// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build go1.21 package quic import ( "bytes" "context" "crypto/tls" "errors" "flag" "fmt" "log/slog" "math" "net/netip" "reflect" "strings" "testing" "time" "golang.org/x/net/internal/quic/qlog" ) var ( testVV = flag.Bool("vv", false, "even more verbose test output") qlogdir = flag.String("qlog", "", "write qlog logs to directory") ) func TestConnTestConn(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want { t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want) } ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) { tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) { when = now }) return }).result() if !ranAt.Equal(tc.endpoint.now) { t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now) } tc.wait() nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2) tc.advanceTo(nextTime) ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) { tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) { when = now }) return }).result() if !ranAt.Equal(nextTime) { t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime) } tc.wait() tc.advanceToTimer() if got := tc.conn.lifetime.state; got != connStateDone { t.Errorf("after advancing to idle timeout, conn state = %v, want done", got) } } type testDatagram struct { packets []*testPacket paddedSize int addr netip.AddrPort } func (d testDatagram) String() string { var b strings.Builder fmt.Fprintf(&b, "datagram with %v packets", len(d.packets)) if d.paddedSize > 0 { fmt.Fprintf(&b, " (padded to %v bytes)", d.paddedSize) } b.WriteString(":") for _, p := range d.packets { b.WriteString("\n") b.WriteString(p.String()) } return b.String() } type testPacket struct { ptype packetType header byte version uint32 num packetNumber keyPhaseBit bool keyNumber int dstConnID []byte srcConnID []byte token []byte originalDstConnID []byte // used for encoding Retry packets frames []debugFrame } func (p testPacket) String() string { var b strings.Builder fmt.Fprintf(&b, " %v %v", p.ptype, p.num) if p.version != 0 { fmt.Fprintf(&b, " version=%v", p.version) } if p.srcConnID != nil { fmt.Fprintf(&b, " src={%x}", p.srcConnID) } if p.dstConnID != nil { fmt.Fprintf(&b, " dst={%x}", p.dstConnID) } if p.token != nil { fmt.Fprintf(&b, " token={%x}", p.token) } for _, f := range p.frames { fmt.Fprintf(&b, "\n %v", f) } return b.String() } // maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test. const maxTestKeyPhases = 3 // A testConn is a Conn whose external interactions (sending and receiving packets, // setting timers) can be manipulated in tests. type testConn struct { t *testing.T conn *Conn endpoint *testEndpoint timer time.Time timerLastFired time.Time idlec chan struct{} // only accessed on the conn's loop // Keys are distinct from the conn's keys, // because the test may know about keys before the conn does. // For example, when sending a datagram with coalesced // Initial and Handshake packets to a client conn, // we use Handshake keys to encrypt the packet. // The client only acquires those keys when it processes // the Initial packet. keysInitial fixedKeyPair keysHandshake fixedKeyPair rkeyAppData test1RTTKeys wkeyAppData test1RTTKeys rsecrets [numberSpaceCount]keySecret wsecrets [numberSpaceCount]keySecret // testConn uses a test hook to snoop on the conn's TLS events. // CRYPTO data produced by the conn's QUICConn is placed in // cryptoDataOut. // // The peerTLSConn is is a QUICConn representing the peer. // CRYPTO data produced by the conn is written to peerTLSConn, // and data produced by peerTLSConn is placed in cryptoDataIn. cryptoDataOut map[tls.QUICEncryptionLevel][]byte cryptoDataIn map[tls.QUICEncryptionLevel][]byte peerTLSConn *tls.QUICConn // Information about the conn's (fake) peer. peerConnID []byte // source conn id of peer's packets peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use // Datagrams, packets, and frames sent by the conn, // but not yet processed by the test. sentDatagrams [][]byte sentPackets []*testPacket sentFrames []debugFrame lastPacket *testPacket recvDatagram chan *datagram // Transport parameters sent by the conn. sentTransportParameters *transportParameters // Frame types to ignore in tests. ignoreFrames map[byte]bool // Values to set in packets sent to the conn. sendKeyNumber int sendKeyPhaseBit bool asyncTestState } type test1RTTKeys struct { hdr headerKey pkt [maxTestKeyPhases]packetKey } type keySecret struct { suite uint16 secret []byte } // newTestConn creates a Conn for testing. // // The Conn's event loop is controlled by the test, // allowing test code to access Conn state directly // by first ensuring the loop goroutine is idle. func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { t.Helper() config := &Config{ TLSConfig: newTestTLSConfig(side), StatelessResetKey: testStatelessResetKey, QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{ Level: QLogLevelFrame, Dir: *qlogdir, })), } var cids newServerConnIDs if side == serverSide { // The initial connection ID for the server is chosen by the client. cids.srcConnID = testPeerConnID(0) cids.dstConnID = testPeerConnID(-1) cids.originalDstConnID = cids.dstConnID } var configTransportParams []func(*transportParameters) var configTestConn []func(*testConn) for _, o := range opts { switch o := o.(type) { case func(*Config): o(config) case func(*tls.Config): o(config.TLSConfig) case func(cids *newServerConnIDs): o(&cids) case func(p *transportParameters): configTransportParams = append(configTransportParams, o) case func(p *testConn): configTestConn = append(configTestConn, o) default: t.Fatalf("unknown newTestConn option %T", o) } } endpoint := newTestEndpoint(t, config) endpoint.configTransportParams = configTransportParams endpoint.configTestConn = configTestConn conn, err := endpoint.e.newConn( endpoint.now, side, cids, netip.MustParseAddrPort("127.0.0.1:443")) if err != nil { t.Fatal(err) } tc := endpoint.conns[conn] tc.wait() return tc } func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn { t.Helper() tc := &testConn{ t: t, endpoint: endpoint, conn: conn, peerConnID: testPeerConnID(0), ignoreFrames: map[byte]bool{ frameTypePadding: true, // ignore PADDING by default }, cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte), cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte), recvDatagram: make(chan *datagram), } t.Cleanup(tc.cleanup) for _, f := range endpoint.configTestConn { f(tc) } conn.testHooks = (*testConnHooks)(tc) if endpoint.peerTLSConn != nil { tc.peerTLSConn = endpoint.peerTLSConn endpoint.peerTLSConn = nil return tc } peerProvidedParams := defaultTransportParameters() peerProvidedParams.initialSrcConnID = testPeerConnID(0) if conn.side == clientSide { peerProvidedParams.originalDstConnID = testLocalConnID(-1) } for _, f := range endpoint.configTransportParams { f(&peerProvidedParams) } peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())} if conn.side == clientSide { tc.peerTLSConn = tls.QUICServer(peerQUICConfig) } else { tc.peerTLSConn = tls.QUICClient(peerQUICConfig) } tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams)) tc.peerTLSConn.Start(context.Background()) t.Cleanup(func() { tc.peerTLSConn.Close() }) return tc } // advance causes time to pass. func (tc *testConn) advance(d time.Duration) { tc.t.Helper() tc.endpoint.advance(d) } // advanceTo sets the current time. func (tc *testConn) advanceTo(now time.Time) { tc.t.Helper() tc.endpoint.advanceTo(now) } // advanceToTimer sets the current time to the time of the Conn's next timer event. func (tc *testConn) advanceToTimer() { if tc.timer.IsZero() { tc.t.Fatalf("advancing to timer, but timer is not set") } tc.advanceTo(tc.timer) } func (tc *testConn) timerDelay() time.Duration { if tc.timer.IsZero() { return math.MaxInt64 // infinite } if tc.timer.Before(tc.endpoint.now) { return 0 } return tc.timer.Sub(tc.endpoint.now) } const infiniteDuration = time.Duration(math.MaxInt64) // timeUntilEvent returns the amount of time until the next connection event. func (tc *testConn) timeUntilEvent() time.Duration { if tc.timer.IsZero() { return infiniteDuration } if tc.timer.Before(tc.endpoint.now) { return 0 } return tc.timer.Sub(tc.endpoint.now) } // wait blocks until the conn becomes idle. // The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire. // Tests shouldn't need to call wait directly. // testConn methods that wake the Conn event loop will call wait for them. func (tc *testConn) wait() { tc.t.Helper() idlec := make(chan struct{}) fail := false tc.conn.sendMsg(func(now time.Time, c *Conn) { if tc.idlec != nil { tc.t.Errorf("testConn.wait called concurrently") fail = true close(idlec) } else { // nextMessage will close idlec. tc.idlec = idlec } }) select { case <-idlec: case <-tc.conn.donec: // We may have async ops that can proceed now that the conn is done. tc.wakeAsync() } if fail { panic(fail) } } func (tc *testConn) cleanup() { if tc.conn == nil { return } tc.conn.exit() <-tc.conn.donec } func logDatagram(t *testing.T, text string, d *testDatagram) { t.Helper() if !*testVV { return } pad := "" if d.paddedSize > 0 { pad = fmt.Sprintf(" (padded to %v)", d.paddedSize) } t.Logf("%v datagram%v", text, pad) for _, p := range d.packets { var s string switch p.ptype { case packetType1RTT: s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num) default: s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID) } if p.token != nil { s += fmt.Sprintf(" token={%x}", p.token) } if p.keyPhaseBit { s += fmt.Sprintf(" KeyPhase") } if p.keyNumber != 0 { s += fmt.Sprintf(" keynum=%v", p.keyNumber) } t.Log(s) for _, f := range p.frames { t.Logf(" %v", f) } } } // write sends the Conn a datagram. func (tc *testConn) write(d *testDatagram) { tc.t.Helper() tc.endpoint.writeDatagram(d) } // writeFrame sends the Conn a datagram containing the given frames. func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { tc.t.Helper() space := spaceForPacketType(ptype) dstConnID := tc.conn.connIDState.local[0].cid if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial { // Only use the transient connection ID in Initial packets. dstConnID = tc.conn.connIDState.local[1].cid } d := &testDatagram{ packets: []*testPacket{{ ptype: ptype, num: tc.peerNextPacketNum[space], keyNumber: tc.sendKeyNumber, keyPhaseBit: tc.sendKeyPhaseBit, frames: frames, version: quicVersion1, dstConnID: dstConnID, srcConnID: tc.peerConnID, }}, } if ptype == packetTypeInitial && tc.conn.side == serverSide { d.paddedSize = 1200 } tc.write(d) } // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the // last one received. func (tc *testConn) writeAckForAll() { tc.t.Helper() if tc.lastPacket == nil { return } tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{ ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}}, }) } // writeAckForLatest sends the Conn a datagram containing an ack for the // most recent packet received. func (tc *testConn) writeAckForLatest() { tc.t.Helper() if tc.lastPacket == nil { return } tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{ ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}}, }) } // ignoreFrame hides frames of the given type sent by the Conn. func (tc *testConn) ignoreFrame(frameType byte) { tc.ignoreFrames[frameType] = true } // readDatagram reads the next datagram sent by the Conn. // It returns nil if the Conn has no more datagrams to send at this time. func (tc *testConn) readDatagram() *testDatagram { tc.t.Helper() tc.wait() tc.sentPackets = nil tc.sentFrames = nil buf := tc.endpoint.read() if buf == nil { return nil } d := parseTestDatagram(tc.t, tc.endpoint, tc, buf) // Log the datagram before removing ignored frames. // When things go wrong, it's useful to see all the frames. logDatagram(tc.t, "-> conn under test sends", d) typeForFrame := func(f debugFrame) byte { // This is very clunky, and points at a problem // in how we specify what frames to ignore in tests. // // We mark frames to ignore using the frame type, // but we've got a debugFrame data structure here. // Perhaps we should be ignoring frames by debugFrame // type instead: tc.ignoreFrame[debugFrameAck](). switch f := f.(type) { case debugFramePadding: return frameTypePadding case debugFramePing: return frameTypePing case debugFrameAck: return frameTypeAck case debugFrameResetStream: return frameTypeResetStream case debugFrameStopSending: return frameTypeStopSending case debugFrameCrypto: return frameTypeCrypto case debugFrameNewToken: return frameTypeNewToken case debugFrameStream: return frameTypeStreamBase case debugFrameMaxData: return frameTypeMaxData case debugFrameMaxStreamData: return frameTypeMaxStreamData case debugFrameMaxStreams: if f.streamType == bidiStream { return frameTypeMaxStreamsBidi } else { return frameTypeMaxStreamsUni } case debugFrameDataBlocked: return frameTypeDataBlocked case debugFrameStreamDataBlocked: return frameTypeStreamDataBlocked case debugFrameStreamsBlocked: if f.streamType == bidiStream { return frameTypeStreamsBlockedBidi } else { return frameTypeStreamsBlockedUni } case debugFrameNewConnectionID: return frameTypeNewConnectionID case debugFrameRetireConnectionID: return frameTypeRetireConnectionID case debugFramePathChallenge: return frameTypePathChallenge case debugFramePathResponse: return frameTypePathResponse case debugFrameConnectionCloseTransport: return frameTypeConnectionCloseTransport case debugFrameConnectionCloseApplication: return frameTypeConnectionCloseApplication case debugFrameHandshakeDone: return frameTypeHandshakeDone } panic(fmt.Errorf("unhandled frame type %T", f)) } for _, p := range d.packets { var frames []debugFrame for _, f := range p.frames { if !tc.ignoreFrames[typeForFrame(f)] { frames = append(frames, f) } } p.frames = frames } return d } // readPacket reads the next packet sent by the Conn. // It returns nil if the Conn has no more packets to send at this time. func (tc *testConn) readPacket() *testPacket { tc.t.Helper() for len(tc.sentPackets) == 0 { d := tc.readDatagram() if d == nil { return nil } for _, p := range d.packets { if len(p.frames) == 0 { tc.lastPacket = p continue } tc.sentPackets = append(tc.sentPackets, p) } } p := tc.sentPackets[0] tc.sentPackets = tc.sentPackets[1:] tc.lastPacket = p return p } // readFrame reads the next frame sent by the Conn. // It returns nil if the Conn has no more frames to send at this time. func (tc *testConn) readFrame() (debugFrame, packetType) { tc.t.Helper() for len(tc.sentFrames) == 0 { p := tc.readPacket() if p == nil { return nil, packetTypeInvalid } tc.sentFrames = p.frames } f := tc.sentFrames[0] tc.sentFrames = tc.sentFrames[1:] return f, tc.lastPacket.ptype } // wantDatagram indicates that we expect the Conn to send a datagram. func (tc *testConn) wantDatagram(expectation string, want *testDatagram) { tc.t.Helper() got := tc.readDatagram() if !datagramEqual(got, want) { tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) } } func datagramEqual(a, b *testDatagram) bool { if a == nil && b == nil { return true } if a == nil || b == nil { return false } if a.paddedSize != b.paddedSize || a.addr != b.addr || len(a.packets) != len(b.packets) { return false } for i := range a.packets { if !packetEqual(a.packets[i], b.packets[i]) { return false } } return true } // wantPacket indicates that we expect the Conn to send a packet. func (tc *testConn) wantPacket(expectation string, want *testPacket) { tc.t.Helper() got := tc.readPacket() if !packetEqual(got, want) { tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want) } } func packetEqual(a, b *testPacket) bool { ac := *a ac.frames = nil ac.header = 0 bc := *b bc.frames = nil bc.header = 0 if !reflect.DeepEqual(ac, bc) { return false } if len(a.frames) != len(b.frames) { return false } for i := range a.frames { if !frameEqual(a.frames[i], b.frames[i]) { return false } } return true } // wantFrame indicates that we expect the Conn to send a frame. func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) { tc.t.Helper() got, gotType := tc.readFrame() if got == nil { tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want) } if gotType != wantType { tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) } if !frameEqual(got, want) { tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want) } } func frameEqual(a, b debugFrame) bool { switch af := a.(type) { case debugFrameConnectionCloseTransport: bf, ok := b.(debugFrameConnectionCloseTransport) return ok && af.code == bf.code } return reflect.DeepEqual(a, b) } // wantFrameType indicates that we expect the Conn to send a frame, // although we don't care about the contents. func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) { tc.t.Helper() got, gotType := tc.readFrame() if got == nil { tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want) } if gotType != wantType { tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) } if reflect.TypeOf(got) != reflect.TypeOf(want) { tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want) } } // wantIdle indicates that we expect the Conn to not send any more frames. func (tc *testConn) wantIdle(expectation string) { tc.t.Helper() switch { case len(tc.sentFrames) > 0: tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0]) case len(tc.sentPackets) > 0: tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0]) } if f, _ := tc.readFrame(); f != nil { tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f) } } func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte { t.Helper() var w packetWriter w.reset(1200) var pnumMaxAcked packetNumber switch p.ptype { case packetTypeRetry: return encodeRetryPacket(p.originalDstConnID, retryPacket{ srcConnID: p.srcConnID, dstConnID: p.dstConnID, token: p.token, }) case packetType1RTT: w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID) default: w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{ ptype: p.ptype, version: p.version, num: p.num, dstConnID: p.dstConnID, srcConnID: p.srcConnID, extra: p.token, }) } for _, f := range p.frames { f.write(&w) } w.appendPaddingTo(pad) if p.ptype != packetType1RTT { var k fixedKeys if tc == nil { if p.ptype == packetTypeInitial { k = initialKeys(p.dstConnID, serverSide).r } else { t.Fatalf("sending %v packet with no conn", p.ptype) } } else { switch p.ptype { case packetTypeInitial: k = tc.keysInitial.w case packetTypeHandshake: k = tc.keysHandshake.w } } if !k.isSet() { t.Fatalf("sending %v packet with no write key", p.ptype) } w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{ ptype: p.ptype, version: p.version, num: p.num, dstConnID: p.dstConnID, srcConnID: p.srcConnID, extra: p.token, }) } else { if tc == nil || !tc.wkeyAppData.hdr.isSet() { t.Fatalf("sending 1-RTT packet with no write key") } // Somewhat hackish: Generate a temporary updatingKeyPair that will // always use our desired key phase. k := &updatingKeyPair{ w: updatingKeys{ hdr: tc.wkeyAppData.hdr, pkt: [2]packetKey{ tc.wkeyAppData.pkt[p.keyNumber], tc.wkeyAppData.pkt[p.keyNumber], }, }, updateAfter: maxPacketNumber, } if p.keyPhaseBit { k.phase |= keyPhaseBit } w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k) } return w.datagram() } func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram { t.Helper() bufSize := len(buf) d := &testDatagram{} size := len(buf) for len(buf) > 0 { if buf[0] == 0 { d.paddedSize = bufSize break } ptype := getPacketType(buf) switch ptype { case packetTypeRetry: retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID) if !ok { t.Fatalf("could not parse %v packet", ptype) } return &testDatagram{ packets: []*testPacket{{ ptype: packetTypeRetry, dstConnID: retry.dstConnID, srcConnID: retry.srcConnID, token: retry.token, }}, } case packetTypeInitial, packetTypeHandshake: var k fixedKeys if tc == nil { if ptype == packetTypeInitial { p, _ := parseGenericLongHeaderPacket(buf) k = initialKeys(p.srcConnID, serverSide).w } else { t.Fatalf("reading %v packet with no conn", ptype) } } else { switch ptype { case packetTypeInitial: k = tc.keysInitial.r case packetTypeHandshake: k = tc.keysHandshake.r } } if !k.isSet() { t.Fatalf("reading %v packet with no read key", ptype) } var pnumMax packetNumber // TODO: Track packet numbers. p, n := parseLongHeaderPacket(buf, k, pnumMax) if n < 0 { t.Fatalf("packet parse error") } frames, err := parseTestFrames(t, p.payload) if err != nil { t.Fatal(err) } var token []byte if ptype == packetTypeInitial && len(p.extra) > 0 { token = p.extra } d.packets = append(d.packets, &testPacket{ ptype: p.ptype, header: buf[0], version: p.version, num: p.num, dstConnID: p.dstConnID, srcConnID: p.srcConnID, token: token, frames: frames, }) buf = buf[n:] case packetType1RTT: if tc == nil || !tc.rkeyAppData.hdr.isSet() { t.Fatalf("reading 1-RTT packet with no read key") } var pnumMax packetNumber // TODO: Track packet numbers. pnumOff := 1 + len(tc.peerConnID) // Try unprotecting the packet with the first maxTestKeyPhases keys. var phase int var pnum packetNumber var hdr []byte var pay []byte var err error for phase = 0; phase < maxTestKeyPhases; phase++ { b := append([]byte{}, buf...) hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax) if err != nil { t.Fatalf("1-RTT packet header parse error") } k := tc.rkeyAppData.pkt[phase] pay, err = k.unprotect(hdr, pay, pnum) if err == nil { break } } if err != nil { t.Fatalf("1-RTT packet payload parse error") } frames, err := parseTestFrames(t, pay) if err != nil { t.Fatal(err) } d.packets = append(d.packets, &testPacket{ ptype: packetType1RTT, header: hdr[0], num: pnum, dstConnID: hdr[1:][:len(tc.peerConnID)], keyPhaseBit: hdr[0]&keyPhaseBit != 0, keyNumber: phase, frames: frames, }) buf = buf[len(buf):] default: t.Fatalf("unhandled packet type %v", ptype) } } // This is rather hackish: If the last frame in the last packet // in the datagram is PADDING, then remove it and record // the padded size in the testDatagram.paddedSize. // // This makes it easier to write a test that expects a datagram // padded to 1200 bytes. if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 { p := d.packets[len(d.packets)-1] f := p.frames[len(p.frames)-1] if _, ok := f.(debugFramePadding); ok { p.frames = p.frames[:len(p.frames)-1] d.paddedSize = size } } return d } func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) { t.Helper() var frames []debugFrame for len(payload) > 0 { f, n := parseDebugFrame(payload) if n < 0 { return nil, errors.New("error parsing frames") } frames = append(frames, f) payload = payload[n:] } return frames, nil } func spaceForPacketType(ptype packetType) numberSpace { switch ptype { case packetTypeInitial: return initialSpace case packetType0RTT: panic("TODO: packetType0RTT") case packetTypeHandshake: return handshakeSpace case packetTypeRetry: panic("retry packets have no number space") case packetType1RTT: return appDataSpace } panic("unknown packet type") } // testConnHooks implements connTestHooks. type testConnHooks testConn func (tc *testConnHooks) init() { tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates tc.keysInitial.r = tc.conn.keysInitial.w tc.keysInitial.w = tc.conn.keysInitial.r if tc.conn.side == serverSide { tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc)) } } // handleTLSEvent processes TLS events generated by // the connection under test's tls.QUICConn. // // We maintain a second tls.QUICConn representing the peer, // and feed the TLS handshake data into it. // // We stash TLS handshake data from both sides in the testConn, // where it can be used by tests. // // We snoop packet protection keys out of the tls.QUICConns, // and verify that both sides of the connection are getting // matching keys. func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) { var space numberSpace switch { case e.Level == tls.QUICEncryptionLevelHandshake: space = handshakeSpace case e.Level == tls.QUICEncryptionLevelApplication: space = appDataSpace default: tc.t.Errorf("unexpected encryption level %v", e.Level) return } if secrets[space].secret == nil { secrets[space].suite = e.Suite secrets[space].secret = append([]byte{}, e.Data...) } else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) { tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level) } } setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) { k.hdr.init(suite, secret) for i := 0; i < len(k.pkt); i++ { k.pkt[i].init(suite, secret) secret = updateSecret(suite, secret) } } switch e.Kind { case tls.QUICSetReadSecret: checkKey("write", &tc.wsecrets, e) switch e.Level { case tls.QUICEncryptionLevelHandshake: tc.keysHandshake.w.init(e.Suite, e.Data) case tls.QUICEncryptionLevelApplication: setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData) } case tls.QUICSetWriteSecret: checkKey("read", &tc.rsecrets, e) switch e.Level { case tls.QUICEncryptionLevelHandshake: tc.keysHandshake.r.init(e.Suite, e.Data) case tls.QUICEncryptionLevelApplication: setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData) } case tls.QUICWriteData: tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...) tc.peerTLSConn.HandleData(e.Level, e.Data) } for { e := tc.peerTLSConn.NextEvent() switch e.Kind { case tls.QUICNoEvent: return case tls.QUICSetReadSecret: checkKey("write", &tc.rsecrets, e) switch e.Level { case tls.QUICEncryptionLevelHandshake: tc.keysHandshake.r.init(e.Suite, e.Data) case tls.QUICEncryptionLevelApplication: setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData) } case tls.QUICSetWriteSecret: checkKey("read", &tc.wsecrets, e) switch e.Level { case tls.QUICEncryptionLevelHandshake: tc.keysHandshake.w.init(e.Suite, e.Data) case tls.QUICEncryptionLevelApplication: setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData) } case tls.QUICWriteData: tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...) case tls.QUICTransportParameters: p, err := unmarshalTransportParams(e.Data) if err != nil { tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err) } else { tc.sentTransportParameters = &p } } } } // nextMessage is called by the Conn's event loop to request its next event. func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) { tc.timer = timer for { if !timer.IsZero() && !timer.After(tc.endpoint.now) { if timer.Equal(tc.timerLastFired) { // If the connection timer fires at time T, the Conn should take some // action to advance the timer into the future. If the Conn reschedules // the timer for the same time, it isn't making progress and we have a bug. tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer) } else { tc.timerLastFired = timer return tc.endpoint.now, timerEvent{} } } select { case m := <-msgc: return tc.endpoint.now, m default: } if !tc.wakeAsync() { break } } // If the message queue is empty, then the conn is idle. if tc.idlec != nil { idlec := tc.idlec tc.idlec = nil close(idlec) } m = <-msgc return tc.endpoint.now, m } func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { return testLocalConnID(seq), nil } func (tc *testConnHooks) timeNow() time.Time { return tc.endpoint.now } // testLocalConnID returns the connection ID with a given sequence number // used by a Conn under test. func testLocalConnID(seq int64) []byte { cid := make([]byte, connIDLen) copy(cid, []byte{0xc0, 0xff, 0xee}) cid[len(cid)-1] = byte(seq) return cid } // testPeerConnID returns the connection ID with a given sequence number // used by the fake peer of a Conn under test. func testPeerConnID(seq int64) []byte { // Use a different length than we choose for our own conn ids, // to help catch any bad assumptions. return []byte{0xbe, 0xee, 0xff, byte(seq)} } func testPeerStatelessResetToken(seq int64) statelessResetToken { return statelessResetToken{ 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq), } } // canceledContext returns a canceled Context. // // Functions which take a context preference progress over cancelation. // For example, a read with a canceled context will return data if any is available. // Tests use canceled contexts to perform non-blocking operations. func canceledContext() context.Context { ctx, cancel := context.WithCancel(context.Background()) cancel() return ctx }