1
2
3
4
5
6
7 package test
8
9
10
11 import (
12 "context"
13 "fmt"
14 "io"
15 "net"
16 "strings"
17 "testing"
18 )
19
// dialTester bundles the per-network connection checks that testDial runs
// against the two ends of a dialed connection.
type dialTester interface {
	// TestServerConn is handed every connection accepted by testDial's
	// listener. testDial calls it from its accept goroutine rather than from
	// the goroutine running the test.
	TestServerConn(t *testing.T, conn net.Conn)

	// TestClientConn is handed the connection returned by the dial in
	// testDial, before any data is read from it.
	TestClientConn(t *testing.T, conn net.Conn)
}
24
25 func testDial(t *testing.T, n, listenAddr string, x dialTester) {
26 server := newServer(t)
27 sshConn := server.Dial(clientConfig())
28 defer sshConn.Close()
29
30 l, err := net.Listen(n, listenAddr)
31 if err != nil {
32 t.Fatalf("Listen: %v", err)
33 }
34 defer l.Close()
35
36 testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
37 go func() {
38 for {
39 c, err := l.Accept()
40 if err != nil {
41 break
42 }
43 x.TestServerConn(t, c)
44
45 io.WriteString(c, testData)
46 c.Close()
47 }
48 }()
49
50 ctx, cancel := context.WithCancel(context.Background())
51 conn, err := sshConn.DialContext(ctx, n, l.Addr().String())
52
53
54 cancel()
55 if err != nil {
56 skipIfIssue64959(t, err)
57 t.Fatalf("Dial: %v", err)
58 }
59 x.TestClientConn(t, conn)
60 defer conn.Close()
61 b, err := io.ReadAll(conn)
62 if err != nil {
63 t.Fatalf("ReadAll: %v", err)
64 }
65 t.Logf("got %q", string(b))
66 if string(b) != testData {
67 t.Fatalf("expected %q, got %q", testData, string(b))
68 }
69 }
70
// tcpDialTester supplies the address checks testDial applies to TCP
// connections.
type tcpDialTester struct {
	// listenAddr is the address the listener was asked to bind to. Only the
	// text before its first colon is used when checking addresses.
	listenAddr string
}

// TestServerConn checks that the local and remote addresses of the accepted
// connection c both start with the host part of listenAddr followed by ":".
//
// testDial calls this from its accept goroutine, not from the goroutine
// running the test, so mismatches are reported with t.Errorf: t.Fatalf may
// only be called from the test goroutine.
func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
	host := strings.Split(x.listenAddr, ":")[0]
	prefix := host + ":"
	if local := c.LocalAddr().String(); !strings.HasPrefix(local, prefix) {
		t.Errorf("expected to start with %q, got %q", prefix, local)
	}
	if remote := c.RemoteAddr().String(); !strings.HasPrefix(remote, prefix) {
		t.Errorf("expected to start with %q, got %q", prefix, remote)
	}
}

// TestClientConn checks that the dialed connection c reports "0.0.0.0:0" as
// both its local and its remote address, failing the test immediately
// otherwise.
func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
	// NOTE(review): presumably the dialing side has no real socket addresses
	// to report for a forwarded connection. Confirm against the dialer.
	if local := c.LocalAddr().String(); local != "0.0.0.0:0" {
		t.Fatalf("expected \"0.0.0.0:0\", got %q", local)
	}
	if remote := c.RemoteAddr().String(); remote != "0.0.0.0:0" {
		t.Fatalf("expected \"0.0.0.0:0\", got %q", remote)
	}
}
95
96 func TestDialTCP(t *testing.T) {
97 x := &tcpDialTester{
98 listenAddr: "127.0.0.1:0",
99 }
100 testDial(t, "tcp", x.listenAddr, x)
101 }
102
// unixDialTester supplies the address checks testDial applies to "unix"
// network connections.
type unixDialTester struct {
	// listenAddr is the listener address that connection addresses are
	// compared against.
	listenAddr string
}

// TestServerConn checks that the accepted connection c has listenAddr as its
// local address and either "@" or the empty string as its remote address.
//
// testDial calls this from its accept goroutine, not from the goroutine
// running the test, so mismatches are reported with t.Errorf: t.Fatalf may
// only be called from the test goroutine.
func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
	if local := c.LocalAddr().String(); local != x.listenAddr {
		t.Errorf("expected %q, got %q", x.listenAddr, local)
	}
	if remote := c.RemoteAddr().String(); remote != "@" && remote != "" {
		t.Errorf("expected \"@\" or \"\", got %q", remote)
	}
}

// TestClientConn checks that the dialed connection c reports listenAddr as
// its remote address and "@" as its local address, failing the test
// immediately otherwise.
func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
	if remote := c.RemoteAddr().String(); remote != x.listenAddr {
		t.Fatalf("expected %q, got %q", x.listenAddr, remote)
	}
	if local := c.LocalAddr().String(); local != "@" {
		t.Fatalf("expected \"@\", got %q", local)
	}
}
124
125 func TestDialUnix(t *testing.T) {
126 addr, cleanup := newTempSocket(t)
127 defer cleanup()
128 x := &unixDialTester{
129 listenAddr: addr,
130 }
131 testDial(t, "unix", x.listenAddr, x)
132 }
133
View as plain text