1
2
3
4
5
6
7 package socket_test
8
9 import (
10 "bytes"
11 "fmt"
12 "io/ioutil"
13 "net"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strings"
19 "syscall"
20 "testing"
21
22 "golang.org/x/net/internal/socket"
23 "golang.org/x/net/nettest"
24 )
25
26 func TestSocket(t *testing.T) {
27 t.Run("Option", func(t *testing.T) {
28 testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
29 })
30 }
31
32 func testSocketOption(t *testing.T, so *socket.Option) {
33 c, err := nettest.NewLocalPacketListener("udp")
34 if err != nil {
35 t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
36 }
37 defer c.Close()
38 cc, err := socket.NewConn(c.(net.Conn))
39 if err != nil {
40 t.Fatal(err)
41 }
42 const N = 2048
43 if err := so.SetInt(cc, N); err != nil {
44 t.Fatal(err)
45 }
46 n, err := so.GetInt(cc)
47 if err != nil {
48 t.Fatal(err)
49 }
50 if n < N {
51 t.Fatalf("got %d; want greater than or equal to %d", n, N)
52 }
53 }
54
// mockControl describes a single control message used as test input: the
// level and type written into its header, and the payload bytes.
type mockControl struct {
	Level, Type int
	Data        []byte
}
60
61 func TestControlMessage(t *testing.T) {
62 switch runtime.GOOS {
63 case "windows":
64 t.Skipf("not supported on %s", runtime.GOOS)
65 }
66
67 for _, tt := range []struct {
68 cs []mockControl
69 }{
70 {
71 []mockControl{
72 {Level: 1, Type: 1},
73 },
74 },
75 {
76 []mockControl{
77 {Level: 2, Type: 2, Data: []byte{0xfe}},
78 },
79 },
80 {
81 []mockControl{
82 {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
83 },
84 },
85 {
86 []mockControl{
87 {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
88 },
89 },
90 {
91 []mockControl{
92 {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
93 {Level: 2, Type: 2, Data: []byte{0xfe}},
94 },
95 },
96 } {
97 var w []byte
98 var tailPadLen int
99 mm := socket.NewControlMessage([]int{0})
100 for i, c := range tt.cs {
101 m := socket.NewControlMessage([]int{len(c.Data)})
102 l := len(m) - len(mm)
103 if i == len(tt.cs)-1 && l > len(c.Data) {
104 tailPadLen = l - len(c.Data)
105 }
106 w = append(w, m...)
107 }
108
109 var err error
110 ww := make([]byte, len(w))
111 copy(ww, w)
112 m := socket.ControlMessage(ww)
113 for _, c := range tt.cs {
114 if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
115 t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
116 }
117 copy(m.Data(len(c.Data)), c.Data)
118 m = m.Next(len(c.Data))
119 }
120 m = socket.ControlMessage(w)
121 for _, c := range tt.cs {
122 m, err = m.Marshal(c.Level, c.Type, c.Data)
123 if err != nil {
124 t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
125 }
126 }
127 if !bytes.Equal(ww, w) {
128 t.Fatalf("got %#v; want %#v", ww, w)
129 }
130
131 ws := [][]byte{w}
132 if tailPadLen > 0 {
133
134 nopad := w[:len(w)-tailPadLen]
135 ws = append(ws, [][]byte{nopad}...)
136 }
137 for _, w := range ws {
138 ms, err := socket.ControlMessage(w).Parse()
139 if err != nil {
140 t.Fatalf("(%v).Parse() = %v", tt.cs, err)
141 }
142 for i, m := range ms {
143 lvl, typ, dataLen, err := m.ParseHeader()
144 if err != nil {
145 t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
146 }
147 if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
148 t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
149 }
150 }
151 }
152 }
153 }
154
155 func TestUDP(t *testing.T) {
156 switch runtime.GOOS {
157 case "windows":
158 t.Skipf("not supported on %s", runtime.GOOS)
159 }
160
161 c, err := nettest.NewLocalPacketListener("udp")
162 if err != nil {
163 t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
164 }
165 defer c.Close()
166
167 type wrappedConn struct{ *net.UDPConn }
168 cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)})
169 if err != nil {
170 t.Fatal(err)
171 }
172
173
174 cDialed, err := net.Dial("udp", c.LocalAddr().String())
175 if err != nil {
176 t.Fatal(err)
177 }
178 ccDialed, err := socket.NewConn(cDialed)
179 if err != nil {
180 t.Fatal(err)
181 }
182
183 const data = "HELLO-R-U-THERE"
184 messageTests := []struct {
185 name string
186 conn *socket.Conn
187 dest net.Addr
188 }{
189 {
190 name: "Message",
191 conn: cc,
192 dest: c.LocalAddr(),
193 },
194 {
195 name: "Message-dialed",
196 conn: ccDialed,
197 dest: nil,
198 },
199 }
200 for _, tt := range messageTests {
201 t.Run(tt.name, func(t *testing.T) {
202 wm := socket.Message{
203 Buffers: bytes.SplitAfter([]byte(data), []byte("-")),
204 Addr: tt.dest,
205 }
206 if err := tt.conn.SendMsg(&wm, 0); err != nil {
207 t.Fatal(err)
208 }
209 b := make([]byte, 32)
210 rm := socket.Message{
211 Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
212 }
213 if err := cc.RecvMsg(&rm, 0); err != nil {
214 t.Fatal(err)
215 }
216 received := string(b[:rm.N])
217 if received != data {
218 t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data)
219 }
220 })
221 }
222
223 switch runtime.GOOS {
224 case "android", "linux":
225 messagesTests := []struct {
226 name string
227 conn *socket.Conn
228 dest net.Addr
229 }{
230 {
231 name: "Messages",
232 conn: cc,
233 dest: c.LocalAddr(),
234 },
235 {
236 name: "Messages-dialed",
237 conn: ccDialed,
238 dest: nil,
239 },
240 }
241 for _, tt := range messagesTests {
242 t.Run(tt.name, func(t *testing.T) {
243 wmbs := bytes.SplitAfter([]byte(data), []byte("-"))
244 wms := []socket.Message{
245 {Buffers: wmbs[:1], Addr: tt.dest},
246 {Buffers: wmbs[1:], Addr: tt.dest},
247 }
248 n, err := tt.conn.SendMsgs(wms, 0)
249 if err != nil {
250 t.Fatal(err)
251 }
252 if n != len(wms) {
253 t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms))
254 }
255 rmbs := [][]byte{make([]byte, 32), make([]byte, 32)}
256 rms := []socket.Message{
257 {Buffers: [][]byte{rmbs[0]}},
258 {Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}},
259 }
260 nrecv := 0
261 for nrecv < len(rms) {
262 n, err := cc.RecvMsgs(rms[nrecv:], 0)
263 if err != nil {
264 t.Fatal(err)
265 }
266 nrecv += n
267 }
268 received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N])
269 assembled := received0 + received1
270 assembledReordered := received1 + received0
271 if assembled != data && assembledReordered != data {
272 t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data)
273 }
274 })
275 }
276 t.Run("Messages-undialed-no-dst", func(t *testing.T) {
277
278
279 data := []byte("HELLO-R-U-THERE")
280 wmbs := bytes.SplitAfter(data, []byte("-"))
281 wms := []socket.Message{
282 {Buffers: wmbs[:1], Addr: nil},
283 {Buffers: wmbs[1:], Addr: nil},
284 }
285 n, err := cc.SendMsgs(wms, 0)
286 if n != 0 && err == nil {
287 t.Fatal("expected error, destination address required")
288 }
289 })
290 }
291
292
293
294
295
296
297 wm := socket.Message{
298 Buffers: [][]byte{{}},
299 Addr: c.LocalAddr(),
300 }
301 cc.SendMsg(&wm, 0)
302 wms := []socket.Message{
303 {Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
304 }
305 cc.SendMsgs(wms, 0)
306 }
307
308 func BenchmarkUDP(b *testing.B) {
309 c, err := nettest.NewLocalPacketListener("udp")
310 if err != nil {
311 b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
312 }
313 defer c.Close()
314 cc, err := socket.NewConn(c.(net.Conn))
315 if err != nil {
316 b.Fatal(err)
317 }
318 data := []byte("HELLO-R-U-THERE")
319 wm := socket.Message{
320 Buffers: [][]byte{data},
321 Addr: c.LocalAddr(),
322 }
323 rm := socket.Message{
324 Buffers: [][]byte{make([]byte, 128)},
325 OOB: make([]byte, 128),
326 }
327
328 for M := 1; M <= 1<<9; M = M << 1 {
329 b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) {
330 for i := 0; i < b.N; i++ {
331 for j := 0; j < M; j++ {
332 if err := cc.SendMsg(&wm, 0); err != nil {
333 b.Fatal(err)
334 }
335 if err := cc.RecvMsg(&rm, 0); err != nil {
336 b.Fatal(err)
337 }
338 }
339 }
340 })
341 switch runtime.GOOS {
342 case "android", "linux":
343 wms := make([]socket.Message, M)
344 for i := range wms {
345 wms[i].Buffers = [][]byte{data}
346 wms[i].Addr = c.LocalAddr()
347 }
348 rms := make([]socket.Message, M)
349 for i := range rms {
350 rms[i].Buffers = [][]byte{make([]byte, 128)}
351 rms[i].OOB = make([]byte, 128)
352 }
353 b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) {
354 for i := 0; i < b.N; i++ {
355 if _, err := cc.SendMsgs(wms, 0); err != nil {
356 b.Fatal(err)
357 }
358 if _, err := cc.RecvMsgs(rms, 0); err != nil {
359 b.Fatal(err)
360 }
361 }
362 })
363 }
364 }
365 }
366
367 func TestRace(t *testing.T) {
368 tests := []string{
369 `
370 package main
371 import (
372 "log"
373 "net"
374
375 "golang.org/x/net/ipv4"
376 )
377
378 var g byte
379
380 func main() {
381 c, err := net.ListenPacket("udp", "127.0.0.1:0")
382 if err != nil {
383 log.Fatalf("ListenPacket: %v", err)
384 }
385 cc := ipv4.NewPacketConn(c)
386 sync := make(chan bool)
387 src := make([]byte, 100)
388 dst := make([]byte, 100)
389 go func() {
390 if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
391 log.Fatalf("WriteTo: %v", err)
392 }
393 }()
394 go func() {
395 if _, _, _, err := cc.ReadFrom(dst); err != nil {
396 log.Fatalf("ReadFrom: %v", err)
397 }
398 sync <- true
399 }()
400 g = dst[0]
401 <-sync
402 }
403 `,
404 `
405 package main
406 import (
407 "log"
408 "net"
409
410 "golang.org/x/net/ipv4"
411 )
412
413 func main() {
414 c, err := net.ListenPacket("udp", "127.0.0.1:0")
415 if err != nil {
416 log.Fatalf("ListenPacket: %v", err)
417 }
418 cc := ipv4.NewPacketConn(c)
419 sync := make(chan bool)
420 src := make([]byte, 100)
421 dst := make([]byte, 100)
422 go func() {
423 if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
424 log.Fatalf("WriteTo: %v", err)
425 }
426 sync <- true
427 }()
428 src[0] = 0
429 go func() {
430 if _, _, _, err := cc.ReadFrom(dst); err != nil {
431 log.Fatalf("ReadFrom: %v", err)
432 }
433 }()
434 <-sync
435 }
436 `,
437 }
438 platforms := map[string]bool{
439 "linux/amd64": true,
440 "linux/ppc64le": true,
441 "linux/arm64": true,
442 }
443 if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
444 t.Skip("skipping test on non-race-enabled host.")
445 }
446 if runtime.Compiler == "gccgo" {
447 t.Skip("skipping race test when built with gccgo")
448 }
449 dir, err := ioutil.TempDir("", "testrace")
450 if err != nil {
451 t.Fatalf("failed to create temp directory: %v", err)
452 }
453 defer os.RemoveAll(dir)
454 goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
455 t.Logf("%s version", goBinary)
456 got, err := exec.Command(goBinary, "version").CombinedOutput()
457 if len(got) > 0 {
458 t.Logf("%s", got)
459 }
460 if err != nil {
461 t.Fatalf("go version failed: %v", err)
462 }
463 for i, test := range tests {
464 t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
465 src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
466 if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
467 t.Fatalf("failed to write file: %v", err)
468 }
469 t.Logf("%s run -race %s", goBinary, src)
470 got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
471 if len(got) > 0 {
472 t.Logf("%s", got)
473 }
474 if strings.Contains(string(got), "-race requires cgo") {
475 t.Log("CGO is not enabled so can't use -race")
476 } else if !strings.Contains(string(got), "WARNING: DATA RACE") {
477 t.Errorf("race not detected for test %d: err:%v", i, err)
478 }
479 })
480 }
481 }
482
View as plain text