1
2
3
4
5
6
7 package unix_test
8
9 import (
10 "bytes"
11 "errors"
12 "net"
13 "os"
14 "testing"
15 "time"
16
17 "golang.org/x/sys/unix"
18 )
19
20
21
22
23
24 func TestSCMCredentials(t *testing.T) {
25 socketTypeTests := []struct {
26 socketType int
27 dataLen int
28 }{
29 {
30 unix.SOCK_STREAM,
31 1,
32 }, {
33 unix.SOCK_DGRAM,
34 0,
35 },
36 }
37
38 for _, tt := range socketTypeTests {
39 fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0)
40 if err != nil {
41 t.Fatalf("Socketpair: %v", err)
42 }
43
44 err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1)
45 if err != nil {
46 unix.Close(fds[0])
47 unix.Close(fds[1])
48 t.Fatalf("SetsockoptInt: %v", err)
49 }
50
51 srvFile := os.NewFile(uintptr(fds[0]), "server")
52 cliFile := os.NewFile(uintptr(fds[1]), "client")
53 defer srvFile.Close()
54 defer cliFile.Close()
55
56 srv, err := net.FileConn(srvFile)
57 if err != nil {
58 t.Errorf("FileConn: %v", err)
59 return
60 }
61 defer srv.Close()
62
63 cli, err := net.FileConn(cliFile)
64 if err != nil {
65 t.Errorf("FileConn: %v", err)
66 return
67 }
68 defer cli.Close()
69
70 var ucred unix.Ucred
71 ucred.Pid = int32(os.Getpid())
72 ucred.Uid = uint32(os.Getuid())
73 ucred.Gid = uint32(os.Getgid())
74 oob := unix.UnixCredentials(&ucred)
75
76
77 n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
78 if err != nil {
79 t.Fatalf("WriteMsgUnix: %v", err)
80 }
81 if n != 0 {
82 t.Fatalf("WriteMsgUnix n = %d, want 0", n)
83 }
84 if oobn != len(oob) {
85 t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
86 }
87
88 oob2 := make([]byte, 10*len(oob))
89 n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
90 if err != nil {
91 t.Fatalf("ReadMsgUnix: %v", err)
92 }
93 if flags != 0 && flags != unix.MSG_CMSG_CLOEXEC {
94 t.Fatalf("ReadMsgUnix flags = %#x, want 0 or %#x (MSG_CMSG_CLOEXEC)", flags, unix.MSG_CMSG_CLOEXEC)
95 }
96 if n != tt.dataLen {
97 t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
98 }
99 if oobn2 != oobn {
100
101
102 t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
103 }
104 oob2 = oob2[:oobn2]
105 if !bytes.Equal(oob, oob2) {
106 t.Fatal("ReadMsgUnix oob bytes don't match")
107 }
108
109 scm, err := unix.ParseSocketControlMessage(oob2)
110 if err != nil {
111 t.Fatalf("ParseSocketControlMessage: %v", err)
112 }
113 newUcred, err := unix.ParseUnixCredentials(&scm[0])
114 if err != nil {
115 t.Fatalf("ParseUnixCredentials: %v", err)
116 }
117 if *newUcred != ucred {
118 t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
119 }
120 }
121 }
122
123 func TestPktInfo(t *testing.T) {
124 testcases := []struct {
125 network string
126 address *net.UDPAddr
127 }{
128 {"udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}},
129 {"udp6", &net.UDPAddr{IP: net.ParseIP("::1")}},
130 }
131 for _, test := range testcases {
132 t.Run(test.network, func(t *testing.T) {
133 conn, err := net.ListenUDP(test.network, test.address)
134 if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) {
135 t.Skipf("%v is not available", test.address)
136 }
137 if err != nil {
138 t.Fatal("Listen:", err)
139 }
140 defer conn.Close()
141
142 var pktInfo []byte
143 var src net.IP
144 switch test.network {
145 case "udp4":
146 var info4 unix.Inet4Pktinfo
147 src = net.ParseIP("127.0.0.2").To4()
148 copy(info4.Spec_dst[:], src)
149 pktInfo = unix.PktInfo4(&info4)
150
151 case "udp6":
152 var info6 unix.Inet6Pktinfo
153 src = net.ParseIP("2001:0DB8::1")
154 copy(info6.Addr[:], src)
155 pktInfo = unix.PktInfo6(&info6)
156
157 raw, err := conn.SyscallConn()
158 if err != nil {
159 t.Fatal("SyscallConn:", err)
160 }
161 var opErr error
162 err = raw.Control(func(fd uintptr) {
163 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_FREEBIND, 1)
164 })
165 if err != nil {
166 t.Fatal("Control:", err)
167 }
168 if errors.Is(opErr, unix.ENOPROTOOPT) {
169
170
171 t.Skip("IPV6_FREEBIND not supported")
172 }
173 if opErr != nil {
174 t.Fatal("Can't enable IPV6_FREEBIND:", opErr)
175 }
176 }
177
178 msg := []byte{1}
179 addr := conn.LocalAddr().(*net.UDPAddr)
180 _, _, err = conn.WriteMsgUDP(msg, pktInfo, addr)
181 if err != nil {
182 t.Fatal("WriteMsgUDP:", err)
183 }
184
185 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
186 _, _, _, remote, err := conn.ReadMsgUDP(msg, nil)
187 if err != nil {
188 t.Fatal("ReadMsgUDP:", err)
189 }
190
191 if !remote.IP.Equal(src) {
192 t.Errorf("Got packet from %v, want %v", remote.IP, src)
193 }
194 })
195 }
196 }
197
198 func TestParseOrigDstAddr(t *testing.T) {
199 testcases := []struct {
200 network string
201 address *net.UDPAddr
202 }{
203 {"udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}},
204 {"udp6", &net.UDPAddr{IP: net.IPv6loopback}},
205 }
206
207 for _, test := range testcases {
208 t.Run(test.network, func(t *testing.T) {
209 conn, err := net.ListenUDP(test.network, test.address)
210 if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) {
211 t.Skipf("%v is not available", test.address)
212 }
213 if err != nil {
214 t.Fatal("Listen:", err)
215 }
216 defer conn.Close()
217
218 raw, err := conn.SyscallConn()
219 if err != nil {
220 t.Fatal("SyscallConn:", err)
221 }
222
223 var opErr error
224 err = raw.Control(func(fd uintptr) {
225 switch test.network {
226 case "udp4":
227 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1)
228 case "udp6":
229 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
230 }
231 })
232 if err != nil {
233 t.Fatal("Control:", err)
234 }
235 if opErr != nil {
236 t.Fatal("Can't enable RECVORIGDSTADDR:", err)
237 }
238
239 msg := []byte{1}
240 addr := conn.LocalAddr().(*net.UDPAddr)
241 _, err = conn.WriteToUDP(msg, addr)
242 if err != nil {
243 t.Fatal("WriteToUDP:", err)
244 }
245
246 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
247 oob := make([]byte, unix.CmsgSpace(unix.SizeofSockaddrInet6))
248 _, oobn, _, _, err := conn.ReadMsgUDP(msg, oob)
249 if err != nil {
250 t.Fatal("ReadMsgUDP:", err)
251 }
252
253 scms, err := unix.ParseSocketControlMessage(oob[:oobn])
254 if err != nil {
255 t.Fatal("ParseSocketControlMessage:", err)
256 }
257
258 sa, err := unix.ParseOrigDstAddr(&scms[0])
259 if err != nil {
260 t.Fatal("ParseOrigDstAddr:", err)
261 }
262
263 switch test.network {
264 case "udp4":
265 sa4, ok := sa.(*unix.SockaddrInet4)
266 if !ok {
267 t.Fatalf("Got %T not *SockaddrInet4", sa)
268 }
269
270 lo := net.IPv4(127, 0, 0, 1)
271 if addr := net.IP(sa4.Addr[:]); !lo.Equal(addr) {
272 t.Errorf("Got address %v, want %v", addr, lo)
273 }
274
275 if sa4.Port != addr.Port {
276 t.Errorf("Got port %d, want %d", sa4.Port, addr.Port)
277 }
278
279 case "udp6":
280 sa6, ok := sa.(*unix.SockaddrInet6)
281 if !ok {
282 t.Fatalf("Got %T, want *SockaddrInet6", sa)
283 }
284
285 if addr := net.IP(sa6.Addr[:]); !net.IPv6loopback.Equal(addr) {
286 t.Errorf("Got address %v, want %v", addr, net.IPv6loopback)
287 }
288
289 if sa6.Port != addr.Port {
290 t.Errorf("Got port %d, want %d", sa6.Port, addr.Port)
291 }
292 }
293 })
294 }
295 }
296
View as plain text