1
2
3
4
5 package curve25519_test
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "encoding/hex"
11 "testing"
12
13 "golang.org/x/crypto/curve25519"
14 )
15
16 const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
17
18 func TestX25519Basepoint(t *testing.T) {
19 x := make([]byte, 32)
20 x[0] = 1
21
22 for i := 0; i < 200; i++ {
23 var err error
24 x, err = curve25519.X25519(x, curve25519.Basepoint)
25 if err != nil {
26 t.Fatal(err)
27 }
28 }
29
30 result := hex.EncodeToString(x)
31 if result != expectedHex {
32 t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
33 }
34 }
35
36 func TestLowOrderPoints(t *testing.T) {
37 scalar := make([]byte, curve25519.ScalarSize)
38 if _, err := rand.Read(scalar); err != nil {
39 t.Fatal(err)
40 }
41 for i, p := range lowOrderPoints {
42 out, err := curve25519.X25519(scalar, p)
43 if err == nil {
44 t.Errorf("%d: expected error, got nil", i)
45 }
46 if out != nil {
47 t.Errorf("%d: expected nil output, got %x", i, out)
48 }
49 }
50 }
51
52 func TestTestVectors(t *testing.T) {
53 t.Run("Legacy", func(t *testing.T) { testTestVectors(t, curve25519.ScalarMult) })
54 t.Run("X25519", func(t *testing.T) {
55 testTestVectors(t, func(dst, scalar, point *[32]byte) {
56 out, err := curve25519.X25519(scalar[:], point[:])
57 if err != nil {
58 t.Fatal(err)
59 }
60 copy(dst[:], out)
61 })
62 })
63 }
64
65 func testTestVectors(t *testing.T, scalarMult func(dst, scalar, point *[32]byte)) {
66 for _, tv := range testVectors {
67 var got [32]byte
68 scalarMult(&got, &tv.In, &tv.Base)
69 if !bytes.Equal(got[:], tv.Expect[:]) {
70 t.Logf(" in = %x", tv.In)
71 t.Logf(" base = %x", tv.Base)
72 t.Logf(" got = %x", got)
73 t.Logf("expect = %x", tv.Expect)
74 t.Fail()
75 }
76 }
77 }
78
79
80
81
82
83
84
85 func TestHighBitIgnored(t *testing.T) {
86 var s, u [32]byte
87 rand.Read(s[:])
88 rand.Read(u[:])
89
90 var hi0, hi1 [32]byte
91
92 u[31] &= 0x7f
93 curve25519.ScalarMult(&hi0, &s, &u)
94
95 u[31] |= 0x80
96 curve25519.ScalarMult(&hi1, &s, &u)
97
98 if !bytes.Equal(hi0[:], hi1[:]) {
99 t.Errorf("high bit of group point should not affect result")
100 }
101 }
102
103 var benchmarkSink byte
104
105 func BenchmarkX25519Basepoint(b *testing.B) {
106 scalar := make([]byte, curve25519.ScalarSize)
107 if _, err := rand.Read(scalar); err != nil {
108 b.Fatal(err)
109 }
110
111 b.ResetTimer()
112 for i := 0; i < b.N; i++ {
113 out, err := curve25519.X25519(scalar, curve25519.Basepoint)
114 if err != nil {
115 b.Fatal(err)
116 }
117 benchmarkSink ^= out[0]
118 }
119 }
120
121 func BenchmarkX25519(b *testing.B) {
122 scalar := make([]byte, curve25519.ScalarSize)
123 if _, err := rand.Read(scalar); err != nil {
124 b.Fatal(err)
125 }
126 point, err := curve25519.X25519(scalar, curve25519.Basepoint)
127 if err != nil {
128 b.Fatal(err)
129 }
130 if _, err := rand.Read(scalar); err != nil {
131 b.Fatal(err)
132 }
133
134 b.ResetTimer()
135 for i := 0; i < b.N; i++ {
136 out, err := curve25519.X25519(scalar, point)
137 if err != nil {
138 b.Fatal(err)
139 }
140 benchmarkSink ^= out[0]
141 }
142 }
143
View as plain text