1
2
3
4
5
6
7 package wycheproof
8
9 import (
10 "bytes"
11 "crypto/ecdh"
12 "fmt"
13 "testing"
14 )
15
16 func TestECDHStdLib(t *testing.T) {
17 type ECDHTestVector struct {
18
19 Comment string `json:"comment,omitempty"`
20
21 Flags []string `json:"flags,omitempty"`
22
23 Private string `json:"private,omitempty"`
24
25 Public string `json:"public,omitempty"`
26
27 Result string `json:"result,omitempty"`
28
29 Shared string `json:"shared,omitempty"`
30
31 TcID int `json:"tcId,omitempty"`
32 }
33
34 type ECDHTestGroup struct {
35 Curve string `json:"curve,omitempty"`
36 Tests []*ECDHTestVector `json:"tests,omitempty"`
37 }
38
39 type Root struct {
40 TestGroups []*ECDHTestGroup `json:"testGroups,omitempty"`
41 }
42
43 flagsShouldPass := map[string]bool{
44
45 "CompressedPoint": false,
46
47 "UnnamedCurve": false,
48
49 "WrongOrder": false,
50 "UnusedParam": false,
51
52
53 "Twist": true,
54 "SmallPublicKey": false,
55 "LowOrderPublic": false,
56 "ZeroSharedSecret": false,
57 "NonCanonicalPublic": true,
58 }
59
60
61
62 curveToCurve := map[string]ecdh.Curve{
63 "secp256r1": ecdh.P256(),
64 "secp384r1": ecdh.P384(),
65 "secp521r1": ecdh.P521(),
66 "curve25519": ecdh.X25519(),
67 }
68
69 curveToKeySize := map[string]int{
70 "secp256r1": 32,
71 "secp384r1": 48,
72 "secp521r1": 66,
73 "curve25519": 32,
74 }
75
76 for _, f := range []string{
77 "ecdh_secp256r1_ecpoint_test.json",
78 "ecdh_secp384r1_ecpoint_test.json",
79 "ecdh_secp521r1_ecpoint_test.json",
80 "x25519_test.json",
81 } {
82 var root Root
83 readTestVector(t, f, &root)
84 for _, tg := range root.TestGroups {
85 if _, ok := curveToCurve[tg.Curve]; !ok {
86 continue
87 }
88 for _, tt := range tg.Tests {
89 tg, tt := tg, tt
90 t.Run(fmt.Sprintf("%s/%d", tg.Curve, tt.TcID), func(t *testing.T) {
91 t.Logf("Type: %v", tt.Result)
92 t.Logf("Flags: %q", tt.Flags)
93 t.Log(tt.Comment)
94
95 shouldPass := shouldPass(tt.Result, tt.Flags, flagsShouldPass)
96
97 curve := curveToCurve[tg.Curve]
98 p := decodeHex(tt.Public)
99 pub, err := curve.NewPublicKey(p)
100 if err != nil {
101 if shouldPass {
102 t.Errorf("NewPublicKey: %v", err)
103 }
104 return
105 }
106
107 privBytes := decodeHex(tt.Private)
108 if len(privBytes) != curveToKeySize[tg.Curve] {
109 t.Skipf("non-standard key size %d", len(privBytes))
110 }
111
112 priv, err := curve.NewPrivateKey(privBytes)
113 if err != nil {
114 if shouldPass {
115 t.Errorf("NewPrivateKey: %v", err)
116 }
117 return
118 }
119
120 shared := decodeHex(tt.Shared)
121 x, err := priv.ECDH(pub)
122 if err != nil {
123 if tg.Curve == "curve25519" && !shouldPass {
124
125
126 return
127 }
128 t.Fatalf("ECDH: %v", err)
129 }
130
131 if bytes.Equal(shared, x) != shouldPass {
132 if shouldPass {
133 t.Errorf("ECDH = %x, want %x", shared, x)
134 } else {
135 t.Errorf("ECDH = %x, want anything else", shared)
136 }
137 }
138 })
139 }
140 }
141 }
142 }
143
View as plain text