1
2
3
4
5 package agent
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "errors"
11 "io"
12 "net"
13 "os"
14 "os/exec"
15 "path/filepath"
16 "runtime"
17 "strconv"
18 "strings"
19 "testing"
20 "time"
21
22 "golang.org/x/crypto/ssh"
23 )
24
25
26 func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, cleanup func()) {
27 if testing.Short() {
28
29
30 t.Skip("skipping test due to -short")
31 }
32 if runtime.GOOS == "windows" {
33 t.Skip("skipping on windows, we don't support connecting to the ssh-agent via a named pipe")
34 }
35
36 bin, err := exec.LookPath("ssh-agent")
37 if err != nil {
38 t.Skip("could not find ssh-agent")
39 }
40
41 cmd := exec.Command(bin, "-s")
42 cmd.Env = []string{}
43 cmd.Stderr = new(bytes.Buffer)
44 out, err := cmd.Output()
45 if err != nil {
46 t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
47 }
48
49
50
51
52
53
54
55 fields := bytes.Split(out, []byte(";"))
56 line := bytes.SplitN(fields[0], []byte("="), 2)
57 line[0] = bytes.TrimLeft(line[0], "\n")
58 if string(line[0]) != "SSH_AUTH_SOCK" {
59 t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
60 }
61 socket = string(line[1])
62
63 line = bytes.SplitN(fields[2], []byte("="), 2)
64 line[0] = bytes.TrimLeft(line[0], "\n")
65 if string(line[0]) != "SSH_AGENT_PID" {
66 t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
67 }
68 pidStr := line[1]
69 pid, err := strconv.Atoi(string(pidStr))
70 if err != nil {
71 t.Fatalf("Atoi(%q): %v", pidStr, err)
72 }
73
74 conn, err := net.Dial("unix", string(socket))
75 if err != nil {
76 t.Fatalf("net.Dial: %v", err)
77 }
78
79 ac := NewClient(conn)
80 return ac, socket, func() {
81 proc, _ := os.FindProcess(pid)
82 if proc != nil {
83 proc.Kill()
84 }
85 conn.Close()
86 os.RemoveAll(filepath.Dir(socket))
87 }
88 }
89
90 func startAgent(t *testing.T, agent Agent) (client ExtendedAgent, cleanup func()) {
91 c1, c2, err := netPipe()
92 if err != nil {
93 t.Fatalf("netPipe: %v", err)
94 }
95 go ServeAgent(agent, c2)
96
97 return NewClient(c1), func() {
98 c1.Close()
99 c2.Close()
100 }
101 }
102
103
104 func startKeyringAgent(t *testing.T) (client ExtendedAgent, cleanup func()) {
105 return startAgent(t, NewKeyring())
106 }
107
108 func testOpenSSHAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
109 agent, _, cleanup := startOpenSSHAgent(t)
110 defer cleanup()
111
112 testAgentInterface(t, agent, key, cert, lifetimeSecs)
113 }
114
115 func testKeyringAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
116 agent, cleanup := startKeyringAgent(t)
117 defer cleanup()
118
119 testAgentInterface(t, agent, key, cert, lifetimeSecs)
120 }
121
122 func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
123 signer, err := ssh.NewSignerFromKey(key)
124 if err != nil {
125 t.Fatalf("NewSignerFromKey(%T): %v", key, err)
126 }
127
128 if keys, err := agent.List(); err != nil {
129 t.Fatalf("RequestIdentities: %v", err)
130 } else if len(keys) > 0 {
131 t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
132 }
133
134
135 var pubKey ssh.PublicKey
136 if cert != nil {
137 err = agent.Add(AddedKey{
138 PrivateKey: key,
139 Certificate: cert,
140 Comment: "comment",
141 LifetimeSecs: lifetimeSecs,
142 })
143 pubKey = cert
144 } else {
145 err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs})
146 pubKey = signer.PublicKey()
147 }
148 if err != nil {
149 t.Fatalf("insert(%T): %v", key, err)
150 }
151
152
153 if keys, err := agent.List(); err != nil {
154 t.Fatalf("List: %v", err)
155 } else if len(keys) != 1 {
156 t.Fatalf("got %v, want 1 key", keys)
157 } else if keys[0].Comment != "comment" {
158 t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
159 } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) {
160 t.Fatalf("key mismatch")
161 }
162
163
164 data := []byte("hello")
165 sig, err := agent.Sign(pubKey, data)
166 if err != nil {
167 t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
168 }
169
170 if err := pubKey.Verify(data, sig); err != nil {
171 t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
172 }
173
174
175 if pubKey.Type() == "ssh-rsa" {
176 sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
177 sig, err = agent.SignWithFlags(pubKey, data, flag)
178 if err != nil {
179 t.Fatalf("SignWithFlags(%s): %v", pubKey.Type(), err)
180 }
181 if sig.Format != expectedSigFormat {
182 t.Fatalf("Signature format didn't match expected value: %s != %s", sig.Format, expectedSigFormat)
183 }
184 if err := pubKey.Verify(data, sig); err != nil {
185 t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
186 }
187 }
188 sshFlagTest(0, ssh.KeyAlgoRSA)
189 sshFlagTest(SignatureFlagRsaSha256, ssh.KeyAlgoRSASHA256)
190 sshFlagTest(SignatureFlagRsaSha512, ssh.KeyAlgoRSASHA512)
191 }
192
193
194 if lifetimeSecs > 0 {
195 time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond)
196 keys, err := agent.List()
197 if err != nil {
198 t.Fatalf("List: %v", err)
199 }
200 if len(keys) > 0 {
201 t.Fatalf("key not expired")
202 }
203 }
204
205 }
206
207 func TestMalformedRequests(t *testing.T) {
208 keyringAgent := NewKeyring()
209
210 testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
211 c, s := net.Pipe()
212 defer c.Close()
213 defer s.Close()
214 go func() {
215 _, err := c.Write(requestBytes)
216 if err != nil {
217 t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
218 }
219 c.Close()
220 }()
221 err := ServeAgent(keyringAgent, s)
222 if err == nil {
223 t.Error("ServeAgent should have returned an error to malformed input")
224 } else {
225 if (err != io.EOF) != wantServerErr {
226 t.Errorf("ServeAgent returned expected error: %v", err)
227 }
228 }
229 }
230
231 var testCases = []struct {
232 name string
233 requestBytes []byte
234 wantServerErr bool
235 }{
236 {"Empty request", []byte{}, false},
237 {"Short header", []byte{0x00}, true},
238 {"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
239 {"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
240 }
241 for _, tc := range testCases {
242 t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
243 }
244 }
245
246 func TestAgent(t *testing.T) {
247 for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
248 testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
249 testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
250 }
251 }
252
253 func TestCert(t *testing.T) {
254 cert := &ssh.Certificate{
255 Key: testPublicKeys["rsa"],
256 ValidBefore: ssh.CertTimeInfinity,
257 CertType: ssh.UserCert,
258 }
259 cert.SignCert(rand.Reader, testSigners["ecdsa"])
260
261 testOpenSSHAgent(t, testPrivateKeys["rsa"], cert, 0)
262 testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
263 }
264
265
// netListener returns a TCP listener on an ephemeral loopback port. It tries
// the IPv4 loopback address first and falls back to the IPv6 one; if both
// attempts fail, the error from the IPv6 attempt is returned.
func netListener() (net.Listener, error) {
	var lastErr error
	for _, addr := range []string{"127.0.0.1:0", "[::1]:0"} {
		ln, err := net.Listen("tcp", addr)
		if err == nil {
			return ln, nil
		}
		lastErr = err
	}
	return nil, lastErr
}
276
277
278
279
280 func netPipe() (net.Conn, net.Conn, error) {
281 listener, err := netListener()
282 if err != nil {
283 return nil, nil, err
284 }
285 defer listener.Close()
286 c1, err := net.Dial("tcp", listener.Addr().String())
287 if err != nil {
288 return nil, nil, err
289 }
290
291 c2, err := listener.Accept()
292 if err != nil {
293 c1.Close()
294 return nil, nil, err
295 }
296
297 return c1, c2, nil
298 }
299
300 func TestServerResponseTooLarge(t *testing.T) {
301 a, b, err := netPipe()
302 if err != nil {
303 t.Fatalf("netPipe: %v", err)
304 }
305 done := make(chan struct{})
306 defer func() { <-done }()
307
308 defer a.Close()
309 defer b.Close()
310
311 var response identitiesAnswerAgentMsg
312 response.NumKeys = 1
313 response.Keys = make([]byte, maxAgentResponseBytes+1)
314
315 agent := NewClient(a)
316 go func() {
317 defer close(done)
318 n, err := b.Write(ssh.Marshal(response))
319 if n < 4 {
320 if runtime.GOOS == "plan9" {
321 if e1, ok := err.(*net.OpError); ok {
322 if e2, ok := e1.Err.(*os.PathError); ok {
323 switch e2.Err.Error() {
324 case "Hangup", "i/o on hungup channel":
325
326 return
327 }
328 }
329 }
330 }
331 t.Errorf("At least 4 bytes (the response size) should have been successfully written: %d < 4: %v", n, err)
332 }
333 }()
334 _, err = agent.List()
335 if err == nil {
336 t.Fatal("Did not get error result")
337 }
338 if err.Error() != "agent: client error: response too large" {
339 t.Fatal("Did not get expected error result")
340 }
341 }
342
343 func TestAuth(t *testing.T) {
344 agent, _, cleanup := startOpenSSHAgent(t)
345 defer cleanup()
346
347 a, b, err := netPipe()
348 if err != nil {
349 t.Fatalf("netPipe: %v", err)
350 }
351
352 defer a.Close()
353 defer b.Close()
354
355 if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
356 t.Errorf("Add: %v", err)
357 }
358
359 serverConf := ssh.ServerConfig{}
360 serverConf.AddHostKey(testSigners["rsa"])
361 serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
362 if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
363 return nil, nil
364 }
365
366 return nil, errors.New("pubkey rejected")
367 }
368
369 go func() {
370 conn, _, _, err := ssh.NewServerConn(a, &serverConf)
371 if err != nil {
372 t.Errorf("NewServerConn error: %v", err)
373 return
374 }
375 conn.Close()
376 }()
377
378 conf := ssh.ClientConfig{
379 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
380 }
381 conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
382 conn, _, _, err := ssh.NewClientConn(b, "", &conf)
383 if err != nil {
384 t.Fatalf("NewClientConn: %v", err)
385 }
386 conn.Close()
387 }
388
389 func TestLockOpenSSHAgent(t *testing.T) {
390 agent, _, cleanup := startOpenSSHAgent(t)
391 defer cleanup()
392 testLockAgent(agent, t)
393 }
394
395 func TestLockKeyringAgent(t *testing.T) {
396 agent, cleanup := startKeyringAgent(t)
397 defer cleanup()
398 testLockAgent(agent, t)
399 }
400
401 func testLockAgent(agent Agent, t *testing.T) {
402 if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
403 t.Errorf("Add: %v", err)
404 }
405 if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil {
406 t.Errorf("Add: %v", err)
407 }
408 if keys, err := agent.List(); err != nil {
409 t.Errorf("List: %v", err)
410 } else if len(keys) != 2 {
411 t.Errorf("Want 2 keys, got %v", keys)
412 }
413
414 passphrase := []byte("secret")
415 if err := agent.Lock(passphrase); err != nil {
416 t.Errorf("Lock: %v", err)
417 }
418
419 if keys, err := agent.List(); err != nil {
420 t.Errorf("List: %v", err)
421 } else if len(keys) != 0 {
422 t.Errorf("Want 0 keys, got %v", keys)
423 }
424
425 signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
426 if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
427 t.Fatalf("Sign did not fail")
428 }
429
430 if err := agent.Remove(signer.PublicKey()); err == nil {
431 t.Fatalf("Remove did not fail")
432 }
433
434 if err := agent.RemoveAll(); err == nil {
435 t.Fatalf("RemoveAll did not fail")
436 }
437
438 if err := agent.Unlock(nil); err == nil {
439 t.Errorf("Unlock with wrong passphrase succeeded")
440 }
441 if err := agent.Unlock(passphrase); err != nil {
442 t.Errorf("Unlock: %v", err)
443 }
444
445 if err := agent.Remove(signer.PublicKey()); err != nil {
446 t.Fatalf("Remove: %v", err)
447 }
448
449 if keys, err := agent.List(); err != nil {
450 t.Errorf("List: %v", err)
451 } else if len(keys) != 1 {
452 t.Errorf("Want 1 keys, got %v", keys)
453 }
454 }
455
456 func testOpenSSHAgentLifetime(t *testing.T) {
457 agent, _, cleanup := startOpenSSHAgent(t)
458 defer cleanup()
459 testAgentLifetime(t, agent)
460 }
461
462 func testKeyringAgentLifetime(t *testing.T) {
463 agent, cleanup := startKeyringAgent(t)
464 defer cleanup()
465 testAgentLifetime(t, agent)
466 }
467
468 func testAgentLifetime(t *testing.T, agent Agent) {
469 for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
470
471 err := agent.Add(AddedKey{
472 PrivateKey: testPrivateKeys[keyType],
473 Comment: "comment",
474 LifetimeSecs: 1,
475 })
476 if err != nil {
477 t.Fatalf("add: %v", err)
478 }
479
480 cert := &ssh.Certificate{
481 Key: testPublicKeys[keyType],
482 ValidBefore: ssh.CertTimeInfinity,
483 CertType: ssh.UserCert,
484 }
485 cert.SignCert(rand.Reader, testSigners[keyType])
486 err = agent.Add(AddedKey{
487 PrivateKey: testPrivateKeys[keyType],
488 Certificate: cert,
489 Comment: "comment",
490 LifetimeSecs: 1,
491 })
492 if err != nil {
493 t.Fatalf("add: %v", err)
494 }
495 }
496 time.Sleep(1100 * time.Millisecond)
497 if keys, err := agent.List(); err != nil {
498 t.Errorf("List: %v", err)
499 } else if len(keys) != 0 {
500 t.Errorf("Want 0 keys, got %v", len(keys))
501 }
502 }
503
504 type keyringExtended struct {
505 *keyring
506 }
507
508 func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) {
509 if extensionType != "my-extension@example.com" {
510 return []byte{agentExtensionFailure}, nil
511 }
512 return append([]byte{agentSuccess}, contents...), nil
513 }
514
515 func TestAgentExtensions(t *testing.T) {
516 agent, _, cleanup := startOpenSSHAgent(t)
517 defer cleanup()
518 _, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
519 if err == nil {
520 t.Fatal("should have gotten agent extension failure")
521 }
522
523 agent, cleanup = startAgent(t, &keyringExtended{})
524 defer cleanup()
525 result, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
526 if err != nil {
527 t.Fatalf("agent extension failure: %v", err)
528 }
529 if len(result) != 4 || !bytes.Equal(result, []byte{agentSuccess, 0x00, 0x01, 0x02}) {
530 t.Fatalf("agent extension result invalid: %v", result)
531 }
532
533 _, err = agent.Extension("bad-extension@example.com", []byte{0x00, 0x01, 0x02})
534 if err == nil {
535 t.Fatal("should have gotten agent extension failure")
536 }
537 }
538
View as plain text