1
2
3
4
5 package agent
6
7 import (
8 "crypto"
9 "crypto/rand"
10 "fmt"
11 pseudorand "math/rand"
12 "reflect"
13 "strings"
14 "testing"
15
16 "golang.org/x/crypto/ssh"
17 )
18
19 func TestServer(t *testing.T) {
20 c1, c2, err := netPipe()
21 if err != nil {
22 t.Fatalf("netPipe: %v", err)
23 }
24 defer c1.Close()
25 defer c2.Close()
26 client := NewClient(c1)
27
28 go ServeAgent(NewKeyring(), c2)
29
30 testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
31 }
32
33 func TestLockServer(t *testing.T) {
34 testLockAgent(NewKeyring(), t)
35 }
36
37 func TestSetupForwardAgent(t *testing.T) {
38 a, b, err := netPipe()
39 if err != nil {
40 t.Fatalf("netPipe: %v", err)
41 }
42
43 defer a.Close()
44 defer b.Close()
45
46 _, socket, cleanup := startOpenSSHAgent(t)
47 defer cleanup()
48
49 serverConf := ssh.ServerConfig{
50 NoClientAuth: true,
51 }
52 serverConf.AddHostKey(testSigners["rsa"])
53 incoming := make(chan *ssh.ServerConn, 1)
54 go func() {
55 conn, _, _, err := ssh.NewServerConn(a, &serverConf)
56 incoming <- conn
57 if err != nil {
58 t.Errorf("NewServerConn error: %v", err)
59 return
60 }
61 }()
62
63 conf := ssh.ClientConfig{
64 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
65 }
66 conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
67 if err != nil {
68 t.Fatalf("NewClientConn: %v", err)
69 }
70 client := ssh.NewClient(conn, chans, reqs)
71
72 if err := ForwardToRemote(client, socket); err != nil {
73 t.Fatalf("SetupForwardAgent: %v", err)
74 }
75 server := <-incoming
76 if server == nil {
77 t.Fatal("Unable to get server")
78 }
79 ch, reqs, err := server.OpenChannel(channelType, nil)
80 if err != nil {
81 t.Fatalf("OpenChannel(%q): %v", channelType, err)
82 }
83 go ssh.DiscardRequests(reqs)
84
85 agentClient := NewClient(ch)
86 testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
87 conn.Close()
88 }
89
90 func TestV1ProtocolMessages(t *testing.T) {
91 c1, c2, err := netPipe()
92 if err != nil {
93 t.Fatalf("netPipe: %v", err)
94 }
95 defer c1.Close()
96 defer c2.Close()
97 c := NewClient(c1)
98
99 go ServeAgent(NewKeyring(), c2)
100
101 testV1ProtocolMessages(t, c.(*client))
102 }
103
104 func testV1ProtocolMessages(t *testing.T, c *client) {
105 reply, err := c.call([]byte{agentRequestV1Identities})
106 if err != nil {
107 t.Fatalf("v1 request all failed: %v", err)
108 }
109 if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
110 t.Fatalf("invalid request all response: %#v", reply)
111 }
112
113 reply, err = c.call([]byte{agentRemoveAllV1Identities})
114 if err != nil {
115 t.Fatalf("v1 remove all failed: %v", err)
116 }
117 if _, ok := reply.(*successAgentMsg); !ok {
118 t.Fatalf("invalid remove all response: %#v", reply)
119 }
120 }
121
122 func verifyKey(sshAgent Agent) error {
123 keys, err := sshAgent.List()
124 if err != nil {
125 return fmt.Errorf("listing keys: %v", err)
126 }
127
128 if len(keys) != 1 {
129 return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
130 }
131
132 buf := make([]byte, 128)
133 if _, err := rand.Read(buf); err != nil {
134 return fmt.Errorf("rand: %v", err)
135 }
136
137 sig, err := sshAgent.Sign(keys[0], buf)
138 if err != nil {
139 return fmt.Errorf("sign: %v", err)
140 }
141
142 if err := keys[0].Verify(buf, sig); err != nil {
143 return fmt.Errorf("verify: %v", err)
144 }
145 return nil
146 }
147
148 func addKeyToAgent(key crypto.PrivateKey) error {
149 sshAgent := NewKeyring()
150 if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
151 return fmt.Errorf("add: %v", err)
152 }
153 return verifyKey(sshAgent)
154 }
155
156 func TestKeyTypes(t *testing.T) {
157 for k, v := range testPrivateKeys {
158 if err := addKeyToAgent(v); err != nil {
159 t.Errorf("error adding key type %s, %v", k, err)
160 }
161 if err := addCertToAgentSock(v, nil); err != nil {
162 t.Errorf("error adding key type %s, %v", k, err)
163 }
164 }
165 }
166
167 func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
168 a, b, err := netPipe()
169 if err != nil {
170 return err
171 }
172 agentServer := NewKeyring()
173 go ServeAgent(agentServer, a)
174
175 agentClient := NewClient(b)
176 if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
177 return fmt.Errorf("add: %v", err)
178 }
179 return verifyKey(agentClient)
180 }
181
182 func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
183 sshAgent := NewKeyring()
184 if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
185 return fmt.Errorf("add: %v", err)
186 }
187 return verifyKey(sshAgent)
188 }
189
190 func TestCertTypes(t *testing.T) {
191 for keyType, key := range testPublicKeys {
192 cert := &ssh.Certificate{
193 ValidPrincipals: []string{"gopher1"},
194 ValidAfter: 0,
195 ValidBefore: ssh.CertTimeInfinity,
196 Key: key,
197 Serial: 1,
198 CertType: ssh.UserCert,
199 SignatureKey: testPublicKeys["rsa"],
200 Permissions: ssh.Permissions{
201 CriticalOptions: map[string]string{},
202 Extensions: map[string]string{},
203 },
204 }
205 if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
206 t.Fatalf("signcert: %v", err)
207 }
208 if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
209 t.Fatalf("%v", err)
210 }
211 if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
212 t.Fatalf("%v", err)
213 }
214 }
215 }
216
217 func TestParseConstraints(t *testing.T) {
218
219 var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
220 lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
221 if err != nil {
222 t.Fatalf("parseConstraints: %v", err)
223 }
224 if lifetimeSecs != msg.LifetimeSecs {
225 t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
226 }
227
228
229 _, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
230 if err != nil {
231 t.Fatalf("%v", err)
232 }
233 if !confirmBeforeUse {
234 t.Error("got comfirmBeforeUse == false")
235 }
236
237
238 var data []byte
239 var expect []ConstraintExtension
240 for i := 0; i < 10; i++ {
241 var ext = ConstraintExtension{
242 ExtensionName: fmt.Sprintf("name%d", i),
243 ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
244 }
245 expect = append(expect, ext)
246 if i%2 == 0 {
247 data = append(data, agentConstrainExtension)
248 } else {
249 data = append(data, agentConstrainExtensionV00)
250 }
251 data = append(data, ssh.Marshal(ext)...)
252 }
253 _, _, extensions, err := parseConstraints(data)
254 if err != nil {
255 t.Fatalf("%v", err)
256 }
257 if !reflect.DeepEqual(expect, extensions) {
258 t.Errorf("got extension %v, want %v", extensions, expect)
259 }
260
261
262 _, _, _, err = parseConstraints([]byte{128})
263 if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
264 t.Errorf("unexpected error: %v", err)
265 }
266 }
267
View as plain text