...

Source file src/golang.org/x/crypto/ssh/tcpip.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"math/rand"
    13  	"net"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  )
    19  
    20  // Listen requests the remote peer open a listening socket on
    21  // addr. Incoming connections will be available by calling Accept on
    22  // the returned net.Listener. The listener must be serviced, or the
    23  // SSH connection may hang.
    24  // N must be "tcp", "tcp4", "tcp6", or "unix".
    25  func (c *Client) Listen(n, addr string) (net.Listener, error) {
    26  	switch n {
    27  	case "tcp", "tcp4", "tcp6":
    28  		laddr, err := net.ResolveTCPAddr(n, addr)
    29  		if err != nil {
    30  			return nil, err
    31  		}
    32  		return c.ListenTCP(laddr)
    33  	case "unix":
    34  		return c.ListenUnix(addr)
    35  	default:
    36  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
    37  	}
    38  }
    39  
    40  // Automatic port allocation is broken with OpenSSH before 6.0. See
    41  // also https://bugzilla.mindrot.org/show_bug.cgi?id=2017.  In
    42  // particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
    43  // rather than the actual port number. This means you can never open
    44  // two different listeners with auto allocated ports. We work around
    45  // this by trying explicit ports until we succeed.
    46  
    47  const openSSHPrefix = "OpenSSH_"
    48  
    49  var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
    50  
    51  // isBrokenOpenSSHVersion returns true if the given version string
    52  // specifies a version of OpenSSH that is known to have a bug in port
    53  // forwarding.
    54  func isBrokenOpenSSHVersion(versionStr string) bool {
    55  	i := strings.Index(versionStr, openSSHPrefix)
    56  	if i < 0 {
    57  		return false
    58  	}
    59  	i += len(openSSHPrefix)
    60  	j := i
    61  	for ; j < len(versionStr); j++ {
    62  		if versionStr[j] < '0' || versionStr[j] > '9' {
    63  			break
    64  		}
    65  	}
    66  	version, _ := strconv.Atoi(versionStr[i:j])
    67  	return version < 6
    68  }
    69  
    70  // autoPortListenWorkaround simulates automatic port allocation by
    71  // trying random ports repeatedly.
    72  func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
    73  	var sshListener net.Listener
    74  	var err error
    75  	const tries = 10
    76  	for i := 0; i < tries; i++ {
    77  		addr := *laddr
    78  		addr.Port = 1024 + portRandomizer.Intn(60000)
    79  		sshListener, err = c.ListenTCP(&addr)
    80  		if err == nil {
    81  			laddr.Port = addr.Port
    82  			return sshListener, err
    83  		}
    84  	}
    85  	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
    86  }
    87  
// channelForwardMsg is the request payload for the "tcpip-forward" and
// "cancel-tcpip-forward" global requests (RFC 4254 7.1): the address to
// bind on the remote side and the port to bind. It is built with
// positional composite literals in this file, and Marshal presumably
// encodes fields in declaration order, so do not reorder the fields.
type channelForwardMsg struct {
	addr  string
	rport uint32
}
    93  
    94  // handleForwards starts goroutines handling forwarded connections.
    95  // It's called on first use by (*Client).ListenTCP to not launch
    96  // goroutines until needed.
    97  func (c *Client) handleForwards() {
    98  	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
    99  	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
   100  }
   101  
   102  // ListenTCP requests the remote peer open a listening socket
   103  // on laddr. Incoming connections will be available by calling
   104  // Accept on the returned net.Listener.
   105  func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
   106  	c.handleForwardsOnce.Do(c.handleForwards)
   107  	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
   108  		return c.autoPortListenWorkaround(laddr)
   109  	}
   110  
   111  	m := channelForwardMsg{
   112  		laddr.IP.String(),
   113  		uint32(laddr.Port),
   114  	}
   115  	// send message
   116  	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	if !ok {
   121  		return nil, errors.New("ssh: tcpip-forward request denied by peer")
   122  	}
   123  
   124  	// If the original port was 0, then the remote side will
   125  	// supply a real port number in the response.
   126  	if laddr.Port == 0 {
   127  		var p struct {
   128  			Port uint32
   129  		}
   130  		if err := Unmarshal(resp, &p); err != nil {
   131  			return nil, err
   132  		}
   133  		laddr.Port = int(p.Port)
   134  	}
   135  
   136  	// Register this forward, using the port number we obtained.
   137  	ch := c.forwards.add(laddr)
   138  
   139  	return &tcpListener{laddr, c, ch}, nil
   140  }
   141  
// forwardList maps the listen addresses of active remote forwards to the
// channels that feed their listeners. The embedded Mutex guards entries;
// every forwardList method locks it before touching entries.
type forwardList struct {
	sync.Mutex
	entries []forwardEntry
}
   148  
// forwardEntry pairs a listen address on the remote SSH server (laddr)
// with the channel (c) on which connections forwarded for that address
// are delivered to the local listener.
type forwardEntry struct {
	laddr net.Addr
	c     chan forward
}
   155  
// forward represents one incoming forwarded connection (TCP or Unix
// stream), handed by forwardList.forward to the listener registered for
// its listen address. forwardList matches addresses by comparing their
// Network() and String() values, so add, remove and forward should be
// given the same address that was registered for the forward request.
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // originator address of the incoming connection ("@" placeholder for Unix sockets)
}
   163  
   164  func (l *forwardList) add(addr net.Addr) chan forward {
   165  	l.Lock()
   166  	defer l.Unlock()
   167  	f := forwardEntry{
   168  		laddr: addr,
   169  		c:     make(chan forward, 1),
   170  	}
   171  	l.entries = append(l.entries, f)
   172  	return f.c
   173  }
   174  
// forwardedTCPPayload is the extra data of a "forwarded-tcpip" channel open
// request (RFC 4254, section 7.2). It holds the address and port that were
// connected on the remote side, followed by the originator's address and
// port.
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}
   182  
   183  // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
   184  func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
   185  	if port == 0 || port > 65535 {
   186  		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
   187  	}
   188  	ip := net.ParseIP(string(addr))
   189  	if ip == nil {
   190  		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
   191  	}
   192  	return &net.TCPAddr{IP: ip, Port: int(port)}, nil
   193  }
   194  
   195  func (l *forwardList) handleChannels(in <-chan NewChannel) {
   196  	for ch := range in {
   197  		var (
   198  			laddr net.Addr
   199  			raddr net.Addr
   200  			err   error
   201  		)
   202  		switch channelType := ch.ChannelType(); channelType {
   203  		case "forwarded-tcpip":
   204  			var payload forwardedTCPPayload
   205  			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
   206  				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
   207  				continue
   208  			}
   209  
   210  			// RFC 4254 section 7.2 specifies that incoming
   211  			// addresses should list the address, in string
   212  			// format. It is implied that this should be an IP
   213  			// address, as it would be impossible to connect to it
   214  			// otherwise.
   215  			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
   216  			if err != nil {
   217  				ch.Reject(ConnectionFailed, err.Error())
   218  				continue
   219  			}
   220  			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
   221  			if err != nil {
   222  				ch.Reject(ConnectionFailed, err.Error())
   223  				continue
   224  			}
   225  
   226  		case "forwarded-streamlocal@openssh.com":
   227  			var payload forwardedStreamLocalPayload
   228  			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
   229  				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
   230  				continue
   231  			}
   232  			laddr = &net.UnixAddr{
   233  				Name: payload.SocketPath,
   234  				Net:  "unix",
   235  			}
   236  			raddr = &net.UnixAddr{
   237  				Name: "@",
   238  				Net:  "unix",
   239  			}
   240  		default:
   241  			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
   242  		}
   243  		if ok := l.forward(laddr, raddr, ch); !ok {
   244  			// Section 7.2, implementations MUST reject spurious incoming
   245  			// connections.
   246  			ch.Reject(Prohibited, "no forward for address")
   247  			continue
   248  		}
   249  
   250  	}
   251  }
   252  
   253  // remove removes the forward entry, and the channel feeding its
   254  // listener.
   255  func (l *forwardList) remove(addr net.Addr) {
   256  	l.Lock()
   257  	defer l.Unlock()
   258  	for i, f := range l.entries {
   259  		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
   260  			l.entries = append(l.entries[:i], l.entries[i+1:]...)
   261  			close(f.c)
   262  			return
   263  		}
   264  	}
   265  }
   266  
   267  // closeAll closes and clears all forwards.
   268  func (l *forwardList) closeAll() {
   269  	l.Lock()
   270  	defer l.Unlock()
   271  	for _, f := range l.entries {
   272  		close(f.c)
   273  	}
   274  	l.entries = nil
   275  }
   276  
   277  func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
   278  	l.Lock()
   279  	defer l.Unlock()
   280  	for _, f := range l.entries {
   281  		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
   282  			f.c <- forward{newCh: ch, raddr: raddr}
   283  			return true
   284  		}
   285  	}
   286  	return false
   287  }
   288  
   289  type tcpListener struct {
   290  	laddr *net.TCPAddr
   291  
   292  	conn *Client
   293  	in   <-chan forward
   294  }
   295  
   296  // Accept waits for and returns the next connection to the listener.
   297  func (l *tcpListener) Accept() (net.Conn, error) {
   298  	s, ok := <-l.in
   299  	if !ok {
   300  		return nil, io.EOF
   301  	}
   302  	ch, incoming, err := s.newCh.Accept()
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  	go DiscardRequests(incoming)
   307  
   308  	return &chanConn{
   309  		Channel: ch,
   310  		laddr:   l.laddr,
   311  		raddr:   s.raddr,
   312  	}, nil
   313  }
   314  
   315  // Close closes the listener.
   316  func (l *tcpListener) Close() error {
   317  	m := channelForwardMsg{
   318  		l.laddr.IP.String(),
   319  		uint32(l.laddr.Port),
   320  	}
   321  
   322  	// this also closes the listener.
   323  	l.conn.forwards.remove(l.laddr)
   324  	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
   325  	if err == nil && !ok {
   326  		err = errors.New("ssh: cancel-tcpip-forward failed")
   327  	}
   328  	return err
   329  }
   330  
   331  // Addr returns the listener's network address.
   332  func (l *tcpListener) Addr() net.Addr {
   333  	return l.laddr
   334  }
   335  
   336  // DialContext initiates a connection to the addr from the remote host.
   337  //
   338  // The provided Context must be non-nil. If the context expires before the
   339  // connection is complete, an error is returned. Once successfully connected,
   340  // any expiration of the context will not affect the connection.
   341  //
   342  // See func Dial for additional information.
   343  func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
   344  	if err := ctx.Err(); err != nil {
   345  		return nil, err
   346  	}
   347  	type connErr struct {
   348  		conn net.Conn
   349  		err  error
   350  	}
   351  	ch := make(chan connErr)
   352  	go func() {
   353  		conn, err := c.Dial(n, addr)
   354  		select {
   355  		case ch <- connErr{conn, err}:
   356  		case <-ctx.Done():
   357  			if conn != nil {
   358  				conn.Close()
   359  			}
   360  		}
   361  	}()
   362  	select {
   363  	case res := <-ch:
   364  		return res.conn, res.err
   365  	case <-ctx.Done():
   366  		return nil, ctx.Err()
   367  	}
   368  }
   369  
   370  // Dial initiates a connection to the addr from the remote host.
   371  // The resulting connection has a zero LocalAddr() and RemoteAddr().
   372  func (c *Client) Dial(n, addr string) (net.Conn, error) {
   373  	var ch Channel
   374  	switch n {
   375  	case "tcp", "tcp4", "tcp6":
   376  		// Parse the address into host and numeric port.
   377  		host, portString, err := net.SplitHostPort(addr)
   378  		if err != nil {
   379  			return nil, err
   380  		}
   381  		port, err := strconv.ParseUint(portString, 10, 16)
   382  		if err != nil {
   383  			return nil, err
   384  		}
   385  		ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
   386  		if err != nil {
   387  			return nil, err
   388  		}
   389  		// Use a zero address for local and remote address.
   390  		zeroAddr := &net.TCPAddr{
   391  			IP:   net.IPv4zero,
   392  			Port: 0,
   393  		}
   394  		return &chanConn{
   395  			Channel: ch,
   396  			laddr:   zeroAddr,
   397  			raddr:   zeroAddr,
   398  		}, nil
   399  	case "unix":
   400  		var err error
   401  		ch, err = c.dialStreamLocal(addr)
   402  		if err != nil {
   403  			return nil, err
   404  		}
   405  		return &chanConn{
   406  			Channel: ch,
   407  			laddr: &net.UnixAddr{
   408  				Name: "@",
   409  				Net:  "unix",
   410  			},
   411  			raddr: &net.UnixAddr{
   412  				Name: addr,
   413  				Net:  "unix",
   414  			},
   415  		}, nil
   416  	default:
   417  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
   418  	}
   419  }
   420  
   421  // DialTCP connects to the remote address raddr on the network net,
   422  // which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
   423  // as the local address for the connection.
   424  func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
   425  	if laddr == nil {
   426  		laddr = &net.TCPAddr{
   427  			IP:   net.IPv4zero,
   428  			Port: 0,
   429  		}
   430  	}
   431  	ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
   432  	if err != nil {
   433  		return nil, err
   434  	}
   435  	return &chanConn{
   436  		Channel: ch,
   437  		laddr:   laddr,
   438  		raddr:   raddr,
   439  	}, nil
   440  }
   441  
// channelOpenDirectMsg is the extra data of a "direct-tcpip" channel open
// request (RFC 4254 7.2). It holds the host and port to connect to (raddr,
// rport), followed by the originator's address and port (laddr, lport).
// It is filled in by (*Client).dial.
type channelOpenDirectMsg struct {
	raddr string
	rport uint32
	laddr string
	lport uint32
}
   449  
   450  func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
   451  	msg := channelOpenDirectMsg{
   452  		raddr: raddr,
   453  		rport: uint32(rport),
   454  		laddr: laddr,
   455  		lport: uint32(lport),
   456  	}
   457  	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
   458  	if err != nil {
   459  		return nil, err
   460  	}
   461  	go DiscardRequests(in)
   462  	return ch, err
   463  }
   464  
// tcpChan wraps a Channel.
// NOTE(review): nothing in this file constructs or references tcpChan;
// connections are returned as chanConn instead. It is presumably kept for
// compatibility. Verify against the rest of the package before removing.
type tcpChan struct {
	Channel // the backing channel
}
   468  
   469  // chanConn fulfills the net.Conn interface without
   470  // the tcpChan having to hold laddr or raddr directly.
   471  type chanConn struct {
   472  	Channel
   473  	laddr, raddr net.Addr
   474  }
   475  
   476  // LocalAddr returns the local network address.
   477  func (t *chanConn) LocalAddr() net.Addr {
   478  	return t.laddr
   479  }
   480  
   481  // RemoteAddr returns the remote network address.
   482  func (t *chanConn) RemoteAddr() net.Addr {
   483  	return t.raddr
   484  }
   485  
   486  // SetDeadline sets the read and write deadlines associated
   487  // with the connection.
   488  func (t *chanConn) SetDeadline(deadline time.Time) error {
   489  	if err := t.SetReadDeadline(deadline); err != nil {
   490  		return err
   491  	}
   492  	return t.SetWriteDeadline(deadline)
   493  }
   494  
   495  // SetReadDeadline sets the read deadline.
   496  // A zero value for t means Read will not time out.
   497  // After the deadline, the error from Read will implement net.Error
   498  // with Timeout() == true.
   499  func (t *chanConn) SetReadDeadline(deadline time.Time) error {
   500  	// for compatibility with previous version,
   501  	// the error message contains "tcpChan"
   502  	return errors.New("ssh: tcpChan: deadline not supported")
   503  }
   504  
   505  // SetWriteDeadline exists to satisfy the net.Conn interface
   506  // but is not implemented by this type.  It always returns an error.
   507  func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
   508  	return errors.New("ssh: tcpChan: deadline not supported")
   509  }
   510  

View as plain text