...

Source file src/golang.org/x/crypto/ssh/test/forward_unix_test.go

Documentation: golang.org/x/crypto/ssh/test

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
     6  
     7  package test
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"io"
    13  	"math/rand"
    14  	"net"
    15  	"runtime"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  type closeWriter interface {
    21  	CloseWrite() error
    22  }
    23  
    24  func testPortForward(t *testing.T, n, listenAddr string) {
    25  	server := newServer(t)
    26  	conn := server.Dial(clientConfig())
    27  	defer conn.Close()
    28  
    29  	sshListener, err := conn.Listen(n, listenAddr)
    30  	if err != nil {
    31  		if runtime.GOOS == "darwin" && err == io.EOF {
    32  			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
    33  		}
    34  		t.Fatal(err)
    35  	}
    36  
    37  	errCh := make(chan error, 1)
    38  
    39  	go func() {
    40  		defer close(errCh)
    41  		sshConn, err := sshListener.Accept()
    42  		if err != nil {
    43  			errCh <- fmt.Errorf("listen.Accept failed: %v", err)
    44  			return
    45  		}
    46  		defer sshConn.Close()
    47  
    48  		_, err = io.Copy(sshConn, sshConn)
    49  		if err != nil && err != io.EOF {
    50  			errCh <- fmt.Errorf("ssh client copy: %v", err)
    51  		}
    52  	}()
    53  
    54  	forwardedAddr := sshListener.Addr().String()
    55  	netConn, err := net.Dial(n, forwardedAddr)
    56  	if err != nil {
    57  		t.Fatalf("net dial failed: %v", err)
    58  	}
    59  
    60  	readChan := make(chan []byte)
    61  	go func() {
    62  		data, _ := io.ReadAll(netConn)
    63  		readChan <- data
    64  	}()
    65  
    66  	// Invent some data.
    67  	data := make([]byte, 100*1000)
    68  	for i := range data {
    69  		data[i] = byte(i % 255)
    70  	}
    71  
    72  	var sent []byte
    73  	for len(sent) < 1000*1000 {
    74  		// Send random sized chunks
    75  		m := rand.Intn(len(data))
    76  		n, err := netConn.Write(data[:m])
    77  		if err != nil {
    78  			break
    79  		}
    80  		sent = append(sent, data[:n]...)
    81  	}
    82  	if err := netConn.(closeWriter).CloseWrite(); err != nil {
    83  		t.Errorf("netConn.CloseWrite: %v", err)
    84  	}
    85  
    86  	// Check for errors on server goroutine
    87  	err = <-errCh
    88  	if err != nil {
    89  		t.Fatalf("server: %v", err)
    90  	}
    91  
    92  	read := <-readChan
    93  
    94  	if len(sent) != len(read) {
    95  		t.Fatalf("got %d bytes, want %d", len(read), len(sent))
    96  	}
    97  	if bytes.Compare(sent, read) != 0 {
    98  		t.Fatalf("read back data does not match")
    99  	}
   100  
   101  	if err := sshListener.Close(); err != nil {
   102  		t.Fatalf("sshListener.Close: %v", err)
   103  	}
   104  
   105  	// Check that the forward disappeared.
   106  	netConn, err = net.Dial(n, forwardedAddr)
   107  	if err == nil {
   108  		netConn.Close()
   109  		t.Errorf("still listening to %s after closing", forwardedAddr)
   110  	}
   111  }
   112  
   113  func TestPortForwardTCP(t *testing.T) {
   114  	testPortForward(t, "tcp", "localhost:0")
   115  }
   116  
   117  func TestPortForwardUnix(t *testing.T) {
   118  	addr, cleanup := newTempSocket(t)
   119  	defer cleanup()
   120  	testPortForward(t, "unix", addr)
   121  }
   122  
   123  func testAcceptClose(t *testing.T, n, listenAddr string) {
   124  	server := newServer(t)
   125  	conn := server.Dial(clientConfig())
   126  
   127  	sshListener, err := conn.Listen(n, listenAddr)
   128  	if err != nil {
   129  		if runtime.GOOS == "darwin" && err == io.EOF {
   130  			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
   131  		}
   132  		t.Fatal(err)
   133  	}
   134  
   135  	quit := make(chan error, 1)
   136  	go func() {
   137  		for {
   138  			c, err := sshListener.Accept()
   139  			if err != nil {
   140  				quit <- err
   141  				break
   142  			}
   143  			c.Close()
   144  		}
   145  	}()
   146  	sshListener.Close()
   147  
   148  	select {
   149  	case <-time.After(1 * time.Second):
   150  		t.Errorf("timeout: listener did not close.")
   151  	case err := <-quit:
   152  		t.Logf("quit as expected (error %v)", err)
   153  	}
   154  }
   155  
   156  func TestAcceptCloseTCP(t *testing.T) {
   157  	testAcceptClose(t, "tcp", "localhost:0")
   158  }
   159  
   160  func TestAcceptCloseUnix(t *testing.T) {
   161  	addr, cleanup := newTempSocket(t)
   162  	defer cleanup()
   163  	testAcceptClose(t, "unix", addr)
   164  }
   165  
   166  // Check that listeners exit if the underlying client transport dies.
   167  func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
   168  	server := newServer(t)
   169  	client := server.Dial(clientConfig())
   170  
   171  	sshListener, err := client.Listen(n, listenAddr)
   172  	if err != nil {
   173  		if runtime.GOOS == "darwin" && err == io.EOF {
   174  			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
   175  		}
   176  		t.Fatal(err)
   177  	}
   178  
   179  	quit := make(chan error, 1)
   180  	go func() {
   181  		for {
   182  			c, err := sshListener.Accept()
   183  			if err != nil {
   184  				quit <- err
   185  				break
   186  			}
   187  			c.Close()
   188  		}
   189  	}()
   190  
   191  	// It would be even nicer if we closed the server side, but it
   192  	// is more involved as the fd for that side is dup()ed.
   193  	server.lastDialConn.Close()
   194  
   195  	err = <-quit
   196  	t.Logf("quit as expected (error %v)", err)
   197  }
   198  
   199  func TestPortForwardConnectionCloseTCP(t *testing.T) {
   200  	testPortForwardConnectionClose(t, "tcp", "localhost:0")
   201  }
   202  
   203  func TestPortForwardConnectionCloseUnix(t *testing.T) {
   204  	addr, cleanup := newTempSocket(t)
   205  	defer cleanup()
   206  	testPortForwardConnectionClose(t, "unix", addr)
   207  }
   208  

View as plain text