1
2
3
4
5
6
7 package test
8
9 import (
10 "bytes"
11 "fmt"
12 "io"
13 "math/rand"
14 "net"
15 "runtime"
16 "testing"
17 "time"
18 )
19
// closeWriter matches connections that can shut down just their write side,
// such as *net.TCPConn and *net.UnixConn. testPortForward uses it to tell the
// echo peer that no more data is coming while still reading the echoed reply.
type closeWriter interface{ CloseWrite() error }
23
24 func testPortForward(t *testing.T, n, listenAddr string) {
25 server := newServer(t)
26 conn := server.Dial(clientConfig())
27 defer conn.Close()
28
29 sshListener, err := conn.Listen(n, listenAddr)
30 if err != nil {
31 if runtime.GOOS == "darwin" && err == io.EOF {
32 t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
33 }
34 t.Fatal(err)
35 }
36
37 errCh := make(chan error, 1)
38
39 go func() {
40 defer close(errCh)
41 sshConn, err := sshListener.Accept()
42 if err != nil {
43 errCh <- fmt.Errorf("listen.Accept failed: %v", err)
44 return
45 }
46 defer sshConn.Close()
47
48 _, err = io.Copy(sshConn, sshConn)
49 if err != nil && err != io.EOF {
50 errCh <- fmt.Errorf("ssh client copy: %v", err)
51 }
52 }()
53
54 forwardedAddr := sshListener.Addr().String()
55 netConn, err := net.Dial(n, forwardedAddr)
56 if err != nil {
57 t.Fatalf("net dial failed: %v", err)
58 }
59
60 readChan := make(chan []byte)
61 go func() {
62 data, _ := io.ReadAll(netConn)
63 readChan <- data
64 }()
65
66
67 data := make([]byte, 100*1000)
68 for i := range data {
69 data[i] = byte(i % 255)
70 }
71
72 var sent []byte
73 for len(sent) < 1000*1000 {
74
75 m := rand.Intn(len(data))
76 n, err := netConn.Write(data[:m])
77 if err != nil {
78 break
79 }
80 sent = append(sent, data[:n]...)
81 }
82 if err := netConn.(closeWriter).CloseWrite(); err != nil {
83 t.Errorf("netConn.CloseWrite: %v", err)
84 }
85
86
87 err = <-errCh
88 if err != nil {
89 t.Fatalf("server: %v", err)
90 }
91
92 read := <-readChan
93
94 if len(sent) != len(read) {
95 t.Fatalf("got %d bytes, want %d", len(read), len(sent))
96 }
97 if bytes.Compare(sent, read) != 0 {
98 t.Fatalf("read back data does not match")
99 }
100
101 if err := sshListener.Close(); err != nil {
102 t.Fatalf("sshListener.Close: %v", err)
103 }
104
105
106 netConn, err = net.Dial(n, forwardedAddr)
107 if err == nil {
108 netConn.Close()
109 t.Errorf("still listening to %s after closing", forwardedAddr)
110 }
111 }
112
113 func TestPortForwardTCP(t *testing.T) {
114 testPortForward(t, "tcp", "localhost:0")
115 }
116
117 func TestPortForwardUnix(t *testing.T) {
118 addr, cleanup := newTempSocket(t)
119 defer cleanup()
120 testPortForward(t, "unix", addr)
121 }
122
123 func testAcceptClose(t *testing.T, n, listenAddr string) {
124 server := newServer(t)
125 conn := server.Dial(clientConfig())
126
127 sshListener, err := conn.Listen(n, listenAddr)
128 if err != nil {
129 if runtime.GOOS == "darwin" && err == io.EOF {
130 t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
131 }
132 t.Fatal(err)
133 }
134
135 quit := make(chan error, 1)
136 go func() {
137 for {
138 c, err := sshListener.Accept()
139 if err != nil {
140 quit <- err
141 break
142 }
143 c.Close()
144 }
145 }()
146 sshListener.Close()
147
148 select {
149 case <-time.After(1 * time.Second):
150 t.Errorf("timeout: listener did not close.")
151 case err := <-quit:
152 t.Logf("quit as expected (error %v)", err)
153 }
154 }
155
156 func TestAcceptCloseTCP(t *testing.T) {
157 testAcceptClose(t, "tcp", "localhost:0")
158 }
159
160 func TestAcceptCloseUnix(t *testing.T) {
161 addr, cleanup := newTempSocket(t)
162 defer cleanup()
163 testAcceptClose(t, "unix", addr)
164 }
165
166
167 func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
168 server := newServer(t)
169 client := server.Dial(clientConfig())
170
171 sshListener, err := client.Listen(n, listenAddr)
172 if err != nil {
173 if runtime.GOOS == "darwin" && err == io.EOF {
174 t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
175 }
176 t.Fatal(err)
177 }
178
179 quit := make(chan error, 1)
180 go func() {
181 for {
182 c, err := sshListener.Accept()
183 if err != nil {
184 quit <- err
185 break
186 }
187 c.Close()
188 }
189 }()
190
191
192
193 server.lastDialConn.Close()
194
195 err = <-quit
196 t.Logf("quit as expected (error %v)", err)
197 }
198
199 func TestPortForwardConnectionCloseTCP(t *testing.T) {
200 testPortForwardConnectionClose(t, "tcp", "localhost:0")
201 }
202
203 func TestPortForwardConnectionCloseUnix(t *testing.T) {
204 addr, cleanup := newTempSocket(t)
205 defer cleanup()
206 testPortForwardConnectionClose(t, "unix", addr)
207 }
208
View as plain text