1
2
3
4
5
6
7 package quic
8
9 import (
10 "bytes"
11 "context"
12 "crypto/tls"
13 "io"
14 "log/slog"
15 "net"
16 "net/netip"
17 "testing"
18 "time"
19
20 "golang.org/x/net/internal/quic/qlog"
21 )
22
// TestConnect checks that newLocalConnPair can dial and accept a connection
// between two loopback endpoints configured with zero-value Configs.
func TestConnect(t *testing.T) {
	serverConf, clientConf := &Config{}, &Config{}
	newLocalConnPair(t, serverConf, clientConf)
}
26
// TestStreamTransfer sends 1 MiB over a client-initiated send-only stream of a
// loopback connection pair and checks that the server reads back exactly the
// same bytes.
//
// The server side runs in its own goroutine. The test waits for that
// goroutine before returning, so its failures are reported against this test.
func TestStreamTransfer(t *testing.T) {
	ctx := context.Background()
	cli, srv := newLocalConnPair(t, &Config{}, &Config{})
	data := makeTestData(1 << 20)

	srvdone := make(chan struct{})
	go func() {
		defer close(srvdone)
		s, err := srv.AcceptStream(ctx)
		if err != nil {
			t.Errorf("AcceptStream: %v", err)
			return
		}
		b, err := io.ReadAll(s)
		if err != nil {
			t.Errorf("io.ReadAll(s): %v", err)
			return
		}
		if !bytes.Equal(b, data) {
			t.Errorf("read data mismatch (got %v bytes, want %v)", len(b), len(data))
		}
		if err := s.Close(); err != nil {
			t.Errorf("s.Close() = %v", err)
		}
	}()

	s, err := cli.NewSendOnlyStream(ctx)
	if err != nil {
		t.Fatalf("NewSendOnlyStream: %v", err)
	}
	n, err := io.Copy(s, bytes.NewReader(data))
	if n != int64(len(data)) || err != nil {
		t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
	}
	if err := s.Close(); err != nil {
		t.Fatalf("s.Close() = %v", err)
	}
	// The server goroutine may still be comparing the data or closing its end
	// of the stream. Wait for it: calling t.Errorf after the test returns would
	// panic, and a mismatch would otherwise go unreported.
	<-srvdone
}
65
66 func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
67 t.Helper()
68 ctx := context.Background()
69 e1 := newLocalEndpoint(t, serverSide, conf1)
70 e2 := newLocalEndpoint(t, clientSide, conf2)
71 c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String())
72 if err != nil {
73 t.Fatal(err)
74 }
75 c1, err := e1.Accept(ctx)
76 if err != nil {
77 t.Fatal(err)
78 }
79 return c2, c1
80 }
81
// newLocalEndpoint starts an Endpoint listening on an ephemeral UDP port on
// 127.0.0.1. It registers a t.Cleanup that closes the Endpoint.
//
// conf is copied before any defaults are filled in, so the caller's Config is
// never modified. On the copy:
//   - a nil TLSConfig is replaced with newTestTLSConfig(side);
//   - a nil QLogLogger is replaced with a qlog JSON handler at QLogLevelFrame,
//     with Dir set to *qlogdir.
//
// Any error from Listen aborts the test.
func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint {
	t.Helper()
	// Always work on a copy. Previously a caller-supplied Config with a
	// TLSConfig had its QLogLogger overwritten in place, which leaked
	// test defaults back into a possibly shared Config.
	localConf := *conf
	conf = &localConf
	if conf.TLSConfig == nil {
		conf.TLSConfig = newTestTLSConfig(side)
	}
	if conf.QLogLogger == nil {
		conf.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
			Level: QLogLevelFrame,
			Dir:   *qlogdir,
		}))
	}
	e, err := Listen("udp", "127.0.0.1:0", conf)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		// Best-effort shutdown; the error is not interesting to the test.
		e.Close(canceledContext())
	})
	return e
}
104
// testEndpoint drives an Endpoint under test with a fake clock and an in-memory
// datagram transport. newTestEndpoint passes this same value to the Endpoint as
// both its UDP connection (viewed as *testEndpointUDPConn) and its hooks
// (viewed as *testEndpointHooks).
type testEndpoint struct {
	t   *testing.T
	e   *Endpoint // the Endpoint under test
	now time.Time // fake current time, returned by the timeNow hook

	// recvc carries datagrams from write to ReadMsgUDPAddrPort.
	recvc chan *datagram
	// idlec is sent on by wait and received by ReadMsgUDPAddrPort while it
	// waits for input; a completed send means the endpoint is waiting to read.
	idlec chan struct{}

	// conns holds a testConn for every Conn the Endpoint has created,
	// registered by the newConn hook.
	conns map[*Conn]*testConn
	// acceptQueue holds connections not yet returned by accept, oldest first.
	acceptQueue []*testConn

	// NOTE(review): configTransportParams and configTestConn are not used in
	// this part of the file; presumably they customize newly created
	// connections. Confirm against newTestConnForConn.
	configTransportParams []func(*transportParameters)
	configTestConn        []func(*testConn)

	// sentDatagrams holds copies of datagrams the Endpoint wrote, oldest
	// first; read consumes them.
	sentDatagrams [][]byte

	// NOTE(review): peerTLSConn is not used in this part of the file;
	// presumably it is the TLS state of the simulated peer. Confirm.
	peerTLSConn *tls.QUICConn

	// lastInitialDstConnID is the destination connection ID of the most
	// recent Initial packet passed to writeDatagram.
	lastInitialDstConnID []byte
}
119
120 func newTestEndpoint(t *testing.T, config *Config) *testEndpoint {
121 te := &testEndpoint{
122 t: t,
123 now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
124 recvc: make(chan *datagram),
125 idlec: make(chan struct{}),
126 conns: make(map[*Conn]*testConn),
127 }
128 var err error
129 te.e, err = newEndpoint((*testEndpointUDPConn)(te), config, (*testEndpointHooks)(te))
130 if err != nil {
131 t.Fatal(err)
132 }
133 t.Cleanup(te.cleanup)
134 return te
135 }
136
137 func (te *testEndpoint) cleanup() {
138 te.e.Close(canceledContext())
139 }
140
141 func (te *testEndpoint) wait() {
142 select {
143 case te.idlec <- struct{}{}:
144 case <-te.e.closec:
145 }
146 for _, tc := range te.conns {
147 tc.wait()
148 }
149 }
150
151
152
153 func (te *testEndpoint) accept() *testConn {
154 if len(te.acceptQueue) == 0 {
155 te.t.Fatalf("accept: expected available conn, but found none")
156 }
157 tc := te.acceptQueue[0]
158 te.acceptQueue = te.acceptQueue[1:]
159 return tc
160 }
161
162 func (te *testEndpoint) write(d *datagram) {
163 te.recvc <- d
164 te.wait()
165 }
166
// testClientAddr (10.0.0.1:8000) is the sender address writeDatagram uses for
// test datagrams that do not specify one.
var testClientAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{10, 0, 0, 1}), 8000)
168
169 func (te *testEndpoint) writeDatagram(d *testDatagram) {
170 te.t.Helper()
171 logDatagram(te.t, "<- endpoint under test receives", d)
172 var buf []byte
173 for _, p := range d.packets {
174 tc := te.connForDestination(p.dstConnID)
175 if p.ptype != packetTypeRetry && tc != nil {
176 space := spaceForPacketType(p.ptype)
177 if p.num >= tc.peerNextPacketNum[space] {
178 tc.peerNextPacketNum[space] = p.num + 1
179 }
180 }
181 if p.ptype == packetTypeInitial {
182 te.lastInitialDstConnID = p.dstConnID
183 }
184 pad := 0
185 if p.ptype == packetType1RTT {
186 pad = d.paddedSize - len(buf)
187 }
188 buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...)
189 }
190 for len(buf) < d.paddedSize {
191 buf = append(buf, 0)
192 }
193 addr := d.addr
194 if !addr.IsValid() {
195 addr = testClientAddr
196 }
197 te.write(&datagram{
198 b: buf,
199 addr: addr,
200 })
201 }
202
// connForDestination returns a test connection that has dstConnID among its
// local connection IDs, or nil if none does. If several match, the choice
// depends on map iteration order.
func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn {
	for _, tc := range te.conns {
		ids := tc.conn.connIDState.local
		for i := range ids {
			if bytes.Equal(ids[i].cid, dstConnID) {
				return tc
			}
		}
	}
	return nil
}
213
// connForSource returns a test connection that has srcConnID among its remote
// (peer-chosen) connection IDs, or nil if none does. If several match, the
// choice depends on map iteration order.
func (te *testEndpoint) connForSource(srcConnID []byte) *testConn {
	for _, tc := range te.conns {
		ids := tc.conn.connIDState.remote
		for i := range ids {
			if bytes.Equal(ids[i].cid, srcConnID) {
				return tc
			}
		}
	}
	return nil
}
224
225 func (te *testEndpoint) read() []byte {
226 te.t.Helper()
227 te.wait()
228 if len(te.sentDatagrams) == 0 {
229 return nil
230 }
231 d := te.sentDatagrams[0]
232 te.sentDatagrams = te.sentDatagrams[1:]
233 return d
234 }
235
236 func (te *testEndpoint) readDatagram() *testDatagram {
237 te.t.Helper()
238 buf := te.read()
239 if buf == nil {
240 return nil
241 }
242 p, _ := parseGenericLongHeaderPacket(buf)
243 tc := te.connForSource(p.dstConnID)
244 d := parseTestDatagram(te.t, te, tc, buf)
245 logDatagram(te.t, "-> endpoint under test sends", d)
246 return d
247 }
248
249
250 func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) {
251 te.t.Helper()
252 got := te.readDatagram()
253 if !datagramEqual(got, want) {
254 te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
255 }
256 }
257
258
259 func (te *testEndpoint) wantIdle(expectation string) {
260 if got := te.readDatagram(); got != nil {
261 te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
262 }
263 }
264
265
// advance moves the fake clock forward by d. It delegates to advanceTo, which
// delivers timer events to connections whose timers are due.
func (te *testEndpoint) advance(d time.Duration) {
	te.t.Helper()
	target := te.now.Add(d)
	te.advanceTo(target)
}
270
271
// advanceTo sets the fake clock to now and fails the test if that would move
// time backwards. It sends a timerEvent to every connection whose timer is not
// after the new time, and waits for that connection to process it.
func (te *testEndpoint) advanceTo(now time.Time) {
	te.t.Helper()
	if now.Before(te.now) {
		te.t.Fatalf("time moved backwards: %v -> %v", te.now, now)
	}
	te.now = now
	for _, tc := range te.conns {
		// NOTE(review): a zero-valued tc.timer is never after te.now, so such
		// connections also get a timerEvent. Confirm this is intended.
		if tc.timer.After(te.now) {
			continue
		}
		tc.conn.sendMsg(timerEvent{})
		tc.wait()
	}
}
285
286
287 type testEndpointHooks testEndpoint
288
289 func (te *testEndpointHooks) timeNow() time.Time {
290 return te.now
291 }
292
293 func (te *testEndpointHooks) newConn(c *Conn) {
294 tc := newTestConnForConn(te.t, (*testEndpoint)(te), c)
295 te.conns[c] = tc
296 }
297
298
299 type testEndpointUDPConn testEndpoint
300
301 func (te *testEndpointUDPConn) Close() error {
302 close(te.recvc)
303 return nil
304 }
305
306 func (te *testEndpointUDPConn) LocalAddr() net.Addr {
307 return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443"))
308 }
309
310 func (te *testEndpointUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
311 for {
312 select {
313 case d, ok := <-te.recvc:
314 if !ok {
315 return 0, 0, 0, netip.AddrPort{}, io.EOF
316 }
317 n = copy(b, d.b)
318 return n, 0, 0, d.addr, nil
319 case <-te.idlec:
320 }
321 }
322 }
323
324 func (te *testEndpointUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
325 te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), b...))
326 return len(b), nil
327 }
328
View as plain text