1
2
3
4
5
6
7 package quic
8
9 import (
10 "bytes"
11 "context"
12 "crypto/tls"
13 "errors"
14 "flag"
15 "fmt"
16 "log/slog"
17 "math"
18 "net/netip"
19 "reflect"
20 "strings"
21 "testing"
22 "time"
23
24 "golang.org/x/net/internal/quic/qlog"
25 )
26
// testVV enables very verbose test output; logDatagram logs nothing unless
// it is set.
var testVV = flag.Bool("vv", false, "even more verbose test output")

// qlogdir is the directory qlog logs are written to (empty by default).
var qlogdir = flag.String("qlog", "", "write qlog logs to directory")
31
32 func TestConnTestConn(t *testing.T) {
33 tc := newTestConn(t, serverSide)
34 tc.handshake()
35 if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
36 t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
37 }
38
39 ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
40 tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
41 when = now
42 })
43 return
44 }).result()
45 if !ranAt.Equal(tc.endpoint.now) {
46 t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
47 }
48 tc.wait()
49
50 nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
51 tc.advanceTo(nextTime)
52 ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
53 tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
54 when = now
55 })
56 return
57 }).result()
58 if !ranAt.Equal(nextTime) {
59 t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
60 }
61 tc.wait()
62
63 tc.advanceToTimer()
64 if got := tc.conn.lifetime.state; got != connStateDone {
65 t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
66 }
67 }
68
// testDatagram is a datagram in decoded form, as built or compared by tests.
type testDatagram struct {
	// packets are the packets carried in the datagram, in order.
	packets []*testPacket
	// paddedSize is the datagram's full size when padding was detected or
	// requested; 0 means unpadded.
	paddedSize int
	// addr is compared by datagramEqual. NOTE(review): whether this is the
	// source or destination address is not visible here.
	addr netip.AddrPort
}
74
75 func (d testDatagram) String() string {
76 var b strings.Builder
77 fmt.Fprintf(&b, "datagram with %v packets", len(d.packets))
78 if d.paddedSize > 0 {
79 fmt.Fprintf(&b, " (padded to %v bytes)", d.paddedSize)
80 }
81 b.WriteString(":")
82 for _, p := range d.packets {
83 b.WriteString("\n")
84 b.WriteString(p.String())
85 }
86 return b.String()
87 }
88
89 type testPacket struct {
90 ptype packetType
91 header byte
92 version uint32
93 num packetNumber
94 keyPhaseBit bool
95 keyNumber int
96 dstConnID []byte
97 srcConnID []byte
98 token []byte
99 originalDstConnID []byte
100 frames []debugFrame
101 }
102
103 func (p testPacket) String() string {
104 var b strings.Builder
105 fmt.Fprintf(&b, " %v %v", p.ptype, p.num)
106 if p.version != 0 {
107 fmt.Fprintf(&b, " version=%v", p.version)
108 }
109 if p.srcConnID != nil {
110 fmt.Fprintf(&b, " src={%x}", p.srcConnID)
111 }
112 if p.dstConnID != nil {
113 fmt.Fprintf(&b, " dst={%x}", p.dstConnID)
114 }
115 if p.token != nil {
116 fmt.Fprintf(&b, " token={%x}", p.token)
117 }
118 for _, f := range p.frames {
119 fmt.Fprintf(&b, "\n %v", f)
120 }
121 return b.String()
122 }
123
124
// maxTestKeyPhases is the number of 1-RTT key generations the test derives
// for each direction (the length of test1RTTKeys.pkt).
const (
	maxTestKeyPhases = 3
)
126
127
128
129 type testConn struct {
130 t *testing.T
131 conn *Conn
132 endpoint *testEndpoint
133 timer time.Time
134 timerLastFired time.Time
135 idlec chan struct{}
136
137
138
139
140
141
142
143
144 keysInitial fixedKeyPair
145 keysHandshake fixedKeyPair
146 rkeyAppData test1RTTKeys
147 wkeyAppData test1RTTKeys
148 rsecrets [numberSpaceCount]keySecret
149 wsecrets [numberSpaceCount]keySecret
150
151
152
153
154
155
156
157
158 cryptoDataOut map[tls.QUICEncryptionLevel][]byte
159 cryptoDataIn map[tls.QUICEncryptionLevel][]byte
160 peerTLSConn *tls.QUICConn
161
162
163 peerConnID []byte
164 peerNextPacketNum [numberSpaceCount]packetNumber
165
166
167
168 sentDatagrams [][]byte
169 sentPackets []*testPacket
170 sentFrames []debugFrame
171 lastPacket *testPacket
172
173 recvDatagram chan *datagram
174
175
176 sentTransportParameters *transportParameters
177
178
179 ignoreFrames map[byte]bool
180
181
182 sendKeyNumber int
183 sendKeyPhaseBit bool
184
185 asyncTestState
186 }
187
// test1RTTKeys holds the 1-RTT keys for one direction: a header protection
// key, and one packet protection key per key generation. In handleTLSEvent,
// pkt[0] comes from the TLS-provided secret and each later entry from the
// secret after one more updateSecret step.
type test1RTTKeys struct {
	hdr headerKey
	pkt [maxTestKeyPhases]packetKey
}
192
// keySecret records the cipher suite and traffic secret first reported for
// an encryption level, so later reports for that level can be checked
// against it.
type keySecret struct {
	suite  uint16
	secret []byte
}
197
198
199
200
201
202
203 func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
204 t.Helper()
205 config := &Config{
206 TLSConfig: newTestTLSConfig(side),
207 StatelessResetKey: testStatelessResetKey,
208 QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
209 Level: QLogLevelFrame,
210 Dir: *qlogdir,
211 })),
212 }
213 var cids newServerConnIDs
214 if side == serverSide {
215
216 cids.srcConnID = testPeerConnID(0)
217 cids.dstConnID = testPeerConnID(-1)
218 cids.originalDstConnID = cids.dstConnID
219 }
220 var configTransportParams []func(*transportParameters)
221 var configTestConn []func(*testConn)
222 for _, o := range opts {
223 switch o := o.(type) {
224 case func(*Config):
225 o(config)
226 case func(*tls.Config):
227 o(config.TLSConfig)
228 case func(cids *newServerConnIDs):
229 o(&cids)
230 case func(p *transportParameters):
231 configTransportParams = append(configTransportParams, o)
232 case func(p *testConn):
233 configTestConn = append(configTestConn, o)
234 default:
235 t.Fatalf("unknown newTestConn option %T", o)
236 }
237 }
238
239 endpoint := newTestEndpoint(t, config)
240 endpoint.configTransportParams = configTransportParams
241 endpoint.configTestConn = configTestConn
242 conn, err := endpoint.e.newConn(
243 endpoint.now,
244 side,
245 cids,
246 netip.MustParseAddrPort("127.0.0.1:443"))
247 if err != nil {
248 t.Fatal(err)
249 }
250 tc := endpoint.conns[conn]
251 tc.wait()
252 return tc
253 }
254
255 func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
256 t.Helper()
257 tc := &testConn{
258 t: t,
259 endpoint: endpoint,
260 conn: conn,
261 peerConnID: testPeerConnID(0),
262 ignoreFrames: map[byte]bool{
263 frameTypePadding: true,
264 },
265 cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
266 cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte),
267 recvDatagram: make(chan *datagram),
268 }
269 t.Cleanup(tc.cleanup)
270 for _, f := range endpoint.configTestConn {
271 f(tc)
272 }
273 conn.testHooks = (*testConnHooks)(tc)
274
275 if endpoint.peerTLSConn != nil {
276 tc.peerTLSConn = endpoint.peerTLSConn
277 endpoint.peerTLSConn = nil
278 return tc
279 }
280
281 peerProvidedParams := defaultTransportParameters()
282 peerProvidedParams.initialSrcConnID = testPeerConnID(0)
283 if conn.side == clientSide {
284 peerProvidedParams.originalDstConnID = testLocalConnID(-1)
285 }
286 for _, f := range endpoint.configTransportParams {
287 f(&peerProvidedParams)
288 }
289
290 peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())}
291 if conn.side == clientSide {
292 tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
293 } else {
294 tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
295 }
296 tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
297 tc.peerTLSConn.Start(context.Background())
298 t.Cleanup(func() {
299 tc.peerTLSConn.Close()
300 })
301
302 return tc
303 }
304
305
306 func (tc *testConn) advance(d time.Duration) {
307 tc.t.Helper()
308 tc.endpoint.advance(d)
309 }
310
311
312 func (tc *testConn) advanceTo(now time.Time) {
313 tc.t.Helper()
314 tc.endpoint.advanceTo(now)
315 }
316
317
318 func (tc *testConn) advanceToTimer() {
319 if tc.timer.IsZero() {
320 tc.t.Fatalf("advancing to timer, but timer is not set")
321 }
322 tc.advanceTo(tc.timer)
323 }
324
325 func (tc *testConn) timerDelay() time.Duration {
326 if tc.timer.IsZero() {
327 return math.MaxInt64
328 }
329 if tc.timer.Before(tc.endpoint.now) {
330 return 0
331 }
332 return tc.timer.Sub(tc.endpoint.now)
333 }
334
// infiniteDuration is the largest representable time.Duration; it stands for
// "no timer is set".
const infiniteDuration time.Duration = math.MaxInt64
336
337
338 func (tc *testConn) timeUntilEvent() time.Duration {
339 if tc.timer.IsZero() {
340 return infiniteDuration
341 }
342 if tc.timer.Before(tc.endpoint.now) {
343 return 0
344 }
345 return tc.timer.Sub(tc.endpoint.now)
346 }
347
348
349
350
351
// wait blocks until the conn under test has no more work to do or has
// exited.
//
// It sends the conn's loop a function that stores a fresh channel in
// tc.idlec. nextMessage closes that channel once it finds no pending message
// and no async operation to wake. Calling wait concurrently is reported as
// a test error and then panics.
func (tc *testConn) wait() {
	tc.t.Helper()
	idlec := make(chan struct{})
	fail := false
	tc.conn.sendMsg(func(now time.Time, c *Conn) {
		if tc.idlec != nil {
			// Another wait is already outstanding.
			tc.t.Errorf("testConn.wait called concurrently")
			fail = true
			close(idlec)
		} else {
			// nextMessage will close idlec when the conn goes idle.
			tc.idlec = idlec
		}
	})
	select {
	case <-idlec:
	case <-tc.conn.donec:
		// The conn has exited; wake any async operations, which may be
		// able to proceed now.
		tc.wakeAsync()
	}
	if fail {
		panic(fail)
	}
}
376
377 func (tc *testConn) cleanup() {
378 if tc.conn == nil {
379 return
380 }
381 tc.conn.exit()
382 <-tc.conn.donec
383 }
384
385 func logDatagram(t *testing.T, text string, d *testDatagram) {
386 t.Helper()
387 if !*testVV {
388 return
389 }
390 pad := ""
391 if d.paddedSize > 0 {
392 pad = fmt.Sprintf(" (padded to %v)", d.paddedSize)
393 }
394 t.Logf("%v datagram%v", text, pad)
395 for _, p := range d.packets {
396 var s string
397 switch p.ptype {
398 case packetType1RTT:
399 s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num)
400 default:
401 s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
402 }
403 if p.token != nil {
404 s += fmt.Sprintf(" token={%x}", p.token)
405 }
406 if p.keyPhaseBit {
407 s += fmt.Sprintf(" KeyPhase")
408 }
409 if p.keyNumber != 0 {
410 s += fmt.Sprintf(" keynum=%v", p.keyNumber)
411 }
412 t.Log(s)
413 for _, f := range p.frames {
414 t.Logf(" %v", f)
415 }
416 }
417 }
418
419
420 func (tc *testConn) write(d *testDatagram) {
421 tc.t.Helper()
422 tc.endpoint.writeDatagram(d)
423 }
424
425
426 func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
427 tc.t.Helper()
428 space := spaceForPacketType(ptype)
429 dstConnID := tc.conn.connIDState.local[0].cid
430 if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial {
431
432 dstConnID = tc.conn.connIDState.local[1].cid
433 }
434 d := &testDatagram{
435 packets: []*testPacket{{
436 ptype: ptype,
437 num: tc.peerNextPacketNum[space],
438 keyNumber: tc.sendKeyNumber,
439 keyPhaseBit: tc.sendKeyPhaseBit,
440 frames: frames,
441 version: quicVersion1,
442 dstConnID: dstConnID,
443 srcConnID: tc.peerConnID,
444 }},
445 }
446 if ptype == packetTypeInitial && tc.conn.side == serverSide {
447 d.paddedSize = 1200
448 }
449 tc.write(d)
450 }
451
452
453
454 func (tc *testConn) writeAckForAll() {
455 tc.t.Helper()
456 if tc.lastPacket == nil {
457 return
458 }
459 tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
460 ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}},
461 })
462 }
463
464
465
466 func (tc *testConn) writeAckForLatest() {
467 tc.t.Helper()
468 if tc.lastPacket == nil {
469 return
470 }
471 tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
472 ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}},
473 })
474 }
475
476
477 func (tc *testConn) ignoreFrame(frameType byte) {
478 tc.ignoreFrames[frameType] = true
479 }
480
481
482
483 func (tc *testConn) readDatagram() *testDatagram {
484 tc.t.Helper()
485 tc.wait()
486 tc.sentPackets = nil
487 tc.sentFrames = nil
488 buf := tc.endpoint.read()
489 if buf == nil {
490 return nil
491 }
492 d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
493
494
495 logDatagram(tc.t, "-> conn under test sends", d)
496 typeForFrame := func(f debugFrame) byte {
497
498
499
500
501
502
503
504 switch f := f.(type) {
505 case debugFramePadding:
506 return frameTypePadding
507 case debugFramePing:
508 return frameTypePing
509 case debugFrameAck:
510 return frameTypeAck
511 case debugFrameResetStream:
512 return frameTypeResetStream
513 case debugFrameStopSending:
514 return frameTypeStopSending
515 case debugFrameCrypto:
516 return frameTypeCrypto
517 case debugFrameNewToken:
518 return frameTypeNewToken
519 case debugFrameStream:
520 return frameTypeStreamBase
521 case debugFrameMaxData:
522 return frameTypeMaxData
523 case debugFrameMaxStreamData:
524 return frameTypeMaxStreamData
525 case debugFrameMaxStreams:
526 if f.streamType == bidiStream {
527 return frameTypeMaxStreamsBidi
528 } else {
529 return frameTypeMaxStreamsUni
530 }
531 case debugFrameDataBlocked:
532 return frameTypeDataBlocked
533 case debugFrameStreamDataBlocked:
534 return frameTypeStreamDataBlocked
535 case debugFrameStreamsBlocked:
536 if f.streamType == bidiStream {
537 return frameTypeStreamsBlockedBidi
538 } else {
539 return frameTypeStreamsBlockedUni
540 }
541 case debugFrameNewConnectionID:
542 return frameTypeNewConnectionID
543 case debugFrameRetireConnectionID:
544 return frameTypeRetireConnectionID
545 case debugFramePathChallenge:
546 return frameTypePathChallenge
547 case debugFramePathResponse:
548 return frameTypePathResponse
549 case debugFrameConnectionCloseTransport:
550 return frameTypeConnectionCloseTransport
551 case debugFrameConnectionCloseApplication:
552 return frameTypeConnectionCloseApplication
553 case debugFrameHandshakeDone:
554 return frameTypeHandshakeDone
555 }
556 panic(fmt.Errorf("unhandled frame type %T", f))
557 }
558 for _, p := range d.packets {
559 var frames []debugFrame
560 for _, f := range p.frames {
561 if !tc.ignoreFrames[typeForFrame(f)] {
562 frames = append(frames, f)
563 }
564 }
565 p.frames = frames
566 }
567 return d
568 }
569
570
571
572 func (tc *testConn) readPacket() *testPacket {
573 tc.t.Helper()
574 for len(tc.sentPackets) == 0 {
575 d := tc.readDatagram()
576 if d == nil {
577 return nil
578 }
579 for _, p := range d.packets {
580 if len(p.frames) == 0 {
581 tc.lastPacket = p
582 continue
583 }
584 tc.sentPackets = append(tc.sentPackets, p)
585 }
586 }
587 p := tc.sentPackets[0]
588 tc.sentPackets = tc.sentPackets[1:]
589 tc.lastPacket = p
590 return p
591 }
592
593
594
595 func (tc *testConn) readFrame() (debugFrame, packetType) {
596 tc.t.Helper()
597 for len(tc.sentFrames) == 0 {
598 p := tc.readPacket()
599 if p == nil {
600 return nil, packetTypeInvalid
601 }
602 tc.sentFrames = p.frames
603 }
604 f := tc.sentFrames[0]
605 tc.sentFrames = tc.sentFrames[1:]
606 return f, tc.lastPacket.ptype
607 }
608
609
610 func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
611 tc.t.Helper()
612 got := tc.readDatagram()
613 if !datagramEqual(got, want) {
614 tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
615 }
616 }
617
618 func datagramEqual(a, b *testDatagram) bool {
619 if a == nil && b == nil {
620 return true
621 }
622 if a == nil || b == nil {
623 return false
624 }
625 if a.paddedSize != b.paddedSize ||
626 a.addr != b.addr ||
627 len(a.packets) != len(b.packets) {
628 return false
629 }
630 for i := range a.packets {
631 if !packetEqual(a.packets[i], b.packets[i]) {
632 return false
633 }
634 }
635 return true
636 }
637
638
639 func (tc *testConn) wantPacket(expectation string, want *testPacket) {
640 tc.t.Helper()
641 got := tc.readPacket()
642 if !packetEqual(got, want) {
643 tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want)
644 }
645 }
646
647 func packetEqual(a, b *testPacket) bool {
648 ac := *a
649 ac.frames = nil
650 ac.header = 0
651 bc := *b
652 bc.frames = nil
653 bc.header = 0
654 if !reflect.DeepEqual(ac, bc) {
655 return false
656 }
657 if len(a.frames) != len(b.frames) {
658 return false
659 }
660 for i := range a.frames {
661 if !frameEqual(a.frames[i], b.frames[i]) {
662 return false
663 }
664 }
665 return true
666 }
667
668
669 func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
670 tc.t.Helper()
671 got, gotType := tc.readFrame()
672 if got == nil {
673 tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
674 }
675 if gotType != wantType {
676 tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
677 }
678 if !frameEqual(got, want) {
679 tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want)
680 }
681 }
682
683 func frameEqual(a, b debugFrame) bool {
684 switch af := a.(type) {
685 case debugFrameConnectionCloseTransport:
686 bf, ok := b.(debugFrameConnectionCloseTransport)
687 return ok && af.code == bf.code
688 }
689 return reflect.DeepEqual(a, b)
690 }
691
692
693
694 func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
695 tc.t.Helper()
696 got, gotType := tc.readFrame()
697 if got == nil {
698 tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
699 }
700 if gotType != wantType {
701 tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
702 }
703 if reflect.TypeOf(got) != reflect.TypeOf(want) {
704 tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want)
705 }
706 }
707
708
709 func (tc *testConn) wantIdle(expectation string) {
710 tc.t.Helper()
711 switch {
712 case len(tc.sentFrames) > 0:
713 tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0])
714 case len(tc.sentPackets) > 0:
715 tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0])
716 }
717 if f, _ := tc.readFrame(); f != nil {
718 tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f)
719 }
720 }
721
// encodeTestPacket encodes and protects p as a datagram to send to the conn
// under test. After the frames it calls appendPaddingTo(pad), presumably
// padding the packet to pad bytes.
//
// Protection depends on the packet type:
//   - Retry packets are encoded by encodeRetryPacket and returned directly.
//   - Other long header packets use tc's Initial or Handshake write keys.
//     When tc is nil, only Initial packets can be sent, using Initial keys
//     derived from p.dstConnID.
//   - 1-RTT packets require tc. They use write key generation p.keyNumber,
//     with the key phase bit taken from p.keyPhaseBit.
//
// A missing key fails the test.
func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
	t.Helper()
	var w packetWriter
	// NOTE(review): 1200 is presumably the maximum datagram size.
	w.reset(1200)
	// pnumMaxAcked is left zero.
	var pnumMaxAcked packetNumber
	switch p.ptype {
	case packetTypeRetry:
		return encodeRetryPacket(p.originalDstConnID, retryPacket{
			srcConnID: p.srcConnID,
			dstConnID: p.dstConnID,
			token:     p.token,
		})
	case packetType1RTT:
		w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
	default:
		w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	}
	for _, f := range p.frames {
		f.write(&w)
	}
	w.appendPaddingTo(pad)
	if p.ptype != packetType1RTT {
		var k fixedKeys
		if tc == nil {
			if p.ptype == packetTypeInitial {
				k = initialKeys(p.dstConnID, serverSide).r
			} else {
				t.Fatalf("sending %v packet with no conn", p.ptype)
			}
		} else {
			switch p.ptype {
			case packetTypeInitial:
				k = tc.keysInitial.w
			case packetTypeHandshake:
				k = tc.keysHandshake.w
			}
		}
		if !k.isSet() {
			t.Fatalf("sending %v packet with no write key", p.ptype)
		}
		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	} else {
		if tc == nil || !tc.wkeyAppData.hdr.isSet() {
			t.Fatalf("sending 1-RTT packet with no write key")
		}
		// Both packet key slots hold generation p.keyNumber, so that
		// generation protects the packet whatever the phase. updateAfter is
		// maxPacketNumber, presumably so the writer never initiates a key
		// update of its own.
		k := &updatingKeyPair{
			w: updatingKeys{
				hdr: tc.wkeyAppData.hdr,
				pkt: [2]packetKey{
					tc.wkeyAppData.pkt[p.keyNumber],
					tc.wkeyAppData.pkt[p.keyNumber],
				},
			},
			updateAfter: maxPacketNumber,
		}
		if p.keyPhaseBit {
			k.phase |= keyPhaseBit
		}
		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
	}
	return w.datagram()
}
800
801 func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
802 t.Helper()
803 bufSize := len(buf)
804 d := &testDatagram{}
805 size := len(buf)
806 for len(buf) > 0 {
807 if buf[0] == 0 {
808 d.paddedSize = bufSize
809 break
810 }
811 ptype := getPacketType(buf)
812 switch ptype {
813 case packetTypeRetry:
814 retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
815 if !ok {
816 t.Fatalf("could not parse %v packet", ptype)
817 }
818 return &testDatagram{
819 packets: []*testPacket{{
820 ptype: packetTypeRetry,
821 dstConnID: retry.dstConnID,
822 srcConnID: retry.srcConnID,
823 token: retry.token,
824 }},
825 }
826 case packetTypeInitial, packetTypeHandshake:
827 var k fixedKeys
828 if tc == nil {
829 if ptype == packetTypeInitial {
830 p, _ := parseGenericLongHeaderPacket(buf)
831 k = initialKeys(p.srcConnID, serverSide).w
832 } else {
833 t.Fatalf("reading %v packet with no conn", ptype)
834 }
835 } else {
836 switch ptype {
837 case packetTypeInitial:
838 k = tc.keysInitial.r
839 case packetTypeHandshake:
840 k = tc.keysHandshake.r
841 }
842 }
843 if !k.isSet() {
844 t.Fatalf("reading %v packet with no read key", ptype)
845 }
846 var pnumMax packetNumber
847 p, n := parseLongHeaderPacket(buf, k, pnumMax)
848 if n < 0 {
849 t.Fatalf("packet parse error")
850 }
851 frames, err := parseTestFrames(t, p.payload)
852 if err != nil {
853 t.Fatal(err)
854 }
855 var token []byte
856 if ptype == packetTypeInitial && len(p.extra) > 0 {
857 token = p.extra
858 }
859 d.packets = append(d.packets, &testPacket{
860 ptype: p.ptype,
861 header: buf[0],
862 version: p.version,
863 num: p.num,
864 dstConnID: p.dstConnID,
865 srcConnID: p.srcConnID,
866 token: token,
867 frames: frames,
868 })
869 buf = buf[n:]
870 case packetType1RTT:
871 if tc == nil || !tc.rkeyAppData.hdr.isSet() {
872 t.Fatalf("reading 1-RTT packet with no read key")
873 }
874 var pnumMax packetNumber
875 pnumOff := 1 + len(tc.peerConnID)
876
877 var phase int
878 var pnum packetNumber
879 var hdr []byte
880 var pay []byte
881 var err error
882 for phase = 0; phase < maxTestKeyPhases; phase++ {
883 b := append([]byte{}, buf...)
884 hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
885 if err != nil {
886 t.Fatalf("1-RTT packet header parse error")
887 }
888 k := tc.rkeyAppData.pkt[phase]
889 pay, err = k.unprotect(hdr, pay, pnum)
890 if err == nil {
891 break
892 }
893 }
894 if err != nil {
895 t.Fatalf("1-RTT packet payload parse error")
896 }
897 frames, err := parseTestFrames(t, pay)
898 if err != nil {
899 t.Fatal(err)
900 }
901 d.packets = append(d.packets, &testPacket{
902 ptype: packetType1RTT,
903 header: hdr[0],
904 num: pnum,
905 dstConnID: hdr[1:][:len(tc.peerConnID)],
906 keyPhaseBit: hdr[0]&keyPhaseBit != 0,
907 keyNumber: phase,
908 frames: frames,
909 })
910 buf = buf[len(buf):]
911 default:
912 t.Fatalf("unhandled packet type %v", ptype)
913 }
914 }
915
916
917
918
919
920
921 if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 {
922 p := d.packets[len(d.packets)-1]
923 f := p.frames[len(p.frames)-1]
924 if _, ok := f.(debugFramePadding); ok {
925 p.frames = p.frames[:len(p.frames)-1]
926 d.paddedSize = size
927 }
928 }
929 return d
930 }
931
932 func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
933 t.Helper()
934 var frames []debugFrame
935 for len(payload) > 0 {
936 f, n := parseDebugFrame(payload)
937 if n < 0 {
938 return nil, errors.New("error parsing frames")
939 }
940 frames = append(frames, f)
941 payload = payload[n:]
942 }
943 return frames, nil
944 }
945
946 func spaceForPacketType(ptype packetType) numberSpace {
947 switch ptype {
948 case packetTypeInitial:
949 return initialSpace
950 case packetType0RTT:
951 panic("TODO: packetType0RTT")
952 case packetTypeHandshake:
953 return handshakeSpace
954 case packetTypeRetry:
955 panic("retry packets have no number space")
956 case packetType1RTT:
957 return appDataSpace
958 }
959 panic("unknown packet type")
960 }
961
962
// testConnHooks is testConn used as the conn's test hooks. newTestConnForConn
// installs it as conn.testHooks. As a distinct named type it has its own
// method set, so the hook methods below are not methods of testConn.
type testConnHooks testConn
964
965 func (tc *testConnHooks) init() {
966 tc.conn.keysAppData.updateAfter = maxPacketNumber
967 tc.keysInitial.r = tc.conn.keysInitial.w
968 tc.keysInitial.w = tc.conn.keysInitial.r
969 if tc.conn.side == serverSide {
970 tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
971 }
972 }
973
974
975
976
977
978
979
980
981
982
983
984
985
986 func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
987 checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
988 var space numberSpace
989 switch {
990 case e.Level == tls.QUICEncryptionLevelHandshake:
991 space = handshakeSpace
992 case e.Level == tls.QUICEncryptionLevelApplication:
993 space = appDataSpace
994 default:
995 tc.t.Errorf("unexpected encryption level %v", e.Level)
996 return
997 }
998 if secrets[space].secret == nil {
999 secrets[space].suite = e.Suite
1000 secrets[space].secret = append([]byte{}, e.Data...)
1001 } else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
1002 tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level)
1003 }
1004 }
1005 setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
1006 k.hdr.init(suite, secret)
1007 for i := 0; i < len(k.pkt); i++ {
1008 k.pkt[i].init(suite, secret)
1009 secret = updateSecret(suite, secret)
1010 }
1011 }
1012 switch e.Kind {
1013 case tls.QUICSetReadSecret:
1014 checkKey("write", &tc.wsecrets, e)
1015 switch e.Level {
1016 case tls.QUICEncryptionLevelHandshake:
1017 tc.keysHandshake.w.init(e.Suite, e.Data)
1018 case tls.QUICEncryptionLevelApplication:
1019 setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
1020 }
1021 case tls.QUICSetWriteSecret:
1022 checkKey("read", &tc.rsecrets, e)
1023 switch e.Level {
1024 case tls.QUICEncryptionLevelHandshake:
1025 tc.keysHandshake.r.init(e.Suite, e.Data)
1026 case tls.QUICEncryptionLevelApplication:
1027 setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
1028 }
1029 case tls.QUICWriteData:
1030 tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
1031 tc.peerTLSConn.HandleData(e.Level, e.Data)
1032 }
1033 for {
1034 e := tc.peerTLSConn.NextEvent()
1035 switch e.Kind {
1036 case tls.QUICNoEvent:
1037 return
1038 case tls.QUICSetReadSecret:
1039 checkKey("write", &tc.rsecrets, e)
1040 switch e.Level {
1041 case tls.QUICEncryptionLevelHandshake:
1042 tc.keysHandshake.r.init(e.Suite, e.Data)
1043 case tls.QUICEncryptionLevelApplication:
1044 setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
1045 }
1046 case tls.QUICSetWriteSecret:
1047 checkKey("read", &tc.wsecrets, e)
1048 switch e.Level {
1049 case tls.QUICEncryptionLevelHandshake:
1050 tc.keysHandshake.w.init(e.Suite, e.Data)
1051 case tls.QUICEncryptionLevelApplication:
1052 setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
1053 }
1054 case tls.QUICWriteData:
1055 tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
1056 case tls.QUICTransportParameters:
1057 p, err := unmarshalTransportParams(e.Data)
1058 if err != nil {
1059 tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err)
1060 } else {
1061 tc.sentTransportParameters = &p
1062 }
1063 }
1064 }
1065 }
1066
1067
// nextMessage is the conn loop's hook for getting its next event. It uses
// the test clock (tc.endpoint.now) in place of real time and records timer
// in tc.timer.
//
// It loops while tc.wakeAsync reports progress, and in each pass:
//   - if timer is set and not after the current time, it returns a
//     timerEvent. If that exact deadline has already fired, it reports the
//     timer as spinning instead of firing it again;
//   - otherwise, if a message is waiting on msgc, it returns that message.
//
// When nothing is left to do, it closes tc.idlec (set by wait) and blocks
// for the next message.
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
	tc.timer = timer
	for {
		if !timer.IsZero() && !timer.After(tc.endpoint.now) {
			if timer.Equal(tc.timerLastFired) {
				// This deadline already fired and the timer was not moved,
				// so the conn is presumably spinning on it.
				tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
			} else {
				tc.timerLastFired = timer
				return tc.endpoint.now, timerEvent{}
			}
		}
		select {
		case m := <-msgc:
			return tc.endpoint.now, m
		default:
		}
		if !tc.wakeAsync() {
			break
		}
	}
	// The conn is idle: release a pending wait, if any.
	if tc.idlec != nil {
		idlec := tc.idlec
		tc.idlec = nil
		close(idlec)
	}
	m = <-msgc
	return tc.endpoint.now, m
}
1100
1101 func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
1102 return testLocalConnID(seq), nil
1103 }
1104
1105 func (tc *testConnHooks) timeNow() time.Time {
1106 return tc.endpoint.now
1107 }
1108
1109
1110
1111 func testLocalConnID(seq int64) []byte {
1112 cid := make([]byte, connIDLen)
1113 copy(cid, []byte{0xc0, 0xff, 0xee})
1114 cid[len(cid)-1] = byte(seq)
1115 return cid
1116 }
1117
1118
1119
// testPeerConnID returns the peer's connection ID for sequence number seq: 4
// bytes, be ee ff followed by the low byte of seq. NOTE(review): presumably
// deliberately a different length from the conn's own IDs.
func testPeerConnID(seq int64) []byte {
	last := byte(seq)
	return []byte{0xbe, 0xee, 0xff, last}
}
1125
1126 func testPeerStatelessResetToken(seq int64) statelessResetToken {
1127 return statelessResetToken{
1128 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
1129 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
1130 }
1131 }
1132
1133
1134
1135
1136
1137
// canceledContext returns a context that has already been canceled.
func canceledContext() context.Context {
	ctx, cancelFunc := context.WithCancel(context.Background())
	defer cancelFunc()
	return ctx
}
1143
View as plain text