1
2
3
4
5
6
7 package test
8
9
10
11 import (
12 "bytes"
13 "crypto/rand"
14 "encoding/base64"
15 "fmt"
16 "log"
17 "net"
18 "os"
19 "os/exec"
20 "os/user"
21 "path/filepath"
22 "testing"
23 "text/template"
24
25 "golang.org/x/crypto/internal/testenv"
26 "golang.org/x/crypto/ssh"
27 "golang.org/x/crypto/ssh/testdata"
28 )
29
// defaultSshdConfig is the sshd_config template shared by every test server.
// {{.Dir}} is replaced with the server's scratch directory, which holds the
// banner, host keys, certificates and authorized_keys file that
// newServerForConfig writes, plus the pid file named by Pidfile.
//
// multiAuthSshdConfigTail is appended to defaultSshdConfig for the
// "MultiAuth" configuration. It turns on PAM, password and
// challenge-response authentication and sets AuthenticationMethods from
// {{.AuthMethods}}.
const (
	defaultSshdConfig = `
Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-sha2-512-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile {{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
	multiAuthSshdConfigTail = `
UsePAM yes
PasswordAuthentication yes
ChallengeResponseAuthentication yes
AuthenticationMethods {{.AuthMethods}}
`
)

// configTmpl maps a configuration name, as accepted by newServerForConfig,
// to its parsed sshd_config template. Parsing happens once at package
// initialization, and a malformed template panics there.
var configTmpl = func() map[string]*template.Template {
	parse := func(text string) *template.Template {
		return template.Must(template.New("").Parse(text))
	}
	return map[string]*template.Template{
		"default":   parse(defaultSshdConfig),
		"MultiAuth": parse(defaultSshdConfig + multiAuthSshdConfigTail),
	}
}()
67
// server describes the sshd setup for a single test. Every TryDial/Dial
// call starts a new sshd process that reads configfile and talks over one
// end of a Unix socket pair.
type server struct {
	t          *testing.T // test that owns this server; failures and cleanups go here
	configfile string     // path of the generated sshd_config

	testUser     string // exported to sshd as TEST_USER when sshdTestPwSo is set
	testPasswd   string // exported to sshd as TEST_PASSWD when sshdTestPwSo is set
	sshdTestPwSo string // library LD_PRELOADed into sshd; empty disables it

	// lastDialConn is the client end of the socket pair created by the most
	// recent TryDialWithAddr call.
	lastDialConn net.Conn
}
78
// username reports the login name of the user running the tests.
// os/user is consulted first. Only if it returns an error does the
// function log that error and fall back on the USER environment variable.
// It panics if the resulting name is empty.
func username() string {
	var name string
	switch u, err := user.Current(); {
	case err != nil:
		log.Printf("user.Current: %v; falling back on $USER", err)
		name = os.Getenv("USER")
	default:
		name = u.Username
	}
	if len(name) == 0 {
		panic("Unable to get username")
	}
	return name
}
94
// storedHostKey is a small in-memory host key database for the tests.
// clientConfig installs its Check method as the ssh.ClientConfig
// HostKeyCallback.
type storedHostKey struct {
	// keys maps a key type (as returned by ssh.PublicKey.Type) to the
	// marshaled public key expected for that type.
	keys map[string][]byte

	// checkCount is the number of times Check has been called.
	checkCount int
}
103
104 func (k *storedHostKey) Add(key ssh.PublicKey) {
105 if k.keys == nil {
106 k.keys = map[string][]byte{}
107 }
108 k.keys[key.Type()] = key.Marshal()
109 }
110
111 func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
112 k.checkCount++
113 algo := key.Type()
114
115 if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
116 return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
117 }
118 return nil
119 }
120
121 func hostKeyDB() *storedHostKey {
122 keyChecker := &storedHostKey{}
123 keyChecker.Add(testPublicKeys["ecdsa"])
124 keyChecker.Add(testPublicKeys["rsa"])
125 keyChecker.Add(testPublicKeys["dsa"])
126 return keyChecker
127 }
128
129 func clientConfig() *ssh.ClientConfig {
130 config := &ssh.ClientConfig{
131 User: username(),
132 Auth: []ssh.AuthMethod{
133 ssh.PublicKeys(testSigners["user"]),
134 },
135 HostKeyCallback: hostKeyDB().Check,
136 HostKeyAlgorithms: []string{
137 ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
138 ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
139 ssh.KeyAlgoED25519,
140 },
141 }
142 return config
143 }
144
145
146
147
// unixConnection returns the two ends of a connected Unix-domain stream
// socket. A listening socket is created under a new temporary directory
// only to make the connection. The listener is closed and the directory
// removal is attempted (best effort) before returning. If an error occurs,
// no connection is left open.
func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
	tmp, err := os.MkdirTemp("", "unixConnection")
	if err != nil {
		return nil, nil, err
	}
	defer os.Remove(tmp)

	sock := &net.UnixAddr{Name: filepath.Join(tmp, "ssh"), Net: "unix"}
	ln, err := net.ListenUnix("unix", sock)
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	near, err := net.DialUnix("unix", nil, sock)
	if err != nil {
		return nil, nil, err
	}

	far, err := ln.AcceptUnix()
	if err != nil {
		near.Close()
		return nil, nil, err
	}
	return near, far, nil
}
174
175 func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
176 return s.TryDialWithAddr(config, "")
177 }
178
179
180
181 func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (client *ssh.Client, err error) {
182 sshd, err := exec.LookPath("sshd")
183 if err != nil {
184 s.t.Skipf("skipping test: %v", err)
185 }
186
187 c1, c2, err := unixConnection()
188 if err != nil {
189 s.t.Fatalf("unixConnection: %v", err)
190 }
191 defer func() {
192
193
194 c2.Close()
195
196
197
198
199 if client == nil {
200 c1.Close()
201 }
202 }()
203
204 f, err := c2.File()
205 if err != nil {
206 s.t.Fatalf("UnixConn.File: %v", err)
207 }
208 defer f.Close()
209
210 cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
211 cmd.Stdin = f
212 cmd.Stdout = f
213 cmd.Stderr = new(bytes.Buffer)
214
215 if s.sshdTestPwSo != "" {
216 if s.testUser == "" {
217 s.t.Fatal("user missing from sshd_test_pw.so config")
218 }
219 if s.testPasswd == "" {
220 s.t.Fatal("password missing from sshd_test_pw.so config")
221 }
222 cmd.Env = append(os.Environ(),
223 fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
224 fmt.Sprintf("TEST_USER=%s", s.testUser),
225 fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
226 }
227
228 if err := cmd.Start(); err != nil {
229 s.t.Fatalf("s.cmd.Start: %v", err)
230 }
231 s.lastDialConn = c1
232 s.t.Cleanup(func() {
233
234
235
236
237 cmd.Process.Signal(os.Interrupt)
238 cmd.Wait()
239 if s.t.Failed() || testing.Verbose() {
240
241 s.t.Logf("sshd:\n%s", cmd.Stderr)
242 }
243 })
244
245 conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
246 if err != nil {
247 return nil, err
248 }
249 return ssh.NewClient(conn, chans, reqs), nil
250 }
251
252 func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
253 conn, err := s.TryDial(config)
254 if err != nil {
255 s.t.Fatalf("ssh.Client: %v", err)
256 }
257 return conn
258 }
259
// writeFile creates or truncates the file at path and writes contents to it.
// A newly created file gets mode 0600 (before umask); an existing file
// keeps its permissions.
//
// It panics on any error, including a failure when closing the file, which
// the previous hand-rolled version silently dropped. Its callers are
// fixture setup that cannot continue without the file.
func writeFile(path string, contents []byte) {
	if err := os.WriteFile(path, contents, 0600); err != nil {
		panic(err)
	}
}
270
271
// randomPassword generates a password from 12 bytes of crypto/rand output.
// The bytes are encoded as unpadded URL-safe base64, which yields 16
// characters. Any error from the random source is returned unchanged.
func randomPassword() (string, error) {
	var raw [12]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
280
281
282
283 func (s *server) setTestPassword(user, passwd string) error {
284 wd, _ := os.Getwd()
285 wrapper := filepath.Join(wd, "sshd_test_pw.so")
286 if _, err := os.Stat(wrapper); err != nil {
287 s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
288 return err
289 }
290
291 s.sshdTestPwSo = wrapper
292 s.testUser = user
293 s.testPasswd = passwd
294 return nil
295 }
296
297
298 func newServer(t *testing.T) *server {
299 return newServerForConfig(t, "default", map[string]string{})
300 }
301
302
303 func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
304 if testing.Short() {
305 t.Skip("skipping test due to -short")
306 }
307 u, err := user.Current()
308 if err != nil {
309 t.Fatalf("user.Current: %v", err)
310 }
311 uname := u.Name
312 if uname == "" {
313
314
315 uname = u.Username
316 }
317 if uname == "root" {
318 t.Skip("skipping test because current user is root")
319 }
320 dir, err := os.MkdirTemp("", "sshtest")
321 if err != nil {
322 t.Fatal(err)
323 }
324 f, err := os.Create(filepath.Join(dir, "sshd_config"))
325 if err != nil {
326 t.Fatal(err)
327 }
328 if _, ok := configTmpl[config]; ok == false {
329 t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
330 }
331 configVars["Dir"] = dir
332 err = configTmpl[config].Execute(f, configVars)
333 if err != nil {
334 t.Fatal(err)
335 }
336 f.Close()
337
338 writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
339
340 for k, v := range testdata.PEMBytes {
341 filename := "id_" + k
342 writeFile(filepath.Join(dir, filename), v)
343 writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
344 }
345
346 for k, v := range testdata.SSHCertificates {
347 filename := "id_" + k + "-cert.pub"
348 writeFile(filepath.Join(dir, filename), v)
349 }
350
351 var authkeys bytes.Buffer
352 for k := range testdata.PEMBytes {
353 authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
354 }
355 writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
356 t.Cleanup(func() {
357 if err := os.RemoveAll(dir); err != nil {
358 t.Error(err)
359 }
360 })
361
362 return &server{
363 t: t,
364 configfile: f.Name(),
365 }
366 }
367
// newTempSocket picks a path for a Unix socket named "sock" inside a new
// temporary directory. The socket itself is not created.
//
// The returned function removes the directory and everything in it; the
// caller should call it when finished. The test fails if the directory
// cannot be created.
func newTempSocket(t *testing.T) (string, func()) {
	dir, err := os.MkdirTemp("", "socket")
	if err != nil {
		t.Fatal(err)
	}
	return filepath.Join(dir, "sock"), func() { os.RemoveAll(dir) }
}
377
View as plain text