1
2
3
4
5 package ssh
6
7
8
9 import (
10 "bytes"
11 crypto_rand "crypto/rand"
12 "errors"
13 "io"
14 "math/rand"
15 "net"
16 "sync"
17 "testing"
18
19 "golang.org/x/crypto/ssh/terminal"
20 )
21
22 type serverType func(Channel, <-chan *Request, *testing.T)
23
24
25 func dial(handler serverType, t *testing.T) *Client {
26 c1, c2, err := netPipe()
27 if err != nil {
28 t.Fatalf("netPipe: %v", err)
29 }
30
31 var wg sync.WaitGroup
32 t.Cleanup(wg.Wait)
33 wg.Add(1)
34 go func() {
35 defer func() {
36 c1.Close()
37 wg.Done()
38 }()
39 conf := ServerConfig{
40 NoClientAuth: true,
41 }
42 conf.AddHostKey(testSigners["rsa"])
43
44 conn, chans, reqs, err := NewServerConn(c1, &conf)
45 if err != nil {
46 t.Errorf("Unable to handshake: %v", err)
47 return
48 }
49 wg.Add(1)
50 go func() {
51 DiscardRequests(reqs)
52 wg.Done()
53 }()
54
55 for newCh := range chans {
56 if newCh.ChannelType() != "session" {
57 newCh.Reject(UnknownChannelType, "unknown channel type")
58 continue
59 }
60
61 ch, inReqs, err := newCh.Accept()
62 if err != nil {
63 t.Errorf("Accept: %v", err)
64 continue
65 }
66 wg.Add(1)
67 go func() {
68 handler(ch, inReqs, t)
69 wg.Done()
70 }()
71 }
72 if err := conn.Wait(); err != io.EOF {
73 t.Logf("server exit reason: %v", err)
74 }
75 }()
76
77 config := &ClientConfig{
78 User: "testuser",
79 HostKeyCallback: InsecureIgnoreHostKey(),
80 }
81
82 conn, chans, reqs, err := NewClientConn(c2, "", config)
83 if err != nil {
84 t.Fatalf("unable to dial remote side: %v", err)
85 }
86
87 return NewClient(conn, chans, reqs)
88 }
89
90
91 func TestSessionShell(t *testing.T) {
92 conn := dial(shellHandler, t)
93 defer conn.Close()
94 session, err := conn.NewSession()
95 if err != nil {
96 t.Fatalf("Unable to request new session: %v", err)
97 }
98 defer session.Close()
99 stdout := new(bytes.Buffer)
100 session.Stdout = stdout
101 if err := session.Shell(); err != nil {
102 t.Fatalf("Unable to execute command: %s", err)
103 }
104 if err := session.Wait(); err != nil {
105 t.Fatalf("Remote command did not exit cleanly: %v", err)
106 }
107 actual := stdout.String()
108 if actual != "golang" {
109 t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
110 }
111 }
112
113
114
115
116 func TestSessionStdoutPipe(t *testing.T) {
117 conn := dial(shellHandler, t)
118 defer conn.Close()
119 session, err := conn.NewSession()
120 if err != nil {
121 t.Fatalf("Unable to request new session: %v", err)
122 }
123 defer session.Close()
124 stdout, err := session.StdoutPipe()
125 if err != nil {
126 t.Fatalf("Unable to request StdoutPipe(): %v", err)
127 }
128 var buf bytes.Buffer
129 if err := session.Shell(); err != nil {
130 t.Fatalf("Unable to execute command: %v", err)
131 }
132 done := make(chan bool, 1)
133 go func() {
134 if _, err := io.Copy(&buf, stdout); err != nil {
135 t.Errorf("Copy of stdout failed: %v", err)
136 }
137 done <- true
138 }()
139 if err := session.Wait(); err != nil {
140 t.Fatalf("Remote command did not exit cleanly: %v", err)
141 }
142 <-done
143 actual := buf.String()
144 if actual != "golang" {
145 t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
146 }
147 }
148
149
150
151 func TestSessionOutput(t *testing.T) {
152 conn := dial(fixedOutputHandler, t)
153 defer conn.Close()
154 session, err := conn.NewSession()
155 if err != nil {
156 t.Fatalf("Unable to request new session: %v", err)
157 }
158 defer session.Close()
159
160 buf, err := session.Output("")
161 if err != nil {
162 t.Error("Remote command did not exit cleanly:", err)
163 }
164 w := "this-is-stdout."
165 g := string(buf)
166 if g != w {
167 t.Error("Remote command did not return expected string:")
168 t.Logf("want %q", w)
169 t.Logf("got %q", g)
170 }
171 }
172
173
174
175 func TestSessionCombinedOutput(t *testing.T) {
176 conn := dial(fixedOutputHandler, t)
177 defer conn.Close()
178 session, err := conn.NewSession()
179 if err != nil {
180 t.Fatalf("Unable to request new session: %v", err)
181 }
182 defer session.Close()
183
184 buf, err := session.CombinedOutput("")
185 if err != nil {
186 t.Error("Remote command did not exit cleanly:", err)
187 }
188 const stdout = "this-is-stdout."
189 const stderr = "this-is-stderr."
190 g := string(buf)
191 if g != stdout+stderr && g != stderr+stdout {
192 t.Error("Remote command did not return expected string:")
193 t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
194 t.Logf("got %q", g)
195 }
196 }
197
198
199 func TestExitStatusNonZero(t *testing.T) {
200 conn := dial(exitStatusNonZeroHandler, t)
201 defer conn.Close()
202 session, err := conn.NewSession()
203 if err != nil {
204 t.Fatalf("Unable to request new session: %v", err)
205 }
206 defer session.Close()
207 if err := session.Shell(); err != nil {
208 t.Fatalf("Unable to execute command: %v", err)
209 }
210 err = session.Wait()
211 if err == nil {
212 t.Fatalf("expected command to fail but it didn't")
213 }
214 e, ok := err.(*ExitError)
215 if !ok {
216 t.Fatalf("expected *ExitError but got %T", err)
217 }
218 if e.ExitStatus() != 15 {
219 t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
220 }
221 }
222
223
224 func TestExitStatusZero(t *testing.T) {
225 conn := dial(exitStatusZeroHandler, t)
226 defer conn.Close()
227 session, err := conn.NewSession()
228 if err != nil {
229 t.Fatalf("Unable to request new session: %v", err)
230 }
231 defer session.Close()
232
233 if err := session.Shell(); err != nil {
234 t.Fatalf("Unable to execute command: %v", err)
235 }
236 err = session.Wait()
237 if err != nil {
238 t.Fatalf("expected nil but got %v", err)
239 }
240 }
241
242
243 func TestExitSignalAndStatus(t *testing.T) {
244 conn := dial(exitSignalAndStatusHandler, t)
245 defer conn.Close()
246 session, err := conn.NewSession()
247 if err != nil {
248 t.Fatalf("Unable to request new session: %v", err)
249 }
250 defer session.Close()
251 if err := session.Shell(); err != nil {
252 t.Fatalf("Unable to execute command: %v", err)
253 }
254 err = session.Wait()
255 if err == nil {
256 t.Fatalf("expected command to fail but it didn't")
257 }
258 e, ok := err.(*ExitError)
259 if !ok {
260 t.Fatalf("expected *ExitError but got %T", err)
261 }
262 if e.Signal() != "TERM" || e.ExitStatus() != 15 {
263 t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
264 }
265 }
266
267
268 func TestKnownExitSignalOnly(t *testing.T) {
269 conn := dial(exitSignalHandler, t)
270 defer conn.Close()
271 session, err := conn.NewSession()
272 if err != nil {
273 t.Fatalf("Unable to request new session: %v", err)
274 }
275 defer session.Close()
276 if err := session.Shell(); err != nil {
277 t.Fatalf("Unable to execute command: %v", err)
278 }
279 err = session.Wait()
280 if err == nil {
281 t.Fatalf("expected command to fail but it didn't")
282 }
283 e, ok := err.(*ExitError)
284 if !ok {
285 t.Fatalf("expected *ExitError but got %T", err)
286 }
287 if e.Signal() != "TERM" || e.ExitStatus() != 143 {
288 t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
289 }
290 }
291
292
293 func TestUnknownExitSignal(t *testing.T) {
294 conn := dial(exitSignalUnknownHandler, t)
295 defer conn.Close()
296 session, err := conn.NewSession()
297 if err != nil {
298 t.Fatalf("Unable to request new session: %v", err)
299 }
300 defer session.Close()
301 if err := session.Shell(); err != nil {
302 t.Fatalf("Unable to execute command: %v", err)
303 }
304 err = session.Wait()
305 if err == nil {
306 t.Fatalf("expected command to fail but it didn't")
307 }
308 e, ok := err.(*ExitError)
309 if !ok {
310 t.Fatalf("expected *ExitError but got %T", err)
311 }
312 if e.Signal() != "SYS" || e.ExitStatus() != 128 {
313 t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
314 }
315 }
316
317 func TestExitWithoutStatusOrSignal(t *testing.T) {
318 conn := dial(exitWithoutSignalOrStatus, t)
319 defer conn.Close()
320 session, err := conn.NewSession()
321 if err != nil {
322 t.Fatalf("Unable to request new session: %v", err)
323 }
324 defer session.Close()
325 if err := session.Shell(); err != nil {
326 t.Fatalf("Unable to execute command: %v", err)
327 }
328 err = session.Wait()
329 if err == nil {
330 t.Fatalf("expected command to fail but it didn't")
331 }
332 if _, ok := err.(*ExitMissingError); !ok {
333 t.Fatalf("got %T want *ExitMissingError", err)
334 }
335 }
336
337
// windowTestBytes is the number of bytes (3,200,000) that TestServerWindow
// sends through the echo server and that echoHandler copies back.
const windowTestBytes = 200 * 16000
339
340
341
342 func TestServerWindow(t *testing.T) {
343 origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
344 io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
345 origBytes := origBuf.Bytes()
346
347 conn := dial(echoHandler, t)
348 defer conn.Close()
349 session, err := conn.NewSession()
350 if err != nil {
351 t.Fatal(err)
352 }
353 defer session.Close()
354
355 serverStdin, err := session.StdinPipe()
356 if err != nil {
357 t.Fatalf("StdinPipe failed: %v", err)
358 }
359
360 result := make(chan []byte)
361 go func() {
362 defer close(result)
363 echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
364 serverStdout, err := session.StdoutPipe()
365 if err != nil {
366 t.Errorf("StdoutPipe failed: %v", err)
367 return
368 }
369 n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
370 if err != nil && err != io.EOF {
371 t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
372 }
373 result <- echoedBuf.Bytes()
374 }()
375
376 written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
377 if err != nil {
378 t.Errorf("failed to copy origBuf to serverStdin: %v", err)
379 } else if written != windowTestBytes {
380 t.Errorf("Wrote only %d of %d bytes to server", written, windowTestBytes)
381 }
382
383 echoedBytes := <-result
384
385 if !bytes.Equal(origBytes, echoedBytes) {
386 t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
387 }
388 }
389
390
391 func TestClientHandlesKeepalives(t *testing.T) {
392 conn := dial(channelKeepaliveSender, t)
393 defer conn.Close()
394 session, err := conn.NewSession()
395 if err != nil {
396 t.Fatal(err)
397 }
398 defer session.Close()
399 if err := session.Shell(); err != nil {
400 t.Fatalf("Unable to execute command: %v", err)
401 }
402 err = session.Wait()
403 if err != nil {
404 t.Fatalf("expected nil but got: %v", err)
405 }
406 }
407
type (
	// exitStatusMsg is the payload that sendStatus marshals into an
	// "exit-status" channel request.
	exitStatusMsg struct {
		Status uint32
	}

	// exitSignalMsg is the payload that sendSignal marshals into an
	// "exit-signal" channel request.
	exitSignalMsg struct {
		Signal     string
		CoreDumped bool
		Errmsg     string
		Lang       string
	}
)
418
419 func handleTerminalRequests(in <-chan *Request) {
420 for req := range in {
421 ok := false
422 switch req.Type {
423 case "shell":
424 ok = true
425 if len(req.Payload) > 0 {
426
427 ok = false
428 }
429 case "env":
430 ok = true
431 }
432 req.Reply(ok, nil)
433 }
434 }
435
436 func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
437 term := terminal.NewTerminal(ch, prompt)
438 go handleTerminalRequests(in)
439 return term
440 }
441
442 func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
443 defer ch.Close()
444
445 shell := newServerShell(ch, in, "> ")
446 readLine(shell, t)
447 sendStatus(0, ch, t)
448 }
449
450 func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
451 defer ch.Close()
452 shell := newServerShell(ch, in, "> ")
453 readLine(shell, t)
454 sendStatus(15, ch, t)
455 }
456
457 func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
458 defer ch.Close()
459 shell := newServerShell(ch, in, "> ")
460 readLine(shell, t)
461 sendStatus(15, ch, t)
462 sendSignal("TERM", ch, t)
463 }
464
465 func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
466 defer ch.Close()
467 shell := newServerShell(ch, in, "> ")
468 readLine(shell, t)
469 sendSignal("TERM", ch, t)
470 }
471
472 func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
473 defer ch.Close()
474 shell := newServerShell(ch, in, "> ")
475 readLine(shell, t)
476 sendSignal("SYS", ch, t)
477 }
478
479 func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
480 defer ch.Close()
481 shell := newServerShell(ch, in, "> ")
482 readLine(shell, t)
483 }
484
485 func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
486 defer ch.Close()
487
488 shell := newServerShell(ch, in, "golang")
489 readLine(shell, t)
490 sendStatus(0, ch, t)
491 }
492
493
494
495 func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
496 defer ch.Close()
497 _, err := ch.Read(nil)
498
499 req, ok := <-in
500 if !ok {
501 t.Fatalf("error: expected channel request, got: %#v", err)
502 return
503 }
504
505
506 req.Reply(true, nil)
507
508 _, err = io.WriteString(ch, "this-is-stdout.")
509 if err != nil {
510 t.Fatalf("error writing on server: %v", err)
511 }
512 _, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
513 if err != nil {
514 t.Fatalf("error writing on server: %v", err)
515 }
516 sendStatus(0, ch, t)
517 }
518
519 func readLine(shell *terminal.Terminal, t *testing.T) {
520 if _, err := shell.ReadLine(); err != nil && err != io.EOF {
521 t.Errorf("unable to read line: %v", err)
522 }
523 }
524
525 func sendStatus(status uint32, ch Channel, t *testing.T) {
526 msg := exitStatusMsg{
527 Status: status,
528 }
529 if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
530 t.Errorf("unable to send status: %v", err)
531 }
532 }
533
534 func sendSignal(signal string, ch Channel, t *testing.T) {
535 sig := exitSignalMsg{
536 Signal: signal,
537 CoreDumped: false,
538 Errmsg: "Process terminated",
539 Lang: "en-GB-oed",
540 }
541 if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
542 t.Errorf("unable to send signal: %v", err)
543 }
544 }
545
546 func discardHandler(ch Channel, t *testing.T) {
547 defer ch.Close()
548 io.Copy(io.Discard, ch)
549 }
550
551 func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
552 defer ch.Close()
553 if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
554 t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
555 }
556 }
557
558
559
// copyNRandomly copies n bytes from src to dst using reads of random size,
// each shorter than 32 KiB and possibly empty. title is not used by the copy.
//
// It returns the number of bytes written and the first error met: the write
// error, io.ErrShortWrite when dst accepted fewer bytes than were read, or
// the read error. If src reports io.EOF before n bytes have been copied,
// io.EOF is returned instead of retrying a source that will never yield more
// data. After a complete copy the error is nil.
func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
	var (
		buf       = make([]byte, 32*1024)
		written   int
		remaining = n
	)
	for remaining > 0 {
		l := rand.Intn(1 << 15)
		if remaining < l {
			l = remaining
		}
		nr, er := src.Read(buf[:l])
		// Bytes returned alongside an error are still forwarded before the
		// error is looked at.
		nw, ew := dst.Write(buf[:nr])
		remaining -= nw
		written += nw
		if ew != nil {
			return written, ew
		}
		if nr != nw {
			return written, io.ErrShortWrite
		}
		// io.EOF together with the final bytes is not an error; io.EOF with
		// bytes still outstanding ends the copy early.
		if er != nil && (er != io.EOF || remaining > 0) {
			return written, er
		}
	}
	return written, nil
}
587
588 func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
589 defer ch.Close()
590 shell := newServerShell(ch, in, "> ")
591 readLine(shell, t)
592 if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
593 t.Errorf("unable to send channel keepalive request: %v", err)
594 }
595 sendStatus(0, ch, t)
596 }
597
598 func TestClientWriteEOF(t *testing.T) {
599 conn := dial(simpleEchoHandler, t)
600 defer conn.Close()
601
602 session, err := conn.NewSession()
603 if err != nil {
604 t.Fatal(err)
605 }
606 defer session.Close()
607 stdin, err := session.StdinPipe()
608 if err != nil {
609 t.Fatalf("StdinPipe failed: %v", err)
610 }
611 stdout, err := session.StdoutPipe()
612 if err != nil {
613 t.Fatalf("StdoutPipe failed: %v", err)
614 }
615
616 data := []byte(`0000`)
617 _, err = stdin.Write(data)
618 if err != nil {
619 t.Fatalf("Write failed: %v", err)
620 }
621 stdin.Close()
622
623 res, err := io.ReadAll(stdout)
624 if err != nil {
625 t.Fatalf("Read failed: %v", err)
626 }
627
628 if !bytes.Equal(data, res) {
629 t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
630 }
631 }
632
633 func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
634 defer ch.Close()
635 data, err := io.ReadAll(ch)
636 if err != nil {
637 t.Errorf("handler read error: %v", err)
638 }
639 _, err = ch.Write(data)
640 if err != nil {
641 t.Errorf("handler write error: %v", err)
642 }
643 }
644
645 func TestSessionID(t *testing.T) {
646 c1, c2, err := netPipe()
647 if err != nil {
648 t.Fatalf("netPipe: %v", err)
649 }
650 defer c1.Close()
651 defer c2.Close()
652
653 serverID := make(chan []byte, 1)
654 clientID := make(chan []byte, 1)
655
656 serverConf := &ServerConfig{
657 NoClientAuth: true,
658 }
659 serverConf.AddHostKey(testSigners["ecdsa"])
660 clientConf := &ClientConfig{
661 HostKeyCallback: InsecureIgnoreHostKey(),
662 User: "user",
663 }
664
665 var wg sync.WaitGroup
666 t.Cleanup(wg.Wait)
667
668 srvErrCh := make(chan error, 1)
669 wg.Add(1)
670 go func() {
671 defer wg.Done()
672 conn, chans, reqs, err := NewServerConn(c1, serverConf)
673 srvErrCh <- err
674 if err != nil {
675 return
676 }
677 serverID <- conn.SessionID()
678 wg.Add(1)
679 go func() {
680 DiscardRequests(reqs)
681 wg.Done()
682 }()
683 for ch := range chans {
684 ch.Reject(Prohibited, "")
685 }
686 }()
687
688 cliErrCh := make(chan error, 1)
689 wg.Add(1)
690 go func() {
691 defer wg.Done()
692 conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
693 cliErrCh <- err
694 if err != nil {
695 return
696 }
697 clientID <- conn.SessionID()
698 wg.Add(1)
699 go func() {
700 DiscardRequests(reqs)
701 wg.Done()
702 }()
703 for ch := range chans {
704 ch.Reject(Prohibited, "")
705 }
706 }()
707
708 if err := <-srvErrCh; err != nil {
709 t.Fatalf("server handshake: %v", err)
710 }
711
712 if err := <-cliErrCh; err != nil {
713 t.Fatalf("client handshake: %v", err)
714 }
715
716 s := <-serverID
717 c := <-clientID
718 if bytes.Compare(s, c) != 0 {
719 t.Errorf("server session ID (%x) != client session ID (%x)", s, c)
720 } else if len(s) == 0 {
721 t.Errorf("client and server SessionID were empty.")
722 }
723 }
724
// noReadConn embeds a net.Conn but overrides Read and Close: Read always
// fails and records in readSeen that it was called, and Close does nothing.
type noReadConn struct {
	readSeen bool
	net.Conn
}

// Close leaves the embedded connection untouched and always returns nil.
func (c *noReadConn) Close() error { return nil }

// Read sets readSeen and returns 0 together with a "noReadConn error".
func (c *noReadConn) Read(b []byte) (int, error) {
	failure := errors.New("noReadConn error")
	c.readSeen = true
	return 0, failure
}
738
739 func TestInvalidServerConfiguration(t *testing.T) {
740 c1, c2, err := netPipe()
741 if err != nil {
742 t.Fatalf("netPipe: %v", err)
743 }
744 defer c1.Close()
745 defer c2.Close()
746
747 serveConn := noReadConn{Conn: c1}
748 serverConf := &ServerConfig{}
749
750 NewServerConn(&serveConn, serverConf)
751 if serveConn.readSeen {
752 t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key")
753 }
754
755 serverConf.AddHostKey(testSigners["ecdsa"])
756
757 NewServerConn(&serveConn, serverConf)
758 if serveConn.readSeen {
759 t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method")
760 }
761 }
762
763 func TestHostKeyAlgorithms(t *testing.T) {
764 serverConf := &ServerConfig{
765 NoClientAuth: true,
766 }
767 serverConf.AddHostKey(testSigners["rsa"])
768 serverConf.AddHostKey(testSigners["ecdsa"])
769
770 var wg sync.WaitGroup
771 t.Cleanup(wg.Wait)
772 connect := func(clientConf *ClientConfig, want string) {
773 var alg string
774 clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
775 alg = key.Type()
776 return nil
777 }
778 c1, c2, err := netPipe()
779 if err != nil {
780 t.Fatalf("netPipe: %v", err)
781 }
782 defer c1.Close()
783 defer c2.Close()
784
785 wg.Add(1)
786 go func() {
787 NewServerConn(c1, serverConf)
788 wg.Done()
789 }()
790 _, _, _, err = NewClientConn(c2, "", clientConf)
791 if err != nil {
792 t.Fatalf("NewClientConn: %v", err)
793 }
794 if alg != want {
795 t.Errorf("selected key algorithm %s, want %s", alg, want)
796 }
797 }
798
799
800
801 clientConf := &ClientConfig{
802 HostKeyCallback: InsecureIgnoreHostKey(),
803 }
804 connect(clientConf, KeyAlgoECDSA256)
805
806
807 clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA}
808 connect(clientConf, KeyAlgoRSA)
809
810
811 clientConf.HostKeyAlgorithms = []string{KeyAlgoRSASHA512}
812
813
814 connect(clientConf, KeyAlgoRSA)
815
816 c1, c2, err := netPipe()
817 if err != nil {
818 t.Fatalf("netPipe: %v", err)
819 }
820 defer c1.Close()
821 defer c2.Close()
822
823 wg.Add(1)
824 go func() {
825 NewServerConn(c1, serverConf)
826 wg.Done()
827 }()
828 clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
829 _, _, _, err = NewClientConn(c2, "", clientConf)
830 if err == nil {
831 t.Fatal("succeeded connecting with unknown hostkey algorithm")
832 }
833 }
834
835 func TestServerClientAuthCallback(t *testing.T) {
836 c1, c2, err := netPipe()
837 if err != nil {
838 t.Fatalf("netPipe: %v", err)
839 }
840 defer c1.Close()
841 defer c2.Close()
842
843 userCh := make(chan string, 1)
844
845 serverConf := &ServerConfig{
846 NoClientAuth: true,
847 NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
848 userCh <- conn.User()
849 return nil, nil
850 },
851 }
852 const someUsername = "some-username"
853
854 serverConf.AddHostKey(testSigners["ecdsa"])
855 clientConf := &ClientConfig{
856 HostKeyCallback: InsecureIgnoreHostKey(),
857 User: someUsername,
858 }
859
860 var wg sync.WaitGroup
861 t.Cleanup(wg.Wait)
862 wg.Add(1)
863 go func() {
864 defer wg.Done()
865 _, chans, reqs, err := NewServerConn(c1, serverConf)
866 if err != nil {
867 t.Errorf("server handshake: %v", err)
868 userCh <- "error"
869 return
870 }
871 wg.Add(1)
872 go func() {
873 DiscardRequests(reqs)
874 wg.Done()
875 }()
876 for ch := range chans {
877 ch.Reject(Prohibited, "")
878 }
879 }()
880
881 conn, _, _, err := NewClientConn(c2, "", clientConf)
882 if err != nil {
883 t.Fatalf("client handshake: %v", err)
884 return
885 }
886 conn.Close()
887
888 got := <-userCh
889 if got != someUsername {
890 t.Errorf("username = %q; want %q", got, someUsername)
891 }
892 }
893
View as plain text