1
2
3
4
5
6
7 package quic
8
9 import (
10 "context"
11 "crypto/rand"
12 "errors"
13 "net"
14 "net/netip"
15 "sync"
16 "sync/atomic"
17 "time"
18 )
19
20
21
22
23
// An Endpoint sends and receives QUIC traffic over a single UDP connection.
// It accepts inbound connections (Accept) and creates outbound ones (Dial).
type Endpoint struct {
	config    *Config
	udpConn   udpConn
	testHooks endpointTestHooks // nil when created through Listen
	resetGen  statelessResetTokenGenerator
	retry     retryState // initialized only when config.RequireAddressValidation is set

	acceptQueue queue[*Conn] // filled by serverConnEstablished, drained by Accept
	connsMap    connsMap     // read by the listen loop; changes are queued through updateConnIDs

	connsMu sync.Mutex         // guards conns and closing
	conns   map[*Conn]struct{} // conns created by newConn that have not yet drained
	closing bool               // set by Close; newConn fails once it is set
	closec  chan struct{}      // closed when the listen loop exits
}
39
// endpointTestHooks lets a caller of newEndpoint intercept endpoint behavior.
// Listen passes nil, in which case no hook is consulted.
type endpointTestHooks interface {
	// timeNow replaces time.Now when handling datagrams for unknown connections.
	timeNow() time.Time
	// newConn is not called from this part of the file.
	// NOTE(review): presumably invoked during conn construction — confirm.
	newConn(c *Conn)
}
44
45
46
// A udpConn is the set of UDP socket operations the endpoint relies on.
// Listen supplies the *net.UDPConn returned by net.ListenUDP.
type udpConn interface {
	Close() error
	LocalAddr() net.Addr
	ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error)
	WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error)
}
53
54
55
56 func Listen(network, address string, config *Config) (*Endpoint, error) {
57 if config.TLSConfig == nil {
58 return nil, errors.New("TLSConfig is not set")
59 }
60 a, err := net.ResolveUDPAddr(network, address)
61 if err != nil {
62 return nil, err
63 }
64 udpConn, err := net.ListenUDP(network, a)
65 if err != nil {
66 return nil, err
67 }
68 return newEndpoint(udpConn, config, nil)
69 }
70
71 func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
72 e := &Endpoint{
73 config: config,
74 udpConn: udpConn,
75 testHooks: hooks,
76 conns: make(map[*Conn]struct{}),
77 acceptQueue: newQueue[*Conn](),
78 closec: make(chan struct{}),
79 }
80 e.resetGen.init(config.StatelessResetKey)
81 e.connsMap.init()
82 if config.RequireAddressValidation {
83 if err := e.retry.init(); err != nil {
84 return nil, err
85 }
86 }
87 go e.listen()
88 return e, nil
89 }
90
91
92 func (e *Endpoint) LocalAddr() netip.AddrPort {
93 a, _ := e.udpConn.LocalAddr().(*net.UDPAddr)
94 return a.AddrPort()
95 }
96
97
98
99
100
101
102
103
104 func (e *Endpoint) Close(ctx context.Context) error {
105 e.acceptQueue.close(errors.New("endpoint closed"))
106
107
108
109 var conns []*Conn
110 e.connsMu.Lock()
111 if !e.closing {
112 e.closing = true
113 for c := range e.conns {
114 conns = append(conns, c)
115 }
116 if len(e.conns) == 0 {
117 e.udpConn.Close()
118 }
119 }
120 e.connsMu.Unlock()
121
122 for _, c := range conns {
123 c.Abort(localTransportError{code: errNo})
124 }
125 select {
126 case <-e.closec:
127 case <-ctx.Done():
128 for _, c := range conns {
129 c.exit()
130 }
131 return ctx.Err()
132 }
133 return nil
134 }
135
136
137 func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
138 return e.acceptQueue.get(ctx, nil)
139 }
140
141
142 func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, error) {
143 u, err := net.ResolveUDPAddr(network, address)
144 if err != nil {
145 return nil, err
146 }
147 addr := u.AddrPort()
148 addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
149 c, err := e.newConn(time.Now(), clientSide, newServerConnIDs{}, addr)
150 if err != nil {
151 return nil, err
152 }
153 if err := c.waitReady(ctx); err != nil {
154 c.Abort(nil)
155 return nil, err
156 }
157 return c, nil
158 }
159
160 func (e *Endpoint) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) {
161 e.connsMu.Lock()
162 defer e.connsMu.Unlock()
163 if e.closing {
164 return nil, errors.New("endpoint closed")
165 }
166 c, err := newConn(now, side, cids, peerAddr, e.config, e)
167 if err != nil {
168 return nil, err
169 }
170 e.conns[c] = struct{}{}
171 return c, nil
172 }
173
174
175
176 func (e *Endpoint) serverConnEstablished(c *Conn) {
177 e.acceptQueue.put(c)
178 }
179
180
181
182 func (e *Endpoint) connDrained(c *Conn) {
183 var cids [][]byte
184 for i := range c.connIDState.local {
185 cids = append(cids, c.connIDState.local[i].cid)
186 }
187 var tokens []statelessResetToken
188 for i := range c.connIDState.remote {
189 tokens = append(tokens, c.connIDState.remote[i].resetToken)
190 }
191 e.connsMap.updateConnIDs(func(conns *connsMap) {
192 for _, cid := range cids {
193 conns.retireConnID(c, cid)
194 }
195 for _, token := range tokens {
196 conns.retireResetToken(c, token)
197 }
198 })
199 e.connsMu.Lock()
200 defer e.connsMu.Unlock()
201 delete(e.conns, c)
202 if e.closing && len(e.conns) == 0 {
203 e.udpConn.Close()
204 }
205 }
206
207 func (e *Endpoint) listen() {
208 defer close(e.closec)
209 for {
210 m := newDatagram()
211
212
213 n, _, _, addr, err := e.udpConn.ReadMsgUDPAddrPort(m.b, nil)
214 if err != nil {
215
216
217
218
219 return
220 }
221 if n == 0 {
222 continue
223 }
224 if e.connsMap.updateNeeded.Load() {
225 e.connsMap.applyUpdates()
226 }
227 m.addr = addr
228 m.b = m.b[:n]
229 e.handleDatagram(m)
230 }
231 }
232
233 func (e *Endpoint) handleDatagram(m *datagram) {
234 dstConnID, ok := dstConnIDForDatagram(m.b)
235 if !ok {
236 m.recycle()
237 return
238 }
239 c := e.connsMap.byConnID[string(dstConnID)]
240 if c == nil {
241
242
243 e.handleUnknownDestinationDatagram(m)
244 return
245 }
246
247
248
249 c.sendMsg(m)
250 }
251
252 func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
253 defer func() {
254 if m != nil {
255 m.recycle()
256 }
257 }()
258 const minimumValidPacketSize = 21
259 if len(m.b) < minimumValidPacketSize {
260 return
261 }
262 var now time.Time
263 if e.testHooks != nil {
264 now = e.testHooks.timeNow()
265 } else {
266 now = time.Now()
267 }
268
269 var token statelessResetToken
270 copy(token[:], m.b[len(m.b)-len(token):])
271 if c := e.connsMap.byResetToken[token]; c != nil {
272 c.sendMsg(func(now time.Time, c *Conn) {
273 c.handleStatelessReset(now, token)
274 })
275 return
276 }
277
278
279 if !isLongHeader(m.b[0]) {
280 e.maybeSendStatelessReset(m.b, m.addr)
281 return
282 }
283 p, ok := parseGenericLongHeaderPacket(m.b)
284 if !ok || len(m.b) < paddedInitialDatagramSize {
285 return
286 }
287 switch p.version {
288 case quicVersion1:
289 case 0:
290
291 return
292 default:
293
294 e.sendVersionNegotiation(p, m.addr)
295 return
296 }
297 if getPacketType(m.b) != packetTypeInitial {
298
299
300
301
302
303 return
304 }
305 cids := newServerConnIDs{
306 srcConnID: p.srcConnID,
307 dstConnID: p.dstConnID,
308 }
309 if e.config.RequireAddressValidation {
310 var ok bool
311 cids.retrySrcConnID = p.dstConnID
312 cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.addr)
313 if !ok {
314 return
315 }
316 } else {
317 cids.originalDstConnID = p.dstConnID
318 }
319 var err error
320 c, err := e.newConn(now, serverSide, cids, m.addr)
321 if err != nil {
322
323
324
325
326 return
327 }
328 c.sendMsg(m)
329 m = nil
330 }
331
332 func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) {
333 if !e.resetGen.canReset {
334
335 return
336 }
337
338
339
340
341
342
343 if len(b) < 1+connIDLen+1+1+16 {
344 return
345 }
346
347 cid := b[1:][:connIDLen]
348 token := e.resetGen.tokenForConnID(cid)
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366 size := min(len(b)-1, 42)
367
368 b = b[:size]
369 rand.Read(b[:len(b)-statelessResetTokenLen])
370 b[0] &^= headerFormLong
371 b[0] |= fixedBit
372 copy(b[len(b)-statelessResetTokenLen:], token[:])
373 e.sendDatagram(b, addr)
374 }
375
376 func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) {
377 m := newDatagram()
378 m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
379 e.sendDatagram(m.b, addr)
380 m.recycle()
381 }
382
383 func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) {
384 keys := initialKeys(in.dstConnID, serverSide)
385 var w packetWriter
386 p := longPacket{
387 ptype: packetTypeInitial,
388 version: quicVersion1,
389 num: 0,
390 dstConnID: in.srcConnID,
391 srcConnID: in.dstConnID,
392 }
393 const pnumMaxAcked = 0
394 w.reset(paddedInitialDatagramSize)
395 w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
396 w.appendConnectionCloseTransportFrame(code, 0, "")
397 w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p)
398 buf := w.datagram()
399 if len(buf) == 0 {
400 return
401 }
402 e.sendDatagram(buf, addr)
403 }
404
405 func (e *Endpoint) sendDatagram(p []byte, addr netip.AddrPort) error {
406 _, err := e.udpConn.WriteToUDPAddrPort(p, addr)
407 return err
408 }
409
410
// A connsMap maps connection IDs and stateless reset tokens to conns.
//
// The two maps are read by the endpoint's listen loop without locking.
// Changes are queued with updateConnIDs and run later by applyUpdates.
type connsMap struct {
	byConnID     map[string]*Conn
	byResetToken map[statelessResetToken]*Conn

	updateMu     sync.Mutex        // guards updates
	updateNeeded atomic.Bool       // true while updates holds unapplied functions
	updates      []func(*connsMap) // queued changes, applied in order
}
419
420 func (m *connsMap) init() {
421 m.byConnID = map[string]*Conn{}
422 m.byResetToken = map[statelessResetToken]*Conn{}
423 }
424
425 func (m *connsMap) addConnID(c *Conn, cid []byte) {
426 m.byConnID[string(cid)] = c
427 }
428
429 func (m *connsMap) retireConnID(c *Conn, cid []byte) {
430 delete(m.byConnID, string(cid))
431 }
432
433 func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
434 m.byResetToken[token] = c
435 }
436
437 func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
438 delete(m.byResetToken, token)
439 }
440
441 func (m *connsMap) updateConnIDs(f func(*connsMap)) {
442 m.updateMu.Lock()
443 defer m.updateMu.Unlock()
444 m.updates = append(m.updates, f)
445 m.updateNeeded.Store(true)
446 }
447
448
449 func (m *connsMap) applyUpdates() {
450 m.updateMu.Lock()
451 defer m.updateMu.Unlock()
452 for _, f := range m.updates {
453 f(m)
454 }
455 clear(m.updates)
456 m.updates = m.updates[:0]
457 m.updateNeeded.Store(false)
458 }
459
View as plain text