1  
     2  
     3  
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"os"
    13  	"sync"
    14  	"time"
    15  )
    16  
    17  
    18  
    19  type Client struct {
    20  	Conn
    21  
    22  	handleForwardsOnce sync.Once 
    23  
    24  	forwards        forwardList 
    25  	mu              sync.Mutex
    26  	channelHandlers map[string]chan NewChannel
    27  }
    28  
    29  
    30  
    31  
    32  func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
    33  	c.mu.Lock()
    34  	defer c.mu.Unlock()
    35  	if c.channelHandlers == nil {
    36  		
    37  		c := make(chan NewChannel)
    38  		close(c)
    39  		return c
    40  	}
    41  
    42  	ch := c.channelHandlers[channelType]
    43  	if ch != nil {
    44  		return nil
    45  	}
    46  
    47  	ch = make(chan NewChannel, chanSize)
    48  	c.channelHandlers[channelType] = ch
    49  	return ch
    50  }
    51  
    52  
    53  func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
    54  	conn := &Client{
    55  		Conn:            c,
    56  		channelHandlers: make(map[string]chan NewChannel, 1),
    57  	}
    58  
    59  	go conn.handleGlobalRequests(reqs)
    60  	go conn.handleChannelOpens(chans)
    61  	go func() {
    62  		conn.Wait()
    63  		conn.forwards.closeAll()
    64  	}()
    65  	return conn
    66  }
    67  
    68  
    69  
    70  
    71  func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
    72  	fullConf := *config
    73  	fullConf.SetDefaults()
    74  	if fullConf.HostKeyCallback == nil {
    75  		c.Close()
    76  		return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
    77  	}
    78  
    79  	conn := &connection{
    80  		sshConn: sshConn{conn: c, user: fullConf.User},
    81  	}
    82  
    83  	if err := conn.clientHandshake(addr, &fullConf); err != nil {
    84  		c.Close()
    85  		return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
    86  	}
    87  	conn.mux = newMux(conn.transport)
    88  	return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
    89  }
    90  
    91  
    92  
    93  func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
    94  	if config.ClientVersion != "" {
    95  		c.clientVersion = []byte(config.ClientVersion)
    96  	} else {
    97  		c.clientVersion = []byte(packageVersion)
    98  	}
    99  	var err error
   100  	c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	c.transport = newClientTransport(
   106  		newTransport(c.sshConn.conn, config.Rand, true ),
   107  		c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
   108  	if err := c.transport.waitSession(); err != nil {
   109  		return err
   110  	}
   111  
   112  	c.sessionID = c.transport.getSessionID()
   113  	return c.clientAuthenticate(config)
   114  }
   115  
   116  
   117  
   118  func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error {
   119  	sig, rest, ok := parseSignatureBody(result.Signature)
   120  	if len(rest) > 0 || !ok {
   121  		return errors.New("ssh: signature parse error")
   122  	}
   123  
   124  	if a := underlyingAlgo(algo); sig.Format != a {
   125  		return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, a)
   126  	}
   127  
   128  	return hostKey.Verify(result.H, sig)
   129  }
   130  
   131  
   132  
   133  func (c *Client) NewSession() (*Session, error) {
   134  	ch, in, err := c.OpenChannel("session", nil)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	return newSession(ch, in)
   139  }
   140  
   141  func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
   142  	for r := range incoming {
   143  		
   144  		
   145  		r.Reply(false, nil)
   146  	}
   147  }
   148  
   149  
   150  func (c *Client) handleChannelOpens(in <-chan NewChannel) {
   151  	for ch := range in {
   152  		c.mu.Lock()
   153  		handler := c.channelHandlers[ch.ChannelType()]
   154  		c.mu.Unlock()
   155  
   156  		if handler != nil {
   157  			handler <- ch
   158  		} else {
   159  			ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType()))
   160  		}
   161  	}
   162  
   163  	c.mu.Lock()
   164  	for _, ch := range c.channelHandlers {
   165  		close(ch)
   166  	}
   167  	c.channelHandlers = nil
   168  	c.mu.Unlock()
   169  }
   170  
   171  
   172  
   173  
   174  
   175  
   176  func Dial(network, addr string, config *ClientConfig) (*Client, error) {
   177  	conn, err := net.DialTimeout(network, addr, config.Timeout)
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	c, chans, reqs, err := NewClientConn(conn, addr, config)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  	return NewClient(c, chans, reqs), nil
   186  }
   187  
   188  
   189  
   190  
   191  
   192  
   193  type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
   194  
   195  
   196  
   197  type BannerCallback func(message string) error
   198  
   199  
   200  
   201  type ClientConfig struct {
   202  	
   203  	
   204  	Config
   205  
   206  	
   207  	User string
   208  
   209  	
   210  	
   211  	
   212  	Auth []AuthMethod
   213  
   214  	
   215  	
   216  	
   217  	
   218  	
   219  	HostKeyCallback HostKeyCallback
   220  
   221  	
   222  	
   223  	
   224  	
   225  	BannerCallback BannerCallback
   226  
   227  	
   228  	
   229  	ClientVersion string
   230  
   231  	
   232  	
   233  	
   234  	
   235  	
   236  	HostKeyAlgorithms []string
   237  
   238  	
   239  	
   240  	
   241  	Timeout time.Duration
   242  }
   243  
   244  
   245  
   246  
   247  func InsecureIgnoreHostKey() HostKeyCallback {
   248  	return func(hostname string, remote net.Addr, key PublicKey) error {
   249  		return nil
   250  	}
   251  }
   252  
   253  type fixedHostKey struct {
   254  	key PublicKey
   255  }
   256  
   257  func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
   258  	if f.key == nil {
   259  		return fmt.Errorf("ssh: required host key was nil")
   260  	}
   261  	if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
   262  		return fmt.Errorf("ssh: host key mismatch")
   263  	}
   264  	return nil
   265  }
   266  
   267  
   268  
   269  func FixedHostKey(key PublicKey) HostKeyCallback {
   270  	hk := &fixedHostKey{key}
   271  	return hk.check
   272  }
   273  
   274  
   275  
   276  func BannerDisplayStderr() BannerCallback {
   277  	return func(banner string) error {
   278  		_, err := os.Stderr.WriteString(banner)
   279  
   280  		return err
   281  	}
   282  }
   283  
View as plain text