1
2
3
4
5 package ssh
6
7 import (
8 "encoding/binary"
9 "fmt"
10 "io"
11 "log"
12 "sync"
13 "sync/atomic"
14 )
15
16
17
// debugMux, when true, makes the multiplexer log every packet it sends and
// receives, log when its read loop exits, and gives each mux a distinct
// channel-ID offset (assigned in newMux) so IDs from different muxes differ.
const (
	debugMux = false
)
19
20
21 type chanList struct {
22
23 sync.Mutex
24
25
26
27 chans []*channel
28
29
30
31
32 offset uint32
33 }
34
35
36 func (c *chanList) add(ch *channel) uint32 {
37 c.Lock()
38 defer c.Unlock()
39 for i := range c.chans {
40 if c.chans[i] == nil {
41 c.chans[i] = ch
42 return uint32(i) + c.offset
43 }
44 }
45 c.chans = append(c.chans, ch)
46 return uint32(len(c.chans)-1) + c.offset
47 }
48
49
50 func (c *chanList) getChan(id uint32) *channel {
51 id -= c.offset
52
53 c.Lock()
54 defer c.Unlock()
55 if id < uint32(len(c.chans)) {
56 return c.chans[id]
57 }
58 return nil
59 }
60
61 func (c *chanList) remove(id uint32) {
62 id -= c.offset
63 c.Lock()
64 if id < uint32(len(c.chans)) {
65 c.chans[id] = nil
66 }
67 c.Unlock()
68 }
69
70
71 func (c *chanList) dropAll() []*channel {
72 c.Lock()
73 defer c.Unlock()
74 var r []*channel
75
76 for _, ch := range c.chans {
77 if ch == nil {
78 continue
79 }
80 r = append(r, ch)
81 }
82 c.chans = nil
83 return r
84 }
85
86
87
88 type mux struct {
89 conn packetConn
90 chanList chanList
91
92 incomingChannels chan NewChannel
93
94 globalSentMu sync.Mutex
95 globalResponses chan interface{}
96 incomingRequests chan *Request
97
98 errCond *sync.Cond
99 err error
100 }
101
102
103
// globalOff is the most recent channel-ID offset handed out by newMux. It is
// advanced atomically, and only when debugMux is enabled.
var (
	globalOff uint32
)
105
106 func (m *mux) Wait() error {
107 m.errCond.L.Lock()
108 defer m.errCond.L.Unlock()
109 for m.err == nil {
110 m.errCond.Wait()
111 }
112 return m.err
113 }
114
115
116 func newMux(p packetConn) *mux {
117 m := &mux{
118 conn: p,
119 incomingChannels: make(chan NewChannel, chanSize),
120 globalResponses: make(chan interface{}, 1),
121 incomingRequests: make(chan *Request, chanSize),
122 errCond: newCond(),
123 }
124 if debugMux {
125 m.chanList.offset = atomic.AddUint32(&globalOff, 1)
126 }
127
128 go m.loop()
129 return m
130 }
131
132 func (m *mux) sendMessage(msg interface{}) error {
133 p := Marshal(msg)
134 if debugMux {
135 log.Printf("send global(%d): %#v", m.chanList.offset, msg)
136 }
137 return m.conn.writePacket(p)
138 }
139
140 func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
141 if wantReply {
142 m.globalSentMu.Lock()
143 defer m.globalSentMu.Unlock()
144 }
145
146 if err := m.sendMessage(globalRequestMsg{
147 Type: name,
148 WantReply: wantReply,
149 Data: payload,
150 }); err != nil {
151 return false, nil, err
152 }
153
154 if !wantReply {
155 return false, nil, nil
156 }
157
158 msg, ok := <-m.globalResponses
159 if !ok {
160 return false, nil, io.EOF
161 }
162 switch msg := msg.(type) {
163 case *globalRequestFailureMsg:
164 return false, msg.Data, nil
165 case *globalRequestSuccessMsg:
166 return true, msg.Data, nil
167 default:
168 return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
169 }
170 }
171
172
173
174 func (m *mux) ackRequest(ok bool, data []byte) error {
175 if ok {
176 return m.sendMessage(globalRequestSuccessMsg{Data: data})
177 }
178 return m.sendMessage(globalRequestFailureMsg{Data: data})
179 }
180
181 func (m *mux) Close() error {
182 return m.conn.Close()
183 }
184
185
186
187 func (m *mux) loop() {
188 var err error
189 for err == nil {
190 err = m.onePacket()
191 }
192
193 for _, ch := range m.chanList.dropAll() {
194 ch.close()
195 }
196
197 close(m.incomingChannels)
198 close(m.incomingRequests)
199 close(m.globalResponses)
200
201 m.conn.Close()
202
203 m.errCond.L.Lock()
204 m.err = err
205 m.errCond.Broadcast()
206 m.errCond.L.Unlock()
207
208 if debugMux {
209 log.Println("loop exit", err)
210 }
211 }
212
213
214 func (m *mux) onePacket() error {
215 packet, err := m.conn.readPacket()
216 if err != nil {
217 return err
218 }
219
220 if debugMux {
221 if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
222 log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
223 } else {
224 p, _ := decode(packet)
225 log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
226 }
227 }
228
229 switch packet[0] {
230 case msgChannelOpen:
231 return m.handleChannelOpen(packet)
232 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
233 return m.handleGlobalPacket(packet)
234 case msgPing:
235 var msg pingMsg
236 if err := Unmarshal(packet, &msg); err != nil {
237 return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
238 }
239 return m.sendMessage(pongMsg(msg))
240 }
241
242
243 if len(packet) < 5 {
244 return parseError(packet[0])
245 }
246 id := binary.BigEndian.Uint32(packet[1:])
247 ch := m.chanList.getChan(id)
248 if ch == nil {
249 return m.handleUnknownChannelPacket(id, packet)
250 }
251
252 return ch.handlePacket(packet)
253 }
254
255 func (m *mux) handleGlobalPacket(packet []byte) error {
256 msg, err := decode(packet)
257 if err != nil {
258 return err
259 }
260
261 switch msg := msg.(type) {
262 case *globalRequestMsg:
263 m.incomingRequests <- &Request{
264 Type: msg.Type,
265 WantReply: msg.WantReply,
266 Payload: msg.Data,
267 mux: m,
268 }
269 case *globalRequestSuccessMsg, *globalRequestFailureMsg:
270 m.globalResponses <- msg
271 default:
272 panic(fmt.Sprintf("not a global message %#v", msg))
273 }
274
275 return nil
276 }
277
278
279 func (m *mux) handleChannelOpen(packet []byte) error {
280 var msg channelOpenMsg
281 if err := Unmarshal(packet, &msg); err != nil {
282 return err
283 }
284
285 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
286 failMsg := channelOpenFailureMsg{
287 PeersID: msg.PeersID,
288 Reason: ConnectionFailed,
289 Message: "invalid request",
290 Language: "en_US.UTF-8",
291 }
292 return m.sendMessage(failMsg)
293 }
294
295 c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
296 c.remoteId = msg.PeersID
297 c.maxRemotePayload = msg.MaxPacketSize
298 c.remoteWin.add(msg.PeersWindow)
299 m.incomingChannels <- c
300 return nil
301 }
302
303 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
304 ch, err := m.openChannel(chanType, extra)
305 if err != nil {
306 return nil, nil, err
307 }
308
309 return ch, ch.incomingRequests, nil
310 }
311
312 func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
313 ch := m.newChannel(chanType, channelOutbound, extra)
314
315 ch.maxIncomingPayload = channelMaxPacket
316
317 open := channelOpenMsg{
318 ChanType: chanType,
319 PeersWindow: ch.myWindow,
320 MaxPacketSize: ch.maxIncomingPayload,
321 TypeSpecificData: extra,
322 PeersID: ch.localId,
323 }
324 if err := m.sendMessage(open); err != nil {
325 return nil, err
326 }
327
328 switch msg := (<-ch.msg).(type) {
329 case *channelOpenConfirmMsg:
330 return ch, nil
331 case *channelOpenFailureMsg:
332 return nil, &OpenChannelError{msg.Reason, msg.Message}
333 default:
334 return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
335 }
336 }
337
338 func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
339 msg, err := decode(packet)
340 if err != nil {
341 return err
342 }
343
344 switch msg := msg.(type) {
345
346
347 case *channelRequestMsg:
348 if msg.WantReply {
349 return m.sendMessage(channelRequestFailureMsg{
350 PeersID: msg.PeersID,
351 })
352 }
353 return nil
354 default:
355 return fmt.Errorf("ssh: invalid channel %d", id)
356 }
357 }
358
View as plain text