1
2
3
4
5 package bpf_test
6
7 import (
8 "net"
9 "runtime"
10 "testing"
11 "time"
12
13 "golang.org/x/net/bpf"
14 "golang.org/x/net/ipv4"
15 "golang.org/x/net/ipv6"
16 "golang.org/x/net/nettest"
17 "golang.org/x/sys/cpu"
18 )
19
20
21
// virtualMachine is the common surface of the Go BPF VM and the OS-backed VM
// used by these tests: Run feeds a packet through a BPF program and returns
// the resulting output byte count, or an error.
type virtualMachine interface {
	Run(packet []byte) (int, error)
}
25
26
27
28
29
30
31
32
33 func testVM(t *testing.T, filter []bpf.Instruction) (virtualMachine, func(), error) {
34 goVM, err := bpf.NewVM(filter)
35 if err != nil {
36
37
38 return nil, nil, err
39 }
40
41 mvm := &multiVirtualMachine{
42 goVM: goVM,
43
44 t: t,
45 }
46
47
48
49 done := func() {}
50 if runtime.GOOS == "linux" && !cpu.IsBigEndian {
51 osVM, osVMDone := testOSVM(t, filter)
52 done = func() { osVMDone() }
53 mvm.osVM = osVM
54 }
55
56 return mvm, done, nil
57 }
58
59
// udpHeaderLen is the size in bytes of a UDP header. multiVirtualMachine.Run
// requires every input to be at least this long, and it forwards only the
// bytes after the header to the OS VM.
const (
	udpHeaderLen = 8
)
61
62
63
64 type multiVirtualMachine struct {
65 goVM virtualMachine
66 osVM virtualMachine
67
68 t *testing.T
69 }
70
71 func (mvm *multiVirtualMachine) Run(in []byte) (int, error) {
72 if len(in) < udpHeaderLen {
73 mvm.t.Fatalf("input must be at least length of UDP header (%d), got: %d",
74 udpHeaderLen, len(in))
75 }
76
77
78
79
80 goOut, goErr := mvm.goVM.Run(in)
81 if goOut >= udpHeaderLen {
82 goOut -= udpHeaderLen
83 }
84
85
86
87
88
89 trim := len(in) - udpHeaderLen
90 if goOut > trim {
91 goOut = trim
92 }
93
94
95 if mvm.osVM == nil {
96 return goOut, goErr
97 }
98
99
100
101 osOut, err := mvm.osVM.Run(in[udpHeaderLen:])
102 if err != nil {
103 mvm.t.Fatalf("error while running OS VM: %v", err)
104 }
105
106
107 var mismatch bool
108 if goOut != osOut {
109 mismatch = true
110 mvm.t.Logf("output byte count does not match:\n- go: %v\n- os: %v", goOut, osOut)
111 }
112
113 if mismatch {
114 mvm.t.Fatal("Go BPF and OS BPF packet outputs do not match")
115 }
116
117 return goOut, goErr
118 }
119
120
121
// osVirtualMachine exercises the operating system's BPF implementation. It
// sends packets through a UDP socket pair whose receiving end has the BPF
// program attached.
type osVirtualMachine struct {
	// l is the listening socket the BPF program is attached to; what
	// arrives on it reflects the filter's verdict.
	l net.PacketConn
	// s is a connection dialed to l and used to send test packets.
	s net.Conn
}
126
127
128
129 func testOSVM(t *testing.T, filter []bpf.Instruction) (virtualMachine, func()) {
130 l, err := nettest.NewLocalPacketListener("udp")
131 if err != nil {
132 t.Fatalf("failed to open OS VM UDP listener: %v", err)
133 }
134
135 prog, err := bpf.Assemble(filter)
136 if err != nil {
137 t.Fatalf("failed to compile BPF program: %v", err)
138 }
139
140 ip := l.LocalAddr().(*net.UDPAddr).IP
141 if ip.To4() != nil && ip.To16() == nil {
142 err = ipv4.NewPacketConn(l).SetBPF(prog)
143 } else {
144 err = ipv6.NewPacketConn(l).SetBPF(prog)
145 }
146 if err != nil {
147 t.Fatalf("failed to attach BPF program to listener: %v", err)
148 }
149
150 s, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String())
151 if err != nil {
152 t.Fatalf("failed to dial connection to listener: %v", err)
153 }
154
155 done := func() {
156 _ = s.Close()
157 _ = l.Close()
158 }
159
160 return &osVirtualMachine{
161 l: l,
162 s: s,
163 }, done
164 }
165
166
167 func (vm *osVirtualMachine) Run(in []byte) (int, error) {
168 go func() {
169 _, _ = vm.s.Write(in)
170 }()
171
172 vm.l.SetDeadline(time.Now().Add(50 * time.Millisecond))
173
174 var b [512]byte
175 n, _, err := vm.l.ReadFrom(b[:])
176 if err != nil {
177
178
179 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
180 return n, nil
181 }
182
183 return n, err
184 }
185
186 return n, nil
187 }
188
View as plain text