...

Source file src/golang.org/x/crypto/ssh/agent/server_test.go

Documentation: golang.org/x/crypto/ssh/agent

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package agent
     6  
     7  import (
     8  	"crypto"
     9  	"crypto/rand"
    10  	"fmt"
    11  	pseudorand "math/rand"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  
    16  	"golang.org/x/crypto/ssh"
    17  )
    18  
    19  func TestServer(t *testing.T) {
    20  	c1, c2, err := netPipe()
    21  	if err != nil {
    22  		t.Fatalf("netPipe: %v", err)
    23  	}
    24  	defer c1.Close()
    25  	defer c2.Close()
    26  	client := NewClient(c1)
    27  
    28  	go ServeAgent(NewKeyring(), c2)
    29  
    30  	testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
    31  }
    32  
    33  func TestLockServer(t *testing.T) {
    34  	testLockAgent(NewKeyring(), t)
    35  }
    36  
    37  func TestSetupForwardAgent(t *testing.T) {
    38  	a, b, err := netPipe()
    39  	if err != nil {
    40  		t.Fatalf("netPipe: %v", err)
    41  	}
    42  
    43  	defer a.Close()
    44  	defer b.Close()
    45  
    46  	_, socket, cleanup := startOpenSSHAgent(t)
    47  	defer cleanup()
    48  
    49  	serverConf := ssh.ServerConfig{
    50  		NoClientAuth: true,
    51  	}
    52  	serverConf.AddHostKey(testSigners["rsa"])
    53  	incoming := make(chan *ssh.ServerConn, 1)
    54  	go func() {
    55  		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
    56  		incoming <- conn
    57  		if err != nil {
    58  			t.Errorf("NewServerConn error: %v", err)
    59  			return
    60  		}
    61  	}()
    62  
    63  	conf := ssh.ClientConfig{
    64  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
    65  	}
    66  	conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
    67  	if err != nil {
    68  		t.Fatalf("NewClientConn: %v", err)
    69  	}
    70  	client := ssh.NewClient(conn, chans, reqs)
    71  
    72  	if err := ForwardToRemote(client, socket); err != nil {
    73  		t.Fatalf("SetupForwardAgent: %v", err)
    74  	}
    75  	server := <-incoming
    76  	if server == nil {
    77  		t.Fatal("Unable to get server")
    78  	}
    79  	ch, reqs, err := server.OpenChannel(channelType, nil)
    80  	if err != nil {
    81  		t.Fatalf("OpenChannel(%q): %v", channelType, err)
    82  	}
    83  	go ssh.DiscardRequests(reqs)
    84  
    85  	agentClient := NewClient(ch)
    86  	testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
    87  	conn.Close()
    88  }
    89  
    90  func TestV1ProtocolMessages(t *testing.T) {
    91  	c1, c2, err := netPipe()
    92  	if err != nil {
    93  		t.Fatalf("netPipe: %v", err)
    94  	}
    95  	defer c1.Close()
    96  	defer c2.Close()
    97  	c := NewClient(c1)
    98  
    99  	go ServeAgent(NewKeyring(), c2)
   100  
   101  	testV1ProtocolMessages(t, c.(*client))
   102  }
   103  
   104  func testV1ProtocolMessages(t *testing.T, c *client) {
   105  	reply, err := c.call([]byte{agentRequestV1Identities})
   106  	if err != nil {
   107  		t.Fatalf("v1 request all failed: %v", err)
   108  	}
   109  	if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
   110  		t.Fatalf("invalid request all response: %#v", reply)
   111  	}
   112  
   113  	reply, err = c.call([]byte{agentRemoveAllV1Identities})
   114  	if err != nil {
   115  		t.Fatalf("v1 remove all failed: %v", err)
   116  	}
   117  	if _, ok := reply.(*successAgentMsg); !ok {
   118  		t.Fatalf("invalid remove all response: %#v", reply)
   119  	}
   120  }
   121  
   122  func verifyKey(sshAgent Agent) error {
   123  	keys, err := sshAgent.List()
   124  	if err != nil {
   125  		return fmt.Errorf("listing keys: %v", err)
   126  	}
   127  
   128  	if len(keys) != 1 {
   129  		return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
   130  	}
   131  
   132  	buf := make([]byte, 128)
   133  	if _, err := rand.Read(buf); err != nil {
   134  		return fmt.Errorf("rand: %v", err)
   135  	}
   136  
   137  	sig, err := sshAgent.Sign(keys[0], buf)
   138  	if err != nil {
   139  		return fmt.Errorf("sign: %v", err)
   140  	}
   141  
   142  	if err := keys[0].Verify(buf, sig); err != nil {
   143  		return fmt.Errorf("verify: %v", err)
   144  	}
   145  	return nil
   146  }
   147  
   148  func addKeyToAgent(key crypto.PrivateKey) error {
   149  	sshAgent := NewKeyring()
   150  	if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
   151  		return fmt.Errorf("add: %v", err)
   152  	}
   153  	return verifyKey(sshAgent)
   154  }
   155  
   156  func TestKeyTypes(t *testing.T) {
   157  	for k, v := range testPrivateKeys {
   158  		if err := addKeyToAgent(v); err != nil {
   159  			t.Errorf("error adding key type %s, %v", k, err)
   160  		}
   161  		if err := addCertToAgentSock(v, nil); err != nil {
   162  			t.Errorf("error adding key type %s, %v", k, err)
   163  		}
   164  	}
   165  }
   166  
   167  func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
   168  	a, b, err := netPipe()
   169  	if err != nil {
   170  		return err
   171  	}
   172  	agentServer := NewKeyring()
   173  	go ServeAgent(agentServer, a)
   174  
   175  	agentClient := NewClient(b)
   176  	if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
   177  		return fmt.Errorf("add: %v", err)
   178  	}
   179  	return verifyKey(agentClient)
   180  }
   181  
   182  func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
   183  	sshAgent := NewKeyring()
   184  	if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
   185  		return fmt.Errorf("add: %v", err)
   186  	}
   187  	return verifyKey(sshAgent)
   188  }
   189  
   190  func TestCertTypes(t *testing.T) {
   191  	for keyType, key := range testPublicKeys {
   192  		cert := &ssh.Certificate{
   193  			ValidPrincipals: []string{"gopher1"},
   194  			ValidAfter:      0,
   195  			ValidBefore:     ssh.CertTimeInfinity,
   196  			Key:             key,
   197  			Serial:          1,
   198  			CertType:        ssh.UserCert,
   199  			SignatureKey:    testPublicKeys["rsa"],
   200  			Permissions: ssh.Permissions{
   201  				CriticalOptions: map[string]string{},
   202  				Extensions:      map[string]string{},
   203  			},
   204  		}
   205  		if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
   206  			t.Fatalf("signcert: %v", err)
   207  		}
   208  		if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
   209  			t.Fatalf("%v", err)
   210  		}
   211  		if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
   212  			t.Fatalf("%v", err)
   213  		}
   214  	}
   215  }
   216  
   217  func TestParseConstraints(t *testing.T) {
   218  	// Test LifetimeSecs
   219  	var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
   220  	lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
   221  	if err != nil {
   222  		t.Fatalf("parseConstraints: %v", err)
   223  	}
   224  	if lifetimeSecs != msg.LifetimeSecs {
   225  		t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
   226  	}
   227  
   228  	// Test ConfirmBeforeUse
   229  	_, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
   230  	if err != nil {
   231  		t.Fatalf("%v", err)
   232  	}
   233  	if !confirmBeforeUse {
   234  		t.Error("got comfirmBeforeUse == false")
   235  	}
   236  
   237  	// Test ConstraintExtensions
   238  	var data []byte
   239  	var expect []ConstraintExtension
   240  	for i := 0; i < 10; i++ {
   241  		var ext = ConstraintExtension{
   242  			ExtensionName:    fmt.Sprintf("name%d", i),
   243  			ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
   244  		}
   245  		expect = append(expect, ext)
   246  		if i%2 == 0 {
   247  			data = append(data, agentConstrainExtension)
   248  		} else {
   249  			data = append(data, agentConstrainExtensionV00)
   250  		}
   251  		data = append(data, ssh.Marshal(ext)...)
   252  	}
   253  	_, _, extensions, err := parseConstraints(data)
   254  	if err != nil {
   255  		t.Fatalf("%v", err)
   256  	}
   257  	if !reflect.DeepEqual(expect, extensions) {
   258  		t.Errorf("got extension %v, want %v", extensions, expect)
   259  	}
   260  
   261  	// Test Unknown Constraint
   262  	_, _, _, err = parseConstraints([]byte{128})
   263  	if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
   264  		t.Errorf("unexpected error: %v", err)
   265  	}
   266  }
   267  

View as plain text