...

Source file src/golang.org/x/crypto/ssh/handshake.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"crypto/rand"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"strings"
    15  	"sync"
    16  )
    17  
    18  // debugHandshake, if set, prints messages sent and received.  Key
    19  // exchange messages are printed as if DH were used, so the debug
    20  // messages are wrong when using ECDH.
    21  const debugHandshake = false
    22  
    23  // chanSize sets the amount of buffering SSH connections. This is
    24  // primarily for testing: setting chanSize=0 uncovers deadlocks more
    25  // quickly.
    26  const chanSize = 16
    27  
    28  // keyingTransport is a packet based transport that supports key
    29  // changes. It need not be thread-safe. It should pass through
    30  // msgNewKeys in both directions.
    31  type keyingTransport interface {
    32  	packetConn
    33  
    34  	// prepareKeyChange sets up a key change. The key change for a
    35  	// direction will be effected if a msgNewKeys message is sent
    36  	// or received.
    37  	prepareKeyChange(*algorithms, *kexResult) error
    38  
    39  	// setStrictMode sets the strict KEX mode, notably triggering
    40  	// sequence number resets on sending or receiving msgNewKeys.
    41  	// If the sequence number is already > 1 when setStrictMode
    42  	// is called, an error is returned.
    43  	setStrictMode() error
    44  
    45  	// setInitialKEXDone indicates to the transport that the initial key exchange
    46  	// was completed
    47  	setInitialKEXDone()
    48  }
    49  
    50  // handshakeTransport implements rekeying on top of a keyingTransport
    51  // and offers a thread-safe writePacket() interface.
    52  type handshakeTransport struct {
    53  	conn   keyingTransport
    54  	config *Config
    55  
    56  	serverVersion []byte
    57  	clientVersion []byte
    58  
    59  	// hostKeys is non-empty if we are the server. In that case,
    60  	// it contains all host keys that can be used to sign the
    61  	// connection.
    62  	hostKeys []Signer
    63  
    64  	// publicKeyAuthAlgorithms is non-empty if we are the server. In that case,
    65  	// it contains the supported client public key authentication algorithms.
    66  	publicKeyAuthAlgorithms []string
    67  
    68  	// hostKeyAlgorithms is non-empty if we are the client. In that case,
    69  	// we accept these key types from the server as host key.
    70  	hostKeyAlgorithms []string
    71  
    72  	// On read error, incoming is closed, and readError is set.
    73  	incoming  chan []byte
    74  	readError error
    75  
    76  	mu               sync.Mutex
    77  	writeError       error
    78  	sentInitPacket   []byte
    79  	sentInitMsg      *kexInitMsg
    80  	pendingPackets   [][]byte // Used when a key exchange is in progress.
    81  	writePacketsLeft uint32
    82  	writeBytesLeft   int64
    83  
    84  	// If the read loop wants to schedule a kex, it pings this
    85  	// channel, and the write loop will send out a kex
    86  	// message.
    87  	requestKex chan struct{}
    88  
    89  	// If the other side requests or confirms a kex, its kexInit
    90  	// packet is sent here for the write loop to find it.
    91  	startKex    chan *pendingKex
    92  	kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits
    93  
    94  	// data for host key checking
    95  	hostKeyCallback HostKeyCallback
    96  	dialAddress     string
    97  	remoteAddr      net.Addr
    98  
    99  	// bannerCallback is non-empty if we are the client and it has been set in
   100  	// ClientConfig. In that case it is called during the user authentication
   101  	// dance to handle a custom server's message.
   102  	bannerCallback BannerCallback
   103  
   104  	// Algorithms agreed in the last key exchange.
   105  	algorithms *algorithms
   106  
   107  	// Counters exclusively owned by readLoop.
   108  	readPacketsLeft uint32
   109  	readBytesLeft   int64
   110  
   111  	// The session ID or nil if first kex did not complete yet.
   112  	sessionID []byte
   113  
   114  	// strictMode indicates if the other side of the handshake indicated
   115  	// that we should be following the strict KEX protocol restrictions.
   116  	strictMode bool
   117  }
   118  
   119  type pendingKex struct {
   120  	otherInit []byte
   121  	done      chan error
   122  }
   123  
   124  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
   125  	t := &handshakeTransport{
   126  		conn:          conn,
   127  		serverVersion: serverVersion,
   128  		clientVersion: clientVersion,
   129  		incoming:      make(chan []byte, chanSize),
   130  		requestKex:    make(chan struct{}, 1),
   131  		startKex:      make(chan *pendingKex),
   132  		kexLoopDone:   make(chan struct{}),
   133  
   134  		config: config,
   135  	}
   136  	t.resetReadThresholds()
   137  	t.resetWriteThresholds()
   138  
   139  	// We always start with a mandatory key exchange.
   140  	t.requestKex <- struct{}{}
   141  	return t
   142  }
   143  
   144  func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
   145  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   146  	t.dialAddress = dialAddr
   147  	t.remoteAddr = addr
   148  	t.hostKeyCallback = config.HostKeyCallback
   149  	t.bannerCallback = config.BannerCallback
   150  	if config.HostKeyAlgorithms != nil {
   151  		t.hostKeyAlgorithms = config.HostKeyAlgorithms
   152  	} else {
   153  		t.hostKeyAlgorithms = supportedHostKeyAlgos
   154  	}
   155  	go t.readLoop()
   156  	go t.kexLoop()
   157  	return t
   158  }
   159  
   160  func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
   161  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   162  	t.hostKeys = config.hostKeys
   163  	t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms
   164  	go t.readLoop()
   165  	go t.kexLoop()
   166  	return t
   167  }
   168  
   169  func (t *handshakeTransport) getSessionID() []byte {
   170  	return t.sessionID
   171  }
   172  
   173  // waitSession waits for the session to be established. This should be
   174  // the first thing to call after instantiating handshakeTransport.
   175  func (t *handshakeTransport) waitSession() error {
   176  	p, err := t.readPacket()
   177  	if err != nil {
   178  		return err
   179  	}
   180  	if p[0] != msgNewKeys {
   181  		return fmt.Errorf("ssh: first packet should be msgNewKeys")
   182  	}
   183  
   184  	return nil
   185  }
   186  
   187  func (t *handshakeTransport) id() string {
   188  	if len(t.hostKeys) > 0 {
   189  		return "server"
   190  	}
   191  	return "client"
   192  }
   193  
   194  func (t *handshakeTransport) printPacket(p []byte, write bool) {
   195  	action := "got"
   196  	if write {
   197  		action = "sent"
   198  	}
   199  
   200  	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
   201  		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
   202  	} else {
   203  		msg, err := decode(p)
   204  		log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
   205  	}
   206  }
   207  
   208  func (t *handshakeTransport) readPacket() ([]byte, error) {
   209  	p, ok := <-t.incoming
   210  	if !ok {
   211  		return nil, t.readError
   212  	}
   213  	return p, nil
   214  }
   215  
   216  func (t *handshakeTransport) readLoop() {
   217  	first := true
   218  	for {
   219  		p, err := t.readOnePacket(first)
   220  		first = false
   221  		if err != nil {
   222  			t.readError = err
   223  			close(t.incoming)
   224  			break
   225  		}
   226  		// If this is the first kex, and strict KEX mode is enabled,
   227  		// we don't ignore any messages, as they may be used to manipulate
   228  		// the packet sequence numbers.
   229  		if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
   230  			continue
   231  		}
   232  		t.incoming <- p
   233  	}
   234  
   235  	// Stop writers too.
   236  	t.recordWriteError(t.readError)
   237  
   238  	// Unblock the writer should it wait for this.
   239  	close(t.startKex)
   240  
   241  	// Don't close t.requestKex; it's also written to from writePacket.
   242  }
   243  
   244  func (t *handshakeTransport) pushPacket(p []byte) error {
   245  	if debugHandshake {
   246  		t.printPacket(p, true)
   247  	}
   248  	return t.conn.writePacket(p)
   249  }
   250  
   251  func (t *handshakeTransport) getWriteError() error {
   252  	t.mu.Lock()
   253  	defer t.mu.Unlock()
   254  	return t.writeError
   255  }
   256  
   257  func (t *handshakeTransport) recordWriteError(err error) {
   258  	t.mu.Lock()
   259  	defer t.mu.Unlock()
   260  	if t.writeError == nil && err != nil {
   261  		t.writeError = err
   262  	}
   263  }
   264  
   265  func (t *handshakeTransport) requestKeyExchange() {
   266  	select {
   267  	case t.requestKex <- struct{}{}:
   268  	default:
   269  		// something already requested a kex, so do nothing.
   270  	}
   271  }
   272  
   273  func (t *handshakeTransport) resetWriteThresholds() {
   274  	t.writePacketsLeft = packetRekeyThreshold
   275  	if t.config.RekeyThreshold > 0 {
   276  		t.writeBytesLeft = int64(t.config.RekeyThreshold)
   277  	} else if t.algorithms != nil {
   278  		t.writeBytesLeft = t.algorithms.w.rekeyBytes()
   279  	} else {
   280  		t.writeBytesLeft = 1 << 30
   281  	}
   282  }
   283  
// kexLoop is the goroutine that performs key exchanges. Each round:
//  1. Waits until we have sent our kexInit and the peer's kexInit has
//     arrived on t.startKex.
//  2. Runs enterKeyExchange.
//  3. Reports the outcome to readLoop through request.done.
//  4. Flushes the packets writePacket queued in the meantime.
//
// The loop ends when a write error has been recorded or t.startKex is
// closed. Before returning, kexLoop closes the underlying connection, answers
// any further requests on t.startKex with the write error, and closes
// t.kexLoopDone.
func (t *handshakeTransport) kexLoop() {

write:
	for t.getWriteError() == nil {
		var request *pendingKex
		var sent bool

		// Loop until we have the peer's kexInit (request) and have sent our
		// own (sent). Either a local rekey request or the peer's kexInit
		// causes ours to be sent.
		for request == nil || !sent {
			var ok bool
			select {
			case request, ok = <-t.startKex:
				if !ok {
					// readLoop closed startKex: it has stopped reading.
					break write
				}
			case <-t.requestKex:
				break
			}

			if !sent {
				if err := t.sendKexInit(); err != nil {
					t.recordWriteError(err)
					// Leaves the inner loop; the error is handled just below.
					break
				}
				sent = true
			}
		}

		if err := t.getWriteError(); err != nil {
			if request != nil {
				request.done <- err
			}
			break
		}

		// We're not servicing t.requestKex, but that is OK:
		// we never block on sending to t.requestKex.

		// We're not servicing t.startKex, but the remote end
		// has just sent us a kexInitMsg, so it can't send
		// another key change request until we reply on the
		// done channel of the pendingKex request: readOnePacket
		// blocks waiting for that reply.

		err := t.enterKeyExchange(request.otherInit)

		// t.mu stays held until the queued packets have been flushed, so
		// writePacket cannot send new packets ahead of them.
		t.mu.Lock()
		t.writeError = err
		// With sentInitMsg cleared, writePacket stops queueing once mu is
		// released.
		t.sentInitPacket = nil
		t.sentInitMsg = nil

		t.resetWriteThresholds()

		// we have completed the key exchange. Since the
		// reader is still blocked, it is safe to clear out
		// the requestKex channel. This avoids the situation
		// where: 1) we consumed our own request for the
		// initial kex, and 2) the kex from the remote side
		// caused another send on the requestKex channel,
		// which would make the next iteration start a second,
		// unneeded key exchange.
	clear:
		for {
			select {
			case <-t.requestKex:
				//
			default:
				break clear
			}
		}

		// done is buffered with capacity one, so this send does not block
		// while t.mu is held.
		request.done <- t.writeError

		// kex finished. Push packets that we received while
		// the kex was in progress. Don't look at t.startKex
		// and don't charge these packets against the write
		// rekey budget: if we trigger another kex while we
		// are still busy with the last one, things will
		// become very confusing.
		for _, p := range t.pendingPackets {
			t.writeError = t.pushPacket(p)
			if t.writeError != nil {
				break
			}
		}
		t.pendingPackets = t.pendingPackets[:0]
		t.mu.Unlock()
	}

	// Unblock reader.
	t.conn.Close()

	// drain startKex channel. We don't service t.requestKex
	// because nobody does blocking sends there. readLoop
	// closes startKex when it exits, which ends this loop.
	for request := range t.startKex {
		request.done <- t.getWriteError()
	}

	// Mark that the loop is done so that Close can return.
	close(t.kexLoopDone)
}
   380  
   381  // The protocol uses uint32 for packet counters, so we can't let them
   382  // reach 1<<32.  We will actually read and write more packets than
   383  // this, though: the other side may send more packets, and after we
   384  // hit this limit on writing we will send a few more packets for the
   385  // key exchange itself.
   386  const packetRekeyThreshold = (1 << 31)
   387  
   388  func (t *handshakeTransport) resetReadThresholds() {
   389  	t.readPacketsLeft = packetRekeyThreshold
   390  	if t.config.RekeyThreshold > 0 {
   391  		t.readBytesLeft = int64(t.config.RekeyThreshold)
   392  	} else if t.algorithms != nil {
   393  		t.readBytesLeft = t.algorithms.r.rekeyBytes()
   394  	} else {
   395  		t.readBytesLeft = 1 << 30
   396  	}
   397  }
   398  
   399  func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
   400  	p, err := t.conn.readPacket()
   401  	if err != nil {
   402  		return nil, err
   403  	}
   404  
   405  	if t.readPacketsLeft > 0 {
   406  		t.readPacketsLeft--
   407  	} else {
   408  		t.requestKeyExchange()
   409  	}
   410  
   411  	if t.readBytesLeft > 0 {
   412  		t.readBytesLeft -= int64(len(p))
   413  	} else {
   414  		t.requestKeyExchange()
   415  	}
   416  
   417  	if debugHandshake {
   418  		t.printPacket(p, false)
   419  	}
   420  
   421  	if first && p[0] != msgKexInit {
   422  		return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
   423  	}
   424  
   425  	if p[0] != msgKexInit {
   426  		return p, nil
   427  	}
   428  
   429  	firstKex := t.sessionID == nil
   430  
   431  	kex := pendingKex{
   432  		done:      make(chan error, 1),
   433  		otherInit: p,
   434  	}
   435  	t.startKex <- &kex
   436  	err = <-kex.done
   437  
   438  	if debugHandshake {
   439  		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
   440  	}
   441  
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  
   446  	t.resetReadThresholds()
   447  
   448  	// By default, a key exchange is hidden from higher layers by
   449  	// translating it into msgIgnore.
   450  	successPacket := []byte{msgIgnore}
   451  	if firstKex {
   452  		// sendKexInit() for the first kex waits for
   453  		// msgNewKeys so the authentication process is
   454  		// guaranteed to happen over an encrypted transport.
   455  		successPacket = []byte{msgNewKeys}
   456  	}
   457  
   458  	return successPacket, nil
   459  }
   460  
   461  const (
   462  	kexStrictClient = "kex-strict-c-v00@openssh.com"
   463  	kexStrictServer = "kex-strict-s-v00@openssh.com"
   464  )
   465  
   466  // sendKexInit sends a key change message.
   467  func (t *handshakeTransport) sendKexInit() error {
   468  	t.mu.Lock()
   469  	defer t.mu.Unlock()
   470  	if t.sentInitMsg != nil {
   471  		// kexInits may be sent either in response to the other side,
   472  		// or because our side wants to initiate a key change, so we
   473  		// may have already sent a kexInit. In that case, don't send a
   474  		// second kexInit.
   475  		return nil
   476  	}
   477  
   478  	msg := &kexInitMsg{
   479  		CiphersClientServer:     t.config.Ciphers,
   480  		CiphersServerClient:     t.config.Ciphers,
   481  		MACsClientServer:        t.config.MACs,
   482  		MACsServerClient:        t.config.MACs,
   483  		CompressionClientServer: supportedCompressions,
   484  		CompressionServerClient: supportedCompressions,
   485  	}
   486  	io.ReadFull(rand.Reader, msg.Cookie[:])
   487  
   488  	// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
   489  	// and possibly to add the ext-info extension algorithm. Since the slice may be the
   490  	// user owned KeyExchanges, we create our own slice in order to avoid using user
   491  	// owned memory by mistake.
   492  	msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
   493  	msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
   494  
   495  	isServer := len(t.hostKeys) > 0
   496  	if isServer {
   497  		for _, k := range t.hostKeys {
   498  			// If k is a MultiAlgorithmSigner, we restrict the signature
   499  			// algorithms. If k is a AlgorithmSigner, presume it supports all
   500  			// signature algorithms associated with the key format. If k is not
   501  			// an AlgorithmSigner, we can only assume it only supports the
   502  			// algorithms that matches the key format. (This means that Sign
   503  			// can't pick a different default).
   504  			keyFormat := k.PublicKey().Type()
   505  
   506  			switch s := k.(type) {
   507  			case MultiAlgorithmSigner:
   508  				for _, algo := range algorithmsForKeyFormat(keyFormat) {
   509  					if contains(s.Algorithms(), underlyingAlgo(algo)) {
   510  						msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
   511  					}
   512  				}
   513  			case AlgorithmSigner:
   514  				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
   515  			default:
   516  				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
   517  			}
   518  		}
   519  
   520  		if t.sessionID == nil {
   521  			msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
   522  		}
   523  	} else {
   524  		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
   525  
   526  		// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
   527  		// algorithms the server supports for public key authentication. See RFC
   528  		// 8308, Section 2.1.
   529  		//
   530  		// We also send the strict KEX mode extension algorithm, in order to opt
   531  		// into the strict KEX mode.
   532  		if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
   533  			msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
   534  			msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
   535  		}
   536  
   537  	}
   538  
   539  	packet := Marshal(msg)
   540  
   541  	// writePacket destroys the contents, so save a copy.
   542  	packetCopy := make([]byte, len(packet))
   543  	copy(packetCopy, packet)
   544  
   545  	if err := t.pushPacket(packetCopy); err != nil {
   546  		return err
   547  	}
   548  
   549  	t.sentInitMsg = msg
   550  	t.sentInitPacket = packet
   551  
   552  	return nil
   553  }
   554  
   555  func (t *handshakeTransport) writePacket(p []byte) error {
   556  	switch p[0] {
   557  	case msgKexInit:
   558  		return errors.New("ssh: only handshakeTransport can send kexInit")
   559  	case msgNewKeys:
   560  		return errors.New("ssh: only handshakeTransport can send newKeys")
   561  	}
   562  
   563  	t.mu.Lock()
   564  	defer t.mu.Unlock()
   565  	if t.writeError != nil {
   566  		return t.writeError
   567  	}
   568  
   569  	if t.sentInitMsg != nil {
   570  		// Copy the packet so the writer can reuse the buffer.
   571  		cp := make([]byte, len(p))
   572  		copy(cp, p)
   573  		t.pendingPackets = append(t.pendingPackets, cp)
   574  		return nil
   575  	}
   576  
   577  	if t.writeBytesLeft > 0 {
   578  		t.writeBytesLeft -= int64(len(p))
   579  	} else {
   580  		t.requestKeyExchange()
   581  	}
   582  
   583  	if t.writePacketsLeft > 0 {
   584  		t.writePacketsLeft--
   585  	} else {
   586  		t.requestKeyExchange()
   587  	}
   588  
   589  	if err := t.pushPacket(p); err != nil {
   590  		t.writeError = err
   591  	}
   592  
   593  	return nil
   594  }
   595  
   596  func (t *handshakeTransport) Close() error {
   597  	// Close the connection. This should cause the readLoop goroutine to wake up
   598  	// and close t.startKex, which will shut down kexLoop if running.
   599  	err := t.conn.Close()
   600  
   601  	// Wait for the kexLoop goroutine to complete.
   602  	// At that point we know that the readLoop goroutine is complete too,
   603  	// because kexLoop itself waits for readLoop to close the startKex channel.
   604  	<-t.kexLoopDone
   605  
   606  	return err
   607  }
   608  
// enterKeyExchange performs one key exchange. It uses otherInitPacket, the
// peer's raw kexInit, together with the kexInit we sent earlier, which
// sendKexInit stored in t.sentInitPacket and t.sentInitMsg. It:
//   - negotiates algorithms into t.algorithms;
//   - enables strict KEX mode on the first exchange if the peer asked for it;
//   - runs the client or server side of the chosen kex;
//   - records t.sessionID on the first exchange;
//   - stages the new keys and sends msgNewKeys;
//   - as a server on the first exchange, sends SSH_MSG_EXT_INFO if the
//     client asked for it;
//   - waits for the peer's msgNewKeys.
//
// Any error aborts the exchange and is returned.
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
	if debugHandshake {
		log.Printf("%s entered key exchange", t.id())
	}

	otherInit := &kexInitMsg{}
	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
		return err
	}

	// The kex implementation needs both version strings and both raw
	// kexInit payloads. Start by assuming we are the server (the peer is
	// the client); the roles are swapped below if we are the client.
	magics := handshakeMagics{
		clientVersion: t.clientVersion,
		serverVersion: t.serverVersion,
		clientKexInit: otherInitPacket,
		serverKexInit: t.sentInitPacket,
	}

	clientInit := otherInit
	serverInit := t.sentInitMsg
	isClient := len(t.hostKeys) == 0
	if isClient {
		clientInit, serverInit = serverInit, clientInit

		magics.clientKexInit = t.sentInitPacket
		magics.serverKexInit = otherInitPacket
	}

	var err error
	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
	if err != nil {
		return err
	}

	// Strict KEX can only be switched on during the first exchange, and
	// only when the peer advertised the marker matching its role.
	if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
		t.strictMode = true
		if err := t.conn.setStrictMode(); err != nil {
			return err
		}
	}

	// We don't send FirstKexFollows, but we handle receiving it.
	//
	// RFC 4253 section 7 defines the kex and the agreement method for
	// first_kex_packet_follows. It states that the guessed packet
	// should be ignored if the "kex algorithm and/or the host
	// key algorithm is guessed wrong (server and client have
	// different preferred algorithm), or if any of the other
	// algorithms cannot be agreed upon". The other algorithms have
	// already been checked above so the kex algorithm and host key
	// algorithm are checked here.
	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
		// other side sent a kex message for the wrong algorithm,
		// which we have to ignore.
		if _, err := t.conn.readPacket(); err != nil {
			return err
		}
	}

	kex, ok := kexAlgoMap[t.algorithms.kex]
	if !ok {
		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
	}

	var result *kexResult
	if len(t.hostKeys) > 0 {
		result, err = t.server(kex, &magics)
	} else {
		result, err = t.client(kex, &magics)
	}

	if err != nil {
		return err
	}

	// The first exchange's H becomes the session ID. Later exchanges keep
	// reusing that value.
	firstKeyExchange := t.sessionID == nil
	if firstKeyExchange {
		t.sessionID = result.H
	}
	result.SessionID = t.sessionID

	// Stage the new keys. The transport switches each direction when
	// msgNewKeys is sent or received in that direction.
	if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
		return err
	}
	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
		return err
	}

	// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
	// message with the server-sig-algs extension if the client supports it. See
	// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
	if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
		supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
		// Two extensions, each a name/value pair of length-prefixed
		// strings: server-sig-algs=<list> and ping@openssh.com="0". The
		// capacity covers the four length prefixes plus the string bytes.
		extInfo := &extInfoMsg{
			NumExtensions: 2,
			Payload:       make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
		}
		extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
		extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
		extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
		extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
		extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
		extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
		extInfo.Payload = appendInt(extInfo.Payload, 1)
		extInfo.Payload = append(extInfo.Payload, "0"...)
		if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
			return err
		}
	}

	// The peer must answer with its own msgNewKeys; anything else is a
	// protocol violation.
	if packet, err := t.conn.readPacket(); err != nil {
		return err
	} else if packet[0] != msgNewKeys {
		return unexpectedMessageError(msgNewKeys, packet[0])
	}

	if firstKeyExchange {
		// Indicates to the transport that the first key exchange is completed
		// after receiving SSH_MSG_NEWKEYS.
		t.conn.setInitialKEXDone()
	}

	return nil
}
   732  
   733  // algorithmSignerWrapper is an AlgorithmSigner that only supports the default
   734  // key format algorithm.
   735  //
   736  // This is technically a violation of the AlgorithmSigner interface, but it
   737  // should be unreachable given where we use this. Anyway, at least it returns an
   738  // error instead of panicing or producing an incorrect signature.
   739  type algorithmSignerWrapper struct {
   740  	Signer
   741  }
   742  
   743  func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
   744  	if algorithm != underlyingAlgo(a.PublicKey().Type()) {
   745  		return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
   746  	}
   747  	return a.Sign(rand, data)
   748  }
   749  
   750  func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
   751  	for _, k := range hostKeys {
   752  		if s, ok := k.(MultiAlgorithmSigner); ok {
   753  			if !contains(s.Algorithms(), underlyingAlgo(algo)) {
   754  				continue
   755  			}
   756  		}
   757  
   758  		if algo == k.PublicKey().Type() {
   759  			return algorithmSignerWrapper{k}
   760  		}
   761  
   762  		k, ok := k.(AlgorithmSigner)
   763  		if !ok {
   764  			continue
   765  		}
   766  		for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) {
   767  			if algo == a {
   768  				return k
   769  			}
   770  		}
   771  	}
   772  	return nil
   773  }
   774  
   775  func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
   776  	hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
   777  	if hostKey == nil {
   778  		return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
   779  	}
   780  
   781  	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
   782  	return r, err
   783  }
   784  
   785  func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
   786  	result, err := kex.Client(t.conn, t.config.Rand, magics)
   787  	if err != nil {
   788  		return nil, err
   789  	}
   790  
   791  	hostKey, err := ParsePublicKey(result.HostKey)
   792  	if err != nil {
   793  		return nil, err
   794  	}
   795  
   796  	if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
   797  		return nil, err
   798  	}
   799  
   800  	err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
   801  	if err != nil {
   802  		return nil, err
   803  	}
   804  
   805  	return result, nil
   806  }
   807  

View as plain text