...

Source file src/golang.org/x/crypto/ssh/mux.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"encoding/binary"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"sync"
    13  	"sync/atomic"
    14  )
    15  
// debugMux, if set, causes messages in the connection protocol to be
// logged. It also makes newMux give every mux a distinct chanList
// offset so that the log lines of different muxes can be told apart.
const debugMux = false
    19  
    20  // chanList is a thread safe channel list.
    21  type chanList struct {
    22  	// protects concurrent access to chans
    23  	sync.Mutex
    24  
    25  	// chans are indexed by the local id of the channel, which the
    26  	// other side should send in the PeersId field.
    27  	chans []*channel
    28  
    29  	// This is a debugging aid: it offsets all IDs by this
    30  	// amount. This helps distinguish otherwise identical
    31  	// server/client muxes
    32  	offset uint32
    33  }
    34  
    35  // Assigns a channel ID to the given channel.
    36  func (c *chanList) add(ch *channel) uint32 {
    37  	c.Lock()
    38  	defer c.Unlock()
    39  	for i := range c.chans {
    40  		if c.chans[i] == nil {
    41  			c.chans[i] = ch
    42  			return uint32(i) + c.offset
    43  		}
    44  	}
    45  	c.chans = append(c.chans, ch)
    46  	return uint32(len(c.chans)-1) + c.offset
    47  }
    48  
    49  // getChan returns the channel for the given ID.
    50  func (c *chanList) getChan(id uint32) *channel {
    51  	id -= c.offset
    52  
    53  	c.Lock()
    54  	defer c.Unlock()
    55  	if id < uint32(len(c.chans)) {
    56  		return c.chans[id]
    57  	}
    58  	return nil
    59  }
    60  
    61  func (c *chanList) remove(id uint32) {
    62  	id -= c.offset
    63  	c.Lock()
    64  	if id < uint32(len(c.chans)) {
    65  		c.chans[id] = nil
    66  	}
    67  	c.Unlock()
    68  }
    69  
    70  // dropAll forgets all channels it knows, returning them in a slice.
    71  func (c *chanList) dropAll() []*channel {
    72  	c.Lock()
    73  	defer c.Unlock()
    74  	var r []*channel
    75  
    76  	for _, ch := range c.chans {
    77  		if ch == nil {
    78  			continue
    79  		}
    80  		r = append(r, ch)
    81  	}
    82  	c.chans = nil
    83  	return r
    84  }
    85  
    86  // mux represents the state for the SSH connection protocol, which
    87  // multiplexes many channels onto a single packet transport.
    88  type mux struct {
    89  	conn     packetConn
    90  	chanList chanList
    91  
    92  	incomingChannels chan NewChannel
    93  
    94  	globalSentMu     sync.Mutex
    95  	globalResponses  chan interface{}
    96  	incomingRequests chan *Request
    97  
    98  	errCond *sync.Cond
    99  	err     error
   100  }
   101  
// globalOff is the source of per-mux debugging offsets. When debugMux
// is set, newMux atomically increments it so that each new chanList
// instantiation has a different offset.
var globalOff uint32
   105  
   106  func (m *mux) Wait() error {
   107  	m.errCond.L.Lock()
   108  	defer m.errCond.L.Unlock()
   109  	for m.err == nil {
   110  		m.errCond.Wait()
   111  	}
   112  	return m.err
   113  }
   114  
   115  // newMux returns a mux that runs over the given connection.
   116  func newMux(p packetConn) *mux {
   117  	m := &mux{
   118  		conn:             p,
   119  		incomingChannels: make(chan NewChannel, chanSize),
   120  		globalResponses:  make(chan interface{}, 1),
   121  		incomingRequests: make(chan *Request, chanSize),
   122  		errCond:          newCond(),
   123  	}
   124  	if debugMux {
   125  		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
   126  	}
   127  
   128  	go m.loop()
   129  	return m
   130  }
   131  
   132  func (m *mux) sendMessage(msg interface{}) error {
   133  	p := Marshal(msg)
   134  	if debugMux {
   135  		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
   136  	}
   137  	return m.conn.writePacket(p)
   138  }
   139  
   140  func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
   141  	if wantReply {
   142  		m.globalSentMu.Lock()
   143  		defer m.globalSentMu.Unlock()
   144  	}
   145  
   146  	if err := m.sendMessage(globalRequestMsg{
   147  		Type:      name,
   148  		WantReply: wantReply,
   149  		Data:      payload,
   150  	}); err != nil {
   151  		return false, nil, err
   152  	}
   153  
   154  	if !wantReply {
   155  		return false, nil, nil
   156  	}
   157  
   158  	msg, ok := <-m.globalResponses
   159  	if !ok {
   160  		return false, nil, io.EOF
   161  	}
   162  	switch msg := msg.(type) {
   163  	case *globalRequestFailureMsg:
   164  		return false, msg.Data, nil
   165  	case *globalRequestSuccessMsg:
   166  		return true, msg.Data, nil
   167  	default:
   168  		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
   169  	}
   170  }
   171  
   172  // ackRequest must be called after processing a global request that
   173  // has WantReply set.
   174  func (m *mux) ackRequest(ok bool, data []byte) error {
   175  	if ok {
   176  		return m.sendMessage(globalRequestSuccessMsg{Data: data})
   177  	}
   178  	return m.sendMessage(globalRequestFailureMsg{Data: data})
   179  }
   180  
// Close closes the underlying packet connection and returns the error
// reported by that close.
func (m *mux) Close() error {
	return m.conn.Close()
}
   184  
   185  // loop runs the connection machine. It will process packets until an
   186  // error is encountered. To synchronize on loop exit, use mux.Wait.
   187  func (m *mux) loop() {
   188  	var err error
   189  	for err == nil {
   190  		err = m.onePacket()
   191  	}
   192  
   193  	for _, ch := range m.chanList.dropAll() {
   194  		ch.close()
   195  	}
   196  
   197  	close(m.incomingChannels)
   198  	close(m.incomingRequests)
   199  	close(m.globalResponses)
   200  
   201  	m.conn.Close()
   202  
   203  	m.errCond.L.Lock()
   204  	m.err = err
   205  	m.errCond.Broadcast()
   206  	m.errCond.L.Unlock()
   207  
   208  	if debugMux {
   209  		log.Println("loop exit", err)
   210  	}
   211  }
   212  
   213  // onePacket reads and processes one packet.
   214  func (m *mux) onePacket() error {
   215  	packet, err := m.conn.readPacket()
   216  	if err != nil {
   217  		return err
   218  	}
   219  
   220  	if debugMux {
   221  		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
   222  			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
   223  		} else {
   224  			p, _ := decode(packet)
   225  			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
   226  		}
   227  	}
   228  
   229  	switch packet[0] {
   230  	case msgChannelOpen:
   231  		return m.handleChannelOpen(packet)
   232  	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
   233  		return m.handleGlobalPacket(packet)
   234  	case msgPing:
   235  		var msg pingMsg
   236  		if err := Unmarshal(packet, &msg); err != nil {
   237  			return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
   238  		}
   239  		return m.sendMessage(pongMsg(msg))
   240  	}
   241  
   242  	// assume a channel packet.
   243  	if len(packet) < 5 {
   244  		return parseError(packet[0])
   245  	}
   246  	id := binary.BigEndian.Uint32(packet[1:])
   247  	ch := m.chanList.getChan(id)
   248  	if ch == nil {
   249  		return m.handleUnknownChannelPacket(id, packet)
   250  	}
   251  
   252  	return ch.handlePacket(packet)
   253  }
   254  
   255  func (m *mux) handleGlobalPacket(packet []byte) error {
   256  	msg, err := decode(packet)
   257  	if err != nil {
   258  		return err
   259  	}
   260  
   261  	switch msg := msg.(type) {
   262  	case *globalRequestMsg:
   263  		m.incomingRequests <- &Request{
   264  			Type:      msg.Type,
   265  			WantReply: msg.WantReply,
   266  			Payload:   msg.Data,
   267  			mux:       m,
   268  		}
   269  	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
   270  		m.globalResponses <- msg
   271  	default:
   272  		panic(fmt.Sprintf("not a global message %#v", msg))
   273  	}
   274  
   275  	return nil
   276  }
   277  
   278  // handleChannelOpen schedules a channel to be Accept()ed.
   279  func (m *mux) handleChannelOpen(packet []byte) error {
   280  	var msg channelOpenMsg
   281  	if err := Unmarshal(packet, &msg); err != nil {
   282  		return err
   283  	}
   284  
   285  	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
   286  		failMsg := channelOpenFailureMsg{
   287  			PeersID:  msg.PeersID,
   288  			Reason:   ConnectionFailed,
   289  			Message:  "invalid request",
   290  			Language: "en_US.UTF-8",
   291  		}
   292  		return m.sendMessage(failMsg)
   293  	}
   294  
   295  	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
   296  	c.remoteId = msg.PeersID
   297  	c.maxRemotePayload = msg.MaxPacketSize
   298  	c.remoteWin.add(msg.PeersWindow)
   299  	m.incomingChannels <- c
   300  	return nil
   301  }
   302  
   303  func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
   304  	ch, err := m.openChannel(chanType, extra)
   305  	if err != nil {
   306  		return nil, nil, err
   307  	}
   308  
   309  	return ch, ch.incomingRequests, nil
   310  }
   311  
   312  func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
   313  	ch := m.newChannel(chanType, channelOutbound, extra)
   314  
   315  	ch.maxIncomingPayload = channelMaxPacket
   316  
   317  	open := channelOpenMsg{
   318  		ChanType:         chanType,
   319  		PeersWindow:      ch.myWindow,
   320  		MaxPacketSize:    ch.maxIncomingPayload,
   321  		TypeSpecificData: extra,
   322  		PeersID:          ch.localId,
   323  	}
   324  	if err := m.sendMessage(open); err != nil {
   325  		return nil, err
   326  	}
   327  
   328  	switch msg := (<-ch.msg).(type) {
   329  	case *channelOpenConfirmMsg:
   330  		return ch, nil
   331  	case *channelOpenFailureMsg:
   332  		return nil, &OpenChannelError{msg.Reason, msg.Message}
   333  	default:
   334  		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
   335  	}
   336  }
   337  
   338  func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
   339  	msg, err := decode(packet)
   340  	if err != nil {
   341  		return err
   342  	}
   343  
   344  	switch msg := msg.(type) {
   345  	// RFC 4254 section 5.4 says unrecognized channel requests should
   346  	// receive a failure response.
   347  	case *channelRequestMsg:
   348  		if msg.WantReply {
   349  			return m.sendMessage(channelRequestFailureMsg{
   350  				PeersID: msg.PeersID,
   351  			})
   352  		}
   353  		return nil
   354  	default:
   355  		return fmt.Errorf("ssh: invalid channel %d", id)
   356  	}
   357  }
   358  

View as plain text