1
2
3
4
5 package ssh
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "errors"
11 "fmt"
12 "io"
13 "net"
14 "reflect"
15 "runtime"
16 "strings"
17 "sync"
18 "testing"
19 )
20
21 type testChecker struct {
22 calls []string
23 }
24
25 func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
26 if dialAddr == "bad" {
27 return fmt.Errorf("dialAddr is bad")
28 }
29
30 if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
31 return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
32 }
33
34 t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
35
36 return nil
37 }
38
39
40
41
// netPipe returns a connected pair of TCP connections over loopback. It tries
// IPv4 first and falls back to IPv6. The first connection is the dialing side
// and the second is the accepted side. The caller must close both.
func netPipe() (net.Conn, net.Conn, error) {
	var (
		listener net.Listener
		err      error
	)
	for _, candidate := range []string{"127.0.0.1:0", "[::1]:0"} {
		if listener, err = net.Listen("tcp", candidate); err == nil {
			break
		}
	}
	if err != nil {
		return nil, nil, err
	}
	defer listener.Close()

	dialed, err := net.Dial("tcp", listener.Addr().String())
	if err != nil {
		return nil, nil, err
	}

	accepted, err := listener.Accept()
	if err != nil {
		// Don't leak the dialed half when the accept fails.
		dialed.Close()
		return nil, nil, err
	}
	return dialed, accepted, nil
}
64
65
66
67 type noiseTransport struct {
68 keyingTransport
69 }
70
71 func (t *noiseTransport) writePacket(p []byte) error {
72 ignore := []byte{msgIgnore}
73 if err := t.keyingTransport.writePacket(ignore); err != nil {
74 return err
75 }
76 debug := []byte{msgDebug, 1, 2, 3}
77 if err := t.keyingTransport.writePacket(debug); err != nil {
78 return err
79 }
80
81 return t.keyingTransport.writePacket(p)
82 }
83
84 func addNoiseTransport(t keyingTransport) keyingTransport {
85 return &noiseTransport{t}
86 }
87
88
89
90
91 func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
92 a, b, err := netPipe()
93 if err != nil {
94 return nil, nil, err
95 }
96
97 var trC, trS keyingTransport
98
99 trC = newTransport(a, rand.Reader, true)
100 trS = newTransport(b, rand.Reader, false)
101 if noise {
102 trC = addNoiseTransport(trC)
103 trS = addNoiseTransport(trS)
104 }
105 clientConf.SetDefaults()
106
107 v := []byte("version")
108 client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
109
110 serverConf := &ServerConfig{}
111 serverConf.AddHostKey(testSigners["ecdsa"])
112 serverConf.AddHostKey(testSigners["rsa"])
113 serverConf.SetDefaults()
114 server = newServerTransport(trS, v, v, serverConf)
115
116 if err := server.waitSession(); err != nil {
117 return nil, nil, fmt.Errorf("server.waitSession: %v", err)
118 }
119 if err := client.waitSession(); err != nil {
120 return nil, nil, fmt.Errorf("client.waitSession: %v", err)
121 }
122
123 return client, server, nil
124 }
125
126 func TestHandshakeBasic(t *testing.T) {
127 if runtime.GOOS == "plan9" {
128 t.Skip("see golang.org/issue/7237")
129 }
130
131 checker := &syncChecker{
132 waitCall: make(chan int, 10),
133 called: make(chan int, 10),
134 }
135
136 checker.waitCall <- 1
137 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
138 if err != nil {
139 t.Fatalf("handshakePair: %v", err)
140 }
141
142 defer trC.Close()
143 defer trS.Close()
144
145
146 <-checker.called
147
148 clientDone := make(chan int, 0)
149 gotHalf := make(chan int, 0)
150 const N = 20
151 errorCh := make(chan error, 1)
152
153 go func() {
154 defer close(clientDone)
155
156
157
158
159 for i := 0; i < N; i++ {
160 p := []byte{msgRequestSuccess, byte(i)}
161 if err := trC.writePacket(p); err != nil {
162 errorCh <- err
163 trC.Close()
164 return
165 }
166 if (i % 10) == 5 {
167 <-gotHalf
168
169 trC.requestKeyExchange()
170
171
172
173
174 <-checker.called
175 }
176 if (i % 10) == 7 {
177
178
179
180 checker.waitCall <- 1
181 }
182 }
183 errorCh <- nil
184 }()
185
186
187 i := 0
188 for ; i < N; i++ {
189 p, err := trS.readPacket()
190 if err != nil && err != io.EOF {
191 t.Fatalf("server error: %v", err)
192 }
193 if (i % 10) == 5 {
194 gotHalf <- 1
195 }
196
197 want := []byte{msgRequestSuccess, byte(i)}
198 if bytes.Compare(p, want) != 0 {
199 t.Errorf("message %d: got %v, want %v", i, p, want)
200 }
201 }
202 <-clientDone
203 if err := <-errorCh; err != nil {
204 t.Fatalf("sendPacket: %v", err)
205 }
206 if i != N {
207 t.Errorf("received %d messages, want 10.", i)
208 }
209
210 close(checker.called)
211 if _, ok := <-checker.called; ok {
212
213
214
215 t.Fatalf("got another host key checks after 2 handshakes")
216 }
217 }
218
219 func TestForceFirstKex(t *testing.T) {
220
221 checker := &testChecker{}
222 clientConf := &ClientConfig{HostKeyCallback: checker.Check}
223 a, b, err := netPipe()
224 if err != nil {
225 t.Fatalf("netPipe: %v", err)
226 }
227
228 var trC, trS keyingTransport
229
230 trC = newTransport(a, rand.Reader, true)
231
232
233 trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
234
235
236 trS = newTransport(b, rand.Reader, false)
237 clientConf.SetDefaults()
238
239 v := []byte("version")
240 client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
241
242 serverConf := &ServerConfig{}
243 serverConf.AddHostKey(testSigners["ecdsa"])
244 serverConf.AddHostKey(testSigners["rsa"])
245 serverConf.SetDefaults()
246 server := newServerTransport(trS, v, v, serverConf)
247
248 defer client.Close()
249 defer server.Close()
250
251
252
253
254
255 if err := server.waitSession(); err == nil {
256 t.Errorf("server first kex init should reject unexpected packet")
257 }
258 }
259
260 func TestHandshakeAutoRekeyWrite(t *testing.T) {
261 checker := &syncChecker{
262 called: make(chan int, 10),
263 waitCall: nil,
264 }
265 clientConf := &ClientConfig{HostKeyCallback: checker.Check}
266 clientConf.RekeyThreshold = 500
267 trC, trS, err := handshakePair(clientConf, "addr", false)
268 if err != nil {
269 t.Fatalf("handshakePair: %v", err)
270 }
271 defer trC.Close()
272 defer trS.Close()
273
274 input := make([]byte, 251)
275 input[0] = msgRequestSuccess
276
277 done := make(chan int, 1)
278 const numPacket = 5
279 go func() {
280 defer close(done)
281 j := 0
282 for ; j < numPacket; j++ {
283 if p, err := trS.readPacket(); err != nil {
284 break
285 } else if !bytes.Equal(input, p) {
286 t.Errorf("got packet type %d, want %d", p[0], input[0])
287 }
288 }
289
290 if j != numPacket {
291 t.Errorf("got %d, want 5 messages", j)
292 }
293 }()
294
295 <-checker.called
296
297 for i := 0; i < numPacket; i++ {
298 p := make([]byte, len(input))
299 copy(p, input)
300 if err := trC.writePacket(p); err != nil {
301 t.Errorf("writePacket: %v", err)
302 }
303 if i == 2 {
304
305 <-checker.called
306 }
307
308 }
309 <-done
310 }
311
312 type syncChecker struct {
313 waitCall chan int
314 called chan int
315 }
316
317 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
318 c.called <- 1
319 if c.waitCall != nil {
320 <-c.waitCall
321 }
322 return nil
323 }
324
325 func TestHandshakeAutoRekeyRead(t *testing.T) {
326 sync := &syncChecker{
327 called: make(chan int, 2),
328 waitCall: nil,
329 }
330 clientConf := &ClientConfig{
331 HostKeyCallback: sync.Check,
332 }
333 clientConf.RekeyThreshold = 500
334
335 trC, trS, err := handshakePair(clientConf, "addr", false)
336 if err != nil {
337 t.Fatalf("handshakePair: %v", err)
338 }
339 defer trC.Close()
340 defer trS.Close()
341
342 packet := make([]byte, 501)
343 packet[0] = msgRequestSuccess
344 if err := trS.writePacket(packet); err != nil {
345 t.Fatalf("writePacket: %v", err)
346 }
347
348
349
350 errorCh := make(chan error, 1)
351 go func() {
352 _, err := trC.readPacket()
353 errorCh <- err
354 }()
355
356 if err := <-errorCh; err != nil {
357 t.Fatalf("readPacket(client): %v", err)
358 }
359
360 <-sync.called
361 }
362
363
364
365 type errorKeyingTransport struct {
366 packetConn
367 readLeft, writeLeft int
368 }
369
370 func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
371 return nil
372 }
373
374 func (n *errorKeyingTransport) getSessionID() []byte {
375 return nil
376 }
377
378 func (n *errorKeyingTransport) writePacket(packet []byte) error {
379 if n.writeLeft == 0 {
380 n.Close()
381 return errors.New("barf")
382 }
383
384 n.writeLeft--
385 return n.packetConn.writePacket(packet)
386 }
387
388 func (n *errorKeyingTransport) readPacket() ([]byte, error) {
389 if n.readLeft == 0 {
390 n.Close()
391 return nil, errors.New("barf")
392 }
393
394 n.readLeft--
395 return n.packetConn.readPacket()
396 }
397
398 func (n *errorKeyingTransport) setStrictMode() error { return nil }
399
400 func (n *errorKeyingTransport) setInitialKEXDone() {}
401
402 func TestHandshakeErrorHandlingRead(t *testing.T) {
403 for i := 0; i < 20; i++ {
404 testHandshakeErrorHandlingN(t, i, -1, false)
405 }
406 }
407
408 func TestHandshakeErrorHandlingWrite(t *testing.T) {
409 for i := 0; i < 20; i++ {
410 testHandshakeErrorHandlingN(t, -1, i, false)
411 }
412 }
413
414 func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
415 for i := 0; i < 20; i++ {
416 testHandshakeErrorHandlingN(t, i, -1, true)
417 }
418 }
419
420 func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
421 for i := 0; i < 20; i++ {
422 testHandshakeErrorHandlingN(t, -1, i, true)
423 }
424 }
425
426
427
428
// testHandshakeErrorHandlingN runs traffic between a server and a client
// handshakeTransport over an in-memory pipe. The server's raw transport is an
// errorKeyingTransport that fails after readLimit reads / writeLimit writes
// (-1 means no limit). The client's transport never fails. The function
// returns once every worker goroutine has stopped.
//
// coupled == false: each side runs one goroutine that writes packets until a
// write fails (then closes the conn) and one that reads until a read fails.
// coupled == true: each side runs a single goroutine that alternates between
// reading a packet and writing msg.
//
// NOTE(review): no explicit timeout is used; a deadlock in handshakeTransport
// would presumably surface as a hung test or a runtime deadlock panic.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
	if (runtime.GOOS == "js" || runtime.GOOS == "wasip1") && runtime.GOARCH == "wasm" {
		t.Skipf("skipping on %s/wasm; see golang.org/issue/32840", runtime.GOOS)
	}
	// Each payload is about a quarter of minRekeyThreshold, so traffic
	// should cross the server's rekey threshold every few packets.
	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})

	a, b := memPipe()
	defer a.Close()
	defer b.Close()

	key := testSigners["ecdsa"]
	// The server rekeys at minRekeyThreshold; the client's threshold is ten
	// times higher.
	serverConf := Config{RekeyThreshold: minRekeyThreshold}
	serverConf.SetDefaults()
	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
	serverConn.hostKeys = []Signer{key}
	go serverConn.readLoop()
	go serverConn.kexLoop()

	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
	clientConf.SetDefaults()
	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
	go clientConn.readLoop()
	go clientConn.kexLoop()

	var wg sync.WaitGroup

	for _, hs := range []packetConn{serverConn, clientConn} {
		if !coupled {
			wg.Add(2)
			// Writer: sends numbered payloads until a write fails, then
			// closes the connection.
			go func(c packetConn) {
				for i := 0; ; i++ {
					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
					if err != nil {
						break
					}
				}
				wg.Done()
				c.Close()
			}(hs)
			// Reader: drains packets until a read fails.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
				}
				wg.Done()
			}(hs)
		} else {
			wg.Add(1)
			// Coupled: read one packet, then write msg, until either fails.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
					if err := c.writePacket(msg); err != nil {
						break
					}

				}
				wg.Done()
			}(hs)
		}
	}
	wg.Wait()
}
499
500 func TestDisconnect(t *testing.T) {
501 if runtime.GOOS == "plan9" {
502 t.Skip("see golang.org/issue/7237")
503 }
504 checker := &testChecker{}
505 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
506 if err != nil {
507 t.Fatalf("handshakePair: %v", err)
508 }
509
510 defer trC.Close()
511 defer trS.Close()
512
513 trC.writePacket([]byte{msgRequestSuccess, 0, 0})
514 errMsg := &disconnectMsg{
515 Reason: 42,
516 Message: "such is life",
517 }
518 trC.writePacket(Marshal(errMsg))
519 trC.writePacket([]byte{msgRequestSuccess, 0, 0})
520
521 packet, err := trS.readPacket()
522 if err != nil {
523 t.Fatalf("readPacket 1: %v", err)
524 }
525 if packet[0] != msgRequestSuccess {
526 t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
527 }
528
529 _, err = trS.readPacket()
530 if err == nil {
531 t.Errorf("readPacket 2 succeeded")
532 } else if !reflect.DeepEqual(err, errMsg) {
533 t.Errorf("got error %#v, want %#v", err, errMsg)
534 }
535
536 _, err = trS.readPacket()
537 if err == nil {
538 t.Errorf("readPacket 3 succeeded")
539 }
540 }
541
542 func TestHandshakeRekeyDefault(t *testing.T) {
543 clientConf := &ClientConfig{
544 Config: Config{
545 Ciphers: []string{"aes128-ctr"},
546 },
547 HostKeyCallback: InsecureIgnoreHostKey(),
548 }
549 trC, trS, err := handshakePair(clientConf, "addr", false)
550 if err != nil {
551 t.Fatalf("handshakePair: %v", err)
552 }
553 defer trC.Close()
554 defer trS.Close()
555
556 trC.writePacket([]byte{msgRequestSuccess, 0, 0})
557 trC.Close()
558
559 rgb := (1024 + trC.readBytesLeft) >> 30
560 wgb := (1024 + trC.writeBytesLeft) >> 30
561
562 if rgb != 64 {
563 t.Errorf("got rekey after %dG read, want 64G", rgb)
564 }
565 if wgb != 64 {
566 t.Errorf("got rekey after %dG write, want 64G", wgb)
567 }
568 }
569
570 func TestHandshakeAEADCipherNoMAC(t *testing.T) {
571 for _, cipher := range []string{chacha20Poly1305ID, gcm128CipherID} {
572 checker := &syncChecker{
573 called: make(chan int, 1),
574 }
575 clientConf := &ClientConfig{
576 Config: Config{
577 Ciphers: []string{cipher},
578 MACs: []string{},
579 },
580 HostKeyCallback: checker.Check,
581 }
582 trC, trS, err := handshakePair(clientConf, "addr", false)
583 if err != nil {
584 t.Fatalf("handshakePair: %v", err)
585 }
586 defer trC.Close()
587 defer trS.Close()
588
589 <-checker.called
590 }
591 }
592
593
594
595
596 func TestNoSHA2Support(t *testing.T) {
597 c1, c2, err := netPipe()
598 if err != nil {
599 t.Fatalf("netPipe: %v", err)
600 }
601 defer c1.Close()
602 defer c2.Close()
603
604 serverConf := &ServerConfig{
605 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
606 return &Permissions{}, nil
607 },
608 }
609 serverConf.AddHostKey(&legacyRSASigner{testSigners["rsa"]})
610 go func() {
611 _, _, _, err := NewServerConn(c1, serverConf)
612 if err != nil {
613 t.Error(err)
614 }
615 }()
616
617 clientConf := &ClientConfig{
618 User: "test",
619 Auth: []AuthMethod{Password("testpw")},
620 HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()),
621 }
622
623 if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
624 t.Fatal(err)
625 }
626 }
627
628 func TestMultiAlgoSignerHandshake(t *testing.T) {
629 algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
630 if !ok {
631 t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
632 }
633 multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
634 if err != nil {
635 t.Fatalf("unable to create multi algorithm signer: %v", err)
636 }
637 c1, c2, err := netPipe()
638 if err != nil {
639 t.Fatalf("netPipe: %v", err)
640 }
641 defer c1.Close()
642 defer c2.Close()
643
644 serverConf := &ServerConfig{
645 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
646 return &Permissions{}, nil
647 },
648 }
649 serverConf.AddHostKey(multiAlgoSigner)
650 go NewServerConn(c1, serverConf)
651
652 clientConf := &ClientConfig{
653 User: "test",
654 Auth: []AuthMethod{Password("testpw")},
655 HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()),
656 HostKeyAlgorithms: []string{KeyAlgoRSASHA512},
657 }
658
659 if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
660 t.Fatal(err)
661 }
662 }
663
664 func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) {
665 algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
666 if !ok {
667 t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
668 }
669 multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
670 if err != nil {
671 t.Fatalf("unable to create multi algorithm signer: %v", err)
672 }
673 c1, c2, err := netPipe()
674 if err != nil {
675 t.Fatalf("netPipe: %v", err)
676 }
677 defer c1.Close()
678 defer c2.Close()
679
680
681 serverConf := &ServerConfig{
682 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
683 return &Permissions{}, nil
684 },
685 }
686 serverConf.AddHostKey(multiAlgoSigner)
687 go NewServerConn(c1, serverConf)
688
689
690 clientConf := &ClientConfig{
691 User: "test",
692 Auth: []AuthMethod{Password("testpw")},
693 HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()),
694 HostKeyAlgorithms: []string{KeyAlgoRSA},
695 }
696
697 _, _, _, err = NewClientConn(c2, "", clientConf)
698 if err == nil {
699 t.Fatal("succeeded connecting with no common hostkey algorithm")
700 }
701 }
702
703 func TestPickIncompatibleHostKeyAlgo(t *testing.T) {
704 algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
705 if !ok {
706 t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
707 }
708 multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
709 if err != nil {
710 t.Fatalf("unable to create multi algorithm signer: %v", err)
711 }
712 signer := pickHostKey([]Signer{multiAlgoSigner}, KeyAlgoRSA)
713 if signer != nil {
714 t.Fatal("incompatible signer returned")
715 }
716 }
717
718 func TestStrictKEXResetSeqFirstKEX(t *testing.T) {
719 if runtime.GOOS == "plan9" {
720 t.Skip("see golang.org/issue/7237")
721 }
722
723 checker := &syncChecker{
724 waitCall: make(chan int, 10),
725 called: make(chan int, 10),
726 }
727
728 checker.waitCall <- 1
729 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
730 if err != nil {
731 t.Fatalf("handshakePair: %v", err)
732 }
733 <-checker.called
734
735 t.Cleanup(func() {
736 trC.Close()
737 trS.Close()
738 })
739
740
741 _, err = trC.readPacket()
742 if err != nil {
743 t.Fatalf("readPacket failed: %s", err)
744 }
745
746
747
748 trC.Close()
749 trS.Close()
750
751
752
753
754
755 if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
756 trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
757 t.Errorf(
758 "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
759 trC.conn.(*transport).reader.seqNum,
760 trC.conn.(*transport).writer.seqNum,
761 trS.conn.(*transport).reader.seqNum,
762 trS.conn.(*transport).writer.seqNum,
763 )
764 }
765 }
766
767 func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) {
768 if runtime.GOOS == "plan9" {
769 t.Skip("see golang.org/issue/7237")
770 }
771
772 checker := &syncChecker{
773 waitCall: make(chan int, 10),
774 called: make(chan int, 10),
775 }
776
777 checker.waitCall <- 1
778 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
779 if err != nil {
780 t.Fatalf("handshakePair: %v", err)
781 }
782 <-checker.called
783
784 t.Cleanup(func() {
785 trC.Close()
786 trS.Close()
787 })
788
789
790 _, err = trC.readPacket()
791 if err != nil {
792 t.Fatalf("readPacket failed: %s", err)
793 }
794
795
796 for i := 0; i < 5; i++ {
797 if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
798 t.Fatalf("writePacket failed: %s", err)
799 }
800 if _, err := trS.readPacket(); err != nil {
801 t.Fatalf("readPacket failed: %s", err)
802 }
803 if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
804 t.Fatalf("writePacket failed: %s", err)
805 }
806 if _, err := trC.readPacket(); err != nil {
807 t.Fatalf("readPacket failed: %s", err)
808 }
809 }
810
811
812 checker.waitCall <- 1
813 trC.requestKeyExchange()
814 <-checker.called
815
816
817
818
819 dummyPacket := []byte{99}
820 if err := trS.writePacket(dummyPacket); err != nil {
821 t.Fatalf("writePacket failed: %s", err)
822 }
823 if p, err := trC.readPacket(); err != nil {
824 t.Fatalf("readPacket failed: %s", err)
825 } else if !bytes.Equal(p, dummyPacket) {
826 t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
827 }
828
829
830
831 trC.Close()
832 trS.Close()
833
834 if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
835 trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
836 t.Errorf(
837 "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
838 trC.conn.(*transport).reader.seqNum,
839 trC.conn.(*transport).writer.seqNum,
840 trS.conn.(*transport).reader.seqNum,
841 trS.conn.(*transport).writer.seqNum,
842 )
843 }
844 }
845
846 func TestSeqNumIncrease(t *testing.T) {
847 if runtime.GOOS == "plan9" {
848 t.Skip("see golang.org/issue/7237")
849 }
850
851 checker := &syncChecker{
852 waitCall: make(chan int, 10),
853 called: make(chan int, 10),
854 }
855
856 checker.waitCall <- 1
857 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
858 if err != nil {
859 t.Fatalf("handshakePair: %v", err)
860 }
861 <-checker.called
862
863 t.Cleanup(func() {
864 trC.Close()
865 trS.Close()
866 })
867
868
869 _, err = trC.readPacket()
870 if err != nil {
871 t.Fatalf("readPacket failed: %s", err)
872 }
873
874
875 for i := 0; i < 5; i++ {
876 if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
877 t.Fatalf("writePacket failed: %s", err)
878 }
879 if _, err := trS.readPacket(); err != nil {
880 t.Fatalf("readPacket failed: %s", err)
881 }
882 if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
883 t.Fatalf("writePacket failed: %s", err)
884 }
885 if _, err := trC.readPacket(); err != nil {
886 t.Fatalf("readPacket failed: %s", err)
887 }
888 }
889
890
891
892 trC.Close()
893 trS.Close()
894
895 if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 ||
896 trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 {
897 t.Errorf(
898 "unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)",
899 trC.conn.(*transport).reader.seqNum,
900 trC.conn.(*transport).writer.seqNum,
901 trS.conn.(*transport).reader.seqNum,
902 trS.conn.(*transport).writer.seqNum,
903 )
904 }
905 }
906
907 func TestStrictKEXUnexpectedMsg(t *testing.T) {
908 if runtime.GOOS == "plan9" {
909 t.Skip("see golang.org/issue/7237")
910 }
911
912
913 _, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true)
914 if err == nil {
915 t.Fatal("handshake should fail when there are unexpected messages during the handshake")
916 }
917
918 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false)
919 if err != nil {
920 t.Fatalf("handshake failed: %s", err)
921 }
922
923
924 if err := trC.writePacket([]byte{msgIgnore}); err != nil {
925 t.Fatalf("writePacket failed: %s", err)
926 }
927 if err := trC.writePacket([]byte{msgDebug}); err != nil {
928 t.Fatalf("writePacket failed: %s", err)
929 }
930 dummyPacket := []byte{99}
931 if err := trC.writePacket(dummyPacket); err != nil {
932 t.Fatalf("writePacket failed: %s", err)
933 }
934
935 if p, err := trS.readPacket(); err != nil {
936 t.Fatalf("readPacket failed: %s", err)
937 } else if !bytes.Equal(p, dummyPacket) {
938 t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
939 }
940 }
941
942 func TestStrictKEXMixed(t *testing.T) {
943
944
945
946 a, b, err := netPipe()
947 if err != nil {
948 t.Fatalf("netPipe failed: %s", err)
949 }
950
951 var trC, trS keyingTransport
952
953 trC = newTransport(a, rand.Reader, true)
954 trS = newTransport(b, rand.Reader, false)
955 trS = addNoiseTransport(trS)
956
957 clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}
958 clientConf.SetDefaults()
959
960 v := []byte("version")
961 client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
962
963 serverConf := &ServerConfig{}
964 serverConf.AddHostKey(testSigners["ecdsa"])
965 serverConf.AddHostKey(testSigners["rsa"])
966 serverConf.SetDefaults()
967
968 transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version"))
969 transport.hostKeys = serverConf.hostKeys
970 transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms
971
972 readOneFailure := make(chan error, 1)
973 go func() {
974 if _, err := transport.readOnePacket(true); err != nil {
975 readOneFailure <- err
976 }
977 }()
978
979
980 msg := &kexInitMsg{
981 KexAlgos: transport.config.KeyExchanges,
982 CiphersClientServer: transport.config.Ciphers,
983 CiphersServerClient: transport.config.Ciphers,
984 MACsClientServer: transport.config.MACs,
985 MACsServerClient: transport.config.MACs,
986 CompressionClientServer: supportedCompressions,
987 CompressionServerClient: supportedCompressions,
988 ServerHostKeyAlgos: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA},
989 }
990 packet := Marshal(msg)
991
992 packetCopy := make([]byte, len(packet))
993 copy(packetCopy, packet)
994 if err := transport.pushPacket(packetCopy); err != nil {
995 t.Fatalf("pushPacket: %s", err)
996 }
997 transport.sentInitMsg = msg
998 transport.sentInitPacket = packet
999
1000 if err := transport.getWriteError(); err != nil {
1001 t.Fatalf("getWriteError failed: %s", err)
1002 }
1003 var request *pendingKex
1004 select {
1005 case err = <-readOneFailure:
1006 t.Fatalf("server readOnePacket failed: %s", err)
1007 case request = <-transport.startKex:
1008 break
1009 }
1010
1011
1012
1013
1014
1015 if err := transport.enterKeyExchange(request.otherInit); err != nil {
1016 t.Fatalf("enterKeyExchange failed: %s", err)
1017 }
1018 if err := client.waitSession(); err != nil {
1019 t.Fatalf("client.waitSession: %v", err)
1020 }
1021 }
1022
View as plain text