Source file
src/crypto/tls/handshake_messages_test.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "crypto/x509"
10 "encoding/hex"
11 "math"
12 "math/rand"
13 "reflect"
14 "strings"
15 "testing"
16 "testing/quick"
17 "time"
18 )
19
20 var tests = []handshakeMessage{
21 &clientHelloMsg{},
22 &serverHelloMsg{},
23 &finishedMsg{},
24
25 &certificateMsg{},
26 &certificateRequestMsg{},
27 &certificateVerifyMsg{
28 hasSignatureAlgorithm: true,
29 },
30 &certificateStatusMsg{},
31 &clientKeyExchangeMsg{},
32 &newSessionTicketMsg{},
33 &encryptedExtensionsMsg{},
34 &endOfEarlyDataMsg{},
35 &keyUpdateMsg{},
36 &newSessionTicketMsgTLS13{},
37 &certificateRequestMsgTLS13{},
38 &certificateMsgTLS13{},
39 &SessionState{},
40 }
41
42 func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
43 t.Helper()
44 b, err := msg.marshal()
45 if err != nil {
46 t.Fatal(err)
47 }
48 return b
49 }
50
51 func TestMarshalUnmarshal(t *testing.T) {
52 rand := rand.New(rand.NewSource(time.Now().UnixNano()))
53
54 for i, m := range tests {
55 ty := reflect.ValueOf(m).Type()
56
57 n := 100
58 if testing.Short() {
59 n = 5
60 }
61 for j := 0; j < n; j++ {
62 v, ok := quick.Value(ty, rand)
63 if !ok {
64 t.Errorf("#%d: failed to create value", i)
65 break
66 }
67
68 m1 := v.Interface().(handshakeMessage)
69 marshaled := mustMarshal(t, m1)
70 if !m.unmarshal(marshaled) {
71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
72 break
73 }
74 m.marshal()
75
76 if m, ok := m.(*SessionState); ok {
77 m.activeCertHandles = nil
78 }
79
80 if !reflect.DeepEqual(m1, m) {
81 t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
82 break
83 }
84
85 if i >= 3 {
86
87
88
89
90
91 for j := 0; j < len(marshaled); j++ {
92 if m.unmarshal(marshaled[0:j]) {
93 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
94 break
95 }
96 }
97 }
98 }
99 }
100 }
101
102 func TestFuzz(t *testing.T) {
103 rand := rand.New(rand.NewSource(0))
104 for _, m := range tests {
105 for j := 0; j < 1000; j++ {
106 len := rand.Intn(1000)
107 bytes := randomBytes(len, rand)
108
109 m.unmarshal(bytes)
110 }
111 }
112 }
113
// randomBytes returns a freshly allocated slice of n bytes filled from rand.
// The returned slice is non-nil even when n is zero. It panics if reading
// from rand reports an error.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
121
122 func randomString(n int, rand *rand.Rand) string {
123 b := randomBytes(n, rand)
124 return string(b)
125 }
126
127 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
128 m := &clientHelloMsg{}
129 m.vers = uint16(rand.Intn(65536))
130 m.random = randomBytes(32, rand)
131 m.sessionId = randomBytes(rand.Intn(32), rand)
132 m.cipherSuites = make([]uint16, rand.Intn(63)+1)
133 for i := 0; i < len(m.cipherSuites); i++ {
134 cs := uint16(rand.Int31())
135 if cs == scsvRenegotiation {
136 cs += 1
137 }
138 m.cipherSuites[i] = cs
139 }
140 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
141 if rand.Intn(10) > 5 {
142 m.serverName = randomString(rand.Intn(255), rand)
143 for strings.HasSuffix(m.serverName, ".") {
144 m.serverName = m.serverName[:len(m.serverName)-1]
145 }
146 }
147 m.ocspStapling = rand.Intn(10) > 5
148 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
149 m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
150 for i := range m.supportedCurves {
151 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
152 }
153 if rand.Intn(10) > 5 {
154 m.ticketSupported = true
155 if rand.Intn(10) > 5 {
156 m.sessionTicket = randomBytes(rand.Intn(300), rand)
157 } else {
158 m.sessionTicket = make([]byte, 0)
159 }
160 }
161 if rand.Intn(10) > 5 {
162 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
163 }
164 if rand.Intn(10) > 5 {
165 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
166 }
167 for i := 0; i < rand.Intn(5); i++ {
168 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
169 }
170 if rand.Intn(10) > 5 {
171 m.scts = true
172 }
173 if rand.Intn(10) > 5 {
174 m.secureRenegotiationSupported = true
175 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
176 }
177 if rand.Intn(10) > 5 {
178 m.extendedMasterSecret = true
179 }
180 for i := 0; i < rand.Intn(5); i++ {
181 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
182 }
183 if rand.Intn(10) > 5 {
184 m.cookie = randomBytes(rand.Intn(500)+1, rand)
185 }
186 for i := 0; i < rand.Intn(5); i++ {
187 var ks keyShare
188 ks.group = CurveID(rand.Intn(30000) + 1)
189 ks.data = randomBytes(rand.Intn(200)+1, rand)
190 m.keyShares = append(m.keyShares, ks)
191 }
192 switch rand.Intn(3) {
193 case 1:
194 m.pskModes = []uint8{pskModeDHE}
195 case 2:
196 m.pskModes = []uint8{pskModeDHE, pskModePlain}
197 }
198 for i := 0; i < rand.Intn(5); i++ {
199 var psk pskIdentity
200 psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
201 psk.label = randomBytes(rand.Intn(500)+1, rand)
202 m.pskIdentities = append(m.pskIdentities, psk)
203 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
204 }
205 if rand.Intn(10) > 5 {
206 m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
207 }
208 if rand.Intn(10) > 5 {
209 m.earlyData = true
210 }
211
212 return reflect.ValueOf(m)
213 }
214
215 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
216 m := &serverHelloMsg{}
217 m.vers = uint16(rand.Intn(65536))
218 m.random = randomBytes(32, rand)
219 m.sessionId = randomBytes(rand.Intn(32), rand)
220 m.cipherSuite = uint16(rand.Int31())
221 m.compressionMethod = uint8(rand.Intn(256))
222 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
223
224 if rand.Intn(10) > 5 {
225 m.ocspStapling = true
226 }
227 if rand.Intn(10) > 5 {
228 m.ticketSupported = true
229 }
230 if rand.Intn(10) > 5 {
231 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
232 }
233
234 for i := 0; i < rand.Intn(4); i++ {
235 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
236 }
237
238 if rand.Intn(10) > 5 {
239 m.secureRenegotiationSupported = true
240 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
241 }
242 if rand.Intn(10) > 5 {
243 m.extendedMasterSecret = true
244 }
245 if rand.Intn(10) > 5 {
246 m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
247 }
248 if rand.Intn(10) > 5 {
249 m.cookie = randomBytes(rand.Intn(500)+1, rand)
250 }
251 if rand.Intn(10) > 5 {
252 for i := 0; i < rand.Intn(5); i++ {
253 m.serverShare.group = CurveID(rand.Intn(30000) + 1)
254 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
255 }
256 } else if rand.Intn(10) > 5 {
257 m.selectedGroup = CurveID(rand.Intn(30000) + 1)
258 }
259 if rand.Intn(10) > 5 {
260 m.selectedIdentityPresent = true
261 m.selectedIdentity = uint16(rand.Intn(0xffff))
262 }
263
264 return reflect.ValueOf(m)
265 }
266
267 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
268 m := &encryptedExtensionsMsg{}
269
270 if rand.Intn(10) > 5 {
271 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
272 }
273 if rand.Intn(10) > 5 {
274 m.earlyData = true
275 }
276
277 return reflect.ValueOf(m)
278 }
279
280 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
281 m := &certificateMsg{}
282 numCerts := rand.Intn(20)
283 m.certificates = make([][]byte, numCerts)
284 for i := 0; i < numCerts; i++ {
285 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
286 }
287 return reflect.ValueOf(m)
288 }
289
290 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
291 m := &certificateRequestMsg{}
292 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
293 for i := 0; i < rand.Intn(100); i++ {
294 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
295 }
296 return reflect.ValueOf(m)
297 }
298
299 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
300 m := &certificateVerifyMsg{}
301 m.hasSignatureAlgorithm = true
302 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
303 m.signature = randomBytes(rand.Intn(15)+1, rand)
304 return reflect.ValueOf(m)
305 }
306
307 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
308 m := &certificateStatusMsg{}
309 m.response = randomBytes(rand.Intn(10)+1, rand)
310 return reflect.ValueOf(m)
311 }
312
313 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
314 m := &clientKeyExchangeMsg{}
315 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
316 return reflect.ValueOf(m)
317 }
318
319 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
320 m := &finishedMsg{}
321 m.verifyData = randomBytes(12, rand)
322 return reflect.ValueOf(m)
323 }
324
325 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
326 m := &newSessionTicketMsg{}
327 m.ticket = randomBytes(rand.Intn(4), rand)
328 return reflect.ValueOf(m)
329 }
330
331 var sessionTestCerts []*x509.Certificate
332
333 func init() {
334 cert, err := x509.ParseCertificate(testRSACertificate)
335 if err != nil {
336 panic(err)
337 }
338 sessionTestCerts = append(sessionTestCerts, cert)
339 cert, err = x509.ParseCertificate(testRSACertificateIssuer)
340 if err != nil {
341 panic(err)
342 }
343 sessionTestCerts = append(sessionTestCerts, cert)
344 }
345
346 func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
347 s := &SessionState{}
348 isTLS13 := rand.Intn(10) > 5
349 if isTLS13 {
350 s.version = VersionTLS13
351 } else {
352 s.version = uint16(rand.Intn(VersionTLS13))
353 }
354 s.isClient = rand.Intn(10) > 5
355 s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
356 s.createdAt = uint64(rand.Int63())
357 s.secret = randomBytes(rand.Intn(100)+1, rand)
358 for n, i := rand.Intn(3), 0; i < n; i++ {
359 s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
360 }
361 if rand.Intn(10) > 5 {
362 s.EarlyData = true
363 }
364 if rand.Intn(10) > 5 {
365 s.extMasterSecret = true
366 }
367 if s.isClient || rand.Intn(10) > 5 {
368 if rand.Intn(10) > 5 {
369 s.peerCertificates = sessionTestCerts
370 } else {
371 s.peerCertificates = sessionTestCerts[:1]
372 }
373 }
374 if rand.Intn(10) > 5 && s.peerCertificates != nil {
375 s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
376 }
377 if rand.Intn(10) > 5 && s.peerCertificates != nil {
378 for i := 0; i < rand.Intn(2)+1; i++ {
379 s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
380 }
381 }
382 if len(s.peerCertificates) > 0 {
383 for i := 0; i < rand.Intn(3); i++ {
384 if rand.Intn(10) > 5 {
385 s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
386 } else {
387 s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
388 }
389 }
390 }
391 if rand.Intn(10) > 5 && s.EarlyData {
392 s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
393 }
394 if s.isClient {
395 if isTLS13 {
396 s.useBy = uint64(rand.Int63())
397 s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
398 }
399 }
400 return reflect.ValueOf(s)
401 }
402
403 func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
404 func (s *SessionState) unmarshal(b []byte) bool {
405 ss, err := ParseSessionState(b)
406 if err != nil {
407 return false
408 }
409 *s = *ss
410 return true
411 }
412
413 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
414 m := &endOfEarlyDataMsg{}
415 return reflect.ValueOf(m)
416 }
417
418 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
419 m := &keyUpdateMsg{}
420 m.updateRequested = rand.Intn(10) > 5
421 return reflect.ValueOf(m)
422 }
423
424 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
425 m := &newSessionTicketMsgTLS13{}
426 m.lifetime = uint32(rand.Intn(500000))
427 m.ageAdd = uint32(rand.Intn(500000))
428 m.nonce = randomBytes(rand.Intn(100), rand)
429 m.label = randomBytes(rand.Intn(1000), rand)
430 if rand.Intn(10) > 5 {
431 m.maxEarlyData = uint32(rand.Intn(500000))
432 }
433 return reflect.ValueOf(m)
434 }
435
436 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
437 m := &certificateRequestMsgTLS13{}
438 if rand.Intn(10) > 5 {
439 m.ocspStapling = true
440 }
441 if rand.Intn(10) > 5 {
442 m.scts = true
443 }
444 if rand.Intn(10) > 5 {
445 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
446 }
447 if rand.Intn(10) > 5 {
448 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
449 }
450 if rand.Intn(10) > 5 {
451 m.certificateAuthorities = make([][]byte, 3)
452 for i := 0; i < 3; i++ {
453 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
454 }
455 }
456 return reflect.ValueOf(m)
457 }
458
459 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
460 m := &certificateMsgTLS13{}
461 for i := 0; i < rand.Intn(2)+1; i++ {
462 m.certificate.Certificate = append(
463 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
464 }
465 if rand.Intn(10) > 5 {
466 m.ocspStapling = true
467 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
468 }
469 if rand.Intn(10) > 5 {
470 m.scts = true
471 for i := 0; i < rand.Intn(2)+1; i++ {
472 m.certificate.SignedCertificateTimestamps = append(
473 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
474 }
475 }
476 return reflect.ValueOf(m)
477 }
478
479 func TestRejectEmptySCTList(t *testing.T) {
480
481
482 var random [32]byte
483 sct := []byte{0x42, 0x42, 0x42, 0x42}
484 serverHello := &serverHelloMsg{
485 vers: VersionTLS12,
486 random: random[:],
487 scts: [][]byte{sct},
488 }
489 serverHelloBytes := mustMarshal(t, serverHello)
490
491 var serverHelloCopy serverHelloMsg
492 if !serverHelloCopy.unmarshal(serverHelloBytes) {
493 t.Fatal("Failed to unmarshal initial message")
494 }
495
496
497 i := bytes.Index(serverHelloBytes, sct)
498 if i < 0 {
499 t.Fatal("Cannot find SCT in ServerHello")
500 }
501
502 var serverHelloEmptySCT []byte
503 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
504
505 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
506 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
507
508
509 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
510 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
511 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
512
513
514 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
515 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
516
517 if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
518 t.Fatal("Unmarshaled ServerHello with empty SCT list")
519 }
520 }
521
522 func TestRejectEmptySCT(t *testing.T) {
523
524
525
526 var random [32]byte
527 serverHello := &serverHelloMsg{
528 vers: VersionTLS12,
529 random: random[:],
530 scts: [][]byte{nil},
531 }
532 serverHelloBytes := mustMarshal(t, serverHello)
533
534 var serverHelloCopy serverHelloMsg
535 if serverHelloCopy.unmarshal(serverHelloBytes) {
536 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
537 }
538 }
539
540 func TestRejectDuplicateExtensions(t *testing.T) {
541 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
542 if err != nil {
543 t.Fatalf("failed to decode test ClientHello: %s", err)
544 }
545 var clientHelloCopy clientHelloMsg
546 if clientHelloCopy.unmarshal(clientHelloBytes) {
547 t.Error("Unmarshaled ClientHello with duplicate extensions")
548 }
549
550 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
551 if err != nil {
552 t.Fatalf("failed to decode test ServerHello: %s", err)
553 }
554 var serverHelloCopy serverHelloMsg
555 if serverHelloCopy.unmarshal(serverHelloBytes) {
556 t.Fatal("Unmarshaled ServerHello with duplicate extensions")
557 }
558 }
559
View as plain text