1
2
3
4
5 package ssh
6
7 import (
8 "encoding/binary"
9 "errors"
10 "fmt"
11 "io"
12 "log"
13 "sync"
14 )
15
// Sizing parameters used by every channel in this file.
const (
	// minPacketLength is the smallest MaxPacketSize accepted from a peer
	// in a channel-open confirmation (see handlePacket).
	minPacketLength = 9

	// channelMaxPacket is the maximum payload size, in bytes, advertised
	// to the peer for a single data packet when a channel is accepted
	// (32 KiB).
	channelMaxPacket = 32 * 1024

	// channelWindowSize is the initial receive window of a new channel:
	// room for 64 maximum-size packets.
	channelWindowSize = channelMaxPacket * 64
)
25
26
27
28 type NewChannel interface {
29
30
31
32 Accept() (Channel, <-chan *Request, error)
33
34
35
36 Reject(reason RejectionReason, message string) error
37
38
39
40 ChannelType() string
41
42
43
44 ExtraData() []byte
45 }
46
47
48
// Channel is a bidirectional data stream multiplexed over a connection,
// with an additional side stream obtained from Stderr.
type Channel interface {
	// Read reads up to len(data) bytes of ordinary (non-extended)
	// channel data.
	Read(data []byte) (int, error)

	// Write sends data to the peer as ordinary channel data.
	Write(data []byte) (int, error)

	// CloseWrite tells the peer that no more data will be written. In the
	// implementation in this file, writes issued afterwards fail with
	// io.EOF while reading remains possible.
	CloseWrite() error

	// Close sends a close message for the channel to the peer.
	Close() error

	// SendRequest sends a request named name with the given payload. When
	// wantReply is true it waits for the peer's answer and reports whether
	// the request succeeded; when wantReply is false the boolean result is
	// always false.
	SendRequest(name string, wantReply bool, payload []byte) (bool, error)

	// Stderr returns a reader/writer for the extended data stream with
	// code 1.
	Stderr() io.ReadWriter
}
80
81
82
83
84 type Request struct {
85 Type string
86 WantReply bool
87 Payload []byte
88
89 ch *channel
90 mux *mux
91 }
92
93
94
95
96 func (r *Request) Reply(ok bool, payload []byte) error {
97 if !r.WantReply {
98 return nil
99 }
100
101 if r.ch == nil {
102 return r.mux.ackRequest(ok, payload)
103 }
104
105 return r.ch.ackRequest(ok)
106 }
107
108
109
// RejectionReason is the numeric code given to the peer when a channel open
// request is refused.
type RejectionReason uint32

// Known rejection reasons, numbered consecutively starting at 1.
const (
	Prohibited = RejectionReason(iota + 1)
	ConnectionFailed
	UnknownChannelType
	ResourceShortage
)

// String returns a human-readable description of the reason. Codes without a
// known description are rendered as "unknown reason N".
func (r RejectionReason) String() string {
	switch r {
	case Prohibited:
		return "administratively prohibited"
	case ConnectionFailed:
		return "connect failed"
	case UnknownChannelType:
		return "unknown channel type"
	case ResourceShortage:
		return "resource shortage"
	}
	// Format as uint32 rather than int: on platforms where int is 32 bits
	// wide, converting a code >= 1<<31 to int would print a negative number.
	return fmt.Sprintf("unknown reason %d", uint32(r))
}
133
// min returns the smaller of a and b as a uint32.
//
// b is expected to be non-negative (callers pass a slice length); a negative
// b yields a. The comparison is done in 64 bits so that a b larger than
// 1<<32-1 is not truncated before comparing: previously min(a, 1<<32)
// returned 0, which would make a chunked writer stop making progress.
func min(a uint32, b int) uint32 {
	if b < 0 || uint64(b) > uint64(a) {
		return a
	}
	return uint32(b)
}
140
// channelDirection distinguishes inbound from outbound channels. Only
// outbound channels may receive an open confirmation or failure from the
// peer (see responseMessageReceived).
type channelDirection uint8

const (
	channelInbound = channelDirection(iota)
	channelOutbound
)
147
148
149
150 type channel struct {
151
152 chanType string
153 extraData []byte
154 localId, remoteId uint32
155
156
157
158
159
160 maxIncomingPayload uint32
161 maxRemotePayload uint32
162
163 mux *mux
164
165
166
167 decided bool
168
169
170
171 direction channelDirection
172
173
174 msg chan interface{}
175
176
177
178
179 sentRequestMu sync.Mutex
180
181 incomingRequests chan *Request
182
183 sentEOF bool
184
185
186 remoteWin window
187 pending *buffer
188 extPending *buffer
189
190
191
192 windowMu sync.Mutex
193 myWindow uint32
194 myConsumed uint32
195
196
197
198
199
200 writeMu sync.Mutex
201 sentClose bool
202
203
204
205 packetPool map[uint32][]byte
206 }
207
208
209
210 func (ch *channel) writePacket(packet []byte) error {
211 ch.writeMu.Lock()
212 if ch.sentClose {
213 ch.writeMu.Unlock()
214 return io.EOF
215 }
216 ch.sentClose = (packet[0] == msgChannelClose)
217 err := ch.mux.conn.writePacket(packet)
218 ch.writeMu.Unlock()
219 return err
220 }
221
222 func (ch *channel) sendMessage(msg interface{}) error {
223 if debugMux {
224 log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
225 }
226
227 p := Marshal(msg)
228 binary.BigEndian.PutUint32(p[1:], ch.remoteId)
229 return ch.writePacket(p)
230 }
231
232
233
234 func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
235 if ch.sentEOF {
236 return 0, io.EOF
237 }
238
239 opCode := byte(msgChannelData)
240 headerLength := uint32(9)
241 if extendedCode > 0 {
242 headerLength += 4
243 opCode = msgChannelExtendedData
244 }
245
246 ch.writeMu.Lock()
247 packet := ch.packetPool[extendedCode]
248
249
250
251 ch.writeMu.Unlock()
252
253 for len(data) > 0 {
254 space := min(ch.maxRemotePayload, len(data))
255 if space, err = ch.remoteWin.reserve(space); err != nil {
256 return n, err
257 }
258 if want := headerLength + space; uint32(cap(packet)) < want {
259 packet = make([]byte, want)
260 } else {
261 packet = packet[:want]
262 }
263
264 todo := data[:space]
265
266 packet[0] = opCode
267 binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
268 if extendedCode > 0 {
269 binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
270 }
271 binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
272 copy(packet[headerLength:], todo)
273 if err = ch.writePacket(packet); err != nil {
274 return n, err
275 }
276
277 n += len(todo)
278 data = data[len(todo):]
279 }
280
281 ch.writeMu.Lock()
282 ch.packetPool[extendedCode] = packet
283 ch.writeMu.Unlock()
284
285 return n, err
286 }
287
288 func (ch *channel) handleData(packet []byte) error {
289 headerLen := 9
290 isExtendedData := packet[0] == msgChannelExtendedData
291 if isExtendedData {
292 headerLen = 13
293 }
294 if len(packet) < headerLen {
295
296 return parseError(packet[0])
297 }
298
299 var extended uint32
300 if isExtendedData {
301 extended = binary.BigEndian.Uint32(packet[5:])
302 }
303
304 length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
305 if length == 0 {
306 return nil
307 }
308 if length > ch.maxIncomingPayload {
309
310 return errors.New("ssh: incoming packet exceeds maximum payload size")
311 }
312
313 data := packet[headerLen:]
314 if length != uint32(len(data)) {
315 return errors.New("ssh: wrong packet length")
316 }
317
318 ch.windowMu.Lock()
319 if ch.myWindow < length {
320 ch.windowMu.Unlock()
321
322 return errors.New("ssh: remote side wrote too much")
323 }
324 ch.myWindow -= length
325 ch.windowMu.Unlock()
326
327 if extended == 1 {
328 ch.extPending.write(data)
329 } else if extended > 0 {
330
331 } else {
332 ch.pending.write(data)
333 }
334 return nil
335 }
336
337 func (c *channel) adjustWindow(adj uint32) error {
338 c.windowMu.Lock()
339
340
341 c.myConsumed += adj
342 var sendAdj uint32
343 if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
344 (c.myWindow < channelWindowSize/2) {
345 sendAdj = c.myConsumed
346 c.myConsumed = 0
347 c.myWindow += sendAdj
348 }
349 c.windowMu.Unlock()
350 if sendAdj == 0 {
351 return nil
352 }
353 return c.sendMessage(windowAdjustMsg{
354 AdditionalBytes: sendAdj,
355 })
356 }
357
358 func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
359 switch extended {
360 case 1:
361 n, err = c.extPending.Read(data)
362 case 0:
363 n, err = c.pending.Read(data)
364 default:
365 return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
366 }
367
368 if n > 0 {
369 err = c.adjustWindow(uint32(n))
370
371
372
373
374 if n > 0 && err == io.EOF {
375 err = nil
376 }
377 }
378
379 return n, err
380 }
381
382 func (c *channel) close() {
383 c.pending.eof()
384 c.extPending.eof()
385 close(c.msg)
386 close(c.incomingRequests)
387 c.writeMu.Lock()
388
389
390 c.sentClose = true
391 c.writeMu.Unlock()
392
393 c.remoteWin.close()
394 }
395
396
397
398
399 func (ch *channel) responseMessageReceived() error {
400 if ch.direction == channelInbound {
401 return errors.New("ssh: channel response message received on inbound channel")
402 }
403 if ch.decided {
404 return errors.New("ssh: duplicate response received for channel")
405 }
406 ch.decided = true
407 return nil
408 }
409
410 func (ch *channel) handlePacket(packet []byte) error {
411 switch packet[0] {
412 case msgChannelData, msgChannelExtendedData:
413 return ch.handleData(packet)
414 case msgChannelClose:
415 ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
416 ch.mux.chanList.remove(ch.localId)
417 ch.close()
418 return nil
419 case msgChannelEOF:
420
421
422 ch.extPending.eof()
423 ch.pending.eof()
424 return nil
425 }
426
427 decoded, err := decode(packet)
428 if err != nil {
429 return err
430 }
431
432 switch msg := decoded.(type) {
433 case *channelOpenFailureMsg:
434 if err := ch.responseMessageReceived(); err != nil {
435 return err
436 }
437 ch.mux.chanList.remove(msg.PeersID)
438 ch.msg <- msg
439 case *channelOpenConfirmMsg:
440 if err := ch.responseMessageReceived(); err != nil {
441 return err
442 }
443 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
444 return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
445 }
446 ch.remoteId = msg.MyID
447 ch.maxRemotePayload = msg.MaxPacketSize
448 ch.remoteWin.add(msg.MyWindow)
449 ch.msg <- msg
450 case *windowAdjustMsg:
451 if !ch.remoteWin.add(msg.AdditionalBytes) {
452 return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
453 }
454 case *channelRequestMsg:
455 req := Request{
456 Type: msg.Request,
457 WantReply: msg.WantReply,
458 Payload: msg.RequestSpecificData,
459 ch: ch,
460 }
461
462 ch.incomingRequests <- &req
463 default:
464 ch.msg <- msg
465 }
466 return nil
467 }
468
469 func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
470 ch := &channel{
471 remoteWin: window{Cond: newCond()},
472 myWindow: channelWindowSize,
473 pending: newBuffer(),
474 extPending: newBuffer(),
475 direction: direction,
476 incomingRequests: make(chan *Request, chanSize),
477 msg: make(chan interface{}, chanSize),
478 chanType: chanType,
479 extraData: extraData,
480 mux: m,
481 packetPool: make(map[uint32][]byte),
482 }
483 ch.localId = m.chanList.add(ch)
484 return ch
485 }
486
var (
	// errUndecided is returned by channel operations attempted before the
	// channel has been accepted or rejected.
	errUndecided = errors.New("ssh: must Accept or Reject channel")

	// errDecidedAlready is returned by a second call to Accept or Reject.
	errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
)
489
490 type extChannel struct {
491 code uint32
492 ch *channel
493 }
494
495 func (e *extChannel) Write(data []byte) (n int, err error) {
496 return e.ch.WriteExtended(data, e.code)
497 }
498
499 func (e *extChannel) Read(data []byte) (n int, err error) {
500 return e.ch.ReadExtended(data, e.code)
501 }
502
503 func (ch *channel) Accept() (Channel, <-chan *Request, error) {
504 if ch.decided {
505 return nil, nil, errDecidedAlready
506 }
507 ch.maxIncomingPayload = channelMaxPacket
508 confirm := channelOpenConfirmMsg{
509 PeersID: ch.remoteId,
510 MyID: ch.localId,
511 MyWindow: ch.myWindow,
512 MaxPacketSize: ch.maxIncomingPayload,
513 }
514 ch.decided = true
515 if err := ch.sendMessage(confirm); err != nil {
516 return nil, nil, err
517 }
518
519 return ch, ch.incomingRequests, nil
520 }
521
522 func (ch *channel) Reject(reason RejectionReason, message string) error {
523 if ch.decided {
524 return errDecidedAlready
525 }
526 reject := channelOpenFailureMsg{
527 PeersID: ch.remoteId,
528 Reason: reason,
529 Message: message,
530 Language: "en",
531 }
532 ch.decided = true
533 return ch.sendMessage(reject)
534 }
535
536 func (ch *channel) Read(data []byte) (int, error) {
537 if !ch.decided {
538 return 0, errUndecided
539 }
540 return ch.ReadExtended(data, 0)
541 }
542
543 func (ch *channel) Write(data []byte) (int, error) {
544 if !ch.decided {
545 return 0, errUndecided
546 }
547 return ch.WriteExtended(data, 0)
548 }
549
550 func (ch *channel) CloseWrite() error {
551 if !ch.decided {
552 return errUndecided
553 }
554 ch.sentEOF = true
555 return ch.sendMessage(channelEOFMsg{
556 PeersID: ch.remoteId})
557 }
558
559 func (ch *channel) Close() error {
560 if !ch.decided {
561 return errUndecided
562 }
563
564 return ch.sendMessage(channelCloseMsg{
565 PeersID: ch.remoteId})
566 }
567
568
569
570 func (ch *channel) Extended(code uint32) io.ReadWriter {
571 if !ch.decided {
572 return nil
573 }
574 return &extChannel{code, ch}
575 }
576
577 func (ch *channel) Stderr() io.ReadWriter {
578 return ch.Extended(1)
579 }
580
581 func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
582 if !ch.decided {
583 return false, errUndecided
584 }
585
586 if wantReply {
587 ch.sentRequestMu.Lock()
588 defer ch.sentRequestMu.Unlock()
589 }
590
591 msg := channelRequestMsg{
592 PeersID: ch.remoteId,
593 Request: name,
594 WantReply: wantReply,
595 RequestSpecificData: payload,
596 }
597
598 if err := ch.sendMessage(msg); err != nil {
599 return false, err
600 }
601
602 if wantReply {
603 m, ok := (<-ch.msg)
604 if !ok {
605 return false, io.EOF
606 }
607 switch m.(type) {
608 case *channelRequestFailureMsg:
609 return false, nil
610 case *channelRequestSuccessMsg:
611 return true, nil
612 default:
613 return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
614 }
615 }
616
617 return false, nil
618 }
619
620
621 func (ch *channel) ackRequest(ok bool) error {
622 if !ch.decided {
623 return errUndecided
624 }
625
626 var msg interface{}
627 if !ok {
628 msg = channelRequestFailureMsg{
629 PeersID: ch.remoteId,
630 }
631 } else {
632 msg = channelRequestSuccessMsg{
633 PeersID: ch.remoteId,
634 }
635 }
636 return ch.sendMessage(msg)
637 }
638
639 func (ch *channel) ChannelType() string {
640 return ch.chanType
641 }
642
643 func (ch *channel) ExtraData() []byte {
644 return ch.extraData
645 }
646
View as plain text