...

Source file src/golang.org/x/crypto/ssh/handshake_test.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"reflect"
    15  	"runtime"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  )
    20  
    21  type testChecker struct {
    22  	calls []string
    23  }
    24  
    25  func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
    26  	if dialAddr == "bad" {
    27  		return fmt.Errorf("dialAddr is bad")
    28  	}
    29  
    30  	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
    31  		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
    32  	}
    33  
    34  	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
    35  
    36  	return nil
    37  }
    38  
    39  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
    40  // therefore is buffered (net.Pipe deadlocks if both sides start with
    41  // a write.)
    42  func netPipe() (net.Conn, net.Conn, error) {
    43  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    44  	if err != nil {
    45  		listener, err = net.Listen("tcp", "[::1]:0")
    46  		if err != nil {
    47  			return nil, nil, err
    48  		}
    49  	}
    50  	defer listener.Close()
    51  	c1, err := net.Dial("tcp", listener.Addr().String())
    52  	if err != nil {
    53  		return nil, nil, err
    54  	}
    55  
    56  	c2, err := listener.Accept()
    57  	if err != nil {
    58  		c1.Close()
    59  		return nil, nil, err
    60  	}
    61  
    62  	return c1, c2, nil
    63  }
    64  
    65  // noiseTransport inserts ignore messages to check that the read loop
    66  // and the key exchange filters out these messages.
    67  type noiseTransport struct {
    68  	keyingTransport
    69  }
    70  
    71  func (t *noiseTransport) writePacket(p []byte) error {
    72  	ignore := []byte{msgIgnore}
    73  	if err := t.keyingTransport.writePacket(ignore); err != nil {
    74  		return err
    75  	}
    76  	debug := []byte{msgDebug, 1, 2, 3}
    77  	if err := t.keyingTransport.writePacket(debug); err != nil {
    78  		return err
    79  	}
    80  
    81  	return t.keyingTransport.writePacket(p)
    82  }
    83  
    84  func addNoiseTransport(t keyingTransport) keyingTransport {
    85  	return &noiseTransport{t}
    86  }
    87  
    88  // handshakePair creates two handshakeTransports connected with each
    89  // other. If the noise argument is true, both transports will try to
    90  // confuse the other side by sending ignore and debug messages.
    91  func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
    92  	a, b, err := netPipe()
    93  	if err != nil {
    94  		return nil, nil, err
    95  	}
    96  
    97  	var trC, trS keyingTransport
    98  
    99  	trC = newTransport(a, rand.Reader, true)
   100  	trS = newTransport(b, rand.Reader, false)
   101  	if noise {
   102  		trC = addNoiseTransport(trC)
   103  		trS = addNoiseTransport(trS)
   104  	}
   105  	clientConf.SetDefaults()
   106  
   107  	v := []byte("version")
   108  	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
   109  
   110  	serverConf := &ServerConfig{}
   111  	serverConf.AddHostKey(testSigners["ecdsa"])
   112  	serverConf.AddHostKey(testSigners["rsa"])
   113  	serverConf.SetDefaults()
   114  	server = newServerTransport(trS, v, v, serverConf)
   115  
   116  	if err := server.waitSession(); err != nil {
   117  		return nil, nil, fmt.Errorf("server.waitSession: %v", err)
   118  	}
   119  	if err := client.waitSession(); err != nil {
   120  		return nil, nil, fmt.Errorf("client.waitSession: %v", err)
   121  	}
   122  
   123  	return client, server, nil
   124  }
   125  
   126  func TestHandshakeBasic(t *testing.T) {
   127  	if runtime.GOOS == "plan9" {
   128  		t.Skip("see golang.org/issue/7237")
   129  	}
   130  
   131  	checker := &syncChecker{
   132  		waitCall: make(chan int, 10),
   133  		called:   make(chan int, 10),
   134  	}
   135  
   136  	checker.waitCall <- 1
   137  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   138  	if err != nil {
   139  		t.Fatalf("handshakePair: %v", err)
   140  	}
   141  
   142  	defer trC.Close()
   143  	defer trS.Close()
   144  
   145  	// Let first kex complete normally.
   146  	<-checker.called
   147  
   148  	clientDone := make(chan int, 0)
   149  	gotHalf := make(chan int, 0)
   150  	const N = 20
   151  	errorCh := make(chan error, 1)
   152  
   153  	go func() {
   154  		defer close(clientDone)
   155  		// Client writes a bunch of stuff, and does a key
   156  		// change in the middle. This should not confuse the
   157  		// handshake in progress. We do this twice, so we test
   158  		// that the packet buffer is reset correctly.
   159  		for i := 0; i < N; i++ {
   160  			p := []byte{msgRequestSuccess, byte(i)}
   161  			if err := trC.writePacket(p); err != nil {
   162  				errorCh <- err
   163  				trC.Close()
   164  				return
   165  			}
   166  			if (i % 10) == 5 {
   167  				<-gotHalf
   168  				// halfway through, we request a key change.
   169  				trC.requestKeyExchange()
   170  
   171  				// Wait until we can be sure the key
   172  				// change has really started before we
   173  				// write more.
   174  				<-checker.called
   175  			}
   176  			if (i % 10) == 7 {
   177  				// write some packets until the kex
   178  				// completes, to test buffering of
   179  				// packets.
   180  				checker.waitCall <- 1
   181  			}
   182  		}
   183  		errorCh <- nil
   184  	}()
   185  
   186  	// Server checks that client messages come in cleanly
   187  	i := 0
   188  	for ; i < N; i++ {
   189  		p, err := trS.readPacket()
   190  		if err != nil && err != io.EOF {
   191  			t.Fatalf("server error: %v", err)
   192  		}
   193  		if (i % 10) == 5 {
   194  			gotHalf <- 1
   195  		}
   196  
   197  		want := []byte{msgRequestSuccess, byte(i)}
   198  		if bytes.Compare(p, want) != 0 {
   199  			t.Errorf("message %d: got %v, want %v", i, p, want)
   200  		}
   201  	}
   202  	<-clientDone
   203  	if err := <-errorCh; err != nil {
   204  		t.Fatalf("sendPacket: %v", err)
   205  	}
   206  	if i != N {
   207  		t.Errorf("received %d messages, want 10.", i)
   208  	}
   209  
   210  	close(checker.called)
   211  	if _, ok := <-checker.called; ok {
   212  		// If all went well, we registered exactly 2 key changes: one
   213  		// that establishes the session, and one that we requested
   214  		// additionally.
   215  		t.Fatalf("got another host key checks after 2 handshakes")
   216  	}
   217  }
   218  
   219  func TestForceFirstKex(t *testing.T) {
   220  	// like handshakePair, but must access the keyingTransport.
   221  	checker := &testChecker{}
   222  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   223  	a, b, err := netPipe()
   224  	if err != nil {
   225  		t.Fatalf("netPipe: %v", err)
   226  	}
   227  
   228  	var trC, trS keyingTransport
   229  
   230  	trC = newTransport(a, rand.Reader, true)
   231  
   232  	// This is the disallowed packet:
   233  	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
   234  
   235  	// Rest of the setup.
   236  	trS = newTransport(b, rand.Reader, false)
   237  	clientConf.SetDefaults()
   238  
   239  	v := []byte("version")
   240  	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
   241  
   242  	serverConf := &ServerConfig{}
   243  	serverConf.AddHostKey(testSigners["ecdsa"])
   244  	serverConf.AddHostKey(testSigners["rsa"])
   245  	serverConf.SetDefaults()
   246  	server := newServerTransport(trS, v, v, serverConf)
   247  
   248  	defer client.Close()
   249  	defer server.Close()
   250  
   251  	// We setup the initial key exchange, but the remote side
   252  	// tries to send serviceRequestMsg in cleartext, which is
   253  	// disallowed.
   254  
   255  	if err := server.waitSession(); err == nil {
   256  		t.Errorf("server first kex init should reject unexpected packet")
   257  	}
   258  }
   259  
   260  func TestHandshakeAutoRekeyWrite(t *testing.T) {
   261  	checker := &syncChecker{
   262  		called:   make(chan int, 10),
   263  		waitCall: nil,
   264  	}
   265  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   266  	clientConf.RekeyThreshold = 500
   267  	trC, trS, err := handshakePair(clientConf, "addr", false)
   268  	if err != nil {
   269  		t.Fatalf("handshakePair: %v", err)
   270  	}
   271  	defer trC.Close()
   272  	defer trS.Close()
   273  
   274  	input := make([]byte, 251)
   275  	input[0] = msgRequestSuccess
   276  
   277  	done := make(chan int, 1)
   278  	const numPacket = 5
   279  	go func() {
   280  		defer close(done)
   281  		j := 0
   282  		for ; j < numPacket; j++ {
   283  			if p, err := trS.readPacket(); err != nil {
   284  				break
   285  			} else if !bytes.Equal(input, p) {
   286  				t.Errorf("got packet type %d, want %d", p[0], input[0])
   287  			}
   288  		}
   289  
   290  		if j != numPacket {
   291  			t.Errorf("got %d, want 5 messages", j)
   292  		}
   293  	}()
   294  
   295  	<-checker.called
   296  
   297  	for i := 0; i < numPacket; i++ {
   298  		p := make([]byte, len(input))
   299  		copy(p, input)
   300  		if err := trC.writePacket(p); err != nil {
   301  			t.Errorf("writePacket: %v", err)
   302  		}
   303  		if i == 2 {
   304  			// Make sure the kex is in progress.
   305  			<-checker.called
   306  		}
   307  
   308  	}
   309  	<-done
   310  }
   311  
   312  type syncChecker struct {
   313  	waitCall chan int
   314  	called   chan int
   315  }
   316  
   317  func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
   318  	c.called <- 1
   319  	if c.waitCall != nil {
   320  		<-c.waitCall
   321  	}
   322  	return nil
   323  }
   324  
   325  func TestHandshakeAutoRekeyRead(t *testing.T) {
   326  	sync := &syncChecker{
   327  		called:   make(chan int, 2),
   328  		waitCall: nil,
   329  	}
   330  	clientConf := &ClientConfig{
   331  		HostKeyCallback: sync.Check,
   332  	}
   333  	clientConf.RekeyThreshold = 500
   334  
   335  	trC, trS, err := handshakePair(clientConf, "addr", false)
   336  	if err != nil {
   337  		t.Fatalf("handshakePair: %v", err)
   338  	}
   339  	defer trC.Close()
   340  	defer trS.Close()
   341  
   342  	packet := make([]byte, 501)
   343  	packet[0] = msgRequestSuccess
   344  	if err := trS.writePacket(packet); err != nil {
   345  		t.Fatalf("writePacket: %v", err)
   346  	}
   347  
   348  	// While we read out the packet, a key change will be
   349  	// initiated.
   350  	errorCh := make(chan error, 1)
   351  	go func() {
   352  		_, err := trC.readPacket()
   353  		errorCh <- err
   354  	}()
   355  
   356  	if err := <-errorCh; err != nil {
   357  		t.Fatalf("readPacket(client): %v", err)
   358  	}
   359  
   360  	<-sync.called
   361  }
   362  
   363  // errorKeyingTransport generates errors after a given number of
   364  // read/write operations.
   365  type errorKeyingTransport struct {
   366  	packetConn
   367  	readLeft, writeLeft int
   368  }
   369  
   370  func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
   371  	return nil
   372  }
   373  
   374  func (n *errorKeyingTransport) getSessionID() []byte {
   375  	return nil
   376  }
   377  
   378  func (n *errorKeyingTransport) writePacket(packet []byte) error {
   379  	if n.writeLeft == 0 {
   380  		n.Close()
   381  		return errors.New("barf")
   382  	}
   383  
   384  	n.writeLeft--
   385  	return n.packetConn.writePacket(packet)
   386  }
   387  
   388  func (n *errorKeyingTransport) readPacket() ([]byte, error) {
   389  	if n.readLeft == 0 {
   390  		n.Close()
   391  		return nil, errors.New("barf")
   392  	}
   393  
   394  	n.readLeft--
   395  	return n.packetConn.readPacket()
   396  }
   397  
   398  func (n *errorKeyingTransport) setStrictMode() error { return nil }
   399  
   400  func (n *errorKeyingTransport) setInitialKEXDone() {}
   401  
   402  func TestHandshakeErrorHandlingRead(t *testing.T) {
   403  	for i := 0; i < 20; i++ {
   404  		testHandshakeErrorHandlingN(t, i, -1, false)
   405  	}
   406  }
   407  
   408  func TestHandshakeErrorHandlingWrite(t *testing.T) {
   409  	for i := 0; i < 20; i++ {
   410  		testHandshakeErrorHandlingN(t, -1, i, false)
   411  	}
   412  }
   413  
   414  func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
   415  	for i := 0; i < 20; i++ {
   416  		testHandshakeErrorHandlingN(t, i, -1, true)
   417  	}
   418  }
   419  
   420  func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
   421  	for i := 0; i < 20; i++ {
   422  		testHandshakeErrorHandlingN(t, -1, i, true)
   423  	}
   424  }
   425  
   426  // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
   427  // handshakeTransport deadlocks, the go runtime will detect it and
   428  // panic.
   429  func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
   430  	if (runtime.GOOS == "js" || runtime.GOOS == "wasip1") && runtime.GOARCH == "wasm" {
   431  		t.Skipf("skipping on %s/wasm; see golang.org/issue/32840", runtime.GOOS)
   432  	}
   433  	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
   434  
   435  	a, b := memPipe()
   436  	defer a.Close()
   437  	defer b.Close()
   438  
   439  	key := testSigners["ecdsa"]
   440  	serverConf := Config{RekeyThreshold: minRekeyThreshold}
   441  	serverConf.SetDefaults()
   442  	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
   443  	serverConn.hostKeys = []Signer{key}
   444  	go serverConn.readLoop()
   445  	go serverConn.kexLoop()
   446  
   447  	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
   448  	clientConf.SetDefaults()
   449  	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
   450  	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
   451  	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
   452  	go clientConn.readLoop()
   453  	go clientConn.kexLoop()
   454  
   455  	var wg sync.WaitGroup
   456  
   457  	for _, hs := range []packetConn{serverConn, clientConn} {
   458  		if !coupled {
   459  			wg.Add(2)
   460  			go func(c packetConn) {
   461  				for i := 0; ; i++ {
   462  					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
   463  					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
   464  					if err != nil {
   465  						break
   466  					}
   467  				}
   468  				wg.Done()
   469  				c.Close()
   470  			}(hs)
   471  			go func(c packetConn) {
   472  				for {
   473  					_, err := c.readPacket()
   474  					if err != nil {
   475  						break
   476  					}
   477  				}
   478  				wg.Done()
   479  			}(hs)
   480  		} else {
   481  			wg.Add(1)
   482  			go func(c packetConn) {
   483  				for {
   484  					_, err := c.readPacket()
   485  					if err != nil {
   486  						break
   487  					}
   488  					if err := c.writePacket(msg); err != nil {
   489  						break
   490  					}
   491  
   492  				}
   493  				wg.Done()
   494  			}(hs)
   495  		}
   496  	}
   497  	wg.Wait()
   498  }
   499  
   500  func TestDisconnect(t *testing.T) {
   501  	if runtime.GOOS == "plan9" {
   502  		t.Skip("see golang.org/issue/7237")
   503  	}
   504  	checker := &testChecker{}
   505  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   506  	if err != nil {
   507  		t.Fatalf("handshakePair: %v", err)
   508  	}
   509  
   510  	defer trC.Close()
   511  	defer trS.Close()
   512  
   513  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   514  	errMsg := &disconnectMsg{
   515  		Reason:  42,
   516  		Message: "such is life",
   517  	}
   518  	trC.writePacket(Marshal(errMsg))
   519  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   520  
   521  	packet, err := trS.readPacket()
   522  	if err != nil {
   523  		t.Fatalf("readPacket 1: %v", err)
   524  	}
   525  	if packet[0] != msgRequestSuccess {
   526  		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
   527  	}
   528  
   529  	_, err = trS.readPacket()
   530  	if err == nil {
   531  		t.Errorf("readPacket 2 succeeded")
   532  	} else if !reflect.DeepEqual(err, errMsg) {
   533  		t.Errorf("got error %#v, want %#v", err, errMsg)
   534  	}
   535  
   536  	_, err = trS.readPacket()
   537  	if err == nil {
   538  		t.Errorf("readPacket 3 succeeded")
   539  	}
   540  }
   541  
   542  func TestHandshakeRekeyDefault(t *testing.T) {
   543  	clientConf := &ClientConfig{
   544  		Config: Config{
   545  			Ciphers: []string{"aes128-ctr"},
   546  		},
   547  		HostKeyCallback: InsecureIgnoreHostKey(),
   548  	}
   549  	trC, trS, err := handshakePair(clientConf, "addr", false)
   550  	if err != nil {
   551  		t.Fatalf("handshakePair: %v", err)
   552  	}
   553  	defer trC.Close()
   554  	defer trS.Close()
   555  
   556  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   557  	trC.Close()
   558  
   559  	rgb := (1024 + trC.readBytesLeft) >> 30
   560  	wgb := (1024 + trC.writeBytesLeft) >> 30
   561  
   562  	if rgb != 64 {
   563  		t.Errorf("got rekey after %dG read, want 64G", rgb)
   564  	}
   565  	if wgb != 64 {
   566  		t.Errorf("got rekey after %dG write, want 64G", wgb)
   567  	}
   568  }
   569  
   570  func TestHandshakeAEADCipherNoMAC(t *testing.T) {
   571  	for _, cipher := range []string{chacha20Poly1305ID, gcm128CipherID} {
   572  		checker := &syncChecker{
   573  			called: make(chan int, 1),
   574  		}
   575  		clientConf := &ClientConfig{
   576  			Config: Config{
   577  				Ciphers: []string{cipher},
   578  				MACs:    []string{},
   579  			},
   580  			HostKeyCallback: checker.Check,
   581  		}
   582  		trC, trS, err := handshakePair(clientConf, "addr", false)
   583  		if err != nil {
   584  			t.Fatalf("handshakePair: %v", err)
   585  		}
   586  		defer trC.Close()
   587  		defer trS.Close()
   588  
   589  		<-checker.called
   590  	}
   591  }
   592  
   593  // TestNoSHA2Support tests a host key Signer that is not an AlgorithmSigner and
   594  // therefore can't do SHA-2 signatures. Ensures the server does not advertise
   595  // support for them in this case.
   596  func TestNoSHA2Support(t *testing.T) {
   597  	c1, c2, err := netPipe()
   598  	if err != nil {
   599  		t.Fatalf("netPipe: %v", err)
   600  	}
   601  	defer c1.Close()
   602  	defer c2.Close()
   603  
   604  	serverConf := &ServerConfig{
   605  		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
   606  			return &Permissions{}, nil
   607  		},
   608  	}
   609  	serverConf.AddHostKey(&legacyRSASigner{testSigners["rsa"]})
   610  	go func() {
   611  		_, _, _, err := NewServerConn(c1, serverConf)
   612  		if err != nil {
   613  			t.Error(err)
   614  		}
   615  	}()
   616  
   617  	clientConf := &ClientConfig{
   618  		User:            "test",
   619  		Auth:            []AuthMethod{Password("testpw")},
   620  		HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()),
   621  	}
   622  
   623  	if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
   624  		t.Fatal(err)
   625  	}
   626  }
   627  
   628  func TestMultiAlgoSignerHandshake(t *testing.T) {
   629  	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
   630  	if !ok {
   631  		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
   632  	}
   633  	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
   634  	if err != nil {
   635  		t.Fatalf("unable to create multi algorithm signer: %v", err)
   636  	}
   637  	c1, c2, err := netPipe()
   638  	if err != nil {
   639  		t.Fatalf("netPipe: %v", err)
   640  	}
   641  	defer c1.Close()
   642  	defer c2.Close()
   643  
   644  	serverConf := &ServerConfig{
   645  		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
   646  			return &Permissions{}, nil
   647  		},
   648  	}
   649  	serverConf.AddHostKey(multiAlgoSigner)
   650  	go NewServerConn(c1, serverConf)
   651  
   652  	clientConf := &ClientConfig{
   653  		User:              "test",
   654  		Auth:              []AuthMethod{Password("testpw")},
   655  		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
   656  		HostKeyAlgorithms: []string{KeyAlgoRSASHA512},
   657  	}
   658  
   659  	if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
   660  		t.Fatal(err)
   661  	}
   662  }
   663  
   664  func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) {
   665  	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
   666  	if !ok {
   667  		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
   668  	}
   669  	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
   670  	if err != nil {
   671  		t.Fatalf("unable to create multi algorithm signer: %v", err)
   672  	}
   673  	c1, c2, err := netPipe()
   674  	if err != nil {
   675  		t.Fatalf("netPipe: %v", err)
   676  	}
   677  	defer c1.Close()
   678  	defer c2.Close()
   679  
   680  	// ssh-rsa is disabled server side
   681  	serverConf := &ServerConfig{
   682  		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
   683  			return &Permissions{}, nil
   684  		},
   685  	}
   686  	serverConf.AddHostKey(multiAlgoSigner)
   687  	go NewServerConn(c1, serverConf)
   688  
   689  	// the client only supports ssh-rsa
   690  	clientConf := &ClientConfig{
   691  		User:              "test",
   692  		Auth:              []AuthMethod{Password("testpw")},
   693  		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
   694  		HostKeyAlgorithms: []string{KeyAlgoRSA},
   695  	}
   696  
   697  	_, _, _, err = NewClientConn(c2, "", clientConf)
   698  	if err == nil {
   699  		t.Fatal("succeeded connecting with no common hostkey algorithm")
   700  	}
   701  }
   702  
   703  func TestPickIncompatibleHostKeyAlgo(t *testing.T) {
   704  	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
   705  	if !ok {
   706  		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
   707  	}
   708  	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
   709  	if err != nil {
   710  		t.Fatalf("unable to create multi algorithm signer: %v", err)
   711  	}
   712  	signer := pickHostKey([]Signer{multiAlgoSigner}, KeyAlgoRSA)
   713  	if signer != nil {
   714  		t.Fatal("incompatible signer returned")
   715  	}
   716  }
   717  
   718  func TestStrictKEXResetSeqFirstKEX(t *testing.T) {
   719  	if runtime.GOOS == "plan9" {
   720  		t.Skip("see golang.org/issue/7237")
   721  	}
   722  
   723  	checker := &syncChecker{
   724  		waitCall: make(chan int, 10),
   725  		called:   make(chan int, 10),
   726  	}
   727  
   728  	checker.waitCall <- 1
   729  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   730  	if err != nil {
   731  		t.Fatalf("handshakePair: %v", err)
   732  	}
   733  	<-checker.called
   734  
   735  	t.Cleanup(func() {
   736  		trC.Close()
   737  		trS.Close()
   738  	})
   739  
   740  	// Throw away the msgExtInfo packet sent during the handshake by the server
   741  	_, err = trC.readPacket()
   742  	if err != nil {
   743  		t.Fatalf("readPacket failed: %s", err)
   744  	}
   745  
   746  	// close the handshake transports before checking the sequence number to
   747  	// avoid races.
   748  	trC.Close()
   749  	trS.Close()
   750  
   751  	// check that the sequence number counters. We reset after msgNewKeys, but
   752  	// then the server immediately writes msgExtInfo, and we close the
   753  	// transports so we expect read 2, write 0 on the client and read 1, write 1
   754  	// on the server.
   755  	if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
   756  		trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
   757  		t.Errorf(
   758  			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
   759  			trC.conn.(*transport).reader.seqNum,
   760  			trC.conn.(*transport).writer.seqNum,
   761  			trS.conn.(*transport).reader.seqNum,
   762  			trS.conn.(*transport).writer.seqNum,
   763  		)
   764  	}
   765  }
   766  
   767  func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) {
   768  	if runtime.GOOS == "plan9" {
   769  		t.Skip("see golang.org/issue/7237")
   770  	}
   771  
   772  	checker := &syncChecker{
   773  		waitCall: make(chan int, 10),
   774  		called:   make(chan int, 10),
   775  	}
   776  
   777  	checker.waitCall <- 1
   778  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   779  	if err != nil {
   780  		t.Fatalf("handshakePair: %v", err)
   781  	}
   782  	<-checker.called
   783  
   784  	t.Cleanup(func() {
   785  		trC.Close()
   786  		trS.Close()
   787  	})
   788  
   789  	// Throw away the msgExtInfo packet sent during the handshake by the server
   790  	_, err = trC.readPacket()
   791  	if err != nil {
   792  		t.Fatalf("readPacket failed: %s", err)
   793  	}
   794  
   795  	// write and read five packets on either side to bump the sequence numbers
   796  	for i := 0; i < 5; i++ {
   797  		if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
   798  			t.Fatalf("writePacket failed: %s", err)
   799  		}
   800  		if _, err := trS.readPacket(); err != nil {
   801  			t.Fatalf("readPacket failed: %s", err)
   802  		}
   803  		if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
   804  			t.Fatalf("writePacket failed: %s", err)
   805  		}
   806  		if _, err := trC.readPacket(); err != nil {
   807  			t.Fatalf("readPacket failed: %s", err)
   808  		}
   809  	}
   810  
   811  	// Request a key exchange, which should cause the sequence numbers to reset
   812  	checker.waitCall <- 1
   813  	trC.requestKeyExchange()
   814  	<-checker.called
   815  
   816  	// write a packet on the client, and then read it, to verify the key change has actually happened, since
   817  	// the HostKeyCallback is called _during_ the handshake, so isn't actually indicative of the handshake
   818  	// finishing.
   819  	dummyPacket := []byte{99}
   820  	if err := trS.writePacket(dummyPacket); err != nil {
   821  		t.Fatalf("writePacket failed: %s", err)
   822  	}
   823  	if p, err := trC.readPacket(); err != nil {
   824  		t.Fatalf("readPacket failed: %s", err)
   825  	} else if !bytes.Equal(p, dummyPacket) {
   826  		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
   827  	}
   828  
   829  	// close the handshake transports before checking the sequence number to
   830  	// avoid races.
   831  	trC.Close()
   832  	trS.Close()
   833  
   834  	if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
   835  		trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
   836  		t.Errorf(
   837  			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
   838  			trC.conn.(*transport).reader.seqNum,
   839  			trC.conn.(*transport).writer.seqNum,
   840  			trS.conn.(*transport).reader.seqNum,
   841  			trS.conn.(*transport).writer.seqNum,
   842  		)
   843  	}
   844  }
   845  
   846  func TestSeqNumIncrease(t *testing.T) {
   847  	if runtime.GOOS == "plan9" {
   848  		t.Skip("see golang.org/issue/7237")
   849  	}
   850  
   851  	checker := &syncChecker{
   852  		waitCall: make(chan int, 10),
   853  		called:   make(chan int, 10),
   854  	}
   855  
   856  	checker.waitCall <- 1
   857  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   858  	if err != nil {
   859  		t.Fatalf("handshakePair: %v", err)
   860  	}
   861  	<-checker.called
   862  
   863  	t.Cleanup(func() {
   864  		trC.Close()
   865  		trS.Close()
   866  	})
   867  
   868  	// Throw away the msgExtInfo packet sent during the handshake by the server
   869  	_, err = trC.readPacket()
   870  	if err != nil {
   871  		t.Fatalf("readPacket failed: %s", err)
   872  	}
   873  
   874  	// write and read five packets on either side to bump the sequence numbers
   875  	for i := 0; i < 5; i++ {
   876  		if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
   877  			t.Fatalf("writePacket failed: %s", err)
   878  		}
   879  		if _, err := trS.readPacket(); err != nil {
   880  			t.Fatalf("readPacket failed: %s", err)
   881  		}
   882  		if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
   883  			t.Fatalf("writePacket failed: %s", err)
   884  		}
   885  		if _, err := trC.readPacket(); err != nil {
   886  			t.Fatalf("readPacket failed: %s", err)
   887  		}
   888  	}
   889  
   890  	// close the handshake transports before checking the sequence number to
   891  	// avoid races.
   892  	trC.Close()
   893  	trS.Close()
   894  
   895  	if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 ||
   896  		trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 {
   897  		t.Errorf(
   898  			"unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)",
   899  			trC.conn.(*transport).reader.seqNum,
   900  			trC.conn.(*transport).writer.seqNum,
   901  			trS.conn.(*transport).reader.seqNum,
   902  			trS.conn.(*transport).writer.seqNum,
   903  		)
   904  	}
   905  }
   906  
   907  func TestStrictKEXUnexpectedMsg(t *testing.T) {
   908  	if runtime.GOOS == "plan9" {
   909  		t.Skip("see golang.org/issue/7237")
   910  	}
   911  
   912  	// Check that unexpected messages during the handshake cause failure
   913  	_, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true)
   914  	if err == nil {
   915  		t.Fatal("handshake should fail when there are unexpected messages during the handshake")
   916  	}
   917  
   918  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false)
   919  	if err != nil {
   920  		t.Fatalf("handshake failed: %s", err)
   921  	}
   922  
   923  	// Check that ignore/debug pacekts are still ignored outside of the handshake
   924  	if err := trC.writePacket([]byte{msgIgnore}); err != nil {
   925  		t.Fatalf("writePacket failed: %s", err)
   926  	}
   927  	if err := trC.writePacket([]byte{msgDebug}); err != nil {
   928  		t.Fatalf("writePacket failed: %s", err)
   929  	}
   930  	dummyPacket := []byte{99}
   931  	if err := trC.writePacket(dummyPacket); err != nil {
   932  		t.Fatalf("writePacket failed: %s", err)
   933  	}
   934  
   935  	if p, err := trS.readPacket(); err != nil {
   936  		t.Fatalf("readPacket failed: %s", err)
   937  	} else if !bytes.Equal(p, dummyPacket) {
   938  		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
   939  	}
   940  }
   941  
   942  func TestStrictKEXMixed(t *testing.T) {
   943  	// Test that we still support a mixed connection, where one side sends kex-strict but the other
   944  	// side doesn't.
   945  
   946  	a, b, err := netPipe()
   947  	if err != nil {
   948  		t.Fatalf("netPipe failed: %s", err)
   949  	}
   950  
   951  	var trC, trS keyingTransport
   952  
   953  	trC = newTransport(a, rand.Reader, true)
   954  	trS = newTransport(b, rand.Reader, false)
   955  	trS = addNoiseTransport(trS)
   956  
   957  	clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}
   958  	clientConf.SetDefaults()
   959  
   960  	v := []byte("version")
   961  	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
   962  
   963  	serverConf := &ServerConfig{}
   964  	serverConf.AddHostKey(testSigners["ecdsa"])
   965  	serverConf.AddHostKey(testSigners["rsa"])
   966  	serverConf.SetDefaults()
   967  
   968  	transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version"))
   969  	transport.hostKeys = serverConf.hostKeys
   970  	transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms
   971  
   972  	readOneFailure := make(chan error, 1)
   973  	go func() {
   974  		if _, err := transport.readOnePacket(true); err != nil {
   975  			readOneFailure <- err
   976  		}
   977  	}()
   978  
   979  	// Basically sendKexInit, but without the kex-strict extension algorithm
   980  	msg := &kexInitMsg{
   981  		KexAlgos:                transport.config.KeyExchanges,
   982  		CiphersClientServer:     transport.config.Ciphers,
   983  		CiphersServerClient:     transport.config.Ciphers,
   984  		MACsClientServer:        transport.config.MACs,
   985  		MACsServerClient:        transport.config.MACs,
   986  		CompressionClientServer: supportedCompressions,
   987  		CompressionServerClient: supportedCompressions,
   988  		ServerHostKeyAlgos:      []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA},
   989  	}
   990  	packet := Marshal(msg)
   991  	// writePacket destroys the contents, so save a copy.
   992  	packetCopy := make([]byte, len(packet))
   993  	copy(packetCopy, packet)
   994  	if err := transport.pushPacket(packetCopy); err != nil {
   995  		t.Fatalf("pushPacket: %s", err)
   996  	}
   997  	transport.sentInitMsg = msg
   998  	transport.sentInitPacket = packet
   999  
  1000  	if err := transport.getWriteError(); err != nil {
  1001  		t.Fatalf("getWriteError failed: %s", err)
  1002  	}
  1003  	var request *pendingKex
  1004  	select {
  1005  	case err = <-readOneFailure:
  1006  		t.Fatalf("server readOnePacket failed: %s", err)
  1007  	case request = <-transport.startKex:
  1008  		break
  1009  	}
  1010  
  1011  	// We expect the following calls to fail if the side which does not support
  1012  	// kex-strict sends unexpected/ignored packets during the handshake, even if
  1013  	// the other side does support kex-strict.
  1014  
  1015  	if err := transport.enterKeyExchange(request.otherInit); err != nil {
  1016  		t.Fatalf("enterKeyExchange failed: %s", err)
  1017  	}
  1018  	if err := client.waitSession(); err != nil {
  1019  		t.Fatalf("client.waitSession: %v", err)
  1020  	}
  1021  }
  1022  

View as plain text