1
2
3
4
5
6
7 package quic
8
9 import (
10 "bytes"
11 "crypto/rand"
12 )
13
14
15 type connIDState struct {
16
17
18
19
20
21
22
23
24 local []connID
25 remote []remoteConnID
26
27 nextLocalSeq int64
28 retireRemotePriorTo int64
29 peerActiveConnIDLimit int64
30
31 originalDstConnID []byte
32 retrySrcConnID []byte
33
34 needSend bool
35 }
36
37
38 type connID struct {
39
40 cid []byte
41
42
43
44
45
46 seq int64
47
48
49 retired bool
50
51
52
53
54
55
56
57
58 send sentVal
59 }
60
61
62 type remoteConnID struct {
63 connID
64 resetToken statelessResetToken
65 }
66
67 func (s *connIDState) initClient(c *Conn) error {
68
69
70 locid, err := c.newConnID(0)
71 if err != nil {
72 return err
73 }
74 s.local = append(s.local, connID{
75 seq: 0,
76 cid: locid,
77 })
78 s.nextLocalSeq = 1
79 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
80 conns.addConnID(c, locid)
81 })
82
83
84
85 remid, err := c.newConnID(-1)
86 if err != nil {
87 return err
88 }
89 s.remote = append(s.remote, remoteConnID{
90 connID: connID{
91 seq: -1,
92 cid: remid,
93 },
94 })
95 s.originalDstConnID = remid
96 return nil
97 }
98
99 func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
100 dstConnID := cloneBytes(cids.dstConnID)
101
102
103
104 s.local = append(s.local, connID{
105 seq: -1,
106 cid: dstConnID,
107 })
108
109
110
111 locid, err := c.newConnID(0)
112 if err != nil {
113 return err
114 }
115 s.local = append(s.local, connID{
116 seq: 0,
117 cid: locid,
118 })
119 s.nextLocalSeq = 1
120 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
121 conns.addConnID(c, dstConnID)
122 conns.addConnID(c, locid)
123 })
124
125
126 s.remote = append(s.remote, remoteConnID{
127 connID: connID{
128 seq: 0,
129 cid: cloneBytes(cids.srcConnID),
130 },
131 })
132 return nil
133 }
134
135
136 func (s *connIDState) srcConnID() []byte {
137 if s.local[0].seq == -1 && len(s.local) > 1 {
138
139 return s.local[1].cid
140 }
141 return s.local[0].cid
142 }
143
144
145 func (s *connIDState) dstConnID() (cid []byte, ok bool) {
146 for i := range s.remote {
147 if !s.remote[i].retired {
148 return s.remote[i].cid, true
149 }
150 }
151 return nil, false
152 }
153
154
155
156 func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
157 for i := range s.remote {
158
159
160 if !s.remote[i].retired {
161 return s.remote[i].resetToken == resetToken
162 }
163 }
164 return false
165 }
166
167
168
169 func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
170 s.peerActiveConnIDLimit = lim
171 return s.issueLocalIDs(c)
172 }
173
174 func (s *connIDState) issueLocalIDs(c *Conn) error {
175 toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
176 for i := range s.local {
177 if s.local[i].seq != -1 && !s.local[i].retired {
178 toIssue--
179 }
180 }
181 var newIDs [][]byte
182 for toIssue > 0 {
183 cid, err := c.newConnID(s.nextLocalSeq)
184 if err != nil {
185 return err
186 }
187 newIDs = append(newIDs, cid)
188 s.local = append(s.local, connID{
189 seq: s.nextLocalSeq,
190 cid: cid,
191 })
192 s.local[len(s.local)-1].send.setUnsent()
193 s.nextLocalSeq++
194 s.needSend = true
195 toIssue--
196 }
197 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
198 for _, cid := range newIDs {
199 conns.addConnID(c, cid)
200 }
201 })
202 return nil
203 }
204
205
206
207 func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error {
208
209
210
211
212 if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) {
213 return localTransportError{
214 code: errTransportParameter,
215 reason: "original_destination_connection_id mismatch",
216 }
217 }
218 s.originalDstConnID = nil
219
220
221 if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) {
222 return localTransportError{
223 code: errTransportParameter,
224 reason: "retry_source_connection_id mismatch",
225 }
226 }
227 s.retrySrcConnID = nil
228
229 if len(s.remote) == 0 || s.remote[0].seq != 0 {
230 return localTransportError{
231 code: errInternal,
232 reason: "remote connection id missing",
233 }
234 }
235 if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
236 return localTransportError{
237 code: errTransportParameter,
238 reason: "initial_source_connection_id mismatch",
239 }
240 }
241 if len(p.statelessResetToken) > 0 {
242 if c.side == serverSide {
243 return localTransportError{
244 code: errTransportParameter,
245 reason: "client sent stateless_reset_token",
246 }
247 }
248 token := statelessResetToken(p.statelessResetToken)
249 s.remote[0].resetToken = token
250 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
251 conns.addResetToken(c, token)
252 })
253 }
254 return nil
255 }
256
257
258
259 func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
260 switch {
261 case ptype == packetTypeInitial && c.side == clientSide:
262 if len(s.remote) == 1 && s.remote[0].seq == -1 {
263
264
265
266 s.remote[0] = remoteConnID{
267 connID: connID{
268 seq: 0,
269 cid: cloneBytes(srcConnID),
270 },
271 }
272 }
273 case ptype == packetTypeHandshake && c.side == serverSide:
274 if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
275
276
277
278 cid := s.local[0].cid
279 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
280 conns.retireConnID(c, cid)
281 })
282 s.local = append(s.local[:0], s.local[1:]...)
283 }
284 }
285 }
286
287 func (s *connIDState) handleRetryPacket(srcConnID []byte) {
288 if len(s.remote) != 1 || s.remote[0].seq != -1 {
289 panic("BUG: handling retry with non-transient remote conn id")
290 }
291 s.retrySrcConnID = cloneBytes(srcConnID)
292 s.remote[0].cid = s.retrySrcConnID
293 }
294
295 func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error {
296 if len(s.remote[0].cid) == 0 {
297
298
299
300
301 return localTransportError{
302 code: errProtocolViolation,
303 reason: "NEW_CONNECTION_ID from peer with zero-length DCID",
304 }
305 }
306
307 if retire > s.retireRemotePriorTo {
308 s.retireRemotePriorTo = retire
309 }
310
311 have := false
312 active := 0
313 for i := range s.remote {
314 rcid := &s.remote[i]
315 if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
316 s.retireRemote(rcid)
317 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
318 conns.retireResetToken(c, rcid.resetToken)
319 })
320 }
321 if !rcid.retired {
322 active++
323 }
324 if rcid.seq == seq {
325 if !bytes.Equal(rcid.cid, cid) {
326 return localTransportError{
327 code: errProtocolViolation,
328 reason: "NEW_CONNECTION_ID does not match prior id",
329 }
330 }
331 have = true
332 }
333 }
334
335 if !have {
336
337
338
339
340
341 s.remote = append(s.remote, remoteConnID{
342 connID: connID{
343 seq: seq,
344 cid: cloneBytes(cid),
345 },
346 resetToken: resetToken,
347 })
348 if seq < s.retireRemotePriorTo {
349
350 s.retireRemote(&s.remote[len(s.remote)-1])
351 } else {
352 active++
353 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
354 conns.addResetToken(c, resetToken)
355 })
356 }
357 }
358
359 if active > activeConnIDLimit {
360
361
362
363 return localTransportError{
364 code: errConnectionIDLimit,
365 reason: "active_connection_id_limit exceeded",
366 }
367 }
368
369
370
371
372
373
374
375 if len(s.remote) > 4*activeConnIDLimit {
376 return localTransportError{
377 code: errConnectionIDLimit,
378 reason: "too many unacknowledged RETIRE_CONNECTION_ID frames",
379 }
380 }
381
382 return nil
383 }
384
385
386 func (s *connIDState) retireRemote(rcid *remoteConnID) {
387 rcid.retired = true
388 rcid.send.setUnsent()
389 s.needSend = true
390 }
391
392 func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
393 if seq >= s.nextLocalSeq {
394 return localTransportError{
395 code: errProtocolViolation,
396 reason: "RETIRE_CONNECTION_ID for unissued sequence number",
397 }
398 }
399 for i := range s.local {
400 if s.local[i].seq == seq {
401 cid := s.local[i].cid
402 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
403 conns.retireConnID(c, cid)
404 })
405 s.local = append(s.local[:i], s.local[i+1:]...)
406 break
407 }
408 }
409 s.issueLocalIDs(c)
410 return nil
411 }
412
413 func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) {
414 for i := range s.local {
415 if s.local[i].seq != seq {
416 continue
417 }
418 s.local[i].send.ackOrLoss(pnum, fate)
419 if fate != packetAcked {
420 s.needSend = true
421 }
422 return
423 }
424 }
425
426 func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
427 for i := 0; i < len(s.remote); i++ {
428 if s.remote[i].seq != seq {
429 continue
430 }
431 if fate == packetAcked {
432
433
434 s.remote = append(s.remote[:i], s.remote[i+1:]...)
435 } else {
436
437 s.needSend = true
438 s.remote[i].send.ackOrLoss(pnum, fate)
439 }
440 return
441 }
442 }
443
444
445
446
447
448
449 func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
450 if !s.needSend && !pto {
451
452 return true
453 }
454 retireBefore := int64(0)
455 if s.local[0].seq != -1 {
456 retireBefore = s.local[0].seq
457 }
458 for i := range s.local {
459 if !s.local[i].send.shouldSendPTO(pto) {
460 continue
461 }
462 if !c.w.appendNewConnectionIDFrame(
463 s.local[i].seq,
464 retireBefore,
465 s.local[i].cid,
466 c.endpoint.resetGen.tokenForConnID(s.local[i].cid),
467 ) {
468 return false
469 }
470 s.local[i].send.setSent(pnum)
471 }
472 for i := range s.remote {
473 if !s.remote[i].send.shouldSendPTO(pto) {
474 continue
475 }
476 if !c.w.appendRetireConnectionIDFrame(s.remote[i].seq) {
477 return false
478 }
479 s.remote[i].send.setSent(pnum)
480 }
481 s.needSend = false
482 return true
483 }
484
485 func cloneBytes(b []byte) []byte {
486 n := make([]byte, len(b))
487 copy(n, b)
488 return n
489 }
490
491 func (c *Conn) newConnID(seq int64) ([]byte, error) {
492 if c.testHooks != nil {
493 return c.testHooks.newConnID(seq)
494 }
495 return newRandomConnID(seq)
496 }
497
498 func newRandomConnID(_ int64) ([]byte, error) {
499
500
501 id := make([]byte, connIDLen)
502 if _, err := rand.Read(id); err != nil {
503
504
505
506 return nil, err
507 }
508 return id, nil
509 }
510
View as plain text