1
2
3
4
5 package wycheproof
6
7 import (
8 "bytes"
9 "crypto/ecdsa"
10 "crypto/elliptic"
11 "crypto/x509"
12 "encoding/asn1"
13 "errors"
14 "fmt"
15 "testing"
16
17 "golang.org/x/crypto/cryptobyte"
18 casn1 "golang.org/x/crypto/cryptobyte/asn1"
19 )
20
21 func TestECDH(t *testing.T) {
22 type ECDHTestVector struct {
23
24 Comment string `json:"comment,omitempty"`
25
26 Flags []string `json:"flags,omitempty"`
27
28 Private string `json:"private,omitempty"`
29
30 Public string `json:"public,omitempty"`
31
32 Result string `json:"result,omitempty"`
33
34 Shared string `json:"shared,omitempty"`
35
36 TcID int `json:"tcId,omitempty"`
37 }
38
39 type ECDHTestGroup struct {
40 Curve string `json:"curve,omitempty"`
41 Tests []*ECDHTestVector `json:"tests,omitempty"`
42 }
43
44 type Root struct {
45 TestGroups []*ECDHTestGroup `json:"testGroups,omitempty"`
46 }
47
48 flagsShouldPass := map[string]bool{
49
50
51 "CompressedPoint": true,
52
53 "UnnamedCurve": false,
54
55 "WrongOrder": false,
56 "UnusedParam": false,
57 }
58
59
60
61 supportedCurves := map[string]bool{
62 "secp224r1": true,
63 "secp256r1": true,
64 "secp384r1": true,
65 "secp521r1": true,
66 }
67
68 var root Root
69 readTestVector(t, "ecdh_test.json", &root)
70 for _, tg := range root.TestGroups {
71 if !supportedCurves[tg.Curve] {
72 continue
73 }
74 for _, tt := range tg.Tests {
75 tg, tt := tg, tt
76 t.Run(fmt.Sprintf("%s/%d", tg.Curve, tt.TcID), func(t *testing.T) {
77 t.Logf("Type: %v", tt.Result)
78 t.Logf("Flags: %q", tt.Flags)
79 t.Log(tt.Comment)
80
81 shouldPass := shouldPass(tt.Result, tt.Flags, flagsShouldPass)
82
83 p := decodeHex(tt.Public)
84 pp, err := x509.ParsePKIXPublicKey(p)
85 if err != nil {
86 pp, err = decodeCompressedPKIX(p)
87 }
88 if err != nil {
89 if shouldPass {
90 t.Errorf("unexpected parsing error: %s", err)
91 }
92 return
93 }
94 pub := pp.(*ecdsa.PublicKey)
95
96 priv := decodeHex(tt.Private)
97 shared := decodeHex(tt.Shared)
98
99 x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, priv)
100 xBytes := make([]byte, (pub.Curve.Params().BitSize+7)/8)
101 got := bytes.Equal(shared, x.FillBytes(xBytes))
102
103 if want := shouldPass; got != want {
104 t.Errorf("wanted success %v, got %v", want, got)
105 }
106 })
107 }
108 }
109 }
110
111 func decodeCompressedPKIX(der []byte) (interface{}, error) {
112 s := cryptobyte.String(der)
113 var s1, s2 cryptobyte.String
114 var algoOID, namedCurveOID asn1.ObjectIdentifier
115 var pointDER []byte
116 if !s.ReadASN1(&s1, casn1.SEQUENCE) || !s.Empty() ||
117 !s1.ReadASN1(&s2, casn1.SEQUENCE) ||
118 !s2.ReadASN1ObjectIdentifier(&algoOID) ||
119 !s2.ReadASN1ObjectIdentifier(&namedCurveOID) || !s2.Empty() ||
120 !s1.ReadASN1BitStringAsBytes(&pointDER) || !s1.Empty() {
121 return nil, errors.New("failed to parse PKIX structure")
122 }
123
124 if !algoOID.Equal(oidPublicKeyECDSA) {
125 return nil, errors.New("wrong algorithm OID")
126 }
127 namedCurve := namedCurveFromOID(namedCurveOID)
128 if namedCurve == nil {
129 return nil, errors.New("unsupported elliptic curve")
130 }
131 x, y := elliptic.UnmarshalCompressed(namedCurve, pointDER)
132 if x == nil {
133 return nil, errors.New("failed to unmarshal elliptic curve point")
134 }
135 pub := &ecdsa.PublicKey{
136 Curve: namedCurve,
137 X: x,
138 Y: y,
139 }
140 return pub, nil
141 }
142
// Object identifiers recognised by decodeCompressedPKIX.
var (
	// oidPublicKeyECDSA is the only algorithm OID decodeCompressedPKIX accepts.
	oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}

	// Curve OIDs, resolved to crypto/elliptic curves by namedCurveFromOID.
	oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33}
	oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7}
	oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34}
	oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35}
)
150
151 func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve {
152 switch {
153 case oid.Equal(oidNamedCurveP224):
154 return elliptic.P224()
155 case oid.Equal(oidNamedCurveP256):
156 return elliptic.P256()
157 case oid.Equal(oidNamedCurveP384):
158 return elliptic.P384()
159 case oid.Equal(oidNamedCurveP521):
160 return elliptic.P521()
161 }
162 return nil
163 }
164
View as plain text