Source file
src/net/splice_test.go
Documentation: net
1
2
3
4
5
6
7 package net
8
9 import (
10 "io"
11 "log"
12 "os"
13 "os/exec"
14 "strconv"
15 "sync"
16 "testing"
17 "time"
18 )
19
20 func TestSplice(t *testing.T) {
21 t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
22 if !testableNetwork("unixgram") {
23 t.Skip("skipping unix-to-tcp tests")
24 }
25 t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
26 t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
27 t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
28 t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
29 t.Run("no-unixpacket", testSpliceNoUnixpacket)
30 t.Run("no-unixgram", testSpliceNoUnixgram)
31 }
32
33 func testSpliceToFile(t *testing.T, upNet, downNet string) {
34 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
35 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
36 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
37 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
38 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
39 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
40 }
41
42 func testSplice(t *testing.T, upNet, downNet string) {
43 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
44 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
45 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
46 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
47 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
48 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
49 t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
50 t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
51 }
52
// spliceTestCase describes one splice scenario. The field order matters:
// instances are also built with positional composite literals.
type spliceTestCase struct {
	upNet   string // network of the source (upstream) connection
	downNet string // network of the destination (downstream) connection

	chunkSize int // bytes per Read/Write issued by the helper processes
	totalSize int // total bytes each helper process transfers

	// limitReadSize, when > 0, wraps the source in an io.LimitedReader with
	// this N. Zero means the source is read without a limit.
	limitReadSize int
}
59
60 func (tc spliceTestCase) test(t *testing.T) {
61 clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
62 defer serverUp.Close()
63 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
64 if err != nil {
65 t.Fatal(err)
66 }
67 defer cleanup()
68 clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
69 defer serverDown.Close()
70 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
71 if err != nil {
72 t.Fatal(err)
73 }
74 defer cleanup()
75 var (
76 r io.Reader = serverUp
77 size = tc.totalSize
78 )
79 if tc.limitReadSize > 0 {
80 if tc.limitReadSize < size {
81 size = tc.limitReadSize
82 }
83
84 r = &io.LimitedReader{
85 N: int64(tc.limitReadSize),
86 R: serverUp,
87 }
88 defer serverUp.Close()
89 }
90 n, err := io.Copy(serverDown, r)
91 serverDown.Close()
92 if err != nil {
93 t.Fatal(err)
94 }
95 if want := int64(size); want != n {
96 t.Errorf("want %d bytes spliced, got %d", want, n)
97 }
98
99 if tc.limitReadSize > 0 {
100 wantN := 0
101 if tc.limitReadSize > size {
102 wantN = tc.limitReadSize - size
103 }
104
105 if n := r.(*io.LimitedReader).N; n != int64(wantN) {
106 t.Errorf("r.N = %d, want %d", n, wantN)
107 }
108 }
109 }
110
111 func (tc spliceTestCase) testFile(t *testing.T) {
112 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
113 if err != nil {
114 t.Fatal(err)
115 }
116 defer f.Close()
117
118 client, server := spliceTestSocketPair(t, tc.upNet)
119 defer server.Close()
120
121 cleanup, err := startSpliceClient(client, "w", tc.chunkSize, tc.totalSize)
122 if err != nil {
123 client.Close()
124 t.Fatal("failed to start splice client:", err)
125 }
126 defer cleanup()
127
128 var (
129 r io.Reader = server
130 actualSize = tc.totalSize
131 )
132 if tc.limitReadSize > 0 {
133 if tc.limitReadSize < actualSize {
134 actualSize = tc.limitReadSize
135 }
136
137 r = &io.LimitedReader{
138 N: int64(tc.limitReadSize),
139 R: r,
140 }
141 }
142
143 got, err := io.Copy(f, r)
144 if err != nil {
145 t.Fatalf("failed to ReadFrom with error: %v", err)
146 }
147 if want := int64(actualSize); got != want {
148 t.Errorf("got %d bytes, want %d", got, want)
149 }
150 if tc.limitReadSize > 0 {
151 wantN := 0
152 if tc.limitReadSize > actualSize {
153 wantN = tc.limitReadSize - actualSize
154 }
155
156 if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
157 t.Errorf("r.N = %d, want %d", gotN, wantN)
158 }
159 }
160 }
161
162 func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
163
164
165
166 if upNet == "unix" || downNet == "unix" {
167 t.Skip("skipping test on unix socket")
168 }
169
170 clientUp, serverUp := spliceTestSocketPair(t, upNet)
171 defer clientUp.Close()
172 clientDown, serverDown := spliceTestSocketPair(t, downNet)
173 defer clientDown.Close()
174
175 serverUp.Close()
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193 msg := "bye"
194 go func() {
195 serverDown.(io.ReaderFrom).ReadFrom(serverUp)
196 io.WriteString(serverDown, msg)
197 serverDown.Close()
198 }()
199
200 buf := make([]byte, 3)
201 _, err := io.ReadFull(clientDown, buf)
202 if err != nil {
203 t.Errorf("clientDown: %v", err)
204 }
205 if string(buf) != msg {
206 t.Errorf("clientDown got %q, want %q", buf, msg)
207 }
208 }
209
210 func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
211 front := newLocalListener(t, upNet)
212 defer front.Close()
213 back := newLocalListener(t, downNet)
214 defer back.Close()
215
216 var wg sync.WaitGroup
217 wg.Add(2)
218
219 proxy := func() {
220 src, err := front.Accept()
221 if err != nil {
222 return
223 }
224 dst, err := Dial(downNet, back.Addr().String())
225 if err != nil {
226 return
227 }
228 defer dst.Close()
229 defer src.Close()
230 go func() {
231 io.Copy(src, dst)
232 wg.Done()
233 }()
234 go func() {
235 io.Copy(dst, src)
236 wg.Done()
237 }()
238 }
239
240 go proxy()
241
242 toFront, err := Dial(upNet, front.Addr().String())
243 if err != nil {
244 t.Fatal(err)
245 }
246
247 io.WriteString(toFront, "foo")
248 toFront.Close()
249
250 fromProxy, err := back.Accept()
251 if err != nil {
252 t.Fatal(err)
253 }
254 defer fromProxy.Close()
255
256 _, err = io.ReadAll(fromProxy)
257 if err != nil {
258 t.Fatal(err)
259 }
260
261 wg.Wait()
262 }
263
264 func testSpliceNoUnixpacket(t *testing.T) {
265 clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
266 defer clientUp.Close()
267 defer serverUp.Close()
268 clientDown, serverDown := spliceTestSocketPair(t, "tcp")
269 defer clientDown.Close()
270 defer serverDown.Close()
271
272
273
274
275
276
277
278
279 _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
280 if err != nil || handled != false {
281 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
282 }
283 }
284
285 func testSpliceNoUnixgram(t *testing.T) {
286 addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
287 if err != nil {
288 t.Fatal(err)
289 }
290 defer os.Remove(addr.Name)
291 up, err := ListenUnixgram("unixgram", addr)
292 if err != nil {
293 t.Fatal(err)
294 }
295 defer up.Close()
296 clientDown, serverDown := spliceTestSocketPair(t, "tcp")
297 defer clientDown.Close()
298 defer serverDown.Close()
299
300 _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
301 if err != nil || handled != false {
302 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
303 }
304 }
305
306 func BenchmarkSplice(b *testing.B) {
307 testHookUninstaller.Do(uninstallTestHooks)
308
309 b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
310 b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
311 b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
312 }
313
314 func benchSplice(b *testing.B, upNet, downNet string) {
315 for i := 0; i <= 10; i++ {
316 chunkSize := 1 << uint(i+10)
317 tc := spliceTestCase{
318 upNet: upNet,
319 downNet: downNet,
320 chunkSize: chunkSize,
321 }
322
323 b.Run(strconv.Itoa(chunkSize), tc.bench)
324 }
325 }
326
327 func (tc spliceTestCase) bench(b *testing.B) {
328
329 useSplice := true
330
331 clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
332 defer serverUp.Close()
333
334 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
335 if err != nil {
336 b.Fatal(err)
337 }
338 defer cleanup()
339
340 clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
341 defer serverDown.Close()
342
343 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
344 if err != nil {
345 b.Fatal(err)
346 }
347 defer cleanup()
348
349 b.SetBytes(int64(tc.chunkSize))
350 b.ResetTimer()
351
352 if useSplice {
353 _, err := io.Copy(serverDown, serverUp)
354 if err != nil {
355 b.Fatal(err)
356 }
357 } else {
358 type onlyReader struct {
359 io.Reader
360 }
361 _, err := io.Copy(serverDown, onlyReader{serverUp})
362 if err != nil {
363 b.Fatal(err)
364 }
365 }
366 }
367
368 func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
369 t.Helper()
370 ln := newLocalListener(t, net)
371 defer ln.Close()
372 var cerr, serr error
373 acceptDone := make(chan struct{})
374 go func() {
375 server, serr = ln.Accept()
376 acceptDone <- struct{}{}
377 }()
378 client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
379 <-acceptDone
380 if cerr != nil {
381 if server != nil {
382 server.Close()
383 }
384 t.Fatal(cerr)
385 }
386 if serr != nil {
387 if client != nil {
388 client.Close()
389 }
390 t.Fatal(serr)
391 }
392 return client, server
393 }
394
395 func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
396 f, err := conn.(interface{ File() (*os.File, error) }).File()
397 if err != nil {
398 return nil, err
399 }
400
401 cmd := exec.Command(os.Args[0], os.Args[1:]...)
402 cmd.Env = []string{
403 "GO_NET_TEST_SPLICE=1",
404 "GO_NET_TEST_SPLICE_OP=" + op,
405 "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
406 "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
407 "TMPDIR=" + os.Getenv("TMPDIR"),
408 }
409 cmd.ExtraFiles = append(cmd.ExtraFiles, f)
410 cmd.Stdout = os.Stdout
411 cmd.Stderr = os.Stderr
412
413 if err := cmd.Start(); err != nil {
414 return nil, err
415 }
416
417 donec := make(chan struct{})
418 go func() {
419 cmd.Wait()
420 conn.Close()
421 f.Close()
422 close(donec)
423 }()
424
425 return func() {
426 select {
427 case <-donec:
428 case <-time.After(5 * time.Second):
429 log.Printf("killing splice client after 5 second shutdown timeout")
430 cmd.Process.Kill()
431 select {
432 case <-donec:
433 case <-time.After(5 * time.Second):
434 log.Printf("splice client didn't die after 10 seconds")
435 }
436 }
437 }, nil
438 }
439
440 func init() {
441 if os.Getenv("GO_NET_TEST_SPLICE") == "" {
442 return
443 }
444 defer os.Exit(0)
445
446 f := os.NewFile(uintptr(3), "splice-test-conn")
447 defer f.Close()
448
449 conn, err := FileConn(f)
450 if err != nil {
451 log.Fatal(err)
452 }
453
454 var chunkSize int
455 if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
456 log.Fatal(err)
457 }
458 buf := make([]byte, chunkSize)
459
460 var totalSize int
461 if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
462 log.Fatal(err)
463 }
464
465 var fn func([]byte) (int, error)
466 switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
467 case "r":
468 fn = conn.Read
469 case "w":
470 defer conn.Close()
471
472 fn = conn.Write
473 default:
474 log.Fatalf("unknown op %q", op)
475 }
476
477 var n int
478 for count := 0; count < totalSize; count += n {
479 if count+chunkSize > totalSize {
480 buf = buf[:totalSize-count]
481 }
482
483 var err error
484 if n, err = fn(buf); err != nil {
485 return
486 }
487 }
488 }
489
490 func BenchmarkSpliceFile(b *testing.B) {
491 b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
492 b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
493 }
494
495 func benchmarkSpliceFile(b *testing.B, proto string) {
496 for i := 0; i <= 10; i++ {
497 size := 1 << (i + 10)
498 bench := spliceFileBench{
499 proto: proto,
500 chunkSize: size,
501 }
502 b.Run(strconv.Itoa(size), bench.benchSpliceFile)
503 }
504 }
505
// spliceFileBench holds the parameters of one benchSpliceFile run.
type spliceFileBench struct {
	// proto is the network of the source connection ("tcp" or "unix" in
	// this file's benchmarks).
	proto string
	// chunkSize is the size of each write by the helper process, and also the
	// byte count attributed to one benchmark op.
	chunkSize int
}
510
511 func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
512 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
513 if err != nil {
514 b.Fatal(err)
515 }
516 defer f.Close()
517
518 totalSize := b.N * bench.chunkSize
519
520 client, server := spliceTestSocketPair(b, bench.proto)
521 defer server.Close()
522
523 cleanup, err := startSpliceClient(client, "w", bench.chunkSize, totalSize)
524 if err != nil {
525 client.Close()
526 b.Fatalf("failed to start splice client: %v", err)
527 }
528 defer cleanup()
529
530 b.ReportAllocs()
531 b.SetBytes(int64(bench.chunkSize))
532 b.ResetTimer()
533
534 got, err := io.Copy(f, server)
535 if err != nil {
536 b.Fatalf("failed to ReadFrom with error: %v", err)
537 }
538 if want := int64(totalSize); got != want {
539 b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
540 }
541 }
542
View as plain text