...

Source file src/golang.org/x/crypto/ssh/agent/client_test.go

Documentation: golang.org/x/crypto/ssh/agent

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package agent
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"errors"
    11  	"io"
    12  	"net"
    13  	"os"
    14  	"os/exec"
    15  	"path/filepath"
    16  	"runtime"
    17  	"strconv"
    18  	"strings"
    19  	"testing"
    20  	"time"
    21  
    22  	"golang.org/x/crypto/ssh"
    23  )
    24  
    25  // startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it.
    26  func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, cleanup func()) {
    27  	if testing.Short() {
    28  		// ssh-agent is not always available, and the key
    29  		// types supported vary by platform.
    30  		t.Skip("skipping test due to -short")
    31  	}
    32  	if runtime.GOOS == "windows" {
    33  		t.Skip("skipping on windows, we don't support connecting to the ssh-agent via a named pipe")
    34  	}
    35  
    36  	bin, err := exec.LookPath("ssh-agent")
    37  	if err != nil {
    38  		t.Skip("could not find ssh-agent")
    39  	}
    40  
    41  	cmd := exec.Command(bin, "-s")
    42  	cmd.Env = []string{} // Do not let the user's environment influence ssh-agent behavior.
    43  	cmd.Stderr = new(bytes.Buffer)
    44  	out, err := cmd.Output()
    45  	if err != nil {
    46  		t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
    47  	}
    48  
    49  	// Output looks like:
    50  	//
    51  	//	SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
    52  	//	SSH_AGENT_PID=15542; export SSH_AGENT_PID;
    53  	//	echo Agent pid 15542;
    54  
    55  	fields := bytes.Split(out, []byte(";"))
    56  	line := bytes.SplitN(fields[0], []byte("="), 2)
    57  	line[0] = bytes.TrimLeft(line[0], "\n")
    58  	if string(line[0]) != "SSH_AUTH_SOCK" {
    59  		t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
    60  	}
    61  	socket = string(line[1])
    62  
    63  	line = bytes.SplitN(fields[2], []byte("="), 2)
    64  	line[0] = bytes.TrimLeft(line[0], "\n")
    65  	if string(line[0]) != "SSH_AGENT_PID" {
    66  		t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
    67  	}
    68  	pidStr := line[1]
    69  	pid, err := strconv.Atoi(string(pidStr))
    70  	if err != nil {
    71  		t.Fatalf("Atoi(%q): %v", pidStr, err)
    72  	}
    73  
    74  	conn, err := net.Dial("unix", string(socket))
    75  	if err != nil {
    76  		t.Fatalf("net.Dial: %v", err)
    77  	}
    78  
    79  	ac := NewClient(conn)
    80  	return ac, socket, func() {
    81  		proc, _ := os.FindProcess(pid)
    82  		if proc != nil {
    83  			proc.Kill()
    84  		}
    85  		conn.Close()
    86  		os.RemoveAll(filepath.Dir(socket))
    87  	}
    88  }
    89  
    90  func startAgent(t *testing.T, agent Agent) (client ExtendedAgent, cleanup func()) {
    91  	c1, c2, err := netPipe()
    92  	if err != nil {
    93  		t.Fatalf("netPipe: %v", err)
    94  	}
    95  	go ServeAgent(agent, c2)
    96  
    97  	return NewClient(c1), func() {
    98  		c1.Close()
    99  		c2.Close()
   100  	}
   101  }
   102  
   103  // startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
   104  func startKeyringAgent(t *testing.T) (client ExtendedAgent, cleanup func()) {
   105  	return startAgent(t, NewKeyring())
   106  }
   107  
   108  func testOpenSSHAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
   109  	agent, _, cleanup := startOpenSSHAgent(t)
   110  	defer cleanup()
   111  
   112  	testAgentInterface(t, agent, key, cert, lifetimeSecs)
   113  }
   114  
   115  func testKeyringAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
   116  	agent, cleanup := startKeyringAgent(t)
   117  	defer cleanup()
   118  
   119  	testAgentInterface(t, agent, key, cert, lifetimeSecs)
   120  }
   121  
   122  func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
   123  	signer, err := ssh.NewSignerFromKey(key)
   124  	if err != nil {
   125  		t.Fatalf("NewSignerFromKey(%T): %v", key, err)
   126  	}
   127  	// The agent should start up empty.
   128  	if keys, err := agent.List(); err != nil {
   129  		t.Fatalf("RequestIdentities: %v", err)
   130  	} else if len(keys) > 0 {
   131  		t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
   132  	}
   133  
   134  	// Attempt to insert the key, with certificate if specified.
   135  	var pubKey ssh.PublicKey
   136  	if cert != nil {
   137  		err = agent.Add(AddedKey{
   138  			PrivateKey:   key,
   139  			Certificate:  cert,
   140  			Comment:      "comment",
   141  			LifetimeSecs: lifetimeSecs,
   142  		})
   143  		pubKey = cert
   144  	} else {
   145  		err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs})
   146  		pubKey = signer.PublicKey()
   147  	}
   148  	if err != nil {
   149  		t.Fatalf("insert(%T): %v", key, err)
   150  	}
   151  
   152  	// Did the key get inserted successfully?
   153  	if keys, err := agent.List(); err != nil {
   154  		t.Fatalf("List: %v", err)
   155  	} else if len(keys) != 1 {
   156  		t.Fatalf("got %v, want 1 key", keys)
   157  	} else if keys[0].Comment != "comment" {
   158  		t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
   159  	} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) {
   160  		t.Fatalf("key mismatch")
   161  	}
   162  
   163  	// Can the agent make a valid signature?
   164  	data := []byte("hello")
   165  	sig, err := agent.Sign(pubKey, data)
   166  	if err != nil {
   167  		t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
   168  	}
   169  
   170  	if err := pubKey.Verify(data, sig); err != nil {
   171  		t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
   172  	}
   173  
   174  	// For tests on RSA keys, try signing with SHA-256 and SHA-512 flags
   175  	if pubKey.Type() == "ssh-rsa" {
   176  		sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
   177  			sig, err = agent.SignWithFlags(pubKey, data, flag)
   178  			if err != nil {
   179  				t.Fatalf("SignWithFlags(%s): %v", pubKey.Type(), err)
   180  			}
   181  			if sig.Format != expectedSigFormat {
   182  				t.Fatalf("Signature format didn't match expected value: %s != %s", sig.Format, expectedSigFormat)
   183  			}
   184  			if err := pubKey.Verify(data, sig); err != nil {
   185  				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
   186  			}
   187  		}
   188  		sshFlagTest(0, ssh.KeyAlgoRSA)
   189  		sshFlagTest(SignatureFlagRsaSha256, ssh.KeyAlgoRSASHA256)
   190  		sshFlagTest(SignatureFlagRsaSha512, ssh.KeyAlgoRSASHA512)
   191  	}
   192  
   193  	// If the key has a lifetime, is it removed when it should be?
   194  	if lifetimeSecs > 0 {
   195  		time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond)
   196  		keys, err := agent.List()
   197  		if err != nil {
   198  			t.Fatalf("List: %v", err)
   199  		}
   200  		if len(keys) > 0 {
   201  			t.Fatalf("key not expired")
   202  		}
   203  	}
   204  
   205  }
   206  
   207  func TestMalformedRequests(t *testing.T) {
   208  	keyringAgent := NewKeyring()
   209  
   210  	testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
   211  		c, s := net.Pipe()
   212  		defer c.Close()
   213  		defer s.Close()
   214  		go func() {
   215  			_, err := c.Write(requestBytes)
   216  			if err != nil {
   217  				t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
   218  			}
   219  			c.Close()
   220  		}()
   221  		err := ServeAgent(keyringAgent, s)
   222  		if err == nil {
   223  			t.Error("ServeAgent should have returned an error to malformed input")
   224  		} else {
   225  			if (err != io.EOF) != wantServerErr {
   226  				t.Errorf("ServeAgent returned expected error: %v", err)
   227  			}
   228  		}
   229  	}
   230  
   231  	var testCases = []struct {
   232  		name          string
   233  		requestBytes  []byte
   234  		wantServerErr bool
   235  	}{
   236  		{"Empty request", []byte{}, false},
   237  		{"Short header", []byte{0x00}, true},
   238  		{"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
   239  		{"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
   240  	}
   241  	for _, tc := range testCases {
   242  		t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
   243  	}
   244  }
   245  
   246  func TestAgent(t *testing.T) {
   247  	for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
   248  		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
   249  		testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
   250  	}
   251  }
   252  
   253  func TestCert(t *testing.T) {
   254  	cert := &ssh.Certificate{
   255  		Key:         testPublicKeys["rsa"],
   256  		ValidBefore: ssh.CertTimeInfinity,
   257  		CertType:    ssh.UserCert,
   258  	}
   259  	cert.SignCert(rand.Reader, testSigners["ecdsa"])
   260  
   261  	testOpenSSHAgent(t, testPrivateKeys["rsa"], cert, 0)
   262  	testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
   263  }
   264  
   265  // netListener creates a localhost network listener.
   266  func netListener() (net.Listener, error) {
   267  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   268  	if err != nil {
   269  		listener, err = net.Listen("tcp", "[::1]:0")
   270  		if err != nil {
   271  			return nil, err
   272  		}
   273  	}
   274  	return listener, nil
   275  }
   276  
   277  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
   278  // therefore is buffered (net.Pipe deadlocks if both sides start with
   279  // a write.)
   280  func netPipe() (net.Conn, net.Conn, error) {
   281  	listener, err := netListener()
   282  	if err != nil {
   283  		return nil, nil, err
   284  	}
   285  	defer listener.Close()
   286  	c1, err := net.Dial("tcp", listener.Addr().String())
   287  	if err != nil {
   288  		return nil, nil, err
   289  	}
   290  
   291  	c2, err := listener.Accept()
   292  	if err != nil {
   293  		c1.Close()
   294  		return nil, nil, err
   295  	}
   296  
   297  	return c1, c2, nil
   298  }
   299  
   300  func TestServerResponseTooLarge(t *testing.T) {
   301  	a, b, err := netPipe()
   302  	if err != nil {
   303  		t.Fatalf("netPipe: %v", err)
   304  	}
   305  	done := make(chan struct{})
   306  	defer func() { <-done }()
   307  
   308  	defer a.Close()
   309  	defer b.Close()
   310  
   311  	var response identitiesAnswerAgentMsg
   312  	response.NumKeys = 1
   313  	response.Keys = make([]byte, maxAgentResponseBytes+1)
   314  
   315  	agent := NewClient(a)
   316  	go func() {
   317  		defer close(done)
   318  		n, err := b.Write(ssh.Marshal(response))
   319  		if n < 4 {
   320  			if runtime.GOOS == "plan9" {
   321  				if e1, ok := err.(*net.OpError); ok {
   322  					if e2, ok := e1.Err.(*os.PathError); ok {
   323  						switch e2.Err.Error() {
   324  						case "Hangup", "i/o on hungup channel":
   325  							// syscall.Pwrite returns -1 in this case even when some data did get written.
   326  							return
   327  						}
   328  					}
   329  				}
   330  			}
   331  			t.Errorf("At least 4 bytes (the response size) should have been successfully written: %d < 4: %v", n, err)
   332  		}
   333  	}()
   334  	_, err = agent.List()
   335  	if err == nil {
   336  		t.Fatal("Did not get error result")
   337  	}
   338  	if err.Error() != "agent: client error: response too large" {
   339  		t.Fatal("Did not get expected error result")
   340  	}
   341  }
   342  
   343  func TestAuth(t *testing.T) {
   344  	agent, _, cleanup := startOpenSSHAgent(t)
   345  	defer cleanup()
   346  
   347  	a, b, err := netPipe()
   348  	if err != nil {
   349  		t.Fatalf("netPipe: %v", err)
   350  	}
   351  
   352  	defer a.Close()
   353  	defer b.Close()
   354  
   355  	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
   356  		t.Errorf("Add: %v", err)
   357  	}
   358  
   359  	serverConf := ssh.ServerConfig{}
   360  	serverConf.AddHostKey(testSigners["rsa"])
   361  	serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
   362  		if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
   363  			return nil, nil
   364  		}
   365  
   366  		return nil, errors.New("pubkey rejected")
   367  	}
   368  
   369  	go func() {
   370  		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
   371  		if err != nil {
   372  			t.Errorf("NewServerConn error: %v", err)
   373  			return
   374  		}
   375  		conn.Close()
   376  	}()
   377  
   378  	conf := ssh.ClientConfig{
   379  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
   380  	}
   381  	conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
   382  	conn, _, _, err := ssh.NewClientConn(b, "", &conf)
   383  	if err != nil {
   384  		t.Fatalf("NewClientConn: %v", err)
   385  	}
   386  	conn.Close()
   387  }
   388  
   389  func TestLockOpenSSHAgent(t *testing.T) {
   390  	agent, _, cleanup := startOpenSSHAgent(t)
   391  	defer cleanup()
   392  	testLockAgent(agent, t)
   393  }
   394  
   395  func TestLockKeyringAgent(t *testing.T) {
   396  	agent, cleanup := startKeyringAgent(t)
   397  	defer cleanup()
   398  	testLockAgent(agent, t)
   399  }
   400  
   401  func testLockAgent(agent Agent, t *testing.T) {
   402  	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
   403  		t.Errorf("Add: %v", err)
   404  	}
   405  	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil {
   406  		t.Errorf("Add: %v", err)
   407  	}
   408  	if keys, err := agent.List(); err != nil {
   409  		t.Errorf("List: %v", err)
   410  	} else if len(keys) != 2 {
   411  		t.Errorf("Want 2 keys, got %v", keys)
   412  	}
   413  
   414  	passphrase := []byte("secret")
   415  	if err := agent.Lock(passphrase); err != nil {
   416  		t.Errorf("Lock: %v", err)
   417  	}
   418  
   419  	if keys, err := agent.List(); err != nil {
   420  		t.Errorf("List: %v", err)
   421  	} else if len(keys) != 0 {
   422  		t.Errorf("Want 0 keys, got %v", keys)
   423  	}
   424  
   425  	signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
   426  	if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
   427  		t.Fatalf("Sign did not fail")
   428  	}
   429  
   430  	if err := agent.Remove(signer.PublicKey()); err == nil {
   431  		t.Fatalf("Remove did not fail")
   432  	}
   433  
   434  	if err := agent.RemoveAll(); err == nil {
   435  		t.Fatalf("RemoveAll did not fail")
   436  	}
   437  
   438  	if err := agent.Unlock(nil); err == nil {
   439  		t.Errorf("Unlock with wrong passphrase succeeded")
   440  	}
   441  	if err := agent.Unlock(passphrase); err != nil {
   442  		t.Errorf("Unlock: %v", err)
   443  	}
   444  
   445  	if err := agent.Remove(signer.PublicKey()); err != nil {
   446  		t.Fatalf("Remove: %v", err)
   447  	}
   448  
   449  	if keys, err := agent.List(); err != nil {
   450  		t.Errorf("List: %v", err)
   451  	} else if len(keys) != 1 {
   452  		t.Errorf("Want 1 keys, got %v", keys)
   453  	}
   454  }
   455  
   456  func testOpenSSHAgentLifetime(t *testing.T) {
   457  	agent, _, cleanup := startOpenSSHAgent(t)
   458  	defer cleanup()
   459  	testAgentLifetime(t, agent)
   460  }
   461  
   462  func testKeyringAgentLifetime(t *testing.T) {
   463  	agent, cleanup := startKeyringAgent(t)
   464  	defer cleanup()
   465  	testAgentLifetime(t, agent)
   466  }
   467  
   468  func testAgentLifetime(t *testing.T, agent Agent) {
   469  	for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
   470  		// Add private keys to the agent.
   471  		err := agent.Add(AddedKey{
   472  			PrivateKey:   testPrivateKeys[keyType],
   473  			Comment:      "comment",
   474  			LifetimeSecs: 1,
   475  		})
   476  		if err != nil {
   477  			t.Fatalf("add: %v", err)
   478  		}
   479  		// Add certs to the agent.
   480  		cert := &ssh.Certificate{
   481  			Key:         testPublicKeys[keyType],
   482  			ValidBefore: ssh.CertTimeInfinity,
   483  			CertType:    ssh.UserCert,
   484  		}
   485  		cert.SignCert(rand.Reader, testSigners[keyType])
   486  		err = agent.Add(AddedKey{
   487  			PrivateKey:   testPrivateKeys[keyType],
   488  			Certificate:  cert,
   489  			Comment:      "comment",
   490  			LifetimeSecs: 1,
   491  		})
   492  		if err != nil {
   493  			t.Fatalf("add: %v", err)
   494  		}
   495  	}
   496  	time.Sleep(1100 * time.Millisecond)
   497  	if keys, err := agent.List(); err != nil {
   498  		t.Errorf("List: %v", err)
   499  	} else if len(keys) != 0 {
   500  		t.Errorf("Want 0 keys, got %v", len(keys))
   501  	}
   502  }
   503  
   504  type keyringExtended struct {
   505  	*keyring
   506  }
   507  
   508  func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) {
   509  	if extensionType != "my-extension@example.com" {
   510  		return []byte{agentExtensionFailure}, nil
   511  	}
   512  	return append([]byte{agentSuccess}, contents...), nil
   513  }
   514  
   515  func TestAgentExtensions(t *testing.T) {
   516  	agent, _, cleanup := startOpenSSHAgent(t)
   517  	defer cleanup()
   518  	_, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
   519  	if err == nil {
   520  		t.Fatal("should have gotten agent extension failure")
   521  	}
   522  
   523  	agent, cleanup = startAgent(t, &keyringExtended{})
   524  	defer cleanup()
   525  	result, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
   526  	if err != nil {
   527  		t.Fatalf("agent extension failure: %v", err)
   528  	}
   529  	if len(result) != 4 || !bytes.Equal(result, []byte{agentSuccess, 0x00, 0x01, 0x02}) {
   530  		t.Fatalf("agent extension result invalid: %v", result)
   531  	}
   532  
   533  	_, err = agent.Extension("bad-extension@example.com", []byte{0x00, 0x01, 0x02})
   534  	if err == nil {
   535  		t.Fatal("should have gotten agent extension failure")
   536  	}
   537  }
   538  

View as plain text