1
2
3
4
5 package ssh
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "strings"
13 )
14
// authResult is the outcome of a single AuthMethod attempt.
type authResult int

const (
	// authFailure: the method was rejected, or the attempt returned an error.
	authFailure authResult = 0
	// authPartialSuccess: the server accepted the method but reported that
	// further methods are still required.
	authPartialSuccess authResult = 1
	// authSuccess: the server reported authentication as complete.
	authSuccess authResult = 2
)
22
23
24 func (c *connection) clientAuthenticate(config *ClientConfig) error {
25
26 if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
27 return err
28 }
29 packet, err := c.transport.readPacket()
30 if err != nil {
31 return err
32 }
33
34
35
36 extensions := make(map[string][]byte)
37 if len(packet) > 0 && packet[0] == msgExtInfo {
38 var extInfo extInfoMsg
39 if err := Unmarshal(packet, &extInfo); err != nil {
40 return err
41 }
42 payload := extInfo.Payload
43 for i := uint32(0); i < extInfo.NumExtensions; i++ {
44 name, rest, ok := parseString(payload)
45 if !ok {
46 return parseError(msgExtInfo)
47 }
48 value, rest, ok := parseString(rest)
49 if !ok {
50 return parseError(msgExtInfo)
51 }
52 extensions[string(name)] = value
53 payload = rest
54 }
55 packet, err = c.transport.readPacket()
56 if err != nil {
57 return err
58 }
59 }
60 var serviceAccept serviceAcceptMsg
61 if err := Unmarshal(packet, &serviceAccept); err != nil {
62 return err
63 }
64
65
66
67 var tried []string
68 var lastMethods []string
69
70 sessionID := c.transport.getSessionID()
71 for auth := AuthMethod(new(noneAuth)); auth != nil; {
72 ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
73 if err != nil {
74
75
76 ok = authFailure
77 }
78 if ok == authSuccess {
79
80 return nil
81 } else if ok == authFailure {
82 if m := auth.method(); !contains(tried, m) {
83 tried = append(tried, m)
84 }
85 }
86 if methods == nil {
87 methods = lastMethods
88 }
89 lastMethods = methods
90
91 auth = nil
92
93 findNext:
94 for _, a := range config.Auth {
95 candidateMethod := a.method()
96 if contains(tried, candidateMethod) {
97 continue
98 }
99 for _, meth := range methods {
100 if meth == candidateMethod {
101 auth = a
102 break findNext
103 }
104 }
105 }
106
107 if auth == nil && err != nil {
108
109
110 return err
111 }
112 }
113 return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
114 }
115
// contains reports whether e occurs in list. A nil or empty list contains
// nothing.
func contains(list []string, e string) bool {
	for i := range list {
		if list[i] == e {
			return true
		}
	}
	return false
}
124
125
126 type AuthMethod interface {
127
128
129
130
131
132 auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)
133
134
135 method() string
136 }
137
138
139 type noneAuth int
140
141 func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
142 if err := c.writePacket(Marshal(&userAuthRequestMsg{
143 User: user,
144 Service: serviceSSH,
145 Method: "none",
146 })); err != nil {
147 return authFailure, nil, err
148 }
149
150 return handleAuthResponse(c)
151 }
152
153 func (n *noneAuth) method() string {
154 return "none"
155 }
156
157
158
159 type passwordCallback func() (password string, err error)
160
161 func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
162 type passwordAuthMsg struct {
163 User string `sshtype:"50"`
164 Service string
165 Method string
166 Reply bool
167 Password string
168 }
169
170 pw, err := cb()
171
172
173
174 if err != nil {
175 return authFailure, nil, err
176 }
177
178 if err := c.writePacket(Marshal(&passwordAuthMsg{
179 User: user,
180 Service: serviceSSH,
181 Method: cb.method(),
182 Reply: false,
183 Password: pw,
184 })); err != nil {
185 return authFailure, nil, err
186 }
187
188 return handleAuthResponse(c)
189 }
190
191 func (cb passwordCallback) method() string {
192 return "password"
193 }
194
195
196 func Password(secret string) AuthMethod {
197 return passwordCallback(func() (string, error) { return secret, nil })
198 }
199
200
201
202 func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
203 return passwordCallback(prompt)
204 }
205
// publickeyAuthMsg is the "publickey" authentication request (message type
// 50, per the sshtype tag). It is used both for the signature-less query
// (validateKey) and for the signed request. NOTE(review): Marshal presumably
// encodes fields in declaration order, so field order must not change.
type publickeyAuthMsg struct {
	User    string `sshtype:"50"`
	Service string
	Method  string
	// HasSig is false for a query asking whether the key is acceptable and
	// true when Sig carries a signature.
	HasSig   bool
	Algoname string
	PubKey   []byte
	// Sig is only set when HasSig is true. Callers fill it with an
	// already length-prefixed SSH string. The "rest" tag presumably makes
	// Marshal emit it raw — confirm against Marshal.
	Sig []byte `ssh:"rest"`
}
219
220
221
222 type publicKeyCallback func() ([]Signer, error)
223
224 func (cb publicKeyCallback) method() string {
225 return "publickey"
226 }
227
228 func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
229 var as MultiAlgorithmSigner
230 keyFormat := signer.PublicKey().Type()
231
232
233
234
235 switch s := signer.(type) {
236 case MultiAlgorithmSigner:
237 as = s
238 case AlgorithmSigner:
239 as = &multiAlgorithmSigner{
240 AlgorithmSigner: s,
241 supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
242 }
243 default:
244 as = &multiAlgorithmSigner{
245 AlgorithmSigner: algorithmSignerWrapper{signer},
246 supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
247 }
248 }
249
250 getFallbackAlgo := func() (string, error) {
251
252
253
254 if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
255 return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
256 underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
257 }
258 return keyFormat, nil
259 }
260
261 extPayload, ok := extensions["server-sig-algs"]
262 if !ok {
263
264
265 algo, err := getFallbackAlgo()
266 return as, algo, err
267 }
268
269
270
271
272
273 serverAlgos := strings.Split(string(extPayload), ",")
274 for _, algo := range serverAlgos {
275 if certAlgo, ok := certificateAlgo(algo); ok {
276 serverAlgos = append(serverAlgos, certAlgo)
277 }
278 }
279
280
281 var keyAlgos []string
282 for _, algo := range algorithmsForKeyFormat(keyFormat) {
283 if contains(as.Algorithms(), underlyingAlgo(algo)) {
284 keyAlgos = append(keyAlgos, algo)
285 }
286 }
287
288 algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
289 if err != nil {
290
291
292 algo, err := getFallbackAlgo()
293 return as, algo, err
294 }
295 return as, algo, nil
296 }
297
298 func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
299
300
301
302
303
304 signers, err := cb()
305 if err != nil {
306 return authFailure, nil, err
307 }
308 var methods []string
309 var errSigAlgo error
310
311 origSignersLen := len(signers)
312 for idx := 0; idx < len(signers); idx++ {
313 signer := signers[idx]
314 pub := signer.PublicKey()
315 as, algo, err := pickSignatureAlgorithm(signer, extensions)
316 if err != nil && errSigAlgo == nil {
317
318
319
320 errSigAlgo = err
321 continue
322 }
323 ok, err := validateKey(pub, algo, user, c)
324 if err != nil {
325 return authFailure, nil, err
326 }
327
328
329
330
331
332 if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
333 if contains(as.Algorithms(), KeyAlgoRSA) {
334
335
336 signers = append(signers, &multiAlgorithmSigner{
337 AlgorithmSigner: as,
338 supportedAlgorithms: []string{KeyAlgoRSA},
339 })
340 }
341 }
342 if !ok {
343 continue
344 }
345
346 pubKey := pub.Marshal()
347 data := buildDataSignedForAuth(session, userAuthRequestMsg{
348 User: user,
349 Service: serviceSSH,
350 Method: cb.method(),
351 }, algo, pubKey)
352 sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
353 if err != nil {
354 return authFailure, nil, err
355 }
356
357
358 s := Marshal(sign)
359 sig := make([]byte, stringLength(len(s)))
360 marshalString(sig, s)
361 msg := publickeyAuthMsg{
362 User: user,
363 Service: serviceSSH,
364 Method: cb.method(),
365 HasSig: true,
366 Algoname: algo,
367 PubKey: pubKey,
368 Sig: sig,
369 }
370 p := Marshal(&msg)
371 if err := c.writePacket(p); err != nil {
372 return authFailure, nil, err
373 }
374 var success authResult
375 success, methods, err = handleAuthResponse(c)
376 if err != nil {
377 return authFailure, nil, err
378 }
379
380
381
382
383
384 if success == authSuccess || !contains(methods, cb.method()) {
385 return success, methods, err
386 }
387 }
388
389 return authFailure, methods, errSigAlgo
390 }
391
392
393 func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) {
394 pubKey := key.Marshal()
395 msg := publickeyAuthMsg{
396 User: user,
397 Service: serviceSSH,
398 Method: "publickey",
399 HasSig: false,
400 Algoname: algo,
401 PubKey: pubKey,
402 }
403 if err := c.writePacket(Marshal(&msg)); err != nil {
404 return false, err
405 }
406
407 return confirmKeyAck(key, algo, c)
408 }
409
410 func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
411 pubKey := key.Marshal()
412
413 for {
414 packet, err := c.readPacket()
415 if err != nil {
416 return false, err
417 }
418 switch packet[0] {
419 case msgUserAuthBanner:
420 if err := handleBannerResponse(c, packet); err != nil {
421 return false, err
422 }
423 case msgUserAuthPubKeyOk:
424 var msg userAuthPubKeyOkMsg
425 if err := Unmarshal(packet, &msg); err != nil {
426 return false, err
427 }
428 if msg.Algo != algo || !bytes.Equal(msg.PubKey, pubKey) {
429 return false, nil
430 }
431 return true, nil
432 case msgUserAuthFailure:
433 return false, nil
434 default:
435 return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0])
436 }
437 }
438 }
439
440
441
442 func PublicKeys(signers ...Signer) AuthMethod {
443 return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
444 }
445
446
447
448 func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
449 return publicKeyCallback(getSigners)
450 }
451
452
453
454
455 func handleAuthResponse(c packetConn) (authResult, []string, error) {
456 gotMsgExtInfo := false
457 for {
458 packet, err := c.readPacket()
459 if err != nil {
460 return authFailure, nil, err
461 }
462
463 switch packet[0] {
464 case msgUserAuthBanner:
465 if err := handleBannerResponse(c, packet); err != nil {
466 return authFailure, nil, err
467 }
468 case msgExtInfo:
469
470 if gotMsgExtInfo {
471 return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
472 }
473 gotMsgExtInfo = true
474 case msgUserAuthFailure:
475 var msg userAuthFailureMsg
476 if err := Unmarshal(packet, &msg); err != nil {
477 return authFailure, nil, err
478 }
479 if msg.PartialSuccess {
480 return authPartialSuccess, msg.Methods, nil
481 }
482 return authFailure, msg.Methods, nil
483 case msgUserAuthSuccess:
484 return authSuccess, nil, nil
485 default:
486 return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
487 }
488 }
489 }
490
491 func handleBannerResponse(c packetConn, packet []byte) error {
492 var msg userAuthBannerMsg
493 if err := Unmarshal(packet, &msg); err != nil {
494 return err
495 }
496
497 transport, ok := c.(*handshakeTransport)
498 if !ok {
499 return nil
500 }
501
502 if transport.bannerCallback != nil {
503 return transport.bannerCallback(msg.Message)
504 }
505
506 return nil
507 }
508
509
510
511
512
513
514
515
516 type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error)
517
518
519
520 func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
521 return challenge
522 }
523
524 func (cb KeyboardInteractiveChallenge) method() string {
525 return "keyboard-interactive"
526 }
527
528 func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
529 type initiateMsg struct {
530 User string `sshtype:"50"`
531 Service string
532 Method string
533 Language string
534 Submethods string
535 }
536
537 if err := c.writePacket(Marshal(&initiateMsg{
538 User: user,
539 Service: serviceSSH,
540 Method: "keyboard-interactive",
541 })); err != nil {
542 return authFailure, nil, err
543 }
544
545 gotMsgExtInfo := false
546 for {
547 packet, err := c.readPacket()
548 if err != nil {
549 return authFailure, nil, err
550 }
551
552
553 switch packet[0] {
554 case msgUserAuthBanner:
555 if err := handleBannerResponse(c, packet); err != nil {
556 return authFailure, nil, err
557 }
558 continue
559 case msgExtInfo:
560
561 if gotMsgExtInfo {
562 return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
563 }
564 gotMsgExtInfo = true
565 continue
566 case msgUserAuthInfoRequest:
567
568 case msgUserAuthFailure:
569 var msg userAuthFailureMsg
570 if err := Unmarshal(packet, &msg); err != nil {
571 return authFailure, nil, err
572 }
573 if msg.PartialSuccess {
574 return authPartialSuccess, msg.Methods, nil
575 }
576 return authFailure, msg.Methods, nil
577 case msgUserAuthSuccess:
578 return authSuccess, nil, nil
579 default:
580 return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
581 }
582
583 var msg userAuthInfoRequestMsg
584 if err := Unmarshal(packet, &msg); err != nil {
585 return authFailure, nil, err
586 }
587
588
589 rest := msg.Prompts
590 var prompts []string
591 var echos []bool
592 for i := 0; i < int(msg.NumPrompts); i++ {
593 prompt, r, ok := parseString(rest)
594 if !ok || len(r) == 0 {
595 return authFailure, nil, errors.New("ssh: prompt format error")
596 }
597 prompts = append(prompts, string(prompt))
598 echos = append(echos, r[0] != 0)
599 rest = r[1:]
600 }
601
602 if len(rest) != 0 {
603 return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
604 }
605
606 answers, err := cb(msg.Name, msg.Instruction, prompts, echos)
607 if err != nil {
608 return authFailure, nil, err
609 }
610
611 if len(answers) != len(prompts) {
612 return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts))
613 }
614 responseLength := 1 + 4
615 for _, a := range answers {
616 responseLength += stringLength(len(a))
617 }
618 serialized := make([]byte, responseLength)
619 p := serialized
620 p[0] = msgUserAuthInfoResponse
621 p = p[1:]
622 p = marshalUint32(p, uint32(len(answers)))
623 for _, a := range answers {
624 p = marshalString(p, []byte(a))
625 }
626
627 if err := c.writePacket(serialized); err != nil {
628 return authFailure, nil, err
629 }
630 }
631 }
632
633 type retryableAuthMethod struct {
634 authMethod AuthMethod
635 maxTries int
636 }
637
638 func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) {
639 for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
640 ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions)
641 if ok != authFailure || err != nil {
642 return ok, methods, err
643 }
644 }
645 return ok, methods, err
646 }
647
648 func (r *retryableAuthMethod) method() string {
649 return r.authMethod.method()
650 }
651
652
653
654
655
656
657
658
659
660
661
662
663 func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod {
664 return &retryableAuthMethod{authMethod: auth, maxTries: maxTries}
665 }
666
667
668
669
670
671 func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod {
672 if gssAPIClient == nil {
673 panic("gss-api client must be not nil with enable gssapi-with-mic")
674 }
675 return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target}
676 }
677
678 type gssAPIWithMICCallback struct {
679 gssAPIClient GSSAPIClient
680 target string
681 }
682
683 func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
684 m := &userAuthRequestMsg{
685 User: user,
686 Service: serviceSSH,
687 Method: g.method(),
688 }
689
690
691 m.Payload = appendU32(m.Payload, 1)
692 m.Payload = appendString(m.Payload, string(krb5OID))
693 if err := c.writePacket(Marshal(m)); err != nil {
694 return authFailure, nil, err
695 }
696
697
698
699
700
701
702 packet, err := c.readPacket()
703 if err != nil {
704 return authFailure, nil, err
705 }
706 userAuthGSSAPIResp := &userAuthGSSAPIResponse{}
707 if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil {
708 return authFailure, nil, err
709 }
710
711
712 var token []byte
713 defer g.gssAPIClient.DeleteSecContext()
714 for {
715
716 nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false)
717 if err != nil {
718 return authFailure, nil, err
719 }
720 if len(nextToken) > 0 {
721 if err := c.writePacket(Marshal(&userAuthGSSAPIToken{
722 Token: nextToken,
723 })); err != nil {
724 return authFailure, nil, err
725 }
726 }
727 if !needContinue {
728 break
729 }
730 packet, err = c.readPacket()
731 if err != nil {
732 return authFailure, nil, err
733 }
734 switch packet[0] {
735 case msgUserAuthFailure:
736 var msg userAuthFailureMsg
737 if err := Unmarshal(packet, &msg); err != nil {
738 return authFailure, nil, err
739 }
740 if msg.PartialSuccess {
741 return authPartialSuccess, msg.Methods, nil
742 }
743 return authFailure, msg.Methods, nil
744 case msgUserAuthGSSAPIError:
745 userAuthGSSAPIErrorResp := &userAuthGSSAPIError{}
746 if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil {
747 return authFailure, nil, err
748 }
749 return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+
750 "Major Status: %d\n"+
751 "Minor Status: %d\n"+
752 "Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus,
753 userAuthGSSAPIErrorResp.Message)
754 case msgUserAuthGSSAPIToken:
755 userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
756 if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
757 return authFailure, nil, err
758 }
759 token = userAuthGSSAPITokenReq.Token
760 }
761 }
762
763
764 micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic")
765 micToken, err := g.gssAPIClient.GetMIC(micField)
766 if err != nil {
767 return authFailure, nil, err
768 }
769 if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{
770 MIC: micToken,
771 })); err != nil {
772 return authFailure, nil, err
773 }
774 return handleAuthResponse(c)
775 }
776
777 func (g *gssAPIWithMICCallback) method() string {
778 return "gssapi-with-mic"
779 }
780
View as plain text