// Copyright 2012 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package agent import ( "crypto" "crypto/rand" "fmt" pseudorand "math/rand" "reflect" "strings" "testing" "golang.org/x/crypto/ssh" ) func TestServer(t *testing.T) { c1, c2, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer c1.Close() defer c2.Close() client := NewClient(c1) go ServeAgent(NewKeyring(), c2) testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0) } func TestLockServer(t *testing.T) { testLockAgent(NewKeyring(), t) } func TestSetupForwardAgent(t *testing.T) { a, b, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer a.Close() defer b.Close() _, socket, cleanup := startOpenSSHAgent(t) defer cleanup() serverConf := ssh.ServerConfig{ NoClientAuth: true, } serverConf.AddHostKey(testSigners["rsa"]) incoming := make(chan *ssh.ServerConn, 1) go func() { conn, _, _, err := ssh.NewServerConn(a, &serverConf) incoming <- conn if err != nil { t.Errorf("NewServerConn error: %v", err) return } }() conf := ssh.ClientConfig{ HostKeyCallback: ssh.InsecureIgnoreHostKey(), } conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) if err != nil { t.Fatalf("NewClientConn: %v", err) } client := ssh.NewClient(conn, chans, reqs) if err := ForwardToRemote(client, socket); err != nil { t.Fatalf("SetupForwardAgent: %v", err) } server := <-incoming if server == nil { t.Fatal("Unable to get server") } ch, reqs, err := server.OpenChannel(channelType, nil) if err != nil { t.Fatalf("OpenChannel(%q): %v", channelType, err) } go ssh.DiscardRequests(reqs) agentClient := NewClient(ch) testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0) conn.Close() } func TestV1ProtocolMessages(t *testing.T) { c1, c2, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer c1.Close() defer c2.Close() c := NewClient(c1) go ServeAgent(NewKeyring(), c2) testV1ProtocolMessages(t, c.(*client)) } func testV1ProtocolMessages(t *testing.T, c *client) { reply, err := c.call([]byte{agentRequestV1Identities}) if err != nil { t.Fatalf("v1 request all failed: %v", err) } if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 { t.Fatalf("invalid request all response: %#v", reply) } reply, err = c.call([]byte{agentRemoveAllV1Identities}) if err != nil { t.Fatalf("v1 remove all failed: %v", err) } if _, ok := reply.(*successAgentMsg); !ok { t.Fatalf("invalid remove all response: %#v", reply) } } func verifyKey(sshAgent Agent) error { keys, err := sshAgent.List() if err != nil { return fmt.Errorf("listing keys: %v", err) } if len(keys) != 1 { return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys)) } buf := make([]byte, 128) if _, err := rand.Read(buf); err != nil { return fmt.Errorf("rand: %v", err) } sig, err := sshAgent.Sign(keys[0], buf) if err != nil { return fmt.Errorf("sign: %v", err) } if err := keys[0].Verify(buf, sig); err != nil { return fmt.Errorf("verify: %v", err) } return nil } func addKeyToAgent(key crypto.PrivateKey) error { sshAgent := NewKeyring() if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil { return fmt.Errorf("add: %v", err) } return verifyKey(sshAgent) } func TestKeyTypes(t *testing.T) { for k, v := range testPrivateKeys { if err := addKeyToAgent(v); err != nil { t.Errorf("error adding key type %s, %v", k, err) } if err := addCertToAgentSock(v, nil); err != nil { t.Errorf("error adding key type %s, %v", k, err) } } } func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error { a, b, err := netPipe() if err != nil { return err } agentServer := NewKeyring() go ServeAgent(agentServer, a) agentClient := NewClient(b) if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil { return fmt.Errorf("add: %v", err) } return verifyKey(agentClient) } func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error { sshAgent := NewKeyring() if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil { return fmt.Errorf("add: %v", err) } return verifyKey(sshAgent) } func TestCertTypes(t *testing.T) { for keyType, key := range testPublicKeys { cert := &ssh.Certificate{ ValidPrincipals: []string{"gopher1"}, ValidAfter: 0, ValidBefore: ssh.CertTimeInfinity, Key: key, Serial: 1, CertType: ssh.UserCert, SignatureKey: testPublicKeys["rsa"], Permissions: ssh.Permissions{ CriticalOptions: map[string]string{}, Extensions: map[string]string{}, }, } if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil { t.Fatalf("signcert: %v", err) } if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil { t.Fatalf("%v", err) } if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil { t.Fatalf("%v", err) } } } func TestParseConstraints(t *testing.T) { // Test LifetimeSecs var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()} lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg)) if err != nil { t.Fatalf("parseConstraints: %v", err) } if lifetimeSecs != msg.LifetimeSecs { t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs) } // Test ConfirmBeforeUse _, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm}) if err != nil { t.Fatalf("%v", err) } if !confirmBeforeUse { t.Error("got comfirmBeforeUse == false") } // Test ConstraintExtensions var data []byte var expect []ConstraintExtension for i := 0; i < 10; i++ { var ext = ConstraintExtension{ ExtensionName: fmt.Sprintf("name%d", i), ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)), } expect = append(expect, ext) if i%2 == 0 { data = append(data, agentConstrainExtension) } else { data = append(data, agentConstrainExtensionV00) } data = append(data, ssh.Marshal(ext)...) } _, _, extensions, err := parseConstraints(data) if err != nil { t.Fatalf("%v", err) } if !reflect.DeepEqual(expect, extensions) { t.Errorf("got extension %v, want %v", extensions, expect) } // Test Unknown Constraint _, _, _, err = parseConstraints([]byte{128}) if err == nil || !strings.Contains(err.Error(), "unknown constraint") { t.Errorf("unexpected error: %v", err) } }