1
2
3
4
5 package ssh
6
7 import (
8 "crypto/rand"
9 "errors"
10 "fmt"
11 "io"
12 "log"
13 "net"
14 "strings"
15 "sync"
16 )
17
18
19
20
// debugHandshake, if set, makes the handshake code in this file log every
// packet it sends or receives, as well as the start and end of each key
// exchange, through the standard log package.
const debugHandshake = false
22
23
24
25
// chanSize is a channel buffer size. In this file it is the capacity of
// handshakeTransport.incoming, the channel through which readLoop hands
// packets to readPacket.
const chanSize = 16
27
28
29
30
// keyingTransport is the packet connection that handshakeTransport runs on
// top of. Besides reading and writing packets it exposes the hooks that the
// key exchange code in this file calls while negotiating keys.
type keyingTransport interface {
	packetConn

	// prepareKeyChange hands the negotiated algorithms and the key
	// exchange result to the transport. enterKeyExchange calls it right
	// before it writes msgNewKeys.
	prepareKeyChange(*algorithms, *kexResult) error

	// setStrictMode is called during the first key exchange when the
	// peer's KEXINIT contains the strict key exchange marker. A returned
	// error aborts the key exchange.
	setStrictMode() error

	// setInitialKEXDone is called once the peer's msgNewKeys of the first
	// key exchange has been read.
	setInitialKEXDone()
}
49
50
51
// handshakeTransport layers key exchange on top of a keyingTransport. It runs
// the initial key exchange, starts a new one when the read or write budgets
// are used up or the peer sends a KEXINIT, and queues outgoing packets while
// a key exchange is in flight.
type handshakeTransport struct {
	conn   keyingTransport
	config *Config

	serverVersion []byte
	clientVersion []byte

	// hostKeys is set by newServerTransport only. A non-empty slice is how
	// the rest of this file tells the server role from the client role.
	hostKeys []Signer

	// publicKeyAuthAlgorithms is copied from ServerConfig and advertised
	// to the client in the "server-sig-algs" extension.
	publicKeyAuthAlgorithms []string

	// hostKeyAlgorithms is set by newClientTransport and sent as the host
	// key algorithm list of the client's KEXINIT.
	hostKeyAlgorithms []string

	// incoming carries packets from readLoop to readPacket. readLoop
	// stores its terminating error in readError before closing incoming.
	incoming  chan []byte
	readError error

	// mu is held by the write-side code in this file (writePacket,
	// sendKexInit, kexLoop, get/recordWriteError) while it touches the
	// fields from writeError through writeBytesLeft.
	mu               sync.Mutex
	writeError       error
	sentInitPacket   []byte      // wire form of our outstanding KEXINIT, nil when none is pending
	sentInitMsg      *kexInitMsg // our outstanding KEXINIT; non-nil means a key exchange is in progress
	pendingPackets   [][]byte    // packets queued by writePacket while a key exchange is in progress
	writePacketsLeft uint32
	writeBytesLeft   int64

	// requestKex is a capacity-1 signal asking kexLoop to start a key
	// exchange. Senders use a non-blocking send, so a signal that is
	// already pending is simply not duplicated.
	requestKex chan struct{}

	// startKex carries the peer's KEXINIT from readOnePacket to kexLoop.
	// readLoop closes it on exit. kexLoopDone is closed when kexLoop
	// returns; Close waits for it.
	startKex    chan *pendingKex
	kexLoopDone chan struct{}

	// The next three fields are set by newClientTransport and used by
	// client() to verify the server's host key.
	hostKeyCallback HostKeyCallback
	dialAddress     string
	remoteAddr      net.Addr

	// bannerCallback is copied from ClientConfig; it is not used within
	// this file.
	bannerCallback BannerCallback

	// algorithms holds the result of the most recent algorithm
	// negotiation; it is nil until enterKeyExchange has run.
	algorithms *algorithms

	// Read-side budgets, decremented by readOnePacket and refilled by
	// resetReadThresholds.
	readPacketsLeft uint32
	readBytesLeft   int64

	// sessionID is the exchange hash of the first key exchange. It stays
	// nil until that exchange has produced a result.
	sessionID []byte

	// strictMode is set when the peer advertised strict key exchange in
	// the first key exchange.
	strictMode bool
}
118
// pendingKex is the request readOnePacket sends to kexLoop when the peer's
// KEXINIT arrives.
type pendingKex struct {
	// otherInit is the raw KEXINIT packet received from the peer.
	otherInit []byte
	// done receives the outcome of the key exchange (nil on success).
	done chan error
}
123
124 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
125 t := &handshakeTransport{
126 conn: conn,
127 serverVersion: serverVersion,
128 clientVersion: clientVersion,
129 incoming: make(chan []byte, chanSize),
130 requestKex: make(chan struct{}, 1),
131 startKex: make(chan *pendingKex),
132 kexLoopDone: make(chan struct{}),
133
134 config: config,
135 }
136 t.resetReadThresholds()
137 t.resetWriteThresholds()
138
139
140 t.requestKex <- struct{}{}
141 return t
142 }
143
144 func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
145 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
146 t.dialAddress = dialAddr
147 t.remoteAddr = addr
148 t.hostKeyCallback = config.HostKeyCallback
149 t.bannerCallback = config.BannerCallback
150 if config.HostKeyAlgorithms != nil {
151 t.hostKeyAlgorithms = config.HostKeyAlgorithms
152 } else {
153 t.hostKeyAlgorithms = supportedHostKeyAlgos
154 }
155 go t.readLoop()
156 go t.kexLoop()
157 return t
158 }
159
160 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
161 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
162 t.hostKeys = config.hostKeys
163 t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms
164 go t.readLoop()
165 go t.kexLoop()
166 return t
167 }
168
169 func (t *handshakeTransport) getSessionID() []byte {
170 return t.sessionID
171 }
172
173
174
175 func (t *handshakeTransport) waitSession() error {
176 p, err := t.readPacket()
177 if err != nil {
178 return err
179 }
180 if p[0] != msgNewKeys {
181 return fmt.Errorf("ssh: first packet should be msgNewKeys")
182 }
183
184 return nil
185 }
186
187 func (t *handshakeTransport) id() string {
188 if len(t.hostKeys) > 0 {
189 return "server"
190 }
191 return "client"
192 }
193
194 func (t *handshakeTransport) printPacket(p []byte, write bool) {
195 action := "got"
196 if write {
197 action = "sent"
198 }
199
200 if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
201 log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
202 } else {
203 msg, err := decode(p)
204 log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
205 }
206 }
207
208 func (t *handshakeTransport) readPacket() ([]byte, error) {
209 p, ok := <-t.incoming
210 if !ok {
211 return nil, t.readError
212 }
213 return p, nil
214 }
215
216 func (t *handshakeTransport) readLoop() {
217 first := true
218 for {
219 p, err := t.readOnePacket(first)
220 first = false
221 if err != nil {
222 t.readError = err
223 close(t.incoming)
224 break
225 }
226
227
228
229 if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
230 continue
231 }
232 t.incoming <- p
233 }
234
235
236 t.recordWriteError(t.readError)
237
238
239 close(t.startKex)
240
241
242 }
243
244 func (t *handshakeTransport) pushPacket(p []byte) error {
245 if debugHandshake {
246 t.printPacket(p, true)
247 }
248 return t.conn.writePacket(p)
249 }
250
251 func (t *handshakeTransport) getWriteError() error {
252 t.mu.Lock()
253 defer t.mu.Unlock()
254 return t.writeError
255 }
256
257 func (t *handshakeTransport) recordWriteError(err error) {
258 t.mu.Lock()
259 defer t.mu.Unlock()
260 if t.writeError == nil && err != nil {
261 t.writeError = err
262 }
263 }
264
265 func (t *handshakeTransport) requestKeyExchange() {
266 select {
267 case t.requestKex <- struct{}{}:
268 default:
269
270 }
271 }
272
273 func (t *handshakeTransport) resetWriteThresholds() {
274 t.writePacketsLeft = packetRekeyThreshold
275 if t.config.RekeyThreshold > 0 {
276 t.writeBytesLeft = int64(t.config.RekeyThreshold)
277 } else if t.algorithms != nil {
278 t.writeBytesLeft = t.algorithms.w.rekeyBytes()
279 } else {
280 t.writeBytesLeft = 1 << 30
281 }
282 }
283
// kexLoop runs in its own goroutine and drives every key exchange. Each
// iteration waits until our KEXINIT has been sent and the peer's KEXINIT has
// arrived, runs the key exchange, reports the result to the reader, and then
// flushes the packets that writePacket queued in the meantime. The loop ends
// on the first write error or when readLoop closes startKex; it then closes
// the connection, answers any remaining requests and closes kexLoopDone.
func (t *handshakeTransport) kexLoop() {

write:
	for t.getWriteError() == nil {
		var request *pendingKex
		var sent bool

		// Wait until both halves are in place: our KEXINIT is sent and
		// the peer's KEXINIT has been handed over by the reader.
		for request == nil || !sent {
			var ok bool
			select {
			case request, ok = <-t.startKex:
				if !ok {
					// readLoop has terminated.
					break write
				}
			case <-t.requestKex:
				break
			}

			if !sent {
				if err := t.sendKexInit(); err != nil {
					t.recordWriteError(err)
					break
				}
				sent = true
			}
		}

		if err := t.getWriteError(); err != nil {
			// Sending failed. If the reader is already waiting for
			// a result, give it the error so it does not hang.
			if request != nil {
				request.done <- err
			}
			break
		}

		// Neither requestKex nor startKex is serviced while the key
		// exchange runs. Sends on requestKex are non-blocking, and the
		// reader that would send on startKex is blocked waiting on
		// request.done.
		err := t.enterKeyExchange(request.otherInit)

		t.mu.Lock()
		t.writeError = err
		t.sentInitPacket = nil
		t.sentInitMsg = nil

		t.resetWriteThresholds()

		// The key exchange is over. Discard any rekey request that was
		// queued while it ran so that it does not immediately trigger
		// another exchange.
	clear:
		for {
			select {
			case <-t.requestKex:
				//
			default:
				break clear
			}
		}

		// Release the reader.
		request.done <- t.writeError

		// Flush the packets writePacket queued during the key
		// exchange. They go straight to pushPacket, so they are not
		// charged against the freshly reset write budgets.
		for _, p := range t.pendingPackets {
			t.writeError = t.pushPacket(p)
			if t.writeError != nil {
				break
			}
		}
		t.pendingPackets = t.pendingPackets[:0]
		t.mu.Unlock()
	}

	// Closing the connection makes a reader blocked in conn.readPacket
	// return.
	t.conn.Close()

	// Answer every request still sent on startKex until readLoop closes
	// the channel. requestKex needs no draining because nobody blocks on
	// it.
	for request := range t.startKex {
		request.done <- t.getWriteError()
	}

	// Let Close return.
	close(t.kexLoopDone)
}
380
381
382
383
384
385
// packetRekeyThreshold is the per-direction packet budget between key
// exchanges: readPacketsLeft and writePacketsLeft are reset to it, and a new
// key exchange is requested once a budget reaches zero.
const packetRekeyThreshold = (1 << 31)
387
388 func (t *handshakeTransport) resetReadThresholds() {
389 t.readPacketsLeft = packetRekeyThreshold
390 if t.config.RekeyThreshold > 0 {
391 t.readBytesLeft = int64(t.config.RekeyThreshold)
392 } else if t.algorithms != nil {
393 t.readBytesLeft = t.algorithms.r.rekeyBytes()
394 } else {
395 t.readBytesLeft = 1 << 30
396 }
397 }
398
399 func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
400 p, err := t.conn.readPacket()
401 if err != nil {
402 return nil, err
403 }
404
405 if t.readPacketsLeft > 0 {
406 t.readPacketsLeft--
407 } else {
408 t.requestKeyExchange()
409 }
410
411 if t.readBytesLeft > 0 {
412 t.readBytesLeft -= int64(len(p))
413 } else {
414 t.requestKeyExchange()
415 }
416
417 if debugHandshake {
418 t.printPacket(p, false)
419 }
420
421 if first && p[0] != msgKexInit {
422 return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
423 }
424
425 if p[0] != msgKexInit {
426 return p, nil
427 }
428
429 firstKex := t.sessionID == nil
430
431 kex := pendingKex{
432 done: make(chan error, 1),
433 otherInit: p,
434 }
435 t.startKex <- &kex
436 err = <-kex.done
437
438 if debugHandshake {
439 log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
440 }
441
442 if err != nil {
443 return nil, err
444 }
445
446 t.resetReadThresholds()
447
448
449
450 successPacket := []byte{msgIgnore}
451 if firstKex {
452
453
454
455 successPacket = []byte{msgNewKeys}
456 }
457
458 return successPacket, nil
459 }
460
// Marker names appended to the key exchange algorithm list of the first
// KEXINIT to advertise strict key exchange: kexStrictClient is sent by the
// client and kexStrictServer by the server.
const (
	kexStrictClient = "kex-strict-c-v00@openssh.com"
	kexStrictServer = "kex-strict-s-v00@openssh.com"
)
465
466
467 func (t *handshakeTransport) sendKexInit() error {
468 t.mu.Lock()
469 defer t.mu.Unlock()
470 if t.sentInitMsg != nil {
471
472
473
474
475 return nil
476 }
477
478 msg := &kexInitMsg{
479 CiphersClientServer: t.config.Ciphers,
480 CiphersServerClient: t.config.Ciphers,
481 MACsClientServer: t.config.MACs,
482 MACsServerClient: t.config.MACs,
483 CompressionClientServer: supportedCompressions,
484 CompressionServerClient: supportedCompressions,
485 }
486 io.ReadFull(rand.Reader, msg.Cookie[:])
487
488
489
490
491
492 msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2)
493 msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
494
495 isServer := len(t.hostKeys) > 0
496 if isServer {
497 for _, k := range t.hostKeys {
498
499
500
501
502
503
504 keyFormat := k.PublicKey().Type()
505
506 switch s := k.(type) {
507 case MultiAlgorithmSigner:
508 for _, algo := range algorithmsForKeyFormat(keyFormat) {
509 if contains(s.Algorithms(), underlyingAlgo(algo)) {
510 msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
511 }
512 }
513 case AlgorithmSigner:
514 msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
515 default:
516 msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
517 }
518 }
519
520 if t.sessionID == nil {
521 msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
522 }
523 } else {
524 msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
525
526
527
528
529
530
531
532 if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
533 msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
534 msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
535 }
536
537 }
538
539 packet := Marshal(msg)
540
541
542 packetCopy := make([]byte, len(packet))
543 copy(packetCopy, packet)
544
545 if err := t.pushPacket(packetCopy); err != nil {
546 return err
547 }
548
549 t.sentInitMsg = msg
550 t.sentInitPacket = packet
551
552 return nil
553 }
554
555 func (t *handshakeTransport) writePacket(p []byte) error {
556 switch p[0] {
557 case msgKexInit:
558 return errors.New("ssh: only handshakeTransport can send kexInit")
559 case msgNewKeys:
560 return errors.New("ssh: only handshakeTransport can send newKeys")
561 }
562
563 t.mu.Lock()
564 defer t.mu.Unlock()
565 if t.writeError != nil {
566 return t.writeError
567 }
568
569 if t.sentInitMsg != nil {
570
571 cp := make([]byte, len(p))
572 copy(cp, p)
573 t.pendingPackets = append(t.pendingPackets, cp)
574 return nil
575 }
576
577 if t.writeBytesLeft > 0 {
578 t.writeBytesLeft -= int64(len(p))
579 } else {
580 t.requestKeyExchange()
581 }
582
583 if t.writePacketsLeft > 0 {
584 t.writePacketsLeft--
585 } else {
586 t.requestKeyExchange()
587 }
588
589 if err := t.pushPacket(p); err != nil {
590 t.writeError = err
591 }
592
593 return nil
594 }
595
596 func (t *handshakeTransport) Close() error {
597
598
599 err := t.conn.Close()
600
601
602
603
604 <-t.kexLoopDone
605
606 return err
607 }
608
// enterKeyExchange runs one complete key exchange. otherInitPacket is the raw
// KEXINIT received from the peer; our own KEXINIT must already be stored in
// sentInitMsg and sentInitPacket. It negotiates algorithms, runs the
// negotiated key exchange method in the client or server role, installs the
// new keys on the connection and exchanges msgNewKeys with the peer. It
// returns the first error encountered.
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
	if debugHandshake {
		log.Printf("%s entered key exchange", t.id())
	}

	otherInit := &kexInitMsg{}
	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
		return err
	}

	// Start with the assignment for the server role: the peer is the
	// client and we are the server.
	magics := handshakeMagics{
		clientVersion: t.clientVersion,
		serverVersion: t.serverVersion,
		clientKexInit: otherInitPacket,
		serverKexInit: t.sentInitPacket,
	}

	clientInit := otherInit
	serverInit := t.sentInitMsg
	isClient := len(t.hostKeys) == 0
	if isClient {
		// We are the client, so swap the roles.
		clientInit, serverInit = serverInit, clientInit

		magics.clientKexInit = t.sentInitPacket
		magics.serverKexInit = otherInitPacket
	}

	var err error
	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
	if err != nil {
		return err
	}

	// Strict mode can only be enabled during the first key exchange, and
	// only when the peer's KEXINIT carries the strict marker for its role.
	if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
		t.strictMode = true
		if err := t.conn.setStrictMode(); err != nil {
			return err
		}
	}

	// The peer announced that a guessed key exchange packet follows its
	// KEXINIT. If the first-listed key exchange or host key algorithms of
	// the two sides differ, that packet is read and discarded.
	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
		// Other side sent a kex message for the wrong algorithm,
		// which we have to ignore.
		if _, err := t.conn.readPacket(); err != nil {
			return err
		}
	}

	kex, ok := kexAlgoMap[t.algorithms.kex]
	if !ok {
		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
	}

	var result *kexResult
	if len(t.hostKeys) > 0 {
		result, err = t.server(kex, &magics)
	} else {
		result, err = t.client(kex, &magics)
	}

	if err != nil {
		return err
	}

	// The exchange hash of the first key exchange becomes the session ID
	// and is reused unchanged by every later key exchange.
	firstKeyExchange := t.sessionID == nil
	if firstKeyExchange {
		t.sessionID = result.H
	}
	result.SessionID = t.sessionID

	if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
		return err
	}
	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
		return err
	}

	// On the first key exchange a server answers a client that advertised
	// "ext-info-c" with an extension info message, sent directly after
	// our msgNewKeys. It carries two extensions: "server-sig-algs" with
	// the comma-separated public key algorithms, and "ping@openssh.com"
	// with the value "0".
	if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
		supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
		extInfo := &extInfoMsg{
			NumExtensions: 2,
			// Capacity: a 4-byte length prefix plus the bytes of
			// each of the four strings appended below.
			Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
		}
		extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
		extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
		extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
		extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
		extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
		extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
		extInfo.Payload = appendInt(extInfo.Payload, 1)
		extInfo.Payload = append(extInfo.Payload, "0"...)
		if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
			return err
		}
	}

	// The peer must now confirm with its own msgNewKeys.
	if packet, err := t.conn.readPacket(); err != nil {
		return err
	} else if packet[0] != msgNewKeys {
		return unexpectedMessageError(msgNewKeys, packet[0])
	}

	if firstKeyExchange {
		// Tell the connection that the initial key exchange is
		// complete.
		t.conn.setInitialKEXDone()
	}

	return nil
}
732
733
734
735
736
737
738
// algorithmSignerWrapper adapts a plain Signer to the AlgorithmSigner return
// type of pickHostKey. Its SignWithAlgorithm accepts only the algorithm that
// corresponds to the key's own type and delegates to Sign.
type algorithmSignerWrapper struct {
	Signer
}
742
743 func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
744 if algorithm != underlyingAlgo(a.PublicKey().Type()) {
745 return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
746 }
747 return a.Sign(rand, data)
748 }
749
750 func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
751 for _, k := range hostKeys {
752 if s, ok := k.(MultiAlgorithmSigner); ok {
753 if !contains(s.Algorithms(), underlyingAlgo(algo)) {
754 continue
755 }
756 }
757
758 if algo == k.PublicKey().Type() {
759 return algorithmSignerWrapper{k}
760 }
761
762 k, ok := k.(AlgorithmSigner)
763 if !ok {
764 continue
765 }
766 for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) {
767 if algo == a {
768 return k
769 }
770 }
771 }
772 return nil
773 }
774
775 func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
776 hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
777 if hostKey == nil {
778 return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
779 }
780
781 r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
782 return r, err
783 }
784
785 func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
786 result, err := kex.Client(t.conn, t.config.Rand, magics)
787 if err != nil {
788 return nil, err
789 }
790
791 hostKey, err := ParsePublicKey(result.HostKey)
792 if err != nil {
793 return nil, err
794 }
795
796 if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
797 return nil, err
798 }
799
800 err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
801 if err != nil {
802 return nil, err
803 }
804
805 return result, nil
806 }
807
View as plain text