...

Source file src/golang.org/x/net/internal/quic/endpoint.go

Documentation: golang.org/x/net/internal/quic

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"context"
    11  	"crypto/rand"
    12  	"errors"
    13  	"net"
    14  	"net/netip"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  )
    19  
// An Endpoint handles QUIC traffic on a network address.
// It can accept inbound connections or create outbound ones.
//
// Multiple goroutines may invoke methods on an Endpoint simultaneously.
type Endpoint struct {
	config    *Config                      // configuration supplied when the Endpoint was created
	udpConn   udpConn                      // socket every datagram is read from and written to
	testHooks endpointTestHooks            // nil unless supplied by tests; Listen passes nil
	resetGen  statelessResetTokenGenerator // derives stateless reset tokens from Config.StatelessResetKey
	retry     retryState                   // initialized only when Config.RequireAddressValidation is set

	acceptQueue queue[*Conn] // new inbound connections
	connsMap    connsMap     // only accessed by the listen loop

	connsMu sync.Mutex         // guards conns and closing
	conns   map[*Conn]struct{} // every Conn added by newConn and not yet removed by connDrained
	closing bool               // set when Close is called
	closec  chan struct{}      // closed when the listen loop exits
}
    39  
// endpointTestHooks lets tests observe and control an Endpoint.
// Endpoints created by Listen have no hooks (testHooks is nil).
type endpointTestHooks interface {
	// timeNow is used in place of time.Now when handling datagrams
	// addressed to unknown connections.
	timeNow() time.Time
	// newConn is presumably invoked when a Conn is created; it is not
	// called from this file — NOTE(review): confirm at the call site.
	newConn(c *Conn)
}
    44  
// A udpConn is a UDP connection.
// It is implemented by net.UDPConn.
// Declaring it as an interface lets newEndpoint accept implementations
// other than *net.UDPConn.
type udpConn interface {
	// Close closes the connection. The Endpoint calls it once it is closing
	// and has no remaining conns; the resulting read error ends the listen loop.
	Close() error
	// LocalAddr returns the local address. Endpoint.LocalAddr only understands
	// a *net.UDPAddr and reports the zero address otherwise.
	LocalAddr() net.Addr
	// ReadMsgUDPAddrPort reads one datagram into b. The listen loop passes
	// a nil control buffer and ignores controln and flags.
	ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error)
	// WriteToUDPAddrPort sends b as a single datagram to addr.
	WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error)
}
    53  
    54  // Listen listens on a local network address.
    55  // The configuration config must be non-nil.
    56  func Listen(network, address string, config *Config) (*Endpoint, error) {
    57  	if config.TLSConfig == nil {
    58  		return nil, errors.New("TLSConfig is not set")
    59  	}
    60  	a, err := net.ResolveUDPAddr(network, address)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	udpConn, err := net.ListenUDP(network, a)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	return newEndpoint(udpConn, config, nil)
    69  }
    70  
    71  func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
    72  	e := &Endpoint{
    73  		config:      config,
    74  		udpConn:     udpConn,
    75  		testHooks:   hooks,
    76  		conns:       make(map[*Conn]struct{}),
    77  		acceptQueue: newQueue[*Conn](),
    78  		closec:      make(chan struct{}),
    79  	}
    80  	e.resetGen.init(config.StatelessResetKey)
    81  	e.connsMap.init()
    82  	if config.RequireAddressValidation {
    83  		if err := e.retry.init(); err != nil {
    84  			return nil, err
    85  		}
    86  	}
    87  	go e.listen()
    88  	return e, nil
    89  }
    90  
    91  // LocalAddr returns the local network address.
    92  func (e *Endpoint) LocalAddr() netip.AddrPort {
    93  	a, _ := e.udpConn.LocalAddr().(*net.UDPAddr)
    94  	return a.AddrPort()
    95  }
    96  
// Close closes the Endpoint.
// Any blocked operations on the Endpoint or associated Conns and Streams will be unblocked
// and return errors.
//
// Close aborts every open connection.
// Data in stream read and write buffers is discarded.
// It waits for the peers of any open connection to acknowledge the connection has been closed.
//
// If ctx is done before shutdown completes, Close calls exit on each
// connection it aborted and returns ctx.Err(). The UDP socket is closed once
// no connections remain: immediately if there are none, otherwise by
// connDrained when the last one drains. Close may be called more than once;
// later calls do not abort connections again and only wait for shutdown.
func (e *Endpoint) Close(ctx context.Context) error {
	// Close the accept queue that Accept reads from.
	e.acceptQueue.close(errors.New("endpoint closed"))

	// It isn't safe to call Conn.Abort or conn.exit with connsMu held,
	// so copy the list of conns.
	var conns []*Conn
	e.connsMu.Lock()
	if !e.closing {
		e.closing = true // setting e.closing prevents new conns from being created
		for c := range e.conns {
			conns = append(conns, c)
		}
		if len(e.conns) == 0 {
			// No conn will reach connDrained, so close the socket now.
			// The resulting read error ends the listen loop.
			e.udpConn.Close()
		}
	}
	e.connsMu.Unlock()

	for _, c := range conns {
		c.Abort(localTransportError{code: errNo})
	}
	// e.closec is closed when the listen loop exits, which happens after
	// udpConn has been closed.
	select {
	case <-e.closec:
	case <-ctx.Done():
		for _, c := range conns {
			c.exit()
		}
		return ctx.Err()
	}
	return nil
}
   135  
   136  // Accept waits for and returns the next connection.
   137  func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
   138  	return e.acceptQueue.get(ctx, nil)
   139  }
   140  
   141  // Dial creates and returns a connection to a network address.
   142  func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, error) {
   143  	u, err := net.ResolveUDPAddr(network, address)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	addr := u.AddrPort()
   148  	addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
   149  	c, err := e.newConn(time.Now(), clientSide, newServerConnIDs{}, addr)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	if err := c.waitReady(ctx); err != nil {
   154  		c.Abort(nil)
   155  		return nil, err
   156  	}
   157  	return c, nil
   158  }
   159  
   160  func (e *Endpoint) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) {
   161  	e.connsMu.Lock()
   162  	defer e.connsMu.Unlock()
   163  	if e.closing {
   164  		return nil, errors.New("endpoint closed")
   165  	}
   166  	c, err := newConn(now, side, cids, peerAddr, e.config, e)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	e.conns[c] = struct{}{}
   171  	return c, nil
   172  }
   173  
   174  // serverConnEstablished is called by a conn when the handshake completes
   175  // for an inbound (serverSide) connection.
   176  func (e *Endpoint) serverConnEstablished(c *Conn) {
   177  	e.acceptQueue.put(c)
   178  }
   179  
   180  // connDrained is called by a conn when it leaves the draining state,
   181  // either when the peer acknowledges connection closure or the drain timeout expires.
   182  func (e *Endpoint) connDrained(c *Conn) {
   183  	var cids [][]byte
   184  	for i := range c.connIDState.local {
   185  		cids = append(cids, c.connIDState.local[i].cid)
   186  	}
   187  	var tokens []statelessResetToken
   188  	for i := range c.connIDState.remote {
   189  		tokens = append(tokens, c.connIDState.remote[i].resetToken)
   190  	}
   191  	e.connsMap.updateConnIDs(func(conns *connsMap) {
   192  		for _, cid := range cids {
   193  			conns.retireConnID(c, cid)
   194  		}
   195  		for _, token := range tokens {
   196  			conns.retireResetToken(c, token)
   197  		}
   198  	})
   199  	e.connsMu.Lock()
   200  	defer e.connsMu.Unlock()
   201  	delete(e.conns, c)
   202  	if e.closing && len(e.conns) == 0 {
   203  		e.udpConn.Close()
   204  	}
   205  }
   206  
   207  func (e *Endpoint) listen() {
   208  	defer close(e.closec)
   209  	for {
   210  		m := newDatagram()
   211  		// TODO: Read and process the ECN (explicit congestion notification) field.
   212  		// https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4
   213  		n, _, _, addr, err := e.udpConn.ReadMsgUDPAddrPort(m.b, nil)
   214  		if err != nil {
   215  			// The user has probably closed the endpoint.
   216  			// We currently don't surface errors from other causes;
   217  			// we could check to see if the endpoint has been closed and
   218  			// record the unexpected error if it has not.
   219  			return
   220  		}
   221  		if n == 0 {
   222  			continue
   223  		}
   224  		if e.connsMap.updateNeeded.Load() {
   225  			e.connsMap.applyUpdates()
   226  		}
   227  		m.addr = addr
   228  		m.b = m.b[:n]
   229  		e.handleDatagram(m)
   230  	}
   231  }
   232  
   233  func (e *Endpoint) handleDatagram(m *datagram) {
   234  	dstConnID, ok := dstConnIDForDatagram(m.b)
   235  	if !ok {
   236  		m.recycle()
   237  		return
   238  	}
   239  	c := e.connsMap.byConnID[string(dstConnID)]
   240  	if c == nil {
   241  		// TODO: Move this branch into a separate goroutine to avoid blocking
   242  		// the endpoint while processing packets.
   243  		e.handleUnknownDestinationDatagram(m)
   244  		return
   245  	}
   246  
   247  	// TODO: This can block the endpoint while waiting for the conn to accept the dgram.
   248  	// Think about buffering between the receive loop and the conn.
   249  	c.sendMsg(m)
   250  }
   251  
   252  func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
   253  	defer func() {
   254  		if m != nil {
   255  			m.recycle()
   256  		}
   257  	}()
   258  	const minimumValidPacketSize = 21
   259  	if len(m.b) < minimumValidPacketSize {
   260  		return
   261  	}
   262  	var now time.Time
   263  	if e.testHooks != nil {
   264  		now = e.testHooks.timeNow()
   265  	} else {
   266  		now = time.Now()
   267  	}
   268  	// Check to see if this is a stateless reset.
   269  	var token statelessResetToken
   270  	copy(token[:], m.b[len(m.b)-len(token):])
   271  	if c := e.connsMap.byResetToken[token]; c != nil {
   272  		c.sendMsg(func(now time.Time, c *Conn) {
   273  			c.handleStatelessReset(now, token)
   274  		})
   275  		return
   276  	}
   277  	// If this is a 1-RTT packet, there's nothing productive we can do with it.
   278  	// Send a stateless reset if possible.
   279  	if !isLongHeader(m.b[0]) {
   280  		e.maybeSendStatelessReset(m.b, m.addr)
   281  		return
   282  	}
   283  	p, ok := parseGenericLongHeaderPacket(m.b)
   284  	if !ok || len(m.b) < paddedInitialDatagramSize {
   285  		return
   286  	}
   287  	switch p.version {
   288  	case quicVersion1:
   289  	case 0:
   290  		// Version Negotiation for an unknown connection.
   291  		return
   292  	default:
   293  		// Unknown version.
   294  		e.sendVersionNegotiation(p, m.addr)
   295  		return
   296  	}
   297  	if getPacketType(m.b) != packetTypeInitial {
   298  		// This packet isn't trying to create a new connection.
   299  		// It might be associated with some connection we've lost state for.
   300  		// We are technically permitted to send a stateless reset for
   301  		// a long-header packet, but this isn't generally useful. See:
   302  		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16
   303  		return
   304  	}
   305  	cids := newServerConnIDs{
   306  		srcConnID: p.srcConnID,
   307  		dstConnID: p.dstConnID,
   308  	}
   309  	if e.config.RequireAddressValidation {
   310  		var ok bool
   311  		cids.retrySrcConnID = p.dstConnID
   312  		cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.addr)
   313  		if !ok {
   314  			return
   315  		}
   316  	} else {
   317  		cids.originalDstConnID = p.dstConnID
   318  	}
   319  	var err error
   320  	c, err := e.newConn(now, serverSide, cids, m.addr)
   321  	if err != nil {
   322  		// The accept queue is probably full.
   323  		// We could send a CONNECTION_CLOSE to the peer to reject the connection.
   324  		// Currently, we just drop the datagram.
   325  		// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
   326  		return
   327  	}
   328  	c.sendMsg(m)
   329  	m = nil // don't recycle, sendMsg takes ownership
   330  }
   331  
   332  func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) {
   333  	if !e.resetGen.canReset {
   334  		// Config.StatelessResetKey isn't set, so we don't send stateless resets.
   335  		return
   336  	}
   337  	// The smallest possible valid packet a peer can send us is:
   338  	//   1 byte of header
   339  	//   connIDLen bytes of destination connection ID
   340  	//   1 byte of packet number
   341  	//   1 byte of payload
   342  	//   16 bytes AEAD expansion
   343  	if len(b) < 1+connIDLen+1+1+16 {
   344  		return
   345  	}
   346  	// TODO: Rate limit stateless resets.
   347  	cid := b[1:][:connIDLen]
   348  	token := e.resetGen.tokenForConnID(cid)
   349  	// We want to generate a stateless reset that is as short as possible,
   350  	// but long enough to be difficult to distinguish from a 1-RTT packet.
   351  	//
   352  	// The minimal 1-RTT packet is:
   353  	//   1 byte of header
   354  	//   0-20 bytes of destination connection ID
   355  	//   1-4 bytes of packet number
   356  	//   1 byte of payload
   357  	//   16 bytes AEAD expansion
   358  	//
   359  	// Assuming the maximum possible connection ID and packet number size,
   360  	// this gives 1 + 20 + 4 + 1 + 16 = 42 bytes.
   361  	//
   362  	// We also must generate a stateless reset that is shorter than the datagram
   363  	// we are responding to, in order to ensure that reset loops terminate.
   364  	//
   365  	// See: https://www.rfc-editor.org/rfc/rfc9000#section-10.3
   366  	size := min(len(b)-1, 42)
   367  	// Reuse the input buffer for generating the stateless reset.
   368  	b = b[:size]
   369  	rand.Read(b[:len(b)-statelessResetTokenLen])
   370  	b[0] &^= headerFormLong // clear long header bit
   371  	b[0] |= fixedBit        // set fixed bit
   372  	copy(b[len(b)-statelessResetTokenLen:], token[:])
   373  	e.sendDatagram(b, addr)
   374  }
   375  
   376  func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) {
   377  	m := newDatagram()
   378  	m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
   379  	e.sendDatagram(m.b, addr)
   380  	m.recycle()
   381  }
   382  
   383  func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) {
   384  	keys := initialKeys(in.dstConnID, serverSide)
   385  	var w packetWriter
   386  	p := longPacket{
   387  		ptype:     packetTypeInitial,
   388  		version:   quicVersion1,
   389  		num:       0,
   390  		dstConnID: in.srcConnID,
   391  		srcConnID: in.dstConnID,
   392  	}
   393  	const pnumMaxAcked = 0
   394  	w.reset(paddedInitialDatagramSize)
   395  	w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
   396  	w.appendConnectionCloseTransportFrame(code, 0, "")
   397  	w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p)
   398  	buf := w.datagram()
   399  	if len(buf) == 0 {
   400  		return
   401  	}
   402  	e.sendDatagram(buf, addr)
   403  }
   404  
   405  func (e *Endpoint) sendDatagram(p []byte, addr netip.AddrPort) error {
   406  	_, err := e.udpConn.WriteToUDPAddrPort(p, addr)
   407  	return err
   408  }
   409  
// A connsMap is an endpoint's mapping of conn ids and reset tokens to conns.
//
// byConnID and byResetToken are read and written only by the Endpoint's
// listen loop. Other goroutines request changes with updateConnIDs, and the
// listen loop applies them with applyUpdates.
type connsMap struct {
	// byConnID maps a connection ID (converted to a string) to its Conn.
	byConnID map[string]*Conn
	// byResetToken maps a stateless reset token to its Conn.
	byResetToken map[statelessResetToken]*Conn

	// updateMu guards updates.
	updateMu sync.Mutex
	// updateNeeded is set while updates is non-empty, so the listen loop
	// can check for pending work without taking updateMu.
	updateNeeded atomic.Bool
	// updates holds pending changes, applied in the order they were queued.
	updates []func(*connsMap)
}
   419  
   420  func (m *connsMap) init() {
   421  	m.byConnID = map[string]*Conn{}
   422  	m.byResetToken = map[statelessResetToken]*Conn{}
   423  }
   424  
   425  func (m *connsMap) addConnID(c *Conn, cid []byte) {
   426  	m.byConnID[string(cid)] = c
   427  }
   428  
   429  func (m *connsMap) retireConnID(c *Conn, cid []byte) {
   430  	delete(m.byConnID, string(cid))
   431  }
   432  
   433  func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
   434  	m.byResetToken[token] = c
   435  }
   436  
   437  func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
   438  	delete(m.byResetToken, token)
   439  }
   440  
   441  func (m *connsMap) updateConnIDs(f func(*connsMap)) {
   442  	m.updateMu.Lock()
   443  	defer m.updateMu.Unlock()
   444  	m.updates = append(m.updates, f)
   445  	m.updateNeeded.Store(true)
   446  }
   447  
   448  // applyConnIDUpdates is called by the datagram receive loop to update its connection ID map.
   449  func (m *connsMap) applyUpdates() {
   450  	m.updateMu.Lock()
   451  	defer m.updateMu.Unlock()
   452  	for _, f := range m.updates {
   453  		f(m)
   454  	}
   455  	clear(m.updates)
   456  	m.updates = m.updates[:0]
   457  	m.updateNeeded.Store(false)
   458  }
   459  

View as plain text