1
2
3
4
5
6
7 package quic
8
9 import (
10 "bytes"
11 "encoding/binary"
12 "encoding/hex"
13 "reflect"
14 "strings"
15 "testing"
16 )
17
18 func TestPacketHeader(t *testing.T) {
19 for _, test := range []struct {
20 name string
21 packet []byte
22 isLongHeader bool
23 packetType packetType
24 dstConnID []byte
25 }{{
26
27
28 name: "rfc9001_a1",
29 packet: unhex(`
30 c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11
31 `),
32 isLongHeader: true,
33 packetType: packetTypeInitial,
34 dstConnID: unhex(`8394c8f03e515708`),
35 }, {
36
37
38 name: "rfc9001_a3",
39 packet: unhex(`
40 cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a
41 `),
42 isLongHeader: true,
43 packetType: packetTypeInitial,
44 dstConnID: []byte{},
45 }, {
46
47 name: "rfc9001_a4",
48 packet: unhex(`
49 ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f
50 0f2496ba
51 `),
52 isLongHeader: true,
53 packetType: packetTypeRetry,
54 dstConnID: []byte{},
55 }, {
56
57 name: "rfc9001_a5",
58 packet: unhex(`
59 4cfe4189655e5cd55c41f69080575d7999c25a5bfb
60 `),
61 isLongHeader: false,
62 packetType: packetType1RTT,
63 dstConnID: unhex(`fe4189655e5cd55c`),
64 }, {
65
66 name: "version_negotiation",
67 packet: unhex(`
68 80 00000000 01ff0001020304
69 `),
70 isLongHeader: true,
71 packetType: packetTypeVersionNegotiation,
72 dstConnID: []byte{0xff},
73 }, {
74
75 name: "truncated_after_connid_length",
76 packet: unhex(`
77 cf0000000105
78 `),
79 isLongHeader: true,
80 packetType: packetTypeInitial,
81 dstConnID: nil,
82 }, {
83
84 name: "truncated_after_version",
85 packet: unhex(`
86 cf00000001
87 `),
88 isLongHeader: true,
89 packetType: packetTypeInitial,
90 dstConnID: nil,
91 }, {
92
93 name: "truncated_in_version",
94 packet: unhex(`
95 cf000000
96 `),
97 isLongHeader: true,
98 packetType: packetTypeInvalid,
99 dstConnID: nil,
100 }} {
101 t.Run(test.name, func(t *testing.T) {
102 if got, want := isLongHeader(test.packet[0]), test.isLongHeader; got != want {
103 t.Errorf("packet %x:\nisLongHeader(packet) = %v, want %v", test.packet, got, want)
104 }
105 if got, want := getPacketType(test.packet), test.packetType; got != want {
106 t.Errorf("packet %x:\ngetPacketType(packet) = %v, want %v", test.packet, got, want)
107 }
108 gotConnID, gotOK := dstConnIDForDatagram(test.packet)
109 wantConnID, wantOK := test.dstConnID, test.dstConnID != nil
110 if !bytes.Equal(gotConnID, wantConnID) || gotOK != wantOK {
111 t.Errorf("packet %x:\ndstConnIDForDatagram(packet) = {%x}, %v; want {%x}, %v", test.packet, gotConnID, gotOK, wantConnID, wantOK)
112 }
113 })
114 }
115 }
116
117 func TestEncodeDecodeVersionNegotiation(t *testing.T) {
118 dstConnID := []byte("this is a very long destination connection id")
119 srcConnID := []byte("this is a very long source connection id")
120 versions := []uint32{1, 0xffffffff}
121 got := appendVersionNegotiation([]byte{}, dstConnID, srcConnID, versions...)
122 want := bytes.Join([][]byte{{
123 0b1100_0000,
124 0, 0, 0, 0,
125 byte(len(dstConnID)),
126 }, dstConnID, {
127 byte(len(srcConnID)),
128 }, srcConnID, {
129 0x00, 0x00, 0x00, 0x01,
130 0xff, 0xff, 0xff, 0xff,
131 }}, nil)
132 if !bytes.Equal(got, want) {
133 t.Fatalf("appendVersionNegotiation(nil, %x, %x, %v):\ngot %x\nwant %x",
134 dstConnID, srcConnID, versions, got, want)
135 }
136 gotDst, gotSrc, gotVersionBytes := parseVersionNegotiation(got)
137 if got, want := gotDst, dstConnID; !bytes.Equal(got, want) {
138 t.Errorf("parseVersionNegotiation: got dstConnID = %x, want %x", got, want)
139 }
140 if got, want := gotSrc, srcConnID; !bytes.Equal(got, want) {
141 t.Errorf("parseVersionNegotiation: got srcConnID = %x, want %x", got, want)
142 }
143 var gotVersions []uint32
144 for len(gotVersionBytes) >= 4 {
145 gotVersions = append(gotVersions, binary.BigEndian.Uint32(gotVersionBytes))
146 gotVersionBytes = gotVersionBytes[4:]
147 }
148 if got, want := gotVersions, versions; !reflect.DeepEqual(got, want) {
149 t.Errorf("parseVersionNegotiation: got versions = %v, want %v", got, want)
150 }
151 }
152
153 func TestParseGenericLongHeaderPacket(t *testing.T) {
154 for _, test := range []struct {
155 name string
156 packet []byte
157 version uint32
158 dstConnID []byte
159 srcConnID []byte
160 data []byte
161 }{{
162 name: "long header packet",
163 packet: unhex(`
164 80 01020304 04a1a2a3a4 05b1b2b3b4b5 c1
165 `),
166 version: 0x01020304,
167 dstConnID: unhex(`a1a2a3a4`),
168 srcConnID: unhex(`b1b2b3b4b5`),
169 data: unhex(`c1`),
170 }, {
171 name: "zero everything",
172 packet: unhex(`
173 80 00000000 00 00
174 `),
175 version: 0,
176 dstConnID: []byte{},
177 srcConnID: []byte{},
178 data: []byte{},
179 }} {
180 t.Run(test.name, func(t *testing.T) {
181 p, ok := parseGenericLongHeaderPacket(test.packet)
182 if !ok {
183 t.Fatalf("parseGenericLongHeaderPacket() = _, false; want true")
184 }
185 if got, want := p.version, test.version; got != want {
186 t.Errorf("version = %v, want %v", got, want)
187 }
188 if got, want := p.dstConnID, test.dstConnID; !bytes.Equal(got, want) {
189 t.Errorf("Destination Connection ID = {%x}, want {%x}", got, want)
190 }
191 if got, want := p.srcConnID, test.srcConnID; !bytes.Equal(got, want) {
192 t.Errorf("Source Connection ID = {%x}, want {%x}", got, want)
193 }
194 if got, want := p.data, test.data; !bytes.Equal(got, want) {
195 t.Errorf("Data = {%x}, want {%x}", got, want)
196 }
197 })
198 }
199 }
200
201 func TestParseGenericLongHeaderPacketErrors(t *testing.T) {
202 for _, test := range []struct {
203 name string
204 packet []byte
205 }{{
206 name: "short header packet",
207 packet: unhex(`
208 00 01020304 04a1a2a3a4 05b1b2b3b4b5 c1
209 `),
210 }, {
211 name: "packet too short",
212 packet: unhex(`
213 80 000000
214 `),
215 }, {
216 name: "destination id too long",
217 packet: unhex(`
218 80 00000000 02 00
219 `),
220 }, {
221 name: "source id too long",
222 packet: unhex(`
223 80 00000000 00 01
224 `),
225 }} {
226 t.Run(test.name, func(t *testing.T) {
227 _, ok := parseGenericLongHeaderPacket(test.packet)
228 if ok {
229 t.Fatalf("parseGenericLongHeaderPacket() = _, true; want false")
230 }
231 })
232 }
233 }
234
// unhex decodes a hexadecimal test vector into bytes. Spaces, tabs, and
// newlines are skipped so vectors can be grouped and wrapped across lines.
// Any other non-hex character (including '\r') or an odd number of digits
// makes it panic, so it is meant only for fixed test data.
func unhex(s string) []byte {
	var digits strings.Builder
	for _, r := range s {
		if r == ' ' || r == '\t' || r == '\n' {
			continue
		}
		digits.WriteRune(r)
	}
	decoded, err := hex.DecodeString(digits.String())
	if err != nil {
		panic(err)
	}
	return decoded
}
248
View as plain text