1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "internal/godebug"
19 "io"
20 "net"
21 "sync"
22 "sync/atomic"
23 "time"
24 )
25
26
27
// Conn is a TLS connection layered over conn: it implements the TLS record
// protocol on top of the underlying net.Conn, or hands handshake data to the
// QUIC layer when quic is non-nil.
type Conn struct {
	// Set when the Conn is constructed (constructors are not in this chunk).
	conn        net.Conn
	isClient    bool
	handshakeFn func(context.Context) error // handshake routine run by handshakeContext
	quic        *quicState                  // nil unless this Conn carries a QUIC handshake

	// isHandshakeComplete is true once a handshake has succeeded. It is an
	// atomic so the fast path in handshakeContext can test it without taking
	// handshakeMutex; handleRenegotiation resets it to false.
	isHandshakeComplete atomic.Bool

	// handshakeMutex serializes handshakes and guards the handshake results
	// read by ConnectionState, OCSPResponse and VerifyHostname.
	handshakeMutex sync.Mutex
	handshakeErr   error   // sticky result of the last handshake attempt
	vers           uint16  // protocol version in use; 0 until known (see writeRecordLocked)
	haveVers       bool    // whether vers has been established (see readRecordOrCCS)
	config         *Config // configuration for this connection

	// handshakes counts successful handshakes, including renegotiations.
	handshakes       int
	extMasterSecret  bool     // extended master secret in use; affects TLSUnique and EKM
	didResume        bool     // whether the session was resumed
	cipherSuite      uint16   // negotiated cipher suite ID
	ocspResponse     []byte   // OCSP response reported via OCSPResponse/ConnectionState
	scts             [][]byte // reported as ConnectionState.SignedCertificateTimestamps
	peerCertificates []*x509.Certificate

	// NOTE(review): not referenced in this chunk; presumably keeps cached
	// handles for peerCertificates alive — confirm against the cert cache.
	activeCertHandles []*activeCert

	// verifiedChains is empty unless the peer's chain was verified;
	// VerifyHostname refuses to run without it.
	verifiedChains [][]*x509.Certificate

	// serverName is reported as ConnectionState.ServerName.
	serverName string

	// NOTE(review): not read in this chunk; presumably records whether the
	// peer supports secure renegotiation — confirm.
	secureRenegotiation bool

	// ekm is the keying-material exporter surfaced through ConnectionState.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// NOTE(review): resumptionSecret and ticketKeys are not used in this
	// chunk; presumably TLS 1.3 resumption state and session-ticket keys.
	resumptionSecret []byte
	ticketKeys       []ticketKey

	// clientFinishedIsFirst selects which Finished value is reported as
	// ConnectionState.TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr holds the result of sending close_notify and
	// closeNotifySent records that it was attempted; closeNotify accesses
	// both while holding out's mutex.
	closeNotifyErr  error
	closeNotifySent bool

	// Finished message data; one of them is reported as TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol.
	clientProtocol string

	// in and out are the read and write halves; each embeds its own mutex.
	in, out   halfConn
	rawInput  bytes.Buffer // raw bytes read from conn, not yet consumed as records
	input     bytes.Reader // decrypted application data waiting for Read
	hand      bytes.Buffer // handshake bytes waiting to be parsed
	buffering bool         // when set, write appends to sendBuf instead of writing to conn
	sendBuf   []byte       // output accumulated while buffering; drained by flush

	// Dynamic record sizing state for maxPayloadSizeForWrite: total bytes
	// written to conn, and how many application-data records have been sized.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records that did not advance the
	// connection; exceeding maxUselessRecords aborts the connection.
	retryCount int

	// activeCall interlocks Write and Close: bit 0 is set once Close has
	// started, and every in-flight Write adds 2.
	activeCall atomic.Int32

	// tmp is scratch space; sendAlertLocked encodes alerts into tmp[0:2].
	tmp [16]byte
}
124
125
126
127
128
129
130 func (c *Conn) LocalAddr() net.Addr {
131 return c.conn.LocalAddr()
132 }
133
134
135 func (c *Conn) RemoteAddr() net.Addr {
136 return c.conn.RemoteAddr()
137 }
138
139
140
141
142 func (c *Conn) SetDeadline(t time.Time) error {
143 return c.conn.SetDeadline(t)
144 }
145
146
147
148 func (c *Conn) SetReadDeadline(t time.Time) error {
149 return c.conn.SetReadDeadline(t)
150 }
151
152
153
154
155 func (c *Conn) SetWriteDeadline(t time.Time) error {
156 return c.conn.SetWriteDeadline(t)
157 }
158
159
160
161
162 func (c *Conn) NetConn() net.Conn {
163 return c.conn
164 }
165
166
167
// halfConn holds the record-protection state for one direction (read or
// write) of a Conn. The embedded mutex guards it.
type halfConn struct {
	sync.Mutex

	err     error     // sticky error, set via setErrorLocked
	version uint16    // protocol version (set by prepareCipherSpec)
	cipher  any       // cipher.Stream, aead or cbcMode; nil before keys are installed
	mac     hash.Hash // MAC for stream/CBC suites; presumably nil for AEAD suites
	seq     [8]byte   // 64-bit big-endian record sequence number

	// scratchBuf is reusable space for MAC input and pre-TLS 1.3 AEAD
	// additional data (8-byte seq + 3 header bytes + 2-byte length = 13).
	scratchBuf [13]byte

	nextCipher any       // staged by prepareCipherSpec, activated by changeCipherSpec
	nextMac    hash.Hash // MAC staged alongside nextCipher

	level         QUICEncryptionLevel // level recorded by setTrafficSecret; used for QUIC writes
	trafficSecret []byte              // current TLS 1.3 traffic secret; KeyUpdate derives the next one from it
}
185
// permanentError wraps a net.Error so that it always reports itself as
// non-temporary, while forwarding its message, timeout flag and unwrap chain.
type permanentError struct {
	err net.Error
}

// Error returns the message of the wrapped error.
func (e *permanentError) Error() string {
	inner := e.err
	return inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	return e.err
}

// Timeout reports whether the wrapped error was a timeout.
func (e *permanentError) Timeout() bool {
	inner := e.err
	return inner.Timeout()
}

// Temporary always reports false: once recorded, the error is permanent.
func (e *permanentError) Temporary() bool {
	const temporary = false
	return temporary
}
194
195 func (hc *halfConn) setErrorLocked(err error) error {
196 if e, ok := err.(net.Error); ok {
197 hc.err = &permanentError{err: e}
198 } else {
199 hc.err = err
200 }
201 return hc.err
202 }
203
204
205
206 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
207 hc.version = version
208 hc.nextCipher = cipher
209 hc.nextMac = mac
210 }
211
212
213
214 func (hc *halfConn) changeCipherSpec() error {
215 if hc.nextCipher == nil || hc.version == VersionTLS13 {
216 return alertInternalError
217 }
218 hc.cipher = hc.nextCipher
219 hc.mac = hc.nextMac
220 hc.nextCipher = nil
221 hc.nextMac = nil
222 for i := range hc.seq {
223 hc.seq[i] = 0
224 }
225 return nil
226 }
227
228 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
229 hc.trafficSecret = secret
230 hc.level = level
231 key, iv := suite.trafficKey(secret)
232 hc.cipher = suite.aead(key, iv)
233 for i := range hc.seq {
234 hc.seq[i] = 0
235 }
236 }
237
238
239 func (hc *halfConn) incSeq() {
240 for i := 7; i >= 0; i-- {
241 hc.seq[i]++
242 if hc.seq[i] != 0 {
243 return
244 }
245 }
246
247
248
249
250 panic("TLS: sequence number wraparound")
251 }
252
253
254
255
256 func (hc *halfConn) explicitNonceLen() int {
257 if hc.cipher == nil {
258 return 0
259 }
260
261 switch c := hc.cipher.(type) {
262 case cipher.Stream:
263 return 0
264 case aead:
265 return c.explicitNonceLen()
266 case cbcMode:
267
268 if hc.version >= VersionTLS11 {
269 return c.BlockSize()
270 }
271 return 0
272 default:
273 panic("unknown cipher type")
274 }
275 }
276
277
278
279
// extractPadding inspects the CBC padding at the end of payload in constant
// time. It returns toRemove, the number of trailing bytes to strip (the
// padding plus its length byte), and good, which is 255 if the padding is
// well-formed and 0 otherwise. Well-formed means paddingLen (the last byte)
// is less than len(payload) and each of the last paddingLen+1 bytes equals
// paddingLen. On failure paddingLen is treated as 0, so toRemove is 1. An
// empty payload yields (0, 0).
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If len(payload)-1 >= paddingLen, t is a small non-negative value, the
	// top bit of int32(^t) is set, and good becomes 0xff. If the subtraction
	// underflowed, good becomes 0x00.
	good = byte(int32(^t) >> 31)

	// Always scan up to 256 bytes (the largest possible padding plus the
	// length byte) so the loop length does not depend on the secret
	// paddingLen. len(payload) is public, so clamping to it is fine.
	toCheck := 256

	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen (the byte is part of the padding), else 0.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear bits of good wherever a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all bits of good together and broadcast the result to every bit,
	// so good ends up exactly 0xff or 0x00.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on failure, so the bytes that would have been
	// stripped stay covered by the MAC. Together with the combined,
	// constant-time MAC-and-padding check in decrypt, this avoids revealing
	// whether it was the padding or the MAC that was bad.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
326
// roundUp returns the smallest multiple of b that is >= a, for a >= 0 and
// b > 0. It computes a + ((b - a%b) % b).
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
330
331
// cbcMode is a cipher.BlockMode whose IV can be replaced. decrypt and encrypt
// call SetIV with the record's explicit IV when one is present (TLS 1.1+).
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
336
337
338
// decrypt authenticates and decrypts record (5-byte header plus payload) in
// place with the current read state. It returns the plaintext, which aliases
// record, and the content type; for TLS 1.3 the real content type comes from
// the end of the inner plaintext. All errors are alert values. On success
// the sequence number is incremented. When a MAC is in use, record's length
// bytes (record[3:5]) are overwritten with the plaintext length.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, change_cipher_spec records are passed through without
	// being decrypted (RFC 8446, Appendix D.4).
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// Defaults for ciphers without CBC padding: padding "good", nothing to strip.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitNonceLen]
			// Without an explicit nonce, the sequence number is passed to the
			// AEAD as the nonce input.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			var additionalData []byte
			if hc.version == VersionTLS13 {
				// TLS 1.3: the additional data is the record header as received.
				additionalData = record[:recordHeaderLen]
			} else {
				// Earlier versions: seq || type || version || plaintext length.
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// Room for any explicit IV plus the MAC and at least one padding
			// byte, rounded up to whole blocks.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// The padding is checked in constant time, and its verdict is
			// only acted on together with the MAC below. The bytes after the
			// MAC (the padding) are also fed to tls10MAC as extra data, which
			// keeps the MAC work roughly independent of the secret paddingLen.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// The inner plaintext may exceed maxPlaintext by the content-type byte.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}

			// Strip trailing zero padding; the last non-zero byte is the real
			// content type. A non-empty plaintext with no non-zero byte is
			// rejected.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }, without branching
		// The MAC covers a header carrying the plaintext length, so patch it in.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Check the MAC and the padding together, in constant time, so a
		// padding failure cannot be distinguished from a MAC failure. See
		// also the end of extractPadding.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
462
463
464
465
// sliceForAppend extends in by n bytes. If in has enough capacity its backing
// array is reused; otherwise the contents are copied into a new allocation.
// head is the extended slice, and tail is its final n bytes (an alias into head).
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	tail = head[len(in):]
	return head, tail
}
476
477
478
// encrypt protects payload with the current write state and appends it to
// record, which must already hold the 5-byte record header. It returns the
// extended record with its length field updated to the protected size, and
// increments the sequence number. rand is read only for explicit IVs that
// must be unpredictable (CBC, and any explicit nonce of 16 bytes or more).
// Without a cipher, payload is appended unchanged and the header is left as is.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// A short explicit AEAD nonce (e.g. 8 bytes) is too small to be
			// safely random, so the unique sequence number is used instead.
			// CBC IVs, by contrast, must be unpredictable and are always
			// drawn from rand below.
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC-then-encrypt: the MAC over the plaintext is encrypted after it.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Append the real content type to the inner plaintext and show
			// application_data as the outer type.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header (the additional data) must carry the final
			// ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data: seq || header (whose length is the plaintext length).
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// Always pad by 1..blockSize bytes, each holding paddingLen-1.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update the length to include the nonce, MAC and any block padding.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
563
564
// RecordHeaderError is returned when a received TLS record header is invalid.
type RecordHeaderError struct {
	// Msg describes the problem with the record header.
	Msg string

	// RecordHeader holds the header bytes that triggered the error, copied
	// from the start of the buffered raw input.
	RecordHeader [5]byte

	// Conn is the underlying net.Conn when the first record did not look
	// like TLS at all (readRecordOrCCS sets it only in that case and sends no
	// alert first). Otherwise it is nil. Presumably callers can use it to
	// reply in another protocol — confirm intended use.
	Conn net.Conn
}

// Error implements the error interface by prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
579
580 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
581 err.Msg = msg
582 err.Conn = conn
583 copy(err.RecordHeader[:], c.rawInput.Bytes())
584 return err
585 }
586
587 func (c *Conn) readRecord() error {
588 return c.readRecordOrCCS(false)
589 }
590
591 func (c *Conn) readChangeCipherSpec() error {
592 return c.readRecordOrCCS(true)
593 }
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
// readRecordOrCCS reads one record from the underlying connection, decrypts
// it with c.in, and dispatches it by content type:
//   - alert: close_notify yields io.EOF; before TLS 1.3, warning alerts are
//     skipped; other alerts become a "remote error".
//   - change_cipher_spec: skipped in TLS 1.3; otherwise accepted only when
//     expectChangeCipherSpec is true, and activates the pending read cipher.
//   - application_data: stored in c.input, only after the handshake and when
//     no ChangeCipherSpec is expected; empty records are skipped.
//   - handshake: appended to c.hand.
//
// Skipped records go through retryReadRecord, which bounds how many may be
// skipped in a row. Errors are recorded on c.in and are sticky. The exception
// is a net.Error reporting Temporary() from the underlying read, which is
// returned without being recorded. Callers such as Read and handshakeContext
// hold c.in's lock.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.isHandshakeComplete.Load()

	// c.input aliases memory owned by c.rawInput, so it must be fully
	// drained before rawInput is modified below.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	if c.quic != nil {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
	}

	// Read the header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// A clean EOF exactly at a record boundary is reported as io.EOF
		// even without a close_notify alert; EOF mid-record stays
		// io.ErrUnexpectedEOF.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record type is 0x80, but an SSLv2-format ClientHello
	// begins with a length whose high bit is set, so this strongly suggests
	// an SSLv2 client.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	expectedVers := c.vers
	if expectedVers == VersionTLS13 {
		// TLS 1.3 records carry the legacy version 0x0303 (TLS 1.2) in
		// their header (RFC 8446, Section 5.1).
		expectedVers = VersionTLS12
	}
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && vers != expectedVers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: bail out before reading a body if this does not
		// look like TLS (wrong record type, or an implausibly large
		// version such as the start of a plaintext protocol). No alert
		// is sent, and the raw conn is attached to the error.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the body.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process the record. decrypt works in place on memory owned by rawInput.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data is only valid once a read cipher is in place.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// A state-advancing record: reset the useless-record counter.
		c.retryCount = 0
	}

	// In TLS 1.3, handshake messages must not be interleaved with other record types.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if c.quic != nil {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// An alert is exactly two bytes: level, description.
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// TLS 1.3 ignores the alert level; every alert except close_notify is fatal.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the warning and read the next record.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages may not be fragmented across a ChangeCipherSpec.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, ChangeCipherSpec records are compatibility noise and
		// are skipped (RFC 8446, Appendix D.4). This only applies once vers
		// is TLS 1.3; earlier ones fall through to the checks below.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Some peers send empty records (e.g. to randomize the CBC IV);
		// skip a bounded number of them.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data is owned by c.rawInput (see the Next call above). That is safe
		// because rawInput is not touched again until c.input is drained
		// (checked at the top of this function).
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
782
783
784
785 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
786 c.retryCount++
787 if c.retryCount > maxUselessRecords {
788 c.sendAlert(alertUnexpectedMessage)
789 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
790 }
791 return c.readRecordOrCCS(expectChangeCipherSpec)
792 }
793
794
795
796
// atLeastReader reads from R until at least N bytes have been read and then
// reports io.EOF. If R hits EOF before that, it reports io.ErrUnexpectedEOF.
// It lets bytes.Buffer.ReadFrom pull in "at least N" bytes.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read reads from R and decrements N by the number of bytes read.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n)
	remaining := r.N
	switch {
	case remaining > 0 && err == io.EOF:
		return n, io.ErrUnexpectedEOF
	case remaining <= 0 && err == nil:
		// Enough has been read; signal EOF so the caller stops here.
		return n, io.EOF
	default:
		return n, err
	}
}
816
817
818
819 func (c *Conn) readFromUntil(r io.Reader, n int) error {
820 if c.rawInput.Len() >= n {
821 return nil
822 }
823 needs := n - c.rawInput.Len()
824
825
826
827 c.rawInput.Grow(needs + bytes.MinRead)
828 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
829 return err
830 }
831
832
833 func (c *Conn) sendAlertLocked(err alert) error {
834 if c.quic != nil {
835 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
836 }
837
838 switch err {
839 case alertNoRenegotiation, alertCloseNotify:
840 c.tmp[0] = alertLevelWarning
841 default:
842 c.tmp[0] = alertLevelError
843 }
844 c.tmp[1] = byte(err)
845
846 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
847 if err == alertCloseNotify {
848
849 return writeErr
850 }
851
852 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
853 }
854
855
856 func (c *Conn) sendAlert(err alert) error {
857 c.out.Lock()
858 defer c.out.Unlock()
859 return c.sendAlertLocked(err)
860 }
861
const (
	// tcpMSSEstimate is the estimated TCP segment payload size.
	// maxPayloadSizeForWrite sizes early application-data records so that a
	// record plus its TLS overhead fits in this many bytes.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes written to the
	// connection after which maxPayloadSizeForWrite stops limiting record
	// sizes and always allows maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
893 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
894 return maxPlaintext
895 }
896
897 if c.bytesSent >= recordSizeBoostThreshold {
898 return maxPlaintext
899 }
900
901
902 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
903 if c.out.cipher != nil {
904 switch ciph := c.out.cipher.(type) {
905 case cipher.Stream:
906 payloadBytes -= c.out.mac.Size()
907 case cipher.AEAD:
908 payloadBytes -= ciph.Overhead()
909 case cbcMode:
910 blockSize := ciph.BlockSize()
911
912
913 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
914
915
916 payloadBytes -= c.out.mac.Size()
917 default:
918 panic("unknown cipher type")
919 }
920 }
921 if c.vers == VersionTLS13 {
922 payloadBytes--
923 }
924
925
926 pkt := c.packetsSent
927 c.packetsSent++
928 if pkt > 1000 {
929 return maxPlaintext
930 }
931
932 n := payloadBytes * int(pkt+1)
933 if n > maxPlaintext {
934 n = maxPlaintext
935 }
936 return n
937 }
938
939 func (c *Conn) write(data []byte) (int, error) {
940 if c.buffering {
941 c.sendBuf = append(c.sendBuf, data...)
942 return len(data), nil
943 }
944
945 n, err := c.conn.Write(data)
946 c.bytesSent += int64(n)
947 return n, err
948 }
949
950 func (c *Conn) flush() (int, error) {
951 if len(c.sendBuf) == 0 {
952 return 0, nil
953 }
954
955 n, err := c.conn.Write(c.sendBuf)
956 c.bytesSent += int64(n)
957 c.sendBuf = nil
958 c.buffering = false
959 return n, err
960 }
961
962
// outBufPool pools the buffers writeRecordLocked uses to assemble outgoing
// records. It stores *[]byte rather than []byte because putting a bare slice
// header into the pool's interface value would allocate on every Put.
var outBufPool = sync.Pool{
	New: func() any {
		return new([]byte)
	},
}
968
969
970
// writeRecordLocked writes data as one or more records of type typ, each no
// larger than maxPayloadSizeForWrite allows, and returns how many bytes of
// data were written. On QUIC connections only handshake data is accepted,
// and it is handed to the QUIC layer rather than framed as records. After a
// ChangeCipherSpec record (before TLS 1.3), it activates the pending write
// cipher. The caller must hold c.out's lock.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	if c.quic != nil {
		if typ != recordTypeHandshake {
			return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
		}
		c.quicWriteCryptoData(c.out.level, data)
		if !c.buffering {
			if _, err := c.flush(); err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Store the (possibly grown) slice back through the pointer obtained
		// from Get, which already lives on the heap. Passing &outBuf to Put
		// instead would make the local slice header escape and allocate.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		// Record header: type, version (2 bytes), length (2 bytes).
		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// Before a version is negotiated, advertise TLS 1.0 in the
			// record layer; some servers reject a higher value on the
			// initial ClientHello.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 freezes the record-layer version at TLS 1.2
			// (RFC 8446, Section 5.1).
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
1041
1042
1043
1044
1045 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1046 c.out.Lock()
1047 defer c.out.Unlock()
1048
1049 data, err := msg.marshal()
1050 if err != nil {
1051 return 0, err
1052 }
1053 if transcript != nil {
1054 transcript.Write(data)
1055 }
1056
1057 return c.writeRecordLocked(recordTypeHandshake, data)
1058 }
1059
1060
1061
1062 func (c *Conn) writeChangeCipherRecord() error {
1063 c.out.Lock()
1064 defer c.out.Unlock()
1065 _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1066 return err
1067 }
1068
1069
1070 func (c *Conn) readHandshakeBytes(n int) error {
1071 if c.quic != nil {
1072 return c.quicReadHandshakeBytes(n)
1073 }
1074 for c.hand.Len() < n {
1075 if err := c.readRecord(); err != nil {
1076 return err
1077 }
1078 }
1079 return nil
1080 }
1081
1082
1083
1084
1085 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1086 if err := c.readHandshakeBytes(4); err != nil {
1087 return nil, err
1088 }
1089 data := c.hand.Bytes()
1090 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1091 if n > maxHandshake {
1092 c.sendAlertLocked(alertInternalError)
1093 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1094 }
1095 if err := c.readHandshakeBytes(4 + n); err != nil {
1096 return nil, err
1097 }
1098 data = c.hand.Next(4 + n)
1099 return c.unmarshalHandshakeMessage(data, transcript)
1100 }
1101
1102 func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
1103 var m handshakeMessage
1104 switch data[0] {
1105 case typeHelloRequest:
1106 m = new(helloRequestMsg)
1107 case typeClientHello:
1108 m = new(clientHelloMsg)
1109 case typeServerHello:
1110 m = new(serverHelloMsg)
1111 case typeNewSessionTicket:
1112 if c.vers == VersionTLS13 {
1113 m = new(newSessionTicketMsgTLS13)
1114 } else {
1115 m = new(newSessionTicketMsg)
1116 }
1117 case typeCertificate:
1118 if c.vers == VersionTLS13 {
1119 m = new(certificateMsgTLS13)
1120 } else {
1121 m = new(certificateMsg)
1122 }
1123 case typeCertificateRequest:
1124 if c.vers == VersionTLS13 {
1125 m = new(certificateRequestMsgTLS13)
1126 } else {
1127 m = &certificateRequestMsg{
1128 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1129 }
1130 }
1131 case typeCertificateStatus:
1132 m = new(certificateStatusMsg)
1133 case typeServerKeyExchange:
1134 m = new(serverKeyExchangeMsg)
1135 case typeServerHelloDone:
1136 m = new(serverHelloDoneMsg)
1137 case typeClientKeyExchange:
1138 m = new(clientKeyExchangeMsg)
1139 case typeCertificateVerify:
1140 m = &certificateVerifyMsg{
1141 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1142 }
1143 case typeFinished:
1144 m = new(finishedMsg)
1145 case typeEncryptedExtensions:
1146 m = new(encryptedExtensionsMsg)
1147 case typeEndOfEarlyData:
1148 m = new(endOfEarlyDataMsg)
1149 case typeKeyUpdate:
1150 m = new(keyUpdateMsg)
1151 default:
1152 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1153 }
1154
1155
1156
1157
1158 data = append([]byte(nil), data...)
1159
1160 if !m.unmarshal(data) {
1161 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1162 }
1163
1164 if transcript != nil {
1165 transcript.Write(data)
1166 }
1167
1168 return m, nil
1169 }
1170
var (
	// errShutdown is returned by Write once a close_notify alert has been sent.
	errShutdown = errors.New("tls: protocol is shutdown")
)
1174
1175
1176
1177
1178
1179
1180
// Write writes b to the connection as application data, running the
// handshake first if necessary. It returns net.ErrClosed once Close has
// started and errShutdown after close_notify has been sent. If the write half
// already holds an error, Write returns it. Any write error is recorded on
// the write half via setErrorLocked, so later Writes fail with it too.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: bit 0 of activeCall means Close has started;
	// otherwise register this Write by adding 2.
	for {
		x := c.activeCall.Load()
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x+2) {
			break
		}
	}
	defer c.activeCall.Add(-2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.isHandshakeComplete.Load() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// With TLS 1.0 CBC ciphers, no per-record IV is carried in the record
	// (see explicitNonceLen), so the IV of the next record is predictable
	// and chosen-plaintext attacks become possible. Sending the first byte
	// in its own record (the "1/n-1 split") means the IV for the rest of the
	// data is derived from ciphertext the attacker could not predict.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1236
1237
1238 func (c *Conn) handleRenegotiation() error {
1239 if c.vers == VersionTLS13 {
1240 return errors.New("tls: internal error: unexpected renegotiation")
1241 }
1242
1243 msg, err := c.readHandshake(nil)
1244 if err != nil {
1245 return err
1246 }
1247
1248 helloReq, ok := msg.(*helloRequestMsg)
1249 if !ok {
1250 c.sendAlert(alertUnexpectedMessage)
1251 return unexpectedMessageError(helloReq, msg)
1252 }
1253
1254 if !c.isClient {
1255 return c.sendAlert(alertNoRenegotiation)
1256 }
1257
1258 switch c.config.Renegotiation {
1259 case RenegotiateNever:
1260 return c.sendAlert(alertNoRenegotiation)
1261 case RenegotiateOnceAsClient:
1262 if c.handshakes > 1 {
1263 return c.sendAlert(alertNoRenegotiation)
1264 }
1265 case RenegotiateFreelyAsClient:
1266
1267 default:
1268 c.sendAlert(alertInternalError)
1269 return errors.New("tls: unknown Renegotiation value")
1270 }
1271
1272 c.handshakeMutex.Lock()
1273 defer c.handshakeMutex.Unlock()
1274
1275 c.isHandshakeComplete.Store(false)
1276 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1277 c.handshakes++
1278 }
1279 return c.handshakeErr
1280 }
1281
1282
1283
1284 func (c *Conn) handlePostHandshakeMessage() error {
1285 if c.vers != VersionTLS13 {
1286 return c.handleRenegotiation()
1287 }
1288
1289 msg, err := c.readHandshake(nil)
1290 if err != nil {
1291 return err
1292 }
1293 c.retryCount++
1294 if c.retryCount > maxUselessRecords {
1295 c.sendAlert(alertUnexpectedMessage)
1296 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1297 }
1298
1299 switch msg := msg.(type) {
1300 case *newSessionTicketMsgTLS13:
1301 return c.handleNewSessionTicket(msg)
1302 case *keyUpdateMsg:
1303 return c.handleKeyUpdate(msg)
1304 }
1305
1306
1307
1308
1309 c.sendAlert(alertUnexpectedMessage)
1310 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1311 }
1312
1313 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1314 if c.quic != nil {
1315 c.sendAlert(alertUnexpectedMessage)
1316 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1317 }
1318
1319 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1320 if cipherSuite == nil {
1321 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1322 }
1323
1324 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1325 c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1326
1327 if keyUpdate.updateRequested {
1328 c.out.Lock()
1329 defer c.out.Unlock()
1330
1331 msg := &keyUpdateMsg{}
1332 msgBytes, err := msg.marshal()
1333 if err != nil {
1334 return err
1335 }
1336 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1337 if err != nil {
1338
1339 c.out.setErrorLocked(err)
1340 return nil
1341 }
1342
1343 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1344 c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1345 }
1346
1347 return nil
1348 }
1349
1350
1351
1352
1353
1354
1355
// Read reads decrypted application data into b, running the handshake first
// if necessary. For a non-empty b it blocks until data is available and
// returns at most the data remaining from one record per call. Handshake
// messages that arrive after the handshake (session tickets, key updates,
// renegotiation requests) are processed along the way. Reads can be bounded
// with SetDeadline or SetReadDeadline.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Return early without touching the record layer, so a zero-length
		// Read cannot block or consume a record.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	n, _ := c.input.Read(b)

	// If this drained the current record and the next buffered record is an
	// alert, process it now. A close_notify is then reported (as io.EOF)
	// together with the final data instead of on a later Read. Only bytes
	// already in rawInput are inspected; readRecord runs only if such a
	// record is already buffered, so this does not wait for a new one.
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err
		}
	}

	return n, nil
}
1398
1399
// Close closes the connection. If the handshake has completed and no Write
// is in flight, it first tries to send close_notify (closeNotify bounds this
// with a 5-second write deadline). An error sending close_notify is returned
// only if closing the underlying connection succeeded. Calls after the first
// return net.ErrClosed.
func (c *Conn) Close() error {
	// Interlock with Write: set bit 0 of activeCall to mark Close as started,
	// remembering how many Writes (x/2) were in flight.
	var x int32
	for {
		x = c.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in flight. Treat this Close as a way to abort it:
		// skip close_notify, which would need c.out's lock that the Write
		// may be holding, and just close the underlying connection.
		return c.conn.Close()
	}

	var alertErr error
	if c.isHandshakeComplete.Load() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1434
// errEarlyCloseWrite is returned by CloseWrite when the handshake has not completed.
var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1436
1437
1438
1439
1440 func (c *Conn) CloseWrite() error {
1441 if !c.isHandshakeComplete.Load() {
1442 return errEarlyCloseWrite
1443 }
1444
1445 return c.closeNotify()
1446 }
1447
1448 func (c *Conn) closeNotify() error {
1449 c.out.Lock()
1450 defer c.out.Unlock()
1451
1452 if !c.closeNotifySent {
1453
1454 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1455 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1456 c.closeNotifySent = true
1457
1458 c.SetWriteDeadline(time.Now())
1459 }
1460 return c.closeNotifyErr
1461 }
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476 func (c *Conn) Handshake() error {
1477 return c.HandshakeContext(context.Background())
1478 }
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490 func (c *Conn) HandshakeContext(ctx context.Context) error {
1491
1492
1493 return c.handshakeContext(ctx)
1494 }
1495
// handshakeContext runs the handshake if it has not completed yet and
// returns its sticky result. The named result lets the deferred interrupter
// cleanup replace the return value with the context's error.
//
// For non-QUIC connections with a cancelable ctx, a helper goroutine closes
// the underlying connection if ctx is done before this function returns. In
// that case the context's error is returned. For QUIC connections, the
// cancellation channel and function are handed to c.quic instead.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast path: once a handshake has succeeded, skip the context and mutex
	// setup entirely (this is hit by every Read and Write).
	if c.isHandshakeComplete.Load() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// Deferred before the interrupter's cleanup below, so (defers being LIFO)
	// cancel runs after that cleanup. While the cleanup waits on
	// interruptRes, handshakeCtx can therefore only be done because ctx
	// itself was canceled, not because of this cancel.
	defer cancel()

	if c.quic != nil {
		c.quic.cancelc = handshakeCtx.Done()
		c.quic.cancel = cancel
	} else if ctx.Done() != nil {
		// ctx can be canceled (a nil Done channel means it never can), so
		// start an interrupter goroutine. It closes the connection if ctx
		// is done before the handshake returns, and always reports exactly
		// once on interruptRes, so the deferred cleanup never leaks it.
		done := make(chan struct{})
		interruptRes := make(chan error, 1)
		defer func() {
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// The handshake was interrupted: report the context error.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Close the connection to unblock the handshake; the close
				// error is deliberately discarded.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Re-check under the mutex: another goroutine may have completed or
	// failed the handshake while we waited.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.isHandshakeComplete.Load() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// Try to flush any alert left in the send buffer by the failed
		// handshake; the flush error is ignored.
		c.flush()
	}

	// Consistency checks between the handshake result and the completion flag.
	if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	if c.quic != nil {
		if c.handshakeErr == nil {
			c.quicHandshakeComplete()
			// Provide the 1-RTT (application) read secret only now that the
			// handshake is complete (RFC 9001, Section 5.7).
			c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
		} else {
			var a alert
			c.out.Lock()
			if !errors.As(c.out.err, &a) {
				a = alertInternalError
			}
			c.out.Unlock()
			// Wrap both the handshake error and the alert (the one recorded
			// on c.out, or internal_error). "%.0w" wraps the AlertError while
			// truncating its text to zero characters, so the message is
			// unchanged.
			c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
		}
		close(c.quic.blockedc)
		close(c.quic.signalc)
	}

	return c.handshakeErr
}
1595
1596
1597 func (c *Conn) ConnectionState() ConnectionState {
1598 c.handshakeMutex.Lock()
1599 defer c.handshakeMutex.Unlock()
1600 return c.connectionStateLocked()
1601 }
1602
// tlsunsafeekm is the GODEBUG setting consulted by connectionStateLocked.
// When it is "1", the keying-material exporter is allowed on pre-TLS 1.3
// connections without extended master secret, and each such use is counted
// via IncNonDefault.
var tlsunsafeekm = godebug.New("tlsunsafeekm")
1604
1605 func (c *Conn) connectionStateLocked() ConnectionState {
1606 var state ConnectionState
1607 state.HandshakeComplete = c.isHandshakeComplete.Load()
1608 state.Version = c.vers
1609 state.NegotiatedProtocol = c.clientProtocol
1610 state.DidResume = c.didResume
1611 state.NegotiatedProtocolIsMutual = true
1612 state.ServerName = c.serverName
1613 state.CipherSuite = c.cipherSuite
1614 state.PeerCertificates = c.peerCertificates
1615 state.VerifiedChains = c.verifiedChains
1616 state.SignedCertificateTimestamps = c.scts
1617 state.OCSPResponse = c.ocspResponse
1618 if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
1619 if c.clientFinishedIsFirst {
1620 state.TLSUnique = c.clientFinished[:]
1621 } else {
1622 state.TLSUnique = c.serverFinished[:]
1623 }
1624 }
1625 if c.config.Renegotiation != RenegotiateNever {
1626 state.ekm = noEKMBecauseRenegotiation
1627 } else if c.vers != VersionTLS13 && !c.extMasterSecret {
1628 state.ekm = func(label string, context []byte, length int) ([]byte, error) {
1629 if tlsunsafeekm.Value() == "1" {
1630 tlsunsafeekm.IncNonDefault()
1631 return c.ekm(label, context, length)
1632 }
1633 return noEKMBecauseNoEMS(label, context, length)
1634 }
1635 } else {
1636 state.ekm = c.ekm
1637 }
1638 return state
1639 }
1640
1641
1642
1643 func (c *Conn) OCSPResponse() []byte {
1644 c.handshakeMutex.Lock()
1645 defer c.handshakeMutex.Unlock()
1646
1647 return c.ocspResponse
1648 }
1649
1650
1651
1652
1653 func (c *Conn) VerifyHostname(host string) error {
1654 c.handshakeMutex.Lock()
1655 defer c.handshakeMutex.Unlock()
1656 if !c.isClient {
1657 return errors.New("tls: VerifyHostname called on TLS server connection")
1658 }
1659 if !c.isHandshakeComplete.Load() {
1660 return errors.New("tls: handshake has not yet been performed")
1661 }
1662 if len(c.verifiedChains) == 0 {
1663 return errors.New("tls: handshake did not verify certificate chain")
1664 }
1665 return c.peerCertificates[0].VerifyHostname(host)
1666 }
1667
View as plain text