1
2
3
4
5 package ssh
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "sync"
12 "testing"
13 )
14
15 func muxPair() (*mux, *mux) {
16 a, b := memPipe()
17
18 s := newMux(a)
19 c := newMux(b)
20
21 return s, c
22 }
23
24
25
26 func channelPair(t *testing.T) (*channel, *channel, *mux) {
27 c, s := muxPair()
28
29 res := make(chan *channel, 1)
30 go func() {
31 newCh, ok := <-s.incomingChannels
32 if !ok {
33 t.Error("no incoming channel")
34 close(res)
35 return
36 }
37 if newCh.ChannelType() != "chan" {
38 t.Errorf("got type %q want chan", newCh.ChannelType())
39 newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType()))
40 close(res)
41 return
42 }
43 ch, _, err := newCh.Accept()
44 if err != nil {
45 t.Errorf("accept: %v", err)
46 close(res)
47 return
48 }
49 res <- ch.(*channel)
50 }()
51
52 ch, err := c.openChannel("chan", nil)
53 if err != nil {
54 t.Fatalf("OpenChannel: %v", err)
55 }
56 w := <-res
57 if w == nil {
58 t.Fatal("unable to get write channel")
59 }
60
61 return w, ch, c
62 }
63
64
65
66 func TestMuxChannelExtendedThreadSafety(t *testing.T) {
67 writer, reader, mux := channelPair(t)
68 defer writer.Close()
69 defer reader.Close()
70 defer mux.Close()
71
72 var wr, rd sync.WaitGroup
73 magic := "hello world"
74
75 wr.Add(2)
76 go func() {
77 io.WriteString(writer, magic)
78 wr.Done()
79 }()
80 go func() {
81 io.WriteString(writer.Stderr(), magic)
82 wr.Done()
83 }()
84
85 rd.Add(2)
86 go func() {
87 c, err := io.ReadAll(reader)
88 if string(c) != magic {
89 t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err)
90 }
91 rd.Done()
92 }()
93 go func() {
94 c, err := io.ReadAll(reader.Stderr())
95 if string(c) != magic {
96 t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err)
97 }
98 rd.Done()
99 }()
100
101 wr.Wait()
102 writer.CloseWrite()
103 rd.Wait()
104 }
105
106 func TestMuxReadWrite(t *testing.T) {
107 s, c, mux := channelPair(t)
108 defer s.Close()
109 defer c.Close()
110 defer mux.Close()
111
112 magic := "hello world"
113 magicExt := "hello stderr"
114 var wg sync.WaitGroup
115 t.Cleanup(wg.Wait)
116 wg.Add(1)
117 go func() {
118 defer wg.Done()
119 _, err := s.Write([]byte(magic))
120 if err != nil {
121 t.Errorf("Write: %v", err)
122 return
123 }
124 _, err = s.Extended(1).Write([]byte(magicExt))
125 if err != nil {
126 t.Errorf("Write: %v", err)
127 return
128 }
129 }()
130
131 var buf [1024]byte
132 n, err := c.Read(buf[:])
133 if err != nil {
134 t.Fatalf("server Read: %v", err)
135 }
136 got := string(buf[:n])
137 if got != magic {
138 t.Fatalf("server: got %q want %q", got, magic)
139 }
140
141 n, err = c.Extended(1).Read(buf[:])
142 if err != nil {
143 t.Fatalf("server Read: %v", err)
144 }
145
146 got = string(buf[:n])
147 if got != magicExt {
148 t.Fatalf("server: got %q want %q", got, magic)
149 }
150 }
151
152 func TestMuxChannelOverflow(t *testing.T) {
153 reader, writer, mux := channelPair(t)
154 defer reader.Close()
155 defer writer.Close()
156 defer mux.Close()
157
158 var wg sync.WaitGroup
159 t.Cleanup(wg.Wait)
160 wg.Add(1)
161 go func() {
162 defer wg.Done()
163 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
164 t.Errorf("could not fill window: %v", err)
165 }
166 writer.Write(make([]byte, 1))
167 }()
168 writer.remoteWin.waitWriterBlocked()
169
170
171 packet := make([]byte, 1+4+4+1)
172 packet[0] = msgChannelData
173 marshalUint32(packet[1:], writer.remoteId)
174 marshalUint32(packet[5:], uint32(1))
175 packet[9] = 42
176
177 if err := writer.mux.conn.writePacket(packet); err != nil {
178 t.Errorf("could not send packet")
179 }
180 if _, err := reader.SendRequest("hello", true, nil); err == nil {
181 t.Errorf("SendRequest succeeded.")
182 }
183 }
184
185 func TestMuxChannelReadUnblock(t *testing.T) {
186 reader, writer, mux := channelPair(t)
187 defer reader.Close()
188 defer writer.Close()
189 defer mux.Close()
190
191 var wg sync.WaitGroup
192 t.Cleanup(wg.Wait)
193 wg.Add(1)
194 go func() {
195 defer wg.Done()
196 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
197 t.Errorf("could not fill window: %v", err)
198 }
199 if _, err := writer.Write(make([]byte, 1)); err != nil {
200 t.Errorf("Write: %v", err)
201 }
202 writer.Close()
203 }()
204
205 writer.remoteWin.waitWriterBlocked()
206
207 buf := make([]byte, 32768)
208 for {
209 _, err := reader.Read(buf)
210 if err == io.EOF {
211 break
212 }
213 if err != nil {
214 t.Fatalf("Read: %v", err)
215 }
216 }
217 }
218
219 func TestMuxChannelCloseWriteUnblock(t *testing.T) {
220 reader, writer, mux := channelPair(t)
221 defer reader.Close()
222 defer writer.Close()
223 defer mux.Close()
224
225 var wg sync.WaitGroup
226 t.Cleanup(wg.Wait)
227 wg.Add(1)
228 go func() {
229 defer wg.Done()
230 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
231 t.Errorf("could not fill window: %v", err)
232 }
233 if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
234 t.Errorf("got %v, want EOF for unblock write", err)
235 }
236 }()
237
238 writer.remoteWin.waitWriterBlocked()
239 reader.Close()
240 }
241
242 func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
243 reader, writer, mux := channelPair(t)
244 defer reader.Close()
245 defer writer.Close()
246 defer mux.Close()
247
248 var wg sync.WaitGroup
249 t.Cleanup(wg.Wait)
250 wg.Add(1)
251 go func() {
252 defer wg.Done()
253 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
254 t.Errorf("could not fill window: %v", err)
255 }
256 if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
257 t.Errorf("got %v, want EOF for unblock write", err)
258 }
259 }()
260
261 writer.remoteWin.waitWriterBlocked()
262 mux.Close()
263 }
264
265 func TestMuxReject(t *testing.T) {
266 client, server := muxPair()
267 defer server.Close()
268 defer client.Close()
269
270 var wg sync.WaitGroup
271 t.Cleanup(wg.Wait)
272 wg.Add(1)
273 go func() {
274 defer wg.Done()
275
276 ch, ok := <-server.incomingChannels
277 if !ok {
278 t.Error("cannot accept channel")
279 return
280 }
281 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
282 t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
283 ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String())
284 return
285 }
286 ch.Reject(RejectionReason(42), "message")
287 }()
288
289 ch, err := client.openChannel("ch", []byte("extra"))
290 if ch != nil {
291 t.Fatal("openChannel not rejected")
292 }
293
294 ocf, ok := err.(*OpenChannelError)
295 if !ok {
296 t.Errorf("got %#v want *OpenChannelError", err)
297 } else if ocf.Reason != 42 || ocf.Message != "message" {
298 t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
299 }
300
301 want := "ssh: rejected: unknown reason 42 (message)"
302 if err.Error() != want {
303 t.Errorf("got %q, want %q", err.Error(), want)
304 }
305 }
306
307 func TestMuxChannelRequest(t *testing.T) {
308 client, server, mux := channelPair(t)
309 defer server.Close()
310 defer client.Close()
311 defer mux.Close()
312
313 var received int
314 var wg sync.WaitGroup
315 t.Cleanup(wg.Wait)
316 wg.Add(1)
317 go func() {
318 for r := range server.incomingRequests {
319 received++
320 r.Reply(r.Type == "yes", nil)
321 }
322 wg.Done()
323 }()
324 _, err := client.SendRequest("yes", false, nil)
325 if err != nil {
326 t.Fatalf("SendRequest: %v", err)
327 }
328 ok, err := client.SendRequest("yes", true, nil)
329 if err != nil {
330 t.Fatalf("SendRequest: %v", err)
331 }
332
333 if !ok {
334 t.Errorf("SendRequest(yes): %v", ok)
335
336 }
337
338 ok, err = client.SendRequest("no", true, nil)
339 if err != nil {
340 t.Fatalf("SendRequest: %v", err)
341 }
342 if ok {
343 t.Errorf("SendRequest(no): %v", ok)
344 }
345
346 client.Close()
347 wg.Wait()
348
349 if received != 3 {
350 t.Errorf("got %d requests, want %d", received, 3)
351 }
352 }
353
354 func TestMuxUnknownChannelRequests(t *testing.T) {
355 clientPipe, serverPipe := memPipe()
356 client := newMux(clientPipe)
357 defer serverPipe.Close()
358 defer client.Close()
359
360 kDone := make(chan error, 1)
361 go func() {
362
363 err := serverPipe.writePacket(Marshal(channelRequestMsg{
364 PeersID: 1,
365 Request: "keepalive@openssh.com",
366 WantReply: false,
367 RequestSpecificData: []byte{},
368 }))
369 if err != nil {
370 kDone <- fmt.Errorf("send: %w", err)
371 return
372 }
373
374
375
376 err = serverPipe.writePacket(Marshal(channelRequestMsg{
377 PeersID: 2,
378 Request: "keepalive@openssh.com",
379 WantReply: true,
380 RequestSpecificData: []byte{},
381 }))
382 if err != nil {
383 kDone <- fmt.Errorf("send: %w", err)
384 return
385 }
386
387 packet, err := serverPipe.readPacket()
388 if err != nil {
389 kDone <- fmt.Errorf("read packet: %w", err)
390 return
391 }
392 decoded, err := decode(packet)
393 if err != nil {
394 kDone <- fmt.Errorf("decode failed: %w", err)
395 return
396 }
397
398 switch msg := decoded.(type) {
399 case *channelRequestFailureMsg:
400 if msg.PeersID != 2 {
401 kDone <- fmt.Errorf("received response to wrong message: %v", msg)
402 return
403
404 }
405 default:
406 kDone <- fmt.Errorf("unexpected channel message: %v", msg)
407 return
408 }
409
410 kDone <- nil
411
412
413
414 packet, err = serverPipe.readPacket()
415 if err != nil {
416 kDone <- fmt.Errorf("read packet: %w", err)
417 return
418 }
419 if packet[0] != msgGlobalRequest {
420 kDone <- errors.New("expected global request")
421 return
422 }
423
424 err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
425 Data: []byte{},
426 }))
427 if err != nil {
428 kDone <- fmt.Errorf("failed to send failure msg: %w", err)
429 return
430 }
431
432 close(kDone)
433 }()
434
435
436
437 if err := <-kDone; err != nil {
438 t.Fatal(err)
439 }
440
441
442 if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
443 t.Fatalf("failed to send keepalive: %v", err)
444 }
445
446
447 if err := <-kDone; err != nil {
448 t.Fatal(err)
449 }
450 }
451
452 func TestMuxClosedChannel(t *testing.T) {
453 clientPipe, serverPipe := memPipe()
454 client := newMux(clientPipe)
455 defer serverPipe.Close()
456 defer client.Close()
457
458 kDone := make(chan error, 1)
459 go func() {
460
461 packet, err := serverPipe.readPacket()
462 if err != nil {
463 kDone <- fmt.Errorf("read packet: %w", err)
464 return
465 }
466 if packet[0] != msgChannelOpen {
467 kDone <- errors.New("expected chan open")
468 return
469 }
470
471 var openMsg channelOpenMsg
472 if err := Unmarshal(packet, &openMsg); err != nil {
473 kDone <- fmt.Errorf("unmarshal: %w", err)
474 return
475 }
476
477
478 err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{
479 PeersID: openMsg.PeersID,
480 MyID: 0,
481 MyWindow: 0,
482 MaxPacketSize: channelMaxPacket,
483 }))
484 if err != nil {
485 kDone <- fmt.Errorf("send: %w", err)
486 return
487 }
488
489
490 err = serverPipe.writePacket(Marshal(channelCloseMsg{
491 PeersID: openMsg.PeersID,
492 }))
493 if err != nil {
494 kDone <- fmt.Errorf("send: %w", err)
495 return
496 }
497
498
499 err = serverPipe.writePacket(Marshal(channelRequestMsg{
500 PeersID: openMsg.PeersID,
501 Request: "keepalive@openssh.com",
502 WantReply: true,
503 RequestSpecificData: []byte{},
504 }))
505 if err != nil {
506 kDone <- fmt.Errorf("send: %w", err)
507 return
508 }
509
510
511 packet, err = serverPipe.readPacket()
512 if err != nil {
513 kDone <- fmt.Errorf("read packet: %w", err)
514 return
515 }
516 if packet[0] != msgChannelClose {
517 kDone <- errors.New("expected channel close")
518 return
519 }
520
521
522 packet, err = serverPipe.readPacket()
523 if err != nil {
524 kDone <- fmt.Errorf("read packet: %w", err)
525 return
526 }
527 if packet[0] != msgChannelFailure {
528 kDone <- errors.New("expected channel failure")
529 return
530 }
531 kDone <- nil
532
533
534
535 packet, err = serverPipe.readPacket()
536 if err != nil {
537 kDone <- fmt.Errorf("read packet: %w", err)
538 return
539 }
540 if packet[0] != msgGlobalRequest {
541 kDone <- errors.New("expected global request")
542 return
543 }
544
545 err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
546 Data: []byte{},
547 }))
548 if err != nil {
549 kDone <- fmt.Errorf("failed to send failure msg: %w", err)
550 return
551 }
552
553 close(kDone)
554 }()
555
556
557 ch, err := client.openChannel("chan", nil)
558 if err != nil {
559 t.Fatalf("OpenChannel: %v", err)
560 }
561 defer ch.Close()
562
563
564 <-kDone
565
566
567 if _, ok := <-ch.incomingRequests; ok {
568 t.Fatalf("channel not closed")
569 }
570
571
572 if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
573 t.Fatalf("failed to send keepalive: %v", err)
574 }
575
576
577 <-kDone
578 }
579
580 func TestMuxGlobalRequest(t *testing.T) {
581 var sawPeek bool
582 var wg sync.WaitGroup
583 defer func() {
584 wg.Wait()
585 if !sawPeek {
586 t.Errorf("never saw 'peek' request")
587 }
588 }()
589
590 clientMux, serverMux := muxPair()
591 defer serverMux.Close()
592 defer clientMux.Close()
593
594 wg.Add(1)
595 go func() {
596 defer wg.Done()
597 for r := range serverMux.incomingRequests {
598 sawPeek = sawPeek || r.Type == "peek"
599 if r.WantReply {
600 err := r.Reply(r.Type == "yes",
601 append([]byte(r.Type), r.Payload...))
602 if err != nil {
603 t.Errorf("AckRequest: %v", err)
604 }
605 }
606 }
607 }()
608
609 _, _, err := clientMux.SendRequest("peek", false, nil)
610 if err != nil {
611 t.Errorf("SendRequest: %v", err)
612 }
613
614 ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
615 if !ok || string(data) != "yesa" || err != nil {
616 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
617 ok, data, err)
618 }
619 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
620 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
621 ok, data, err)
622 }
623
624 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
625 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
626 ok, data, err)
627 }
628 }
629
630 func TestMuxGlobalRequestUnblock(t *testing.T) {
631 clientMux, serverMux := muxPair()
632 defer serverMux.Close()
633 defer clientMux.Close()
634
635 result := make(chan error, 1)
636 go func() {
637 _, _, err := clientMux.SendRequest("hello", true, nil)
638 result <- err
639 }()
640
641 <-serverMux.incomingRequests
642 serverMux.conn.Close()
643 err := <-result
644
645 if err != io.EOF {
646 t.Errorf("want EOF, got %v", io.EOF)
647 }
648 }
649
650 func TestMuxChannelRequestUnblock(t *testing.T) {
651 a, b, connB := channelPair(t)
652 defer a.Close()
653 defer b.Close()
654 defer connB.Close()
655
656 result := make(chan error, 1)
657 go func() {
658 _, err := a.SendRequest("hello", true, nil)
659 result <- err
660 }()
661
662 <-b.incomingRequests
663 connB.conn.Close()
664 err := <-result
665
666 if err != io.EOF {
667 t.Errorf("want EOF, got %v", err)
668 }
669 }
670
671 func TestMuxCloseChannel(t *testing.T) {
672 r, w, mux := channelPair(t)
673 defer mux.Close()
674 defer r.Close()
675 defer w.Close()
676
677 result := make(chan error, 1)
678 go func() {
679 var b [1024]byte
680 _, err := r.Read(b[:])
681 result <- err
682 }()
683 if err := w.Close(); err != nil {
684 t.Errorf("w.Close: %v", err)
685 }
686
687 if _, err := w.Write([]byte("hello")); err != io.EOF {
688 t.Errorf("got err %v, want io.EOF after Close", err)
689 }
690
691 if err := <-result; err != io.EOF {
692 t.Errorf("got %v (%T), want io.EOF", err, err)
693 }
694 }
695
696 func TestMuxCloseWriteChannel(t *testing.T) {
697 r, w, mux := channelPair(t)
698 defer mux.Close()
699
700 result := make(chan error, 1)
701 go func() {
702 var b [1024]byte
703 _, err := r.Read(b[:])
704 result <- err
705 }()
706 if err := w.CloseWrite(); err != nil {
707 t.Errorf("w.CloseWrite: %v", err)
708 }
709
710 if _, err := w.Write([]byte("hello")); err != io.EOF {
711 t.Errorf("got err %v, want io.EOF after CloseWrite", err)
712 }
713
714 if err := <-result; err != io.EOF {
715 t.Errorf("got %v (%T), want io.EOF", err, err)
716 }
717 }
718
719 func TestMuxInvalidRecord(t *testing.T) {
720 a, b := muxPair()
721 defer a.Close()
722 defer b.Close()
723
724 packet := make([]byte, 1+4+4+1)
725 packet[0] = msgChannelData
726 marshalUint32(packet[1:], 29348723 )
727 marshalUint32(packet[5:], 1)
728 packet[9] = 42
729
730 a.conn.writePacket(packet)
731 go a.SendRequest("hello", false, nil)
732
733 req, ok := <-b.incomingRequests
734 if ok {
735 t.Errorf("got request %#v after receiving invalid packet", req)
736 }
737 }
738
739 func TestZeroWindowAdjust(t *testing.T) {
740 a, b, mux := channelPair(t)
741 defer a.Close()
742 defer b.Close()
743 defer mux.Close()
744
745 go func() {
746 io.WriteString(a, "hello")
747
748 a.sendMessage(windowAdjustMsg{})
749 io.WriteString(a, "world")
750 a.Close()
751 }()
752
753 want := "helloworld"
754 c, _ := io.ReadAll(b)
755 if string(c) != want {
756 t.Errorf("got %q want %q", c, want)
757 }
758 }
759
760 func TestMuxMaxPacketSize(t *testing.T) {
761 a, b, mux := channelPair(t)
762 defer a.Close()
763 defer b.Close()
764 defer mux.Close()
765
766 large := make([]byte, a.maxRemotePayload+1)
767 packet := make([]byte, 1+4+4+1+len(large))
768 packet[0] = msgChannelData
769 marshalUint32(packet[1:], a.remoteId)
770 marshalUint32(packet[5:], uint32(len(large)))
771 packet[9] = 42
772
773 if err := a.mux.conn.writePacket(packet); err != nil {
774 t.Errorf("could not send packet")
775 }
776
777 var wg sync.WaitGroup
778 t.Cleanup(wg.Wait)
779 wg.Add(1)
780 go func() {
781 a.SendRequest("hello", false, nil)
782 wg.Done()
783 }()
784
785 _, ok := <-b.incomingRequests
786 if ok {
787 t.Errorf("connection still alive after receiving large packet.")
788 }
789 }
790
791 func TestMuxChannelWindowDeferredUpdates(t *testing.T) {
792 s, c, mux := channelPair(t)
793 cTransport := mux.conn.(*memTransport)
794 defer s.Close()
795 defer c.Close()
796 defer mux.Close()
797
798 var wg sync.WaitGroup
799 t.Cleanup(wg.Wait)
800
801 data := make([]byte, 1024)
802
803 wg.Add(1)
804 go func() {
805 defer wg.Done()
806 _, err := s.Write(data)
807 if err != nil {
808 t.Errorf("Write: %v", err)
809 return
810 }
811 }()
812 cWritesInit := cTransport.getWriteCount()
813 buf := make([]byte, 1)
814 for i := 0; i < len(data); i++ {
815 n, err := c.Read(buf)
816 if n != len(buf) || err != nil {
817 t.Fatalf("Read: %v, %v", n, err)
818 }
819 }
820 cWrites := cTransport.getWriteCount() - cWritesInit
821
822
823 if cWrites > 30 {
824 t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites)
825 }
826 }
827
828
829 func TestDebug(t *testing.T) {
830 if debugMux {
831 t.Error("mux debug switched on")
832 }
833 if debugHandshake {
834 t.Error("handshake debug switched on")
835 }
836 if debugTransport {
837 t.Error("transport debug switched on")
838 }
839 }
840
View as plain text