...

Source file src/golang.org/x/net/internal/quic/conn_test.go

Documentation: golang.org/x/net/internal/quic

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"errors"
    14  	"flag"
    15  	"fmt"
    16  	"log/slog"
    17  	"math"
    18  	"net/netip"
    19  	"reflect"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  
    24  	"golang.org/x/net/internal/quic/qlog"
    25  )
    26  
    27  var (
    28  	testVV  = flag.Bool("vv", false, "even more verbose test output")
    29  	qlogdir = flag.String("qlog", "", "write qlog logs to directory")
    30  )
    31  
    32  func TestConnTestConn(t *testing.T) {
    33  	tc := newTestConn(t, serverSide)
    34  	tc.handshake()
    35  	if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
    36  		t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
    37  	}
    38  
    39  	ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
    40  		tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
    41  			when = now
    42  		})
    43  		return
    44  	}).result()
    45  	if !ranAt.Equal(tc.endpoint.now) {
    46  		t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
    47  	}
    48  	tc.wait()
    49  
    50  	nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
    51  	tc.advanceTo(nextTime)
    52  	ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
    53  		tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
    54  			when = now
    55  		})
    56  		return
    57  	}).result()
    58  	if !ranAt.Equal(nextTime) {
    59  		t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
    60  	}
    61  	tc.wait()
    62  
    63  	tc.advanceToTimer()
    64  	if got := tc.conn.lifetime.state; got != connStateDone {
    65  		t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
    66  	}
    67  }
    68  
// A testDatagram is the test's representation of a single datagram:
// the packets it carries plus metadata about the datagram itself.
type testDatagram struct {
	// packets are the packets contained in the datagram, in order.
	packets []*testPacket
	// paddedSize, when greater than zero, is the size in bytes
	// the datagram is padded to.
	paddedSize int
	addr       netip.AddrPort
}
    74  
    75  func (d testDatagram) String() string {
    76  	var b strings.Builder
    77  	fmt.Fprintf(&b, "datagram with %v packets", len(d.packets))
    78  	if d.paddedSize > 0 {
    79  		fmt.Fprintf(&b, " (padded to %v bytes)", d.paddedSize)
    80  	}
    81  	b.WriteString(":")
    82  	for _, p := range d.packets {
    83  		b.WriteString("\n")
    84  		b.WriteString(p.String())
    85  	}
    86  	return b.String()
    87  }
    88  
// A testPacket is the test's representation of a single QUIC packet,
// either parsed from a datagram sent by the conn or built by the test
// to be sent to the conn.
type testPacket struct {
	ptype             packetType
	header            byte // first byte of the packet as parsed; ignored by packetEqual
	version           uint32
	num               packetNumber
	keyPhaseBit       bool // 1-RTT only: value of the key phase bit in the header
	keyNumber         int  // 1-RTT only: index of the test's 1-RTT packet key
	dstConnID         []byte
	srcConnID         []byte
	token             []byte
	originalDstConnID []byte // used for encoding Retry packets
	frames            []debugFrame
}
   102  
   103  func (p testPacket) String() string {
   104  	var b strings.Builder
   105  	fmt.Fprintf(&b, "  %v %v", p.ptype, p.num)
   106  	if p.version != 0 {
   107  		fmt.Fprintf(&b, " version=%v", p.version)
   108  	}
   109  	if p.srcConnID != nil {
   110  		fmt.Fprintf(&b, " src={%x}", p.srcConnID)
   111  	}
   112  	if p.dstConnID != nil {
   113  		fmt.Fprintf(&b, " dst={%x}", p.dstConnID)
   114  	}
   115  	if p.token != nil {
   116  		fmt.Fprintf(&b, " token={%x}", p.token)
   117  	}
   118  	for _, f := range p.frames {
   119  		fmt.Fprintf(&b, "\n    %v", f)
   120  	}
   121  	return b.String()
   122  }
   123  
// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
// It sizes test1RTTKeys.pkt and bounds the keys tried when
// unprotecting a 1-RTT packet sent by the conn.
const maxTestKeyPhases = 3
   126  
// A testConn is a Conn whose external interactions (sending and receiving packets,
// setting timers) can be manipulated in tests.
type testConn struct {
	t        *testing.T
	conn     *Conn
	endpoint *testEndpoint
	// timer is the time of the conn's next timer event.
	// It is the zero Time when no timer is set.
	timer          time.Time
	timerLastFired time.Time
	idlec          chan struct{} // only accessed on the conn's loop

	// Keys are distinct from the conn's keys,
	// because the test may know about keys before the conn does.
	// For example, when sending a datagram with coalesced
	// Initial and Handshake packets to a client conn,
	// we use Handshake keys to encrypt the packet.
	// The client only acquires those keys when it processes
	// the Initial packet.
	keysInitial   fixedKeyPair
	keysHandshake fixedKeyPair
	rkeyAppData   test1RTTKeys
	wkeyAppData   test1RTTKeys
	rsecrets      [numberSpaceCount]keySecret
	wsecrets      [numberSpaceCount]keySecret

	// testConn uses a test hook to snoop on the conn's TLS events.
	// CRYPTO data produced by the conn's QUICConn is placed in
	// cryptoDataOut.
	//
	// The peerTLSConn is a QUICConn representing the peer.
	// CRYPTO data produced by the conn is written to peerTLSConn,
	// and data produced by peerTLSConn is placed in cryptoDataIn.
	cryptoDataOut map[tls.QUICEncryptionLevel][]byte
	cryptoDataIn  map[tls.QUICEncryptionLevel][]byte
	peerTLSConn   *tls.QUICConn

	// Information about the conn's (fake) peer.
	peerConnID        []byte                         // source conn id of peer's packets
	peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use

	// Datagrams, packets, and frames sent by the conn,
	// but not yet processed by the test.
	sentDatagrams [][]byte
	sentPackets   []*testPacket
	sentFrames    []debugFrame
	// lastPacket is the most recent packet returned by (or skipped over in)
	// readPacket; the writeAckFor* helpers acknowledge it.
	lastPacket *testPacket

	recvDatagram chan *datagram

	// Transport parameters sent by the conn.
	sentTransportParameters *transportParameters

	// Frame types to ignore in tests.
	ignoreFrames map[byte]bool

	// Values to set in packets sent to the conn.
	sendKeyNumber   int
	sendKeyPhaseBit bool

	asyncTestState
}
   187  
// test1RTTKeys holds the test's 1-RTT keys for one direction:
// a single header protection key and up to maxTestKeyPhases packet keys,
// indexed by key number.
type test1RTTKeys struct {
	hdr headerKey
	pkt [maxTestKeyPhases]packetKey
}
   192  
// keySecret pairs a secret with a 16-bit suite identifier.
// NOTE(review): suite is presumably the TLS cipher suite ID the secret
// was derived for — confirm against the code that populates it.
type keySecret struct {
	suite  uint16
	secret []byte
}
   197  
   198  // newTestConn creates a Conn for testing.
   199  //
   200  // The Conn's event loop is controlled by the test,
   201  // allowing test code to access Conn state directly
   202  // by first ensuring the loop goroutine is idle.
   203  func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
   204  	t.Helper()
   205  	config := &Config{
   206  		TLSConfig:         newTestTLSConfig(side),
   207  		StatelessResetKey: testStatelessResetKey,
   208  		QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
   209  			Level: QLogLevelFrame,
   210  			Dir:   *qlogdir,
   211  		})),
   212  	}
   213  	var cids newServerConnIDs
   214  	if side == serverSide {
   215  		// The initial connection ID for the server is chosen by the client.
   216  		cids.srcConnID = testPeerConnID(0)
   217  		cids.dstConnID = testPeerConnID(-1)
   218  		cids.originalDstConnID = cids.dstConnID
   219  	}
   220  	var configTransportParams []func(*transportParameters)
   221  	var configTestConn []func(*testConn)
   222  	for _, o := range opts {
   223  		switch o := o.(type) {
   224  		case func(*Config):
   225  			o(config)
   226  		case func(*tls.Config):
   227  			o(config.TLSConfig)
   228  		case func(cids *newServerConnIDs):
   229  			o(&cids)
   230  		case func(p *transportParameters):
   231  			configTransportParams = append(configTransportParams, o)
   232  		case func(p *testConn):
   233  			configTestConn = append(configTestConn, o)
   234  		default:
   235  			t.Fatalf("unknown newTestConn option %T", o)
   236  		}
   237  	}
   238  
   239  	endpoint := newTestEndpoint(t, config)
   240  	endpoint.configTransportParams = configTransportParams
   241  	endpoint.configTestConn = configTestConn
   242  	conn, err := endpoint.e.newConn(
   243  		endpoint.now,
   244  		side,
   245  		cids,
   246  		netip.MustParseAddrPort("127.0.0.1:443"))
   247  	if err != nil {
   248  		t.Fatal(err)
   249  	}
   250  	tc := endpoint.conns[conn]
   251  	tc.wait()
   252  	return tc
   253  }
   254  
// newTestConnForConn wraps an existing Conn in a testConn and installs
// the testConn as the Conn's test hooks.
//
// If the endpoint has a pending peerTLSConn, the new testConn takes
// ownership of it (the endpoint's field is cleared). Otherwise a fresh
// tls.QUICConn is created to play the role of the peer, configured with
// default transport parameters as modified by endpoint.configTransportParams.
func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
	t.Helper()
	tc := &testConn{
		t:          t,
		endpoint:   endpoint,
		conn:       conn,
		peerConnID: testPeerConnID(0),
		ignoreFrames: map[byte]bool{
			frameTypePadding: true, // ignore PADDING by default
		},
		cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
		cryptoDataIn:  make(map[tls.QUICEncryptionLevel][]byte),
		recvDatagram:  make(chan *datagram),
	}
	// Shut the conn down when the test finishes.
	t.Cleanup(tc.cleanup)
	// Apply per-testConn options before installing the hooks.
	for _, f := range endpoint.configTestConn {
		f(tc)
	}
	conn.testHooks = (*testConnHooks)(tc)

	// Reuse a peer TLS conn handed over by the endpoint, if there is one.
	if endpoint.peerTLSConn != nil {
		tc.peerTLSConn = endpoint.peerTLSConn
		endpoint.peerTLSConn = nil
		return tc
	}

	// Build the transport parameters the fake peer will send.
	peerProvidedParams := defaultTransportParameters()
	peerProvidedParams.initialSrcConnID = testPeerConnID(0)
	if conn.side == clientSide {
		peerProvidedParams.originalDstConnID = testLocalConnID(-1)
	}
	for _, f := range endpoint.configTransportParams {
		f(&peerProvidedParams)
	}

	// The peer's TLS conn takes the opposite role to the conn under test.
	peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())}
	if conn.side == clientSide {
		tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
	} else {
		tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
	}
	tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
	tc.peerTLSConn.Start(context.Background())
	t.Cleanup(func() {
		tc.peerTLSConn.Close()
	})

	return tc
}
   304  
// advance causes time to pass.
// It delegates to the endpoint, advancing the endpoint's clock by d.
func (tc *testConn) advance(d time.Duration) {
	tc.t.Helper()
	tc.endpoint.advance(d)
}
   310  
// advanceTo sets the current time.
// It delegates to the endpoint, which owns the test clock.
func (tc *testConn) advanceTo(now time.Time) {
	tc.t.Helper()
	tc.endpoint.advanceTo(now)
}
   316  
   317  // advanceToTimer sets the current time to the time of the Conn's next timer event.
   318  func (tc *testConn) advanceToTimer() {
   319  	if tc.timer.IsZero() {
   320  		tc.t.Fatalf("advancing to timer, but timer is not set")
   321  	}
   322  	tc.advanceTo(tc.timer)
   323  }
   324  
   325  func (tc *testConn) timerDelay() time.Duration {
   326  	if tc.timer.IsZero() {
   327  		return math.MaxInt64 // infinite
   328  	}
   329  	if tc.timer.Before(tc.endpoint.now) {
   330  		return 0
   331  	}
   332  	return tc.timer.Sub(tc.endpoint.now)
   333  }
   334  
// infiniteDuration is the largest representable Duration.
// timeUntilEvent returns it when no timer is set.
const infiniteDuration = time.Duration(math.MaxInt64)
   336  
   337  // timeUntilEvent returns the amount of time until the next connection event.
   338  func (tc *testConn) timeUntilEvent() time.Duration {
   339  	if tc.timer.IsZero() {
   340  		return infiniteDuration
   341  	}
   342  	if tc.timer.Before(tc.endpoint.now) {
   343  		return 0
   344  	}
   345  	return tc.timer.Sub(tc.endpoint.now)
   346  }
   347  
// wait blocks until the conn becomes idle.
// The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire.
// Tests shouldn't need to call wait directly.
// testConn methods that wake the Conn event loop will call wait for them.
//
// wait panics if another wait is already pending on the same testConn.
func (tc *testConn) wait() {
	tc.t.Helper()
	idlec := make(chan struct{})
	fail := false
	// Register idlec from the conn's loop, the only place tc.idlec may be touched.
	tc.conn.sendMsg(func(now time.Time, c *Conn) {
		if tc.idlec != nil {
			tc.t.Errorf("testConn.wait called concurrently")
			fail = true
			// Unblock the select below so the failure is reported.
			close(idlec)
		} else {
			// nextMessage will close idlec.
			tc.idlec = idlec
		}
	})
	select {
	case <-idlec:
	case <-tc.conn.donec:
		// We may have async ops that can proceed now that the conn is done.
		tc.wakeAsync()
	}
	if fail {
		panic(fail)
	}
}
   376  
   377  func (tc *testConn) cleanup() {
   378  	if tc.conn == nil {
   379  		return
   380  	}
   381  	tc.conn.exit()
   382  	<-tc.conn.donec
   383  }
   384  
   385  func logDatagram(t *testing.T, text string, d *testDatagram) {
   386  	t.Helper()
   387  	if !*testVV {
   388  		return
   389  	}
   390  	pad := ""
   391  	if d.paddedSize > 0 {
   392  		pad = fmt.Sprintf(" (padded to %v)", d.paddedSize)
   393  	}
   394  	t.Logf("%v datagram%v", text, pad)
   395  	for _, p := range d.packets {
   396  		var s string
   397  		switch p.ptype {
   398  		case packetType1RTT:
   399  			s = fmt.Sprintf("  %v pnum=%v", p.ptype, p.num)
   400  		default:
   401  			s = fmt.Sprintf("  %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
   402  		}
   403  		if p.token != nil {
   404  			s += fmt.Sprintf(" token={%x}", p.token)
   405  		}
   406  		if p.keyPhaseBit {
   407  			s += fmt.Sprintf(" KeyPhase")
   408  		}
   409  		if p.keyNumber != 0 {
   410  			s += fmt.Sprintf(" keynum=%v", p.keyNumber)
   411  		}
   412  		t.Log(s)
   413  		for _, f := range p.frames {
   414  			t.Logf("    %v", f)
   415  		}
   416  	}
   417  }
   418  
// write sends the Conn a datagram.
// The datagram is handed to the endpoint's writeDatagram.
func (tc *testConn) write(d *testDatagram) {
	tc.t.Helper()
	tc.endpoint.writeDatagram(d)
}
   424  
   425  // writeFrame sends the Conn a datagram containing the given frames.
   426  func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
   427  	tc.t.Helper()
   428  	space := spaceForPacketType(ptype)
   429  	dstConnID := tc.conn.connIDState.local[0].cid
   430  	if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial {
   431  		// Only use the transient connection ID in Initial packets.
   432  		dstConnID = tc.conn.connIDState.local[1].cid
   433  	}
   434  	d := &testDatagram{
   435  		packets: []*testPacket{{
   436  			ptype:       ptype,
   437  			num:         tc.peerNextPacketNum[space],
   438  			keyNumber:   tc.sendKeyNumber,
   439  			keyPhaseBit: tc.sendKeyPhaseBit,
   440  			frames:      frames,
   441  			version:     quicVersion1,
   442  			dstConnID:   dstConnID,
   443  			srcConnID:   tc.peerConnID,
   444  		}},
   445  	}
   446  	if ptype == packetTypeInitial && tc.conn.side == serverSide {
   447  		d.paddedSize = 1200
   448  	}
   449  	tc.write(d)
   450  }
   451  
   452  // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
   453  // last one received.
   454  func (tc *testConn) writeAckForAll() {
   455  	tc.t.Helper()
   456  	if tc.lastPacket == nil {
   457  		return
   458  	}
   459  	tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
   460  		ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}},
   461  	})
   462  }
   463  
   464  // writeAckForLatest sends the Conn a datagram containing an ack for the
   465  // most recent packet received.
   466  func (tc *testConn) writeAckForLatest() {
   467  	tc.t.Helper()
   468  	if tc.lastPacket == nil {
   469  		return
   470  	}
   471  	tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
   472  		ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}},
   473  	})
   474  }
   475  
// ignoreFrame hides frames of the given type sent by the Conn.
// Hidden frames are stripped from packets by readDatagram.
func (tc *testConn) ignoreFrame(frameType byte) {
	tc.ignoreFrames[frameType] = true
}
   480  
// readDatagram reads the next datagram sent by the Conn.
// It returns nil if the Conn has no more datagrams to send at this time.
//
// Any packets and frames buffered from a previous datagram are discarded.
// Frames whose type is in tc.ignoreFrames are removed from the returned
// datagram's packets.
func (tc *testConn) readDatagram() *testDatagram {
	tc.t.Helper()
	tc.wait()
	// Drop anything left over from the previously read datagram.
	tc.sentPackets = nil
	tc.sentFrames = nil
	buf := tc.endpoint.read()
	if buf == nil {
		return nil
	}
	d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
	// Log the datagram before removing ignored frames.
	// When things go wrong, it's useful to see all the frames.
	logDatagram(tc.t, "-> conn under test sends", d)
	// typeForFrame maps a debugFrame to the frame type byte used as
	// the key in tc.ignoreFrames. It panics on an unknown debugFrame type.
	typeForFrame := func(f debugFrame) byte {
		// This is very clunky, and points at a problem
		// in how we specify what frames to ignore in tests.
		//
		// We mark frames to ignore using the frame type,
		// but we've got a debugFrame data structure here.
		// Perhaps we should be ignoring frames by debugFrame
		// type instead: tc.ignoreFrame[debugFrameAck]().
		switch f := f.(type) {
		case debugFramePadding:
			return frameTypePadding
		case debugFramePing:
			return frameTypePing
		case debugFrameAck:
			return frameTypeAck
		case debugFrameResetStream:
			return frameTypeResetStream
		case debugFrameStopSending:
			return frameTypeStopSending
		case debugFrameCrypto:
			return frameTypeCrypto
		case debugFrameNewToken:
			return frameTypeNewToken
		case debugFrameStream:
			return frameTypeStreamBase
		case debugFrameMaxData:
			return frameTypeMaxData
		case debugFrameMaxStreamData:
			return frameTypeMaxStreamData
		case debugFrameMaxStreams:
			if f.streamType == bidiStream {
				return frameTypeMaxStreamsBidi
			} else {
				return frameTypeMaxStreamsUni
			}
		case debugFrameDataBlocked:
			return frameTypeDataBlocked
		case debugFrameStreamDataBlocked:
			return frameTypeStreamDataBlocked
		case debugFrameStreamsBlocked:
			if f.streamType == bidiStream {
				return frameTypeStreamsBlockedBidi
			} else {
				return frameTypeStreamsBlockedUni
			}
		case debugFrameNewConnectionID:
			return frameTypeNewConnectionID
		case debugFrameRetireConnectionID:
			return frameTypeRetireConnectionID
		case debugFramePathChallenge:
			return frameTypePathChallenge
		case debugFramePathResponse:
			return frameTypePathResponse
		case debugFrameConnectionCloseTransport:
			return frameTypeConnectionCloseTransport
		case debugFrameConnectionCloseApplication:
			return frameTypeConnectionCloseApplication
		case debugFrameHandshakeDone:
			return frameTypeHandshakeDone
		}
		panic(fmt.Errorf("unhandled frame type %T", f))
	}
	// Filter out ignored frames. A packet may end up with no frames at all.
	for _, p := range d.packets {
		var frames []debugFrame
		for _, f := range p.frames {
			if !tc.ignoreFrames[typeForFrame(f)] {
				frames = append(frames, f)
			}
		}
		p.frames = frames
	}
	return d
}
   569  
   570  // readPacket reads the next packet sent by the Conn.
   571  // It returns nil if the Conn has no more packets to send at this time.
   572  func (tc *testConn) readPacket() *testPacket {
   573  	tc.t.Helper()
   574  	for len(tc.sentPackets) == 0 {
   575  		d := tc.readDatagram()
   576  		if d == nil {
   577  			return nil
   578  		}
   579  		for _, p := range d.packets {
   580  			if len(p.frames) == 0 {
   581  				tc.lastPacket = p
   582  				continue
   583  			}
   584  			tc.sentPackets = append(tc.sentPackets, p)
   585  		}
   586  	}
   587  	p := tc.sentPackets[0]
   588  	tc.sentPackets = tc.sentPackets[1:]
   589  	tc.lastPacket = p
   590  	return p
   591  }
   592  
   593  // readFrame reads the next frame sent by the Conn.
   594  // It returns nil if the Conn has no more frames to send at this time.
   595  func (tc *testConn) readFrame() (debugFrame, packetType) {
   596  	tc.t.Helper()
   597  	for len(tc.sentFrames) == 0 {
   598  		p := tc.readPacket()
   599  		if p == nil {
   600  			return nil, packetTypeInvalid
   601  		}
   602  		tc.sentFrames = p.frames
   603  	}
   604  	f := tc.sentFrames[0]
   605  	tc.sentFrames = tc.sentFrames[1:]
   606  	return f, tc.lastPacket.ptype
   607  }
   608  
   609  // wantDatagram indicates that we expect the Conn to send a datagram.
   610  func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
   611  	tc.t.Helper()
   612  	got := tc.readDatagram()
   613  	if !datagramEqual(got, want) {
   614  		tc.t.Fatalf("%v:\ngot datagram:  %v\nwant datagram: %v", expectation, got, want)
   615  	}
   616  }
   617  
   618  func datagramEqual(a, b *testDatagram) bool {
   619  	if a == nil && b == nil {
   620  		return true
   621  	}
   622  	if a == nil || b == nil {
   623  		return false
   624  	}
   625  	if a.paddedSize != b.paddedSize ||
   626  		a.addr != b.addr ||
   627  		len(a.packets) != len(b.packets) {
   628  		return false
   629  	}
   630  	for i := range a.packets {
   631  		if !packetEqual(a.packets[i], b.packets[i]) {
   632  			return false
   633  		}
   634  	}
   635  	return true
   636  }
   637  
   638  // wantPacket indicates that we expect the Conn to send a packet.
   639  func (tc *testConn) wantPacket(expectation string, want *testPacket) {
   640  	tc.t.Helper()
   641  	got := tc.readPacket()
   642  	if !packetEqual(got, want) {
   643  		tc.t.Fatalf("%v:\ngot packet:  %v\nwant packet: %v", expectation, got, want)
   644  	}
   645  }
   646  
   647  func packetEqual(a, b *testPacket) bool {
   648  	ac := *a
   649  	ac.frames = nil
   650  	ac.header = 0
   651  	bc := *b
   652  	bc.frames = nil
   653  	bc.header = 0
   654  	if !reflect.DeepEqual(ac, bc) {
   655  		return false
   656  	}
   657  	if len(a.frames) != len(b.frames) {
   658  		return false
   659  	}
   660  	for i := range a.frames {
   661  		if !frameEqual(a.frames[i], b.frames[i]) {
   662  			return false
   663  		}
   664  	}
   665  	return true
   666  }
   667  
   668  // wantFrame indicates that we expect the Conn to send a frame.
   669  func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
   670  	tc.t.Helper()
   671  	got, gotType := tc.readFrame()
   672  	if got == nil {
   673  		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
   674  	}
   675  	if gotType != wantType {
   676  		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame:  %v", expectation, gotType, wantType, got)
   677  	}
   678  	if !frameEqual(got, want) {
   679  		tc.t.Fatalf("%v:\ngot frame:  %v\nwant frame: %v", expectation, got, want)
   680  	}
   681  }
   682  
   683  func frameEqual(a, b debugFrame) bool {
   684  	switch af := a.(type) {
   685  	case debugFrameConnectionCloseTransport:
   686  		bf, ok := b.(debugFrameConnectionCloseTransport)
   687  		return ok && af.code == bf.code
   688  	}
   689  	return reflect.DeepEqual(a, b)
   690  }
   691  
   692  // wantFrameType indicates that we expect the Conn to send a frame,
   693  // although we don't care about the contents.
   694  func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
   695  	tc.t.Helper()
   696  	got, gotType := tc.readFrame()
   697  	if got == nil {
   698  		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
   699  	}
   700  	if gotType != wantType {
   701  		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame:  %v", expectation, gotType, wantType, got)
   702  	}
   703  	if reflect.TypeOf(got) != reflect.TypeOf(want) {
   704  		tc.t.Fatalf("%v:\ngot frame:  %v\nwant frame of type: %v", expectation, got, want)
   705  	}
   706  }
   707  
   708  // wantIdle indicates that we expect the Conn to not send any more frames.
   709  func (tc *testConn) wantIdle(expectation string) {
   710  	tc.t.Helper()
   711  	switch {
   712  	case len(tc.sentFrames) > 0:
   713  		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0])
   714  	case len(tc.sentPackets) > 0:
   715  		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0])
   716  	}
   717  	if f, _ := tc.readFrame(); f != nil {
   718  		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f)
   719  	}
   720  }
   721  
// encodeTestPacket encodes and protects p, returning the resulting bytes.
//
// The packet is padded to pad bytes via the packet writer's appendPaddingTo.
// Retry packets are encoded by encodeRetryPacket and are not padded.
//
// Keys come from tc. When tc is nil, only Initial packets can be encoded,
// using initial keys derived from the packet's destination connection ID.
// The test fails if no suitable write key is available.
func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
	t.Helper()
	var w packetWriter
	w.reset(1200)
	// pnumMaxAcked is left at its zero value for every packet encoded here.
	var pnumMaxAcked packetNumber
	switch p.ptype {
	case packetTypeRetry:
		return encodeRetryPacket(p.originalDstConnID, retryPacket{
			srcConnID: p.srcConnID,
			dstConnID: p.dstConnID,
			token:     p.token,
		})
	case packetType1RTT:
		w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
	default:
		w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	}
	for _, f := range p.frames {
		f.write(&w)
	}
	w.appendPaddingTo(pad)
	if p.ptype != packetType1RTT {
		// Long header packet: pick the fixed key for the packet type.
		// k stays unset for types other than Initial and Handshake,
		// which is caught by the isSet check below.
		var k fixedKeys
		if tc == nil {
			if p.ptype == packetTypeInitial {
				k = initialKeys(p.dstConnID, serverSide).r
			} else {
				t.Fatalf("sending %v packet with no conn", p.ptype)
			}
		} else {
			switch p.ptype {
			case packetTypeInitial:
				k = tc.keysInitial.w
			case packetTypeHandshake:
				k = tc.keysHandshake.w
			}
		}
		if !k.isSet() {
			t.Fatalf("sending %v packet with no write key", p.ptype)
		}
		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	} else {
		if tc == nil || !tc.wkeyAppData.hdr.isSet() {
			t.Fatalf("sending 1-RTT packet with no write key")
		}
		// Somewhat hackish: Generate a temporary updatingKeyPair that will
		// always use our desired key phase.
		k := &updatingKeyPair{
			w: updatingKeys{
				hdr: tc.wkeyAppData.hdr,
				// Both slots hold the same key, selected by p.keyNumber.
				pkt: [2]packetKey{
					tc.wkeyAppData.pkt[p.keyNumber],
					tc.wkeyAppData.pkt[p.keyNumber],
				},
			},
			updateAfter: maxPacketNumber,
		}
		if p.keyPhaseBit {
			k.phase |= keyPhaseBit
		}
		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
	}
	return w.datagram()
}
   800  
// parseTestDatagram parses and unprotects a datagram sent by the conn
// under test, returning its packets and their frames.
//
// Keys come from tc. When tc is nil, only Initial packets can be read,
// using initial keys derived from the packet's source connection ID.
// A Retry packet is returned on its own, as the datagram's only packet.
// A 1-RTT packet consumes the rest of the datagram.
//
// The test fails if a packet cannot be parsed or no suitable read key
// is available.
func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
	t.Helper()
	// NOTE(review): bufSize and size both hold the original datagram
	// length; one of them is redundant.
	bufSize := len(buf)
	d := &testDatagram{}
	size := len(buf)
	for len(buf) > 0 {
		// A zero first byte is treated as padding after the last packet:
		// record the full datagram size and stop parsing.
		if buf[0] == 0 {
			d.paddedSize = bufSize
			break
		}
		ptype := getPacketType(buf)
		switch ptype {
		case packetTypeRetry:
			retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
			if !ok {
				t.Fatalf("could not parse %v packet", ptype)
			}
			return &testDatagram{
				packets: []*testPacket{{
					ptype:     packetTypeRetry,
					dstConnID: retry.dstConnID,
					srcConnID: retry.srcConnID,
					token:     retry.token,
				}},
			}
		case packetTypeInitial, packetTypeHandshake:
			var k fixedKeys
			if tc == nil {
				if ptype == packetTypeInitial {
					p, _ := parseGenericLongHeaderPacket(buf)
					k = initialKeys(p.srcConnID, serverSide).w
				} else {
					t.Fatalf("reading %v packet with no conn", ptype)
				}
			} else {
				switch ptype {
				case packetTypeInitial:
					k = tc.keysInitial.r
				case packetTypeHandshake:
					k = tc.keysHandshake.r
				}
			}
			if !k.isSet() {
				t.Fatalf("reading %v packet with no read key", ptype)
			}
			var pnumMax packetNumber // TODO: Track packet numbers.
			p, n := parseLongHeaderPacket(buf, k, pnumMax)
			if n < 0 {
				t.Fatalf("packet parse error")
			}
			frames, err := parseTestFrames(t, p.payload)
			if err != nil {
				t.Fatal(err)
			}
			// Only Initial packets carry a token; an empty one is reported as nil.
			var token []byte
			if ptype == packetTypeInitial && len(p.extra) > 0 {
				token = p.extra
			}
			d.packets = append(d.packets, &testPacket{
				ptype:     p.ptype,
				header:    buf[0],
				version:   p.version,
				num:       p.num,
				dstConnID: p.dstConnID,
				srcConnID: p.srcConnID,
				token:     token,
				frames:    frames,
			})
			buf = buf[n:]
		case packetType1RTT:
			if tc == nil || !tc.rkeyAppData.hdr.isSet() {
				t.Fatalf("reading 1-RTT packet with no read key")
			}
			var pnumMax packetNumber // TODO: Track packet numbers.
			// The packet number follows the first byte and the destination
			// connection ID, which is assumed to be tc.peerConnID.
			pnumOff := 1 + len(tc.peerConnID)
			// Try unprotecting the packet with the first maxTestKeyPhases keys.
			var phase int
			var pnum packetNumber
			var hdr []byte
			var pay []byte
			var err error
			for phase = 0; phase < maxTestKeyPhases; phase++ {
				// Work on a copy so each attempt starts from the original bytes.
				b := append([]byte{}, buf...)
				hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
				if err != nil {
					t.Fatalf("1-RTT packet header parse error")
				}
				k := tc.rkeyAppData.pkt[phase]
				pay, err = k.unprotect(hdr, pay, pnum)
				if err == nil {
					break
				}
			}
			if err != nil {
				t.Fatalf("1-RTT packet payload parse error")
			}
			frames, err := parseTestFrames(t, pay)
			if err != nil {
				t.Fatal(err)
			}
			d.packets = append(d.packets, &testPacket{
				ptype:       packetType1RTT,
				header:      hdr[0],
				num:         pnum,
				dstConnID:   hdr[1:][:len(tc.peerConnID)],
				keyPhaseBit: hdr[0]&keyPhaseBit != 0,
				keyNumber:   phase,
				frames:      frames,
			})
			// The 1-RTT packet takes up the remainder of the datagram.
			buf = buf[len(buf):]
		default:
			t.Fatalf("unhandled packet type %v", ptype)
		}
	}
	// This is rather hackish: If the last frame in the last packet
	// in the datagram is PADDING, then remove it and record
	// the padded size in the testDatagram.paddedSize.
	//
	// This makes it easier to write a test that expects a datagram
	// padded to 1200 bytes.
	if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 {
		p := d.packets[len(d.packets)-1]
		f := p.frames[len(p.frames)-1]
		if _, ok := f.(debugFramePadding); ok {
			p.frames = p.frames[:len(p.frames)-1]
			d.paddedSize = size
		}
	}
	return d
}
   931  
   932  func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
   933  	t.Helper()
   934  	var frames []debugFrame
   935  	for len(payload) > 0 {
   936  		f, n := parseDebugFrame(payload)
   937  		if n < 0 {
   938  			return nil, errors.New("error parsing frames")
   939  		}
   940  		frames = append(frames, f)
   941  		payload = payload[n:]
   942  	}
   943  	return frames, nil
   944  }
   945  
   946  func spaceForPacketType(ptype packetType) numberSpace {
   947  	switch ptype {
   948  	case packetTypeInitial:
   949  		return initialSpace
   950  	case packetType0RTT:
   951  		panic("TODO: packetType0RTT")
   952  	case packetTypeHandshake:
   953  		return handshakeSpace
   954  	case packetTypeRetry:
   955  		panic("retry packets have no number space")
   956  	case packetType1RTT:
   957  		return appDataSpace
   958  	}
   959  	panic("unknown packet type")
   960  }
   961  
// testConnHooks implements connTestHooks.
//
// It is a distinct named type whose underlying type is testConn,
// so its hook methods are kept apart from testConn's own method set.
// A *testConnHooks can be converted back with (*testConn)(tc).
type testConnHooks testConn
   964  
   965  func (tc *testConnHooks) init() {
   966  	tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
   967  	tc.keysInitial.r = tc.conn.keysInitial.w
   968  	tc.keysInitial.w = tc.conn.keysInitial.r
   969  	if tc.conn.side == serverSide {
   970  		tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
   971  	}
   972  }
   973  
   974  // handleTLSEvent processes TLS events generated by
   975  // the connection under test's tls.QUICConn.
   976  //
   977  // We maintain a second tls.QUICConn representing the peer,
   978  // and feed the TLS handshake data into it.
   979  //
   980  // We stash TLS handshake data from both sides in the testConn,
   981  // where it can be used by tests.
   982  //
   983  // We snoop packet protection keys out of the tls.QUICConns,
   984  // and verify that both sides of the connection are getting
   985  // matching keys.
   986  func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
   987  	checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
   988  		var space numberSpace
   989  		switch {
   990  		case e.Level == tls.QUICEncryptionLevelHandshake:
   991  			space = handshakeSpace
   992  		case e.Level == tls.QUICEncryptionLevelApplication:
   993  			space = appDataSpace
   994  		default:
   995  			tc.t.Errorf("unexpected encryption level %v", e.Level)
   996  			return
   997  		}
   998  		if secrets[space].secret == nil {
   999  			secrets[space].suite = e.Suite
  1000  			secrets[space].secret = append([]byte{}, e.Data...)
  1001  		} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
  1002  			tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level)
  1003  		}
  1004  	}
  1005  	setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
  1006  		k.hdr.init(suite, secret)
  1007  		for i := 0; i < len(k.pkt); i++ {
  1008  			k.pkt[i].init(suite, secret)
  1009  			secret = updateSecret(suite, secret)
  1010  		}
  1011  	}
  1012  	switch e.Kind {
  1013  	case tls.QUICSetReadSecret:
  1014  		checkKey("write", &tc.wsecrets, e)
  1015  		switch e.Level {
  1016  		case tls.QUICEncryptionLevelHandshake:
  1017  			tc.keysHandshake.w.init(e.Suite, e.Data)
  1018  		case tls.QUICEncryptionLevelApplication:
  1019  			setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
  1020  		}
  1021  	case tls.QUICSetWriteSecret:
  1022  		checkKey("read", &tc.rsecrets, e)
  1023  		switch e.Level {
  1024  		case tls.QUICEncryptionLevelHandshake:
  1025  			tc.keysHandshake.r.init(e.Suite, e.Data)
  1026  		case tls.QUICEncryptionLevelApplication:
  1027  			setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
  1028  		}
  1029  	case tls.QUICWriteData:
  1030  		tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
  1031  		tc.peerTLSConn.HandleData(e.Level, e.Data)
  1032  	}
  1033  	for {
  1034  		e := tc.peerTLSConn.NextEvent()
  1035  		switch e.Kind {
  1036  		case tls.QUICNoEvent:
  1037  			return
  1038  		case tls.QUICSetReadSecret:
  1039  			checkKey("write", &tc.rsecrets, e)
  1040  			switch e.Level {
  1041  			case tls.QUICEncryptionLevelHandshake:
  1042  				tc.keysHandshake.r.init(e.Suite, e.Data)
  1043  			case tls.QUICEncryptionLevelApplication:
  1044  				setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
  1045  			}
  1046  		case tls.QUICSetWriteSecret:
  1047  			checkKey("read", &tc.wsecrets, e)
  1048  			switch e.Level {
  1049  			case tls.QUICEncryptionLevelHandshake:
  1050  				tc.keysHandshake.w.init(e.Suite, e.Data)
  1051  			case tls.QUICEncryptionLevelApplication:
  1052  				setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
  1053  			}
  1054  		case tls.QUICWriteData:
  1055  			tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
  1056  		case tls.QUICTransportParameters:
  1057  			p, err := unmarshalTransportParams(e.Data)
  1058  			if err != nil {
  1059  				tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err)
  1060  			} else {
  1061  				tc.sentTransportParameters = &p
  1062  			}
  1063  		}
  1064  	}
  1065  }
  1066  
  1067  // nextMessage is called by the Conn's event loop to request its next event.
  1068  func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
  1069  	tc.timer = timer
  1070  	for {
  1071  		if !timer.IsZero() && !timer.After(tc.endpoint.now) {
  1072  			if timer.Equal(tc.timerLastFired) {
  1073  				// If the connection timer fires at time T, the Conn should take some
  1074  				// action to advance the timer into the future. If the Conn reschedules
  1075  				// the timer for the same time, it isn't making progress and we have a bug.
  1076  				tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
  1077  			} else {
  1078  				tc.timerLastFired = timer
  1079  				return tc.endpoint.now, timerEvent{}
  1080  			}
  1081  		}
  1082  		select {
  1083  		case m := <-msgc:
  1084  			return tc.endpoint.now, m
  1085  		default:
  1086  		}
  1087  		if !tc.wakeAsync() {
  1088  			break
  1089  		}
  1090  	}
  1091  	// If the message queue is empty, then the conn is idle.
  1092  	if tc.idlec != nil {
  1093  		idlec := tc.idlec
  1094  		tc.idlec = nil
  1095  		close(idlec)
  1096  	}
  1097  	m = <-msgc
  1098  	return tc.endpoint.now, m
  1099  }
  1100  
  1101  func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
  1102  	return testLocalConnID(seq), nil
  1103  }
  1104  
  1105  func (tc *testConnHooks) timeNow() time.Time {
  1106  	return tc.endpoint.now
  1107  }
  1108  
  1109  // testLocalConnID returns the connection ID with a given sequence number
  1110  // used by a Conn under test.
  1111  func testLocalConnID(seq int64) []byte {
  1112  	cid := make([]byte, connIDLen)
  1113  	copy(cid, []byte{0xc0, 0xff, 0xee})
  1114  	cid[len(cid)-1] = byte(seq)
  1115  	return cid
  1116  }
  1117  
  1118  // testPeerConnID returns the connection ID with a given sequence number
  1119  // used by the fake peer of a Conn under test.
  1120  func testPeerConnID(seq int64) []byte {
  1121  	// Use a different length than we choose for our own conn ids,
  1122  	// to help catch any bad assumptions.
  1123  	return []byte{0xbe, 0xee, 0xff, byte(seq)}
  1124  }
  1125  
  1126  func testPeerStatelessResetToken(seq int64) statelessResetToken {
  1127  	return statelessResetToken{
  1128  		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
  1129  		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
  1130  	}
  1131  }
  1132  
  1133  // canceledContext returns a canceled Context.
  1134  //
  1135  // Functions which take a context preference progress over cancelation.
  1136  // For example, a read with a canceled context will return data if any is available.
  1137  // Tests use canceled contexts to perform non-blocking operations.
  1138  func canceledContext() context.Context {
  1139  	ctx, cancel := context.WithCancel(context.Background())
  1140  	cancel()
  1141  	return ctx
  1142  }
  1143  

View as plain text