// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.21

package quic

import (
	"bytes"
	"crypto/rand"
)

// connIDState is a conn's connection IDs.
type connIDState struct {
	// The destination connection IDs of packets we receive are local.
	// The destination connection IDs of packets we send are remote.
	//
	// Local IDs are usually issued by us, and remote IDs by the peer.
	// The exception is the transient destination connection ID sent in
	// a client's Initial packets, which is chosen by the client.
	//
	// These are []connID rather than []*connID to minimize allocations.
	local  []connID
	remote []remoteConnID

	// nextLocalSeq is the sequence number given to the next local
	// connection ID we issue. Sequence numbers at or above it have not
	// been issued, so the peer may not retire them.
	nextLocalSeq          int64
	retireRemotePriorTo   int64 // largest Retire Prior To value sent by the peer
	peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter

	originalDstConnID []byte // expected original_destination_connection_id param
	retrySrcConnID    []byte // expected retry_source_connection_id param

	// needSend is set when NEW_CONNECTION_ID or RETIRE_CONNECTION_ID
	// frames may need to be sent. When it is clear (and no PTO is in
	// progress) frame appending takes a fast path and sends nothing.
	needSend bool
}

// A connID is a connection ID and associated metadata.
type connID struct {
	// cid is the connection ID itself.
	cid []byte

	// seq is the connection ID's sequence number:
	// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1
	//
	// For the transient destination ID in a client's Initial packet, this is -1.
	seq int64

	// retired is set when the connection ID is retired.
	//
	// In this file only remote IDs are marked retired. Local IDs are
	// removed from the list when the peer retires them.
	retired bool

	// send is set when the connection ID's state needs to be sent to the peer.
	//
	// For local IDs, this indicates a new ID that should be sent
	// in a NEW_CONNECTION_ID frame.
	//
	// For remote IDs, this indicates a retired ID that should be sent
	// in a RETIRE_CONNECTION_ID frame.
	send sentVal
}

// A remoteConnID is a connection ID and stateless reset token.
type remoteConnID struct { connID resetToken statelessResetToken } func (s *connIDState) initClient(c *Conn) error { // Client chooses its initial connection ID, and sends it // in the Source Connection ID field of the first Initial packet. locid, err := c.newConnID(0) if err != nil { return err } s.local = append(s.local, connID{ seq: 0, cid: locid, }) s.nextLocalSeq = 1 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addConnID(c, locid) }) // Client chooses an initial, transient connection ID for the server, // and sends it in the Destination Connection ID field of the first Initial packet. remid, err := c.newConnID(-1) if err != nil { return err } s.remote = append(s.remote, remoteConnID{ connID: connID{ seq: -1, cid: remid, }, }) s.originalDstConnID = remid return nil } func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error { dstConnID := cloneBytes(cids.dstConnID) // Client-chosen, transient connection ID received in the first Initial packet. // The server will not use this as the Source Connection ID of packets it sends, // but remembers it because it may receive packets sent to this destination. s.local = append(s.local, connID{ seq: -1, cid: dstConnID, }) // Server chooses a connection ID, and sends it in the Source Connection ID of // the response to the clent. locid, err := c.newConnID(0) if err != nil { return err } s.local = append(s.local, connID{ seq: 0, cid: locid, }) s.nextLocalSeq = 1 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addConnID(c, dstConnID) conns.addConnID(c, locid) }) // Client chose its own connection ID. s.remote = append(s.remote, remoteConnID{ connID: connID{ seq: 0, cid: cloneBytes(cids.srcConnID), }, }) return nil } // srcConnID is the Source Connection ID to use in a sent packet. func (s *connIDState) srcConnID() []byte { if s.local[0].seq == -1 && len(s.local) > 1 { // Don't use the transient connection ID if another is available. 
return s.local[1].cid } return s.local[0].cid } // dstConnID is the Destination Connection ID to use in a sent packet. func (s *connIDState) dstConnID() (cid []byte, ok bool) { for i := range s.remote { if !s.remote[i].retired { return s.remote[i].cid, true } } return nil, false } // isValidStatelessResetToken reports whether the given reset token is // associated with a non-retired connection ID which we have used. func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool { for i := range s.remote { // We currently only use the first available remote connection ID, // so any other reset token is not valid. if !s.remote[i].retired { return s.remote[i].resetToken == resetToken } } return false } // setPeerActiveConnIDLimit sets the active_connection_id_limit // transport parameter received from the peer. func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error { s.peerActiveConnIDLimit = lim return s.issueLocalIDs(c) } func (s *connIDState) issueLocalIDs(c *Conn) error { toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit) for i := range s.local { if s.local[i].seq != -1 && !s.local[i].retired { toIssue-- } } var newIDs [][]byte for toIssue > 0 { cid, err := c.newConnID(s.nextLocalSeq) if err != nil { return err } newIDs = append(newIDs, cid) s.local = append(s.local, connID{ seq: s.nextLocalSeq, cid: cid, }) s.local[len(s.local)-1].send.setUnsent() s.nextLocalSeq++ s.needSend = true toIssue-- } c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { for _, cid := range newIDs { conns.addConnID(c, cid) } }) return nil } // validateTransportParameters verifies the original_destination_connection_id and // initial_source_connection_id transport parameters match the expected values. func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error { // TODO: Consider returning more detailed errors, for debugging. 
// Verify original_destination_connection_id matches // the transient remote connection ID we chose (client) // or is empty (server). if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) { return localTransportError{ code: errTransportParameter, reason: "original_destination_connection_id mismatch", } } s.originalDstConnID = nil // we have no further need for this // Verify retry_source_connection_id matches the value from // the server's Retry packet (when one was sent), or is empty. if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) { return localTransportError{ code: errTransportParameter, reason: "retry_source_connection_id mismatch", } } s.retrySrcConnID = nil // we have no further need for this // Verify initial_source_connection_id matches the first remote connection ID. if len(s.remote) == 0 || s.remote[0].seq != 0 { return localTransportError{ code: errInternal, reason: "remote connection id missing", } } if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) { return localTransportError{ code: errTransportParameter, reason: "initial_source_connection_id mismatch", } } if len(p.statelessResetToken) > 0 { if c.side == serverSide { return localTransportError{ code: errTransportParameter, reason: "client sent stateless_reset_token", } } token := statelessResetToken(p.statelessResetToken) s.remote[0].resetToken = token c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addResetToken(c, token) }) } return nil } // handlePacket updates the connection ID state during the handshake // (Initial and Handshake packets). func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) { switch { case ptype == packetTypeInitial && c.side == clientSide: if len(s.remote) == 1 && s.remote[0].seq == -1 { // We're a client connection processing the first Initial packet // from the server. Replace the transient remote connection ID // with the Source Connection ID from the packet. 
s.remote[0] = remoteConnID{ connID: connID{ seq: 0, cid: cloneBytes(srcConnID), }, } } case ptype == packetTypeHandshake && c.side == serverSide: if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired { // We're a server connection processing the first Handshake packet from // the client. Discard the transient, client-chosen connection ID used // for Initial packets; the client will never send it again. cid := s.local[0].cid c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireConnID(c, cid) }) s.local = append(s.local[:0], s.local[1:]...) } } } func (s *connIDState) handleRetryPacket(srcConnID []byte) { if len(s.remote) != 1 || s.remote[0].seq != -1 { panic("BUG: handling retry with non-transient remote conn id") } s.retrySrcConnID = cloneBytes(srcConnID) s.remote[0].cid = s.retrySrcConnID } func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error { if len(s.remote[0].cid) == 0 { // "An endpoint that is sending packets with a zero-length // Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID // frame as a connection error of type PROTOCOL_VIOLATION." // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6 return localTransportError{ code: errProtocolViolation, reason: "NEW_CONNECTION_ID from peer with zero-length DCID", } } if retire > s.retireRemotePriorTo { s.retireRemotePriorTo = retire } have := false // do we already have this connection ID? 
active := 0 for i := range s.remote { rcid := &s.remote[i] if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo { s.retireRemote(rcid) c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireResetToken(c, rcid.resetToken) }) } if !rcid.retired { active++ } if rcid.seq == seq { if !bytes.Equal(rcid.cid, cid) { return localTransportError{ code: errProtocolViolation, reason: "NEW_CONNECTION_ID does not match prior id", } } have = true // yes, we've seen this sequence number } } if !have { // This is a new connection ID that we have not seen before. // // We could take steps to keep the list of remote connection IDs // sorted by sequence number, but there's no particular need // so we don't bother. s.remote = append(s.remote, remoteConnID{ connID: connID{ seq: seq, cid: cloneBytes(cid), }, resetToken: resetToken, }) if seq < s.retireRemotePriorTo { // This ID was already retired by a previous NEW_CONNECTION_ID frame. s.retireRemote(&s.remote[len(s.remote)-1]) } else { active++ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.addResetToken(c, resetToken) }) } } if active > activeConnIDLimit { // Retired connection IDs (including newly-retired ones) do not count // against the limit. // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5 return localTransportError{ code: errConnectionIDLimit, reason: "active_connection_id_limit exceeded", } } // "An endpoint SHOULD limit the number of connection IDs it has retired locally // for which RETIRE_CONNECTION_ID frames have not yet been acknowledged." // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6 // // Set a limit of four times the active_connection_id_limit for // the total number of remote connection IDs we keep state for locally. if len(s.remote) > 4*activeConnIDLimit { return localTransportError{ code: errConnectionIDLimit, reason: "too many unacknowledged RETIRE_CONNECTION_ID frames", } } return nil } // retireRemote marks a remote connection ID as retired. 
func (s *connIDState) retireRemote(rcid *remoteConnID) { rcid.retired = true rcid.send.setUnsent() s.needSend = true } func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { if seq >= s.nextLocalSeq { return localTransportError{ code: errProtocolViolation, reason: "RETIRE_CONNECTION_ID for unissued sequence number", } } for i := range s.local { if s.local[i].seq == seq { cid := s.local[i].cid c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { conns.retireConnID(c, cid) }) s.local = append(s.local[:i], s.local[i+1:]...) break } } s.issueLocalIDs(c) return nil } func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) { for i := range s.local { if s.local[i].seq != seq { continue } s.local[i].send.ackOrLoss(pnum, fate) if fate != packetAcked { s.needSend = true } return } } func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) { for i := 0; i < len(s.remote); i++ { if s.remote[i].seq != seq { continue } if fate == packetAcked { // We have retired this connection ID, and the peer has acked. // Discard its state completely. s.remote = append(s.remote[:i], s.remote[i+1:]...) } else { // RETIRE_CONNECTION_ID frame was lost, mark for retransmission. s.needSend = true s.remote[i].send.ackOrLoss(pnum, fate) } return } } // appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames // to the current packet. // // It returns true if no more frames need appending, // false if not everything fit in the current packet. func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool { if !s.needSend && !pto { // Fast path: We don't need to send anything. 
return true } retireBefore := int64(0) if s.local[0].seq != -1 { retireBefore = s.local[0].seq } for i := range s.local { if !s.local[i].send.shouldSendPTO(pto) { continue } if !c.w.appendNewConnectionIDFrame( s.local[i].seq, retireBefore, s.local[i].cid, c.endpoint.resetGen.tokenForConnID(s.local[i].cid), ) { return false } s.local[i].send.setSent(pnum) } for i := range s.remote { if !s.remote[i].send.shouldSendPTO(pto) { continue } if !c.w.appendRetireConnectionIDFrame(s.remote[i].seq) { return false } s.remote[i].send.setSent(pnum) } s.needSend = false return true } func cloneBytes(b []byte) []byte { n := make([]byte, len(b)) copy(n, b) return n } func (c *Conn) newConnID(seq int64) ([]byte, error) { if c.testHooks != nil { return c.testHooks.newConnID(seq) } return newRandomConnID(seq) } func newRandomConnID(_ int64) ([]byte, error) { // It is not necessary for connection IDs to be cryptographically secure, // but it doesn't hurt. id := make([]byte, connIDLen) if _, err := rand.Read(id); err != nil { // TODO: Surface this error as a metric or log event or something. // rand.Read really shouldn't ever fail, but if it does, we should // have a way to inform the user. return nil, err } return id, nil }