...

Source file src/golang.org/x/net/internal/quic/endpoint_test.go

Documentation: golang.org/x/net/internal/quic

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"io"
    14  	"log/slog"
    15  	"net"
    16  	"net/netip"
    17  	"testing"
    18  	"time"
    19  
    20  	"golang.org/x/net/internal/quic/qlog"
    21  )
    22  
    23  func TestConnect(t *testing.T) {
    24  	newLocalConnPair(t, &Config{}, &Config{})
    25  }
    26  
    27  func TestStreamTransfer(t *testing.T) {
    28  	ctx := context.Background()
    29  	cli, srv := newLocalConnPair(t, &Config{}, &Config{})
    30  	data := makeTestData(1 << 20)
    31  
    32  	srvdone := make(chan struct{})
    33  	go func() {
    34  		defer close(srvdone)
    35  		s, err := srv.AcceptStream(ctx)
    36  		if err != nil {
    37  			t.Errorf("AcceptStream: %v", err)
    38  			return
    39  		}
    40  		b, err := io.ReadAll(s)
    41  		if err != nil {
    42  			t.Errorf("io.ReadAll(s): %v", err)
    43  			return
    44  		}
    45  		if !bytes.Equal(b, data) {
    46  			t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
    47  		}
    48  		if err := s.Close(); err != nil {
    49  			t.Errorf("s.Close() = %v", err)
    50  		}
    51  	}()
    52  
    53  	s, err := cli.NewSendOnlyStream(ctx)
    54  	if err != nil {
    55  		t.Fatalf("NewStream: %v", err)
    56  	}
    57  	n, err := io.Copy(s, bytes.NewBuffer(data))
    58  	if n != int64(len(data)) || err != nil {
    59  		t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
    60  	}
    61  	if err := s.Close(); err != nil {
    62  		t.Fatalf("s.Close() = %v", err)
    63  	}
    64  }
    65  
    66  func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
    67  	t.Helper()
    68  	ctx := context.Background()
    69  	e1 := newLocalEndpoint(t, serverSide, conf1)
    70  	e2 := newLocalEndpoint(t, clientSide, conf2)
    71  	c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String())
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	c1, err := e1.Accept(ctx)
    76  	if err != nil {
    77  		t.Fatal(err)
    78  	}
    79  	return c2, c1
    80  }
    81  
    82  func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint {
    83  	t.Helper()
    84  	if conf.TLSConfig == nil {
    85  		newConf := *conf
    86  		conf = &newConf
    87  		conf.TLSConfig = newTestTLSConfig(side)
    88  	}
    89  	if conf.QLogLogger == nil {
    90  		conf.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
    91  			Level: QLogLevelFrame,
    92  			Dir:   *qlogdir,
    93  		}))
    94  	}
    95  	e, err := Listen("udp", "127.0.0.1:0", conf)
    96  	if err != nil {
    97  		t.Fatal(err)
    98  	}
    99  	t.Cleanup(func() {
   100  		e.Close(canceledContext())
   101  	})
   102  	return e
   103  }
   104  
   105  type testEndpoint struct {
   106  	t                     *testing.T
   107  	e                     *Endpoint
   108  	now                   time.Time
   109  	recvc                 chan *datagram
   110  	idlec                 chan struct{}
   111  	conns                 map[*Conn]*testConn
   112  	acceptQueue           []*testConn
   113  	configTransportParams []func(*transportParameters)
   114  	configTestConn        []func(*testConn)
   115  	sentDatagrams         [][]byte
   116  	peerTLSConn           *tls.QUICConn
   117  	lastInitialDstConnID  []byte // for parsing Retry packets
   118  }
   119  
   120  func newTestEndpoint(t *testing.T, config *Config) *testEndpoint {
   121  	te := &testEndpoint{
   122  		t:     t,
   123  		now:   time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
   124  		recvc: make(chan *datagram),
   125  		idlec: make(chan struct{}),
   126  		conns: make(map[*Conn]*testConn),
   127  	}
   128  	var err error
   129  	te.e, err = newEndpoint((*testEndpointUDPConn)(te), config, (*testEndpointHooks)(te))
   130  	if err != nil {
   131  		t.Fatal(err)
   132  	}
   133  	t.Cleanup(te.cleanup)
   134  	return te
   135  }
   136  
   137  func (te *testEndpoint) cleanup() {
   138  	te.e.Close(canceledContext())
   139  }
   140  
   141  func (te *testEndpoint) wait() {
   142  	select {
   143  	case te.idlec <- struct{}{}:
   144  	case <-te.e.closec:
   145  	}
   146  	for _, tc := range te.conns {
   147  		tc.wait()
   148  	}
   149  }
   150  
   151  // accept returns a server connection from the endpoint.
   152  // Unlike Endpoint.Accept, connections are available as soon as they are created.
   153  func (te *testEndpoint) accept() *testConn {
   154  	if len(te.acceptQueue) == 0 {
   155  		te.t.Fatalf("accept: expected available conn, but found none")
   156  	}
   157  	tc := te.acceptQueue[0]
   158  	te.acceptQueue = te.acceptQueue[1:]
   159  	return tc
   160  }
   161  
   162  func (te *testEndpoint) write(d *datagram) {
   163  	te.recvc <- d
   164  	te.wait()
   165  }
   166  
   167  var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000")
   168  
   169  func (te *testEndpoint) writeDatagram(d *testDatagram) {
   170  	te.t.Helper()
   171  	logDatagram(te.t, "<- endpoint under test receives", d)
   172  	var buf []byte
   173  	for _, p := range d.packets {
   174  		tc := te.connForDestination(p.dstConnID)
   175  		if p.ptype != packetTypeRetry && tc != nil {
   176  			space := spaceForPacketType(p.ptype)
   177  			if p.num >= tc.peerNextPacketNum[space] {
   178  				tc.peerNextPacketNum[space] = p.num + 1
   179  			}
   180  		}
   181  		if p.ptype == packetTypeInitial {
   182  			te.lastInitialDstConnID = p.dstConnID
   183  		}
   184  		pad := 0
   185  		if p.ptype == packetType1RTT {
   186  			pad = d.paddedSize - len(buf)
   187  		}
   188  		buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...)
   189  	}
   190  	for len(buf) < d.paddedSize {
   191  		buf = append(buf, 0)
   192  	}
   193  	addr := d.addr
   194  	if !addr.IsValid() {
   195  		addr = testClientAddr
   196  	}
   197  	te.write(&datagram{
   198  		b:    buf,
   199  		addr: addr,
   200  	})
   201  }
   202  
   203  func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn {
   204  	for _, tc := range te.conns {
   205  		for _, loc := range tc.conn.connIDState.local {
   206  			if bytes.Equal(loc.cid, dstConnID) {
   207  				return tc
   208  			}
   209  		}
   210  	}
   211  	return nil
   212  }
   213  
   214  func (te *testEndpoint) connForSource(srcConnID []byte) *testConn {
   215  	for _, tc := range te.conns {
   216  		for _, loc := range tc.conn.connIDState.remote {
   217  			if bytes.Equal(loc.cid, srcConnID) {
   218  				return tc
   219  			}
   220  		}
   221  	}
   222  	return nil
   223  }
   224  
   225  func (te *testEndpoint) read() []byte {
   226  	te.t.Helper()
   227  	te.wait()
   228  	if len(te.sentDatagrams) == 0 {
   229  		return nil
   230  	}
   231  	d := te.sentDatagrams[0]
   232  	te.sentDatagrams = te.sentDatagrams[1:]
   233  	return d
   234  }
   235  
   236  func (te *testEndpoint) readDatagram() *testDatagram {
   237  	te.t.Helper()
   238  	buf := te.read()
   239  	if buf == nil {
   240  		return nil
   241  	}
   242  	p, _ := parseGenericLongHeaderPacket(buf)
   243  	tc := te.connForSource(p.dstConnID)
   244  	d := parseTestDatagram(te.t, te, tc, buf)
   245  	logDatagram(te.t, "-> endpoint under test sends", d)
   246  	return d
   247  }
   248  
   249  // wantDatagram indicates that we expect the Endpoint to send a datagram.
   250  func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) {
   251  	te.t.Helper()
   252  	got := te.readDatagram()
   253  	if !datagramEqual(got, want) {
   254  		te.t.Fatalf("%v:\ngot datagram:  %v\nwant datagram: %v", expectation, got, want)
   255  	}
   256  }
   257  
   258  // wantIdle indicates that we expect the Endpoint to not send any more datagrams.
   259  func (te *testEndpoint) wantIdle(expectation string) {
   260  	if got := te.readDatagram(); got != nil {
   261  		te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
   262  	}
   263  }
   264  
   265  // advance causes time to pass.
   266  func (te *testEndpoint) advance(d time.Duration) {
   267  	te.t.Helper()
   268  	te.advanceTo(te.now.Add(d))
   269  }
   270  
   271  // advanceTo sets the current time.
   272  func (te *testEndpoint) advanceTo(now time.Time) {
   273  	te.t.Helper()
   274  	if te.now.After(now) {
   275  		te.t.Fatalf("time moved backwards: %v -> %v", te.now, now)
   276  	}
   277  	te.now = now
   278  	for _, tc := range te.conns {
   279  		if !tc.timer.After(te.now) {
   280  			tc.conn.sendMsg(timerEvent{})
   281  			tc.wait()
   282  		}
   283  	}
   284  }
   285  
   286  // testEndpointHooks implements endpointTestHooks.
   287  type testEndpointHooks testEndpoint
   288  
   289  func (te *testEndpointHooks) timeNow() time.Time {
   290  	return te.now
   291  }
   292  
   293  func (te *testEndpointHooks) newConn(c *Conn) {
   294  	tc := newTestConnForConn(te.t, (*testEndpoint)(te), c)
   295  	te.conns[c] = tc
   296  }
   297  
   298  // testEndpointUDPConn implements UDPConn.
   299  type testEndpointUDPConn testEndpoint
   300  
   301  func (te *testEndpointUDPConn) Close() error {
   302  	close(te.recvc)
   303  	return nil
   304  }
   305  
   306  func (te *testEndpointUDPConn) LocalAddr() net.Addr {
   307  	return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443"))
   308  }
   309  
   310  func (te *testEndpointUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
   311  	for {
   312  		select {
   313  		case d, ok := <-te.recvc:
   314  			if !ok {
   315  				return 0, 0, 0, netip.AddrPort{}, io.EOF
   316  			}
   317  			n = copy(b, d.b)
   318  			return n, 0, 0, d.addr, nil
   319  		case <-te.idlec:
   320  		}
   321  	}
   322  }
   323  
   324  func (te *testEndpointUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
   325  	te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), b...))
   326  	return len(b), nil
   327  }
   328  

View as plain text