...

Source file src/golang.org/x/crypto/ssh/session_test.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  // Session tests.
     8  
     9  import (
    10  	"bytes"
    11  	crypto_rand "crypto/rand"
    12  	"errors"
    13  	"io"
    14  	"math/rand"
    15  	"net"
    16  	"sync"
    17  	"testing"
    18  
    19  	"golang.org/x/crypto/ssh/terminal"
    20  )
    21  
    22  type serverType func(Channel, <-chan *Request, *testing.T)
    23  
    24  // dial constructs a new test server and returns a *ClientConn.
    25  func dial(handler serverType, t *testing.T) *Client {
    26  	c1, c2, err := netPipe()
    27  	if err != nil {
    28  		t.Fatalf("netPipe: %v", err)
    29  	}
    30  
    31  	var wg sync.WaitGroup
    32  	t.Cleanup(wg.Wait)
    33  	wg.Add(1)
    34  	go func() {
    35  		defer func() {
    36  			c1.Close()
    37  			wg.Done()
    38  		}()
    39  		conf := ServerConfig{
    40  			NoClientAuth: true,
    41  		}
    42  		conf.AddHostKey(testSigners["rsa"])
    43  
    44  		conn, chans, reqs, err := NewServerConn(c1, &conf)
    45  		if err != nil {
    46  			t.Errorf("Unable to handshake: %v", err)
    47  			return
    48  		}
    49  		wg.Add(1)
    50  		go func() {
    51  			DiscardRequests(reqs)
    52  			wg.Done()
    53  		}()
    54  
    55  		for newCh := range chans {
    56  			if newCh.ChannelType() != "session" {
    57  				newCh.Reject(UnknownChannelType, "unknown channel type")
    58  				continue
    59  			}
    60  
    61  			ch, inReqs, err := newCh.Accept()
    62  			if err != nil {
    63  				t.Errorf("Accept: %v", err)
    64  				continue
    65  			}
    66  			wg.Add(1)
    67  			go func() {
    68  				handler(ch, inReqs, t)
    69  				wg.Done()
    70  			}()
    71  		}
    72  		if err := conn.Wait(); err != io.EOF {
    73  			t.Logf("server exit reason: %v", err)
    74  		}
    75  	}()
    76  
    77  	config := &ClientConfig{
    78  		User:            "testuser",
    79  		HostKeyCallback: InsecureIgnoreHostKey(),
    80  	}
    81  
    82  	conn, chans, reqs, err := NewClientConn(c2, "", config)
    83  	if err != nil {
    84  		t.Fatalf("unable to dial remote side: %v", err)
    85  	}
    86  
    87  	return NewClient(conn, chans, reqs)
    88  }
    89  
    90  // Test a simple string is returned to session.Stdout.
    91  func TestSessionShell(t *testing.T) {
    92  	conn := dial(shellHandler, t)
    93  	defer conn.Close()
    94  	session, err := conn.NewSession()
    95  	if err != nil {
    96  		t.Fatalf("Unable to request new session: %v", err)
    97  	}
    98  	defer session.Close()
    99  	stdout := new(bytes.Buffer)
   100  	session.Stdout = stdout
   101  	if err := session.Shell(); err != nil {
   102  		t.Fatalf("Unable to execute command: %s", err)
   103  	}
   104  	if err := session.Wait(); err != nil {
   105  		t.Fatalf("Remote command did not exit cleanly: %v", err)
   106  	}
   107  	actual := stdout.String()
   108  	if actual != "golang" {
   109  		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
   110  	}
   111  }
   112  
   113  // TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
   114  
   115  // Test a simple string is returned via StdoutPipe.
   116  func TestSessionStdoutPipe(t *testing.T) {
   117  	conn := dial(shellHandler, t)
   118  	defer conn.Close()
   119  	session, err := conn.NewSession()
   120  	if err != nil {
   121  		t.Fatalf("Unable to request new session: %v", err)
   122  	}
   123  	defer session.Close()
   124  	stdout, err := session.StdoutPipe()
   125  	if err != nil {
   126  		t.Fatalf("Unable to request StdoutPipe(): %v", err)
   127  	}
   128  	var buf bytes.Buffer
   129  	if err := session.Shell(); err != nil {
   130  		t.Fatalf("Unable to execute command: %v", err)
   131  	}
   132  	done := make(chan bool, 1)
   133  	go func() {
   134  		if _, err := io.Copy(&buf, stdout); err != nil {
   135  			t.Errorf("Copy of stdout failed: %v", err)
   136  		}
   137  		done <- true
   138  	}()
   139  	if err := session.Wait(); err != nil {
   140  		t.Fatalf("Remote command did not exit cleanly: %v", err)
   141  	}
   142  	<-done
   143  	actual := buf.String()
   144  	if actual != "golang" {
   145  		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
   146  	}
   147  }
   148  
   149  // Test that a simple string is returned via the Output helper,
   150  // and that stderr is discarded.
   151  func TestSessionOutput(t *testing.T) {
   152  	conn := dial(fixedOutputHandler, t)
   153  	defer conn.Close()
   154  	session, err := conn.NewSession()
   155  	if err != nil {
   156  		t.Fatalf("Unable to request new session: %v", err)
   157  	}
   158  	defer session.Close()
   159  
   160  	buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
   161  	if err != nil {
   162  		t.Error("Remote command did not exit cleanly:", err)
   163  	}
   164  	w := "this-is-stdout."
   165  	g := string(buf)
   166  	if g != w {
   167  		t.Error("Remote command did not return expected string:")
   168  		t.Logf("want %q", w)
   169  		t.Logf("got  %q", g)
   170  	}
   171  }
   172  
   173  // Test that both stdout and stderr are returned
   174  // via the CombinedOutput helper.
   175  func TestSessionCombinedOutput(t *testing.T) {
   176  	conn := dial(fixedOutputHandler, t)
   177  	defer conn.Close()
   178  	session, err := conn.NewSession()
   179  	if err != nil {
   180  		t.Fatalf("Unable to request new session: %v", err)
   181  	}
   182  	defer session.Close()
   183  
   184  	buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
   185  	if err != nil {
   186  		t.Error("Remote command did not exit cleanly:", err)
   187  	}
   188  	const stdout = "this-is-stdout."
   189  	const stderr = "this-is-stderr."
   190  	g := string(buf)
   191  	if g != stdout+stderr && g != stderr+stdout {
   192  		t.Error("Remote command did not return expected string:")
   193  		t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
   194  		t.Logf("got  %q", g)
   195  	}
   196  }
   197  
   198  // Test non-0 exit status is returned correctly.
   199  func TestExitStatusNonZero(t *testing.T) {
   200  	conn := dial(exitStatusNonZeroHandler, t)
   201  	defer conn.Close()
   202  	session, err := conn.NewSession()
   203  	if err != nil {
   204  		t.Fatalf("Unable to request new session: %v", err)
   205  	}
   206  	defer session.Close()
   207  	if err := session.Shell(); err != nil {
   208  		t.Fatalf("Unable to execute command: %v", err)
   209  	}
   210  	err = session.Wait()
   211  	if err == nil {
   212  		t.Fatalf("expected command to fail but it didn't")
   213  	}
   214  	e, ok := err.(*ExitError)
   215  	if !ok {
   216  		t.Fatalf("expected *ExitError but got %T", err)
   217  	}
   218  	if e.ExitStatus() != 15 {
   219  		t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
   220  	}
   221  }
   222  
   223  // Test 0 exit status is returned correctly.
   224  func TestExitStatusZero(t *testing.T) {
   225  	conn := dial(exitStatusZeroHandler, t)
   226  	defer conn.Close()
   227  	session, err := conn.NewSession()
   228  	if err != nil {
   229  		t.Fatalf("Unable to request new session: %v", err)
   230  	}
   231  	defer session.Close()
   232  
   233  	if err := session.Shell(); err != nil {
   234  		t.Fatalf("Unable to execute command: %v", err)
   235  	}
   236  	err = session.Wait()
   237  	if err != nil {
   238  		t.Fatalf("expected nil but got %v", err)
   239  	}
   240  }
   241  
   242  // Test exit signal and status are both returned correctly.
   243  func TestExitSignalAndStatus(t *testing.T) {
   244  	conn := dial(exitSignalAndStatusHandler, t)
   245  	defer conn.Close()
   246  	session, err := conn.NewSession()
   247  	if err != nil {
   248  		t.Fatalf("Unable to request new session: %v", err)
   249  	}
   250  	defer session.Close()
   251  	if err := session.Shell(); err != nil {
   252  		t.Fatalf("Unable to execute command: %v", err)
   253  	}
   254  	err = session.Wait()
   255  	if err == nil {
   256  		t.Fatalf("expected command to fail but it didn't")
   257  	}
   258  	e, ok := err.(*ExitError)
   259  	if !ok {
   260  		t.Fatalf("expected *ExitError but got %T", err)
   261  	}
   262  	if e.Signal() != "TERM" || e.ExitStatus() != 15 {
   263  		t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   264  	}
   265  }
   266  
   267  // Test exit signal and status are both returned correctly.
   268  func TestKnownExitSignalOnly(t *testing.T) {
   269  	conn := dial(exitSignalHandler, t)
   270  	defer conn.Close()
   271  	session, err := conn.NewSession()
   272  	if err != nil {
   273  		t.Fatalf("Unable to request new session: %v", err)
   274  	}
   275  	defer session.Close()
   276  	if err := session.Shell(); err != nil {
   277  		t.Fatalf("Unable to execute command: %v", err)
   278  	}
   279  	err = session.Wait()
   280  	if err == nil {
   281  		t.Fatalf("expected command to fail but it didn't")
   282  	}
   283  	e, ok := err.(*ExitError)
   284  	if !ok {
   285  		t.Fatalf("expected *ExitError but got %T", err)
   286  	}
   287  	if e.Signal() != "TERM" || e.ExitStatus() != 143 {
   288  		t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   289  	}
   290  }
   291  
   292  // Test exit signal and status are both returned correctly.
   293  func TestUnknownExitSignal(t *testing.T) {
   294  	conn := dial(exitSignalUnknownHandler, t)
   295  	defer conn.Close()
   296  	session, err := conn.NewSession()
   297  	if err != nil {
   298  		t.Fatalf("Unable to request new session: %v", err)
   299  	}
   300  	defer session.Close()
   301  	if err := session.Shell(); err != nil {
   302  		t.Fatalf("Unable to execute command: %v", err)
   303  	}
   304  	err = session.Wait()
   305  	if err == nil {
   306  		t.Fatalf("expected command to fail but it didn't")
   307  	}
   308  	e, ok := err.(*ExitError)
   309  	if !ok {
   310  		t.Fatalf("expected *ExitError but got %T", err)
   311  	}
   312  	if e.Signal() != "SYS" || e.ExitStatus() != 128 {
   313  		t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   314  	}
   315  }
   316  
   317  func TestExitWithoutStatusOrSignal(t *testing.T) {
   318  	conn := dial(exitWithoutSignalOrStatus, t)
   319  	defer conn.Close()
   320  	session, err := conn.NewSession()
   321  	if err != nil {
   322  		t.Fatalf("Unable to request new session: %v", err)
   323  	}
   324  	defer session.Close()
   325  	if err := session.Shell(); err != nil {
   326  		t.Fatalf("Unable to execute command: %v", err)
   327  	}
   328  	err = session.Wait()
   329  	if err == nil {
   330  		t.Fatalf("expected command to fail but it didn't")
   331  	}
   332  	if _, ok := err.(*ExitMissingError); !ok {
   333  		t.Fatalf("got %T want *ExitMissingError", err)
   334  	}
   335  }
   336  
   337  // windowTestBytes is the number of bytes that we'll send to the SSH server.
   338  const windowTestBytes = 16000 * 200
   339  
   340  // TestServerWindow writes random data to the server. The server is expected to echo
   341  // the same data back, which is compared against the original.
   342  func TestServerWindow(t *testing.T) {
   343  	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
   344  	io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
   345  	origBytes := origBuf.Bytes()
   346  
   347  	conn := dial(echoHandler, t)
   348  	defer conn.Close()
   349  	session, err := conn.NewSession()
   350  	if err != nil {
   351  		t.Fatal(err)
   352  	}
   353  	defer session.Close()
   354  
   355  	serverStdin, err := session.StdinPipe()
   356  	if err != nil {
   357  		t.Fatalf("StdinPipe failed: %v", err)
   358  	}
   359  
   360  	result := make(chan []byte)
   361  	go func() {
   362  		defer close(result)
   363  		echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
   364  		serverStdout, err := session.StdoutPipe()
   365  		if err != nil {
   366  			t.Errorf("StdoutPipe failed: %v", err)
   367  			return
   368  		}
   369  		n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
   370  		if err != nil && err != io.EOF {
   371  			t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
   372  		}
   373  		result <- echoedBuf.Bytes()
   374  	}()
   375  
   376  	written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
   377  	if err != nil {
   378  		t.Errorf("failed to copy origBuf to serverStdin: %v", err)
   379  	} else if written != windowTestBytes {
   380  		t.Errorf("Wrote only %d of %d bytes to server", written, windowTestBytes)
   381  	}
   382  
   383  	echoedBytes := <-result
   384  
   385  	if !bytes.Equal(origBytes, echoedBytes) {
   386  		t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
   387  	}
   388  }
   389  
   390  // Verify the client can handle a keepalive packet from the server.
   391  func TestClientHandlesKeepalives(t *testing.T) {
   392  	conn := dial(channelKeepaliveSender, t)
   393  	defer conn.Close()
   394  	session, err := conn.NewSession()
   395  	if err != nil {
   396  		t.Fatal(err)
   397  	}
   398  	defer session.Close()
   399  	if err := session.Shell(); err != nil {
   400  		t.Fatalf("Unable to execute command: %v", err)
   401  	}
   402  	err = session.Wait()
   403  	if err != nil {
   404  		t.Fatalf("expected nil but got: %v", err)
   405  	}
   406  }
   407  
   408  type exitStatusMsg struct {
   409  	Status uint32
   410  }
   411  
   412  type exitSignalMsg struct {
   413  	Signal     string
   414  	CoreDumped bool
   415  	Errmsg     string
   416  	Lang       string
   417  }
   418  
   419  func handleTerminalRequests(in <-chan *Request) {
   420  	for req := range in {
   421  		ok := false
   422  		switch req.Type {
   423  		case "shell":
   424  			ok = true
   425  			if len(req.Payload) > 0 {
   426  				// We don't accept any commands, only the default shell.
   427  				ok = false
   428  			}
   429  		case "env":
   430  			ok = true
   431  		}
   432  		req.Reply(ok, nil)
   433  	}
   434  }
   435  
   436  func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
   437  	term := terminal.NewTerminal(ch, prompt)
   438  	go handleTerminalRequests(in)
   439  	return term
   440  }
   441  
   442  func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
   443  	defer ch.Close()
   444  	// this string is returned to stdout
   445  	shell := newServerShell(ch, in, "> ")
   446  	readLine(shell, t)
   447  	sendStatus(0, ch, t)
   448  }
   449  
   450  func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
   451  	defer ch.Close()
   452  	shell := newServerShell(ch, in, "> ")
   453  	readLine(shell, t)
   454  	sendStatus(15, ch, t)
   455  }
   456  
   457  func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
   458  	defer ch.Close()
   459  	shell := newServerShell(ch, in, "> ")
   460  	readLine(shell, t)
   461  	sendStatus(15, ch, t)
   462  	sendSignal("TERM", ch, t)
   463  }
   464  
   465  func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
   466  	defer ch.Close()
   467  	shell := newServerShell(ch, in, "> ")
   468  	readLine(shell, t)
   469  	sendSignal("TERM", ch, t)
   470  }
   471  
   472  func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
   473  	defer ch.Close()
   474  	shell := newServerShell(ch, in, "> ")
   475  	readLine(shell, t)
   476  	sendSignal("SYS", ch, t)
   477  }
   478  
   479  func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
   480  	defer ch.Close()
   481  	shell := newServerShell(ch, in, "> ")
   482  	readLine(shell, t)
   483  }
   484  
   485  func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
   486  	defer ch.Close()
   487  	// this string is returned to stdout
   488  	shell := newServerShell(ch, in, "golang")
   489  	readLine(shell, t)
   490  	sendStatus(0, ch, t)
   491  }
   492  
   493  // Ignores the command, writes fixed strings to stderr and stdout.
   494  // Strings are "this-is-stdout." and "this-is-stderr.".
   495  func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
   496  	defer ch.Close()
   497  	_, err := ch.Read(nil)
   498  
   499  	req, ok := <-in
   500  	if !ok {
   501  		t.Fatalf("error: expected channel request, got: %#v", err)
   502  		return
   503  	}
   504  
   505  	// ignore request, always send some text
   506  	req.Reply(true, nil)
   507  
   508  	_, err = io.WriteString(ch, "this-is-stdout.")
   509  	if err != nil {
   510  		t.Fatalf("error writing on server: %v", err)
   511  	}
   512  	_, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
   513  	if err != nil {
   514  		t.Fatalf("error writing on server: %v", err)
   515  	}
   516  	sendStatus(0, ch, t)
   517  }
   518  
   519  func readLine(shell *terminal.Terminal, t *testing.T) {
   520  	if _, err := shell.ReadLine(); err != nil && err != io.EOF {
   521  		t.Errorf("unable to read line: %v", err)
   522  	}
   523  }
   524  
   525  func sendStatus(status uint32, ch Channel, t *testing.T) {
   526  	msg := exitStatusMsg{
   527  		Status: status,
   528  	}
   529  	if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
   530  		t.Errorf("unable to send status: %v", err)
   531  	}
   532  }
   533  
   534  func sendSignal(signal string, ch Channel, t *testing.T) {
   535  	sig := exitSignalMsg{
   536  		Signal:     signal,
   537  		CoreDumped: false,
   538  		Errmsg:     "Process terminated",
   539  		Lang:       "en-GB-oed",
   540  	}
   541  	if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
   542  		t.Errorf("unable to send signal: %v", err)
   543  	}
   544  }
   545  
   546  func discardHandler(ch Channel, t *testing.T) {
   547  	defer ch.Close()
   548  	io.Copy(io.Discard, ch)
   549  }
   550  
   551  func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
   552  	defer ch.Close()
   553  	if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
   554  		t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
   555  	}
   556  }
   557  
   558  // copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
   559  // buffer size to exercise more code paths.
   560  func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
   561  	var (
   562  		buf       = make([]byte, 32*1024)
   563  		written   int
   564  		remaining = n
   565  	)
   566  	for remaining > 0 {
   567  		l := rand.Intn(1 << 15)
   568  		if remaining < l {
   569  			l = remaining
   570  		}
   571  		nr, er := src.Read(buf[:l])
   572  		nw, ew := dst.Write(buf[:nr])
   573  		remaining -= nw
   574  		written += nw
   575  		if ew != nil {
   576  			return written, ew
   577  		}
   578  		if nr != nw {
   579  			return written, io.ErrShortWrite
   580  		}
   581  		if er != nil && er != io.EOF {
   582  			return written, er
   583  		}
   584  	}
   585  	return written, nil
   586  }
   587  
   588  func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
   589  	defer ch.Close()
   590  	shell := newServerShell(ch, in, "> ")
   591  	readLine(shell, t)
   592  	if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
   593  		t.Errorf("unable to send channel keepalive request: %v", err)
   594  	}
   595  	sendStatus(0, ch, t)
   596  }
   597  
   598  func TestClientWriteEOF(t *testing.T) {
   599  	conn := dial(simpleEchoHandler, t)
   600  	defer conn.Close()
   601  
   602  	session, err := conn.NewSession()
   603  	if err != nil {
   604  		t.Fatal(err)
   605  	}
   606  	defer session.Close()
   607  	stdin, err := session.StdinPipe()
   608  	if err != nil {
   609  		t.Fatalf("StdinPipe failed: %v", err)
   610  	}
   611  	stdout, err := session.StdoutPipe()
   612  	if err != nil {
   613  		t.Fatalf("StdoutPipe failed: %v", err)
   614  	}
   615  
   616  	data := []byte(`0000`)
   617  	_, err = stdin.Write(data)
   618  	if err != nil {
   619  		t.Fatalf("Write failed: %v", err)
   620  	}
   621  	stdin.Close()
   622  
   623  	res, err := io.ReadAll(stdout)
   624  	if err != nil {
   625  		t.Fatalf("Read failed: %v", err)
   626  	}
   627  
   628  	if !bytes.Equal(data, res) {
   629  		t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
   630  	}
   631  }
   632  
   633  func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
   634  	defer ch.Close()
   635  	data, err := io.ReadAll(ch)
   636  	if err != nil {
   637  		t.Errorf("handler read error: %v", err)
   638  	}
   639  	_, err = ch.Write(data)
   640  	if err != nil {
   641  		t.Errorf("handler write error: %v", err)
   642  	}
   643  }
   644  
   645  func TestSessionID(t *testing.T) {
   646  	c1, c2, err := netPipe()
   647  	if err != nil {
   648  		t.Fatalf("netPipe: %v", err)
   649  	}
   650  	defer c1.Close()
   651  	defer c2.Close()
   652  
   653  	serverID := make(chan []byte, 1)
   654  	clientID := make(chan []byte, 1)
   655  
   656  	serverConf := &ServerConfig{
   657  		NoClientAuth: true,
   658  	}
   659  	serverConf.AddHostKey(testSigners["ecdsa"])
   660  	clientConf := &ClientConfig{
   661  		HostKeyCallback: InsecureIgnoreHostKey(),
   662  		User:            "user",
   663  	}
   664  
   665  	var wg sync.WaitGroup
   666  	t.Cleanup(wg.Wait)
   667  
   668  	srvErrCh := make(chan error, 1)
   669  	wg.Add(1)
   670  	go func() {
   671  		defer wg.Done()
   672  		conn, chans, reqs, err := NewServerConn(c1, serverConf)
   673  		srvErrCh <- err
   674  		if err != nil {
   675  			return
   676  		}
   677  		serverID <- conn.SessionID()
   678  		wg.Add(1)
   679  		go func() {
   680  			DiscardRequests(reqs)
   681  			wg.Done()
   682  		}()
   683  		for ch := range chans {
   684  			ch.Reject(Prohibited, "")
   685  		}
   686  	}()
   687  
   688  	cliErrCh := make(chan error, 1)
   689  	wg.Add(1)
   690  	go func() {
   691  		defer wg.Done()
   692  		conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
   693  		cliErrCh <- err
   694  		if err != nil {
   695  			return
   696  		}
   697  		clientID <- conn.SessionID()
   698  		wg.Add(1)
   699  		go func() {
   700  			DiscardRequests(reqs)
   701  			wg.Done()
   702  		}()
   703  		for ch := range chans {
   704  			ch.Reject(Prohibited, "")
   705  		}
   706  	}()
   707  
   708  	if err := <-srvErrCh; err != nil {
   709  		t.Fatalf("server handshake: %v", err)
   710  	}
   711  
   712  	if err := <-cliErrCh; err != nil {
   713  		t.Fatalf("client handshake: %v", err)
   714  	}
   715  
   716  	s := <-serverID
   717  	c := <-clientID
   718  	if bytes.Compare(s, c) != 0 {
   719  		t.Errorf("server session ID (%x) != client session ID (%x)", s, c)
   720  	} else if len(s) == 0 {
   721  		t.Errorf("client and server SessionID were empty.")
   722  	}
   723  }
   724  
   725  type noReadConn struct {
   726  	readSeen bool
   727  	net.Conn
   728  }
   729  
   730  func (c *noReadConn) Close() error {
   731  	return nil
   732  }
   733  
   734  func (c *noReadConn) Read(b []byte) (int, error) {
   735  	c.readSeen = true
   736  	return 0, errors.New("noReadConn error")
   737  }
   738  
   739  func TestInvalidServerConfiguration(t *testing.T) {
   740  	c1, c2, err := netPipe()
   741  	if err != nil {
   742  		t.Fatalf("netPipe: %v", err)
   743  	}
   744  	defer c1.Close()
   745  	defer c2.Close()
   746  
   747  	serveConn := noReadConn{Conn: c1}
   748  	serverConf := &ServerConfig{}
   749  
   750  	NewServerConn(&serveConn, serverConf)
   751  	if serveConn.readSeen {
   752  		t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key")
   753  	}
   754  
   755  	serverConf.AddHostKey(testSigners["ecdsa"])
   756  
   757  	NewServerConn(&serveConn, serverConf)
   758  	if serveConn.readSeen {
   759  		t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method")
   760  	}
   761  }
   762  
   763  func TestHostKeyAlgorithms(t *testing.T) {
   764  	serverConf := &ServerConfig{
   765  		NoClientAuth: true,
   766  	}
   767  	serverConf.AddHostKey(testSigners["rsa"])
   768  	serverConf.AddHostKey(testSigners["ecdsa"])
   769  
   770  	var wg sync.WaitGroup
   771  	t.Cleanup(wg.Wait)
   772  	connect := func(clientConf *ClientConfig, want string) {
   773  		var alg string
   774  		clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
   775  			alg = key.Type()
   776  			return nil
   777  		}
   778  		c1, c2, err := netPipe()
   779  		if err != nil {
   780  			t.Fatalf("netPipe: %v", err)
   781  		}
   782  		defer c1.Close()
   783  		defer c2.Close()
   784  
   785  		wg.Add(1)
   786  		go func() {
   787  			NewServerConn(c1, serverConf)
   788  			wg.Done()
   789  		}()
   790  		_, _, _, err = NewClientConn(c2, "", clientConf)
   791  		if err != nil {
   792  			t.Fatalf("NewClientConn: %v", err)
   793  		}
   794  		if alg != want {
   795  			t.Errorf("selected key algorithm %s, want %s", alg, want)
   796  		}
   797  	}
   798  
   799  	// By default, we get the preferred algorithm, which is ECDSA 256.
   800  
   801  	clientConf := &ClientConfig{
   802  		HostKeyCallback: InsecureIgnoreHostKey(),
   803  	}
   804  	connect(clientConf, KeyAlgoECDSA256)
   805  
   806  	// Client asks for RSA explicitly.
   807  	clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA}
   808  	connect(clientConf, KeyAlgoRSA)
   809  
   810  	// Client asks for RSA-SHA2-512 explicitly.
   811  	clientConf.HostKeyAlgorithms = []string{KeyAlgoRSASHA512}
   812  	// We get back an "ssh-rsa" key but the verification happened
   813  	// with an RSA-SHA2-512 signature.
   814  	connect(clientConf, KeyAlgoRSA)
   815  
   816  	c1, c2, err := netPipe()
   817  	if err != nil {
   818  		t.Fatalf("netPipe: %v", err)
   819  	}
   820  	defer c1.Close()
   821  	defer c2.Close()
   822  
   823  	wg.Add(1)
   824  	go func() {
   825  		NewServerConn(c1, serverConf)
   826  		wg.Done()
   827  	}()
   828  	clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
   829  	_, _, _, err = NewClientConn(c2, "", clientConf)
   830  	if err == nil {
   831  		t.Fatal("succeeded connecting with unknown hostkey algorithm")
   832  	}
   833  }
   834  
   835  func TestServerClientAuthCallback(t *testing.T) {
   836  	c1, c2, err := netPipe()
   837  	if err != nil {
   838  		t.Fatalf("netPipe: %v", err)
   839  	}
   840  	defer c1.Close()
   841  	defer c2.Close()
   842  
   843  	userCh := make(chan string, 1)
   844  
   845  	serverConf := &ServerConfig{
   846  		NoClientAuth: true,
   847  		NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
   848  			userCh <- conn.User()
   849  			return nil, nil
   850  		},
   851  	}
   852  	const someUsername = "some-username"
   853  
   854  	serverConf.AddHostKey(testSigners["ecdsa"])
   855  	clientConf := &ClientConfig{
   856  		HostKeyCallback: InsecureIgnoreHostKey(),
   857  		User:            someUsername,
   858  	}
   859  
   860  	var wg sync.WaitGroup
   861  	t.Cleanup(wg.Wait)
   862  	wg.Add(1)
   863  	go func() {
   864  		defer wg.Done()
   865  		_, chans, reqs, err := NewServerConn(c1, serverConf)
   866  		if err != nil {
   867  			t.Errorf("server handshake: %v", err)
   868  			userCh <- "error"
   869  			return
   870  		}
   871  		wg.Add(1)
   872  		go func() {
   873  			DiscardRequests(reqs)
   874  			wg.Done()
   875  		}()
   876  		for ch := range chans {
   877  			ch.Reject(Prohibited, "")
   878  		}
   879  	}()
   880  
   881  	conn, _, _, err := NewClientConn(c2, "", clientConf)
   882  	if err != nil {
   883  		t.Fatalf("client handshake: %v", err)
   884  		return
   885  	}
   886  	conn.Close()
   887  
   888  	got := <-userCh
   889  	if got != someUsername {
   890  		t.Errorf("username = %q; want %q", got, someUsername)
   891  	}
   892  }
   893  

View as plain text