1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package test
20
21 import (
22 "fmt"
23 "strings"
24 "testing"
25
26 "golang.org/x/crypto/ssh"
27 )
28
29
// multiAuthTestCase describes one TestMultiAuth scenario: an ordered list of
// authentication methods and the number of times each client-side callback is
// expected to run while the client authenticates.
type multiAuthTestCase struct {
	// authMethods is joined with "," for the server's "AuthMethods" setting and,
	// in the same order, used to build the client's list of auth methods.
	authMethods []string

	// Expected invocation counts for the password callback and the
	// keyboard-interactive callback, respectively.
	expectedPasswordCbs, expectedKbdIntCbs int
}
35
36
// multiAuthTestCtx holds per-subtest state for TestMultiAuth: the password the
// client answers with, and how many times each client-side auth callback has
// been invoked.
type multiAuthTestCtx struct {
	// password is the secret returned by the password and keyboard-interactive
	// callbacks.
	password string

	// Invocation counters for the password callback and the
	// keyboard-interactive callback, respectively.
	numPasswordCbs, numKbdIntCbs int
}
42
43
44 func newMultiAuthTestCtx(t *testing.T) *multiAuthTestCtx {
45 password, err := randomPassword()
46 if err != nil {
47 t.Fatalf("Failed to generate random test password: %s", err.Error())
48 }
49
50 return &multiAuthTestCtx{
51 password: password,
52 }
53 }
54
55
56 func (ctx *multiAuthTestCtx) passwordCb() (secret string, err error) {
57 ctx.numPasswordCbs++
58 return ctx.password, nil
59 }
60
61
62 func (ctx *multiAuthTestCtx) kbdIntCb(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
63 if len(questions) == 0 {
64 return nil, nil
65 }
66
67 ctx.numKbdIntCbs++
68 if len(questions) == 1 {
69 return []string{ctx.password}, nil
70 }
71
72 return nil, fmt.Errorf("unsupported keyboard-interactive flow")
73 }
74
75
76 func TestMultiAuth(t *testing.T) {
77 testCases := []multiAuthTestCase{
78
79 {
80 authMethods: []string{"password", "publickey"},
81 expectedPasswordCbs: 1,
82 },
83
84 {
85 authMethods: []string{"keyboard-interactive", "publickey"},
86 expectedKbdIntCbs: 1,
87 },
88
89 {
90 authMethods: []string{"publickey", "password"},
91 expectedPasswordCbs: 1,
92 },
93
94 {
95 authMethods: []string{"publickey", "keyboard-interactive"},
96 expectedKbdIntCbs: 1,
97 },
98
99 {
100 authMethods: []string{"password", "password"},
101 expectedPasswordCbs: 2,
102 },
103 }
104
105 for _, testCase := range testCases {
106 t.Run(strings.Join(testCase.authMethods, ","), func(t *testing.T) {
107 ctx := newMultiAuthTestCtx(t)
108
109 server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")})
110
111 clientConfig := clientConfig()
112 server.setTestPassword(clientConfig.User, ctx.password)
113
114 publicKeyAuthMethod := clientConfig.Auth[0]
115 clientConfig.Auth = nil
116 for _, authMethod := range testCase.authMethods {
117 switch authMethod {
118 case "publickey":
119 clientConfig.Auth = append(clientConfig.Auth, publicKeyAuthMethod)
120 case "password":
121 clientConfig.Auth = append(clientConfig.Auth,
122 ssh.RetryableAuthMethod(ssh.PasswordCallback(ctx.passwordCb), 5))
123 case "keyboard-interactive":
124 clientConfig.Auth = append(clientConfig.Auth,
125 ssh.RetryableAuthMethod(ssh.KeyboardInteractive(ctx.kbdIntCb), 5))
126 default:
127 t.Fatalf("Unknown authentication method %s", authMethod)
128 }
129 }
130
131 conn := server.Dial(clientConfig)
132 defer conn.Close()
133
134 if ctx.numPasswordCbs != testCase.expectedPasswordCbs {
135 t.Fatalf("passwordCallback was called %d times, expected %d times", ctx.numPasswordCbs, testCase.expectedPasswordCbs)
136 }
137
138 if ctx.numKbdIntCbs != testCase.expectedKbdIntCbs {
139 t.Fatalf("keyboardInteractiveCallback was called %d times, expected %d times", ctx.numKbdIntCbs, testCase.expectedKbdIntCbs)
140 }
141 })
142 }
143 }
144
View as plain text