1
2
3
4
5 package ssh
6
7 import (
8 "context"
9 "errors"
10 "fmt"
11 "io"
12 "math/rand"
13 "net"
14 "strconv"
15 "strings"
16 "sync"
17 "time"
18 )
19
20
21
22
23
24
25 func (c *Client) Listen(n, addr string) (net.Listener, error) {
26 switch n {
27 case "tcp", "tcp4", "tcp6":
28 laddr, err := net.ResolveTCPAddr(n, addr)
29 if err != nil {
30 return nil, err
31 }
32 return c.ListenTCP(laddr)
33 case "unix":
34 return c.ListenUnix(addr)
35 default:
36 return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
37 }
38 }
39
40
41
42
43
44
45
46
// openSSHPrefix is the marker that precedes the version number in an
// OpenSSH server identification string.
const openSSHPrefix = "OpenSSH_"

// portRandomizer picks candidate ports for autoPortListenWorkaround. It is
// seeded from the wall clock at package initialisation.
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))

// isBrokenOpenSSHVersion reports whether versionStr names an OpenSSH server
// whose major version is below 6. The major version is the run of ASCII
// digits right after "OpenSSH_"; an empty run counts as version 0 (and so as
// broken). Strings without the "OpenSSH_" marker are never considered broken.
func isBrokenOpenSSHVersion(versionStr string) bool {
	start := strings.Index(versionStr, openSSHPrefix)
	if start == -1 {
		return false
	}
	rest := versionStr[start+len(openSSHPrefix):]
	// Find where the leading digit run ends; -1 means the whole tail is digits.
	end := strings.IndexFunc(rest, func(r rune) bool {
		return r < '0' || r > '9'
	})
	if end == -1 {
		end = len(rest)
	}
	// Parse errors are deliberately ignored: "" parses as 0 (broken) and an
	// overflowing run saturates to a large value (not broken).
	major, _ := strconv.Atoi(rest[:end])
	return major < 6
}
69
70
71
72 func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
73 var sshListener net.Listener
74 var err error
75 const tries = 10
76 for i := 0; i < tries; i++ {
77 addr := *laddr
78 addr.Port = 1024 + portRandomizer.Intn(60000)
79 sshListener, err = c.ListenTCP(&addr)
80 if err == nil {
81 laddr.Port = addr.Port
82 return sshListener, err
83 }
84 }
85 return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
86 }
87
88
// channelForwardMsg is the request payload that ListenTCP sends with
// "tcpip-forward" and tcpListener.Close sends with "cancel-tcpip-forward".
// It is encoded with Marshal. NOTE(review): the encoding presumably follows
// field declaration order, so keep the order (address, then port) stable.
type channelForwardMsg struct {
	addr  string // IP address to listen on, as produced by net.IP.String
	rport uint32 // port to listen on; 0 lets the server choose
}
93
94
95
96
97 func (c *Client) handleForwards() {
98 go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
99 go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
100 }
101
102
103
104
105 func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
106 c.handleForwardsOnce.Do(c.handleForwards)
107 if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
108 return c.autoPortListenWorkaround(laddr)
109 }
110
111 m := channelForwardMsg{
112 laddr.IP.String(),
113 uint32(laddr.Port),
114 }
115
116 ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
117 if err != nil {
118 return nil, err
119 }
120 if !ok {
121 return nil, errors.New("ssh: tcpip-forward request denied by peer")
122 }
123
124
125
126 if laddr.Port == 0 {
127 var p struct {
128 Port uint32
129 }
130 if err := Unmarshal(resp, &p); err != nil {
131 return nil, err
132 }
133 laddr.Port = int(p.Port)
134 }
135
136
137 ch := c.forwards.add(laddr)
138
139 return &tcpListener{laddr, c, ch}, nil
140 }
141
142
143
144 type forwardList struct {
145 sync.Mutex
146 entries []forwardEntry
147 }
148
149
150
151 type forwardEntry struct {
152 laddr net.Addr
153 c chan forward
154 }
155
156
157
158
159 type forward struct {
160 newCh NewChannel
161 raddr net.Addr
162 }
163
164 func (l *forwardList) add(addr net.Addr) chan forward {
165 l.Lock()
166 defer l.Unlock()
167 f := forwardEntry{
168 laddr: addr,
169 c: make(chan forward, 1),
170 }
171 l.entries = append(l.entries, f)
172 return f.c
173 }
174
175
// forwardedTCPPayload is the extra data of a "forwarded-tcpip" channel-open
// request: the address and port that were connected to, followed by the
// address and port the connection came from.
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}

// parseTCPAddr builds a *net.TCPAddr from a literal IP string and a port.
// Port 0 and ports above 65535 are rejected, as is anything net.ParseIP
// cannot parse, including host names.
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
	if port < 1 || port > 65535 {
		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
	}
	if ip := net.ParseIP(addr); ip != nil {
		return &net.TCPAddr{IP: ip, Port: int(port)}, nil
	}
	return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
}
194
195 func (l *forwardList) handleChannels(in <-chan NewChannel) {
196 for ch := range in {
197 var (
198 laddr net.Addr
199 raddr net.Addr
200 err error
201 )
202 switch channelType := ch.ChannelType(); channelType {
203 case "forwarded-tcpip":
204 var payload forwardedTCPPayload
205 if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
206 ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
207 continue
208 }
209
210
211
212
213
214
215 laddr, err = parseTCPAddr(payload.Addr, payload.Port)
216 if err != nil {
217 ch.Reject(ConnectionFailed, err.Error())
218 continue
219 }
220 raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
221 if err != nil {
222 ch.Reject(ConnectionFailed, err.Error())
223 continue
224 }
225
226 case "forwarded-streamlocal@openssh.com":
227 var payload forwardedStreamLocalPayload
228 if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
229 ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
230 continue
231 }
232 laddr = &net.UnixAddr{
233 Name: payload.SocketPath,
234 Net: "unix",
235 }
236 raddr = &net.UnixAddr{
237 Name: "@",
238 Net: "unix",
239 }
240 default:
241 panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
242 }
243 if ok := l.forward(laddr, raddr, ch); !ok {
244
245
246 ch.Reject(Prohibited, "no forward for address")
247 continue
248 }
249
250 }
251 }
252
253
254
255 func (l *forwardList) remove(addr net.Addr) {
256 l.Lock()
257 defer l.Unlock()
258 for i, f := range l.entries {
259 if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
260 l.entries = append(l.entries[:i], l.entries[i+1:]...)
261 close(f.c)
262 return
263 }
264 }
265 }
266
267
268 func (l *forwardList) closeAll() {
269 l.Lock()
270 defer l.Unlock()
271 for _, f := range l.entries {
272 close(f.c)
273 }
274 l.entries = nil
275 }
276
277 func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
278 l.Lock()
279 defer l.Unlock()
280 for _, f := range l.entries {
281 if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
282 f.c <- forward{newCh: ch, raddr: raddr}
283 return true
284 }
285 }
286 return false
287 }
288
289 type tcpListener struct {
290 laddr *net.TCPAddr
291
292 conn *Client
293 in <-chan forward
294 }
295
296
297 func (l *tcpListener) Accept() (net.Conn, error) {
298 s, ok := <-l.in
299 if !ok {
300 return nil, io.EOF
301 }
302 ch, incoming, err := s.newCh.Accept()
303 if err != nil {
304 return nil, err
305 }
306 go DiscardRequests(incoming)
307
308 return &chanConn{
309 Channel: ch,
310 laddr: l.laddr,
311 raddr: s.raddr,
312 }, nil
313 }
314
315
316 func (l *tcpListener) Close() error {
317 m := channelForwardMsg{
318 l.laddr.IP.String(),
319 uint32(l.laddr.Port),
320 }
321
322
323 l.conn.forwards.remove(l.laddr)
324 ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
325 if err == nil && !ok {
326 err = errors.New("ssh: cancel-tcpip-forward failed")
327 }
328 return err
329 }
330
331
332 func (l *tcpListener) Addr() net.Addr {
333 return l.laddr
334 }
335
336
337
338
339
340
341
342
343 func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
344 if err := ctx.Err(); err != nil {
345 return nil, err
346 }
347 type connErr struct {
348 conn net.Conn
349 err error
350 }
351 ch := make(chan connErr)
352 go func() {
353 conn, err := c.Dial(n, addr)
354 select {
355 case ch <- connErr{conn, err}:
356 case <-ctx.Done():
357 if conn != nil {
358 conn.Close()
359 }
360 }
361 }()
362 select {
363 case res := <-ch:
364 return res.conn, res.err
365 case <-ctx.Done():
366 return nil, ctx.Err()
367 }
368 }
369
370
371
372 func (c *Client) Dial(n, addr string) (net.Conn, error) {
373 var ch Channel
374 switch n {
375 case "tcp", "tcp4", "tcp6":
376
377 host, portString, err := net.SplitHostPort(addr)
378 if err != nil {
379 return nil, err
380 }
381 port, err := strconv.ParseUint(portString, 10, 16)
382 if err != nil {
383 return nil, err
384 }
385 ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
386 if err != nil {
387 return nil, err
388 }
389
390 zeroAddr := &net.TCPAddr{
391 IP: net.IPv4zero,
392 Port: 0,
393 }
394 return &chanConn{
395 Channel: ch,
396 laddr: zeroAddr,
397 raddr: zeroAddr,
398 }, nil
399 case "unix":
400 var err error
401 ch, err = c.dialStreamLocal(addr)
402 if err != nil {
403 return nil, err
404 }
405 return &chanConn{
406 Channel: ch,
407 laddr: &net.UnixAddr{
408 Name: "@",
409 Net: "unix",
410 },
411 raddr: &net.UnixAddr{
412 Name: addr,
413 Net: "unix",
414 },
415 }, nil
416 default:
417 return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
418 }
419 }
420
421
422
423
424 func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
425 if laddr == nil {
426 laddr = &net.TCPAddr{
427 IP: net.IPv4zero,
428 Port: 0,
429 }
430 }
431 ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
432 if err != nil {
433 return nil, err
434 }
435 return &chanConn{
436 Channel: ch,
437 laddr: laddr,
438 raddr: raddr,
439 }, nil
440 }
441
442
443 type channelOpenDirectMsg struct {
444 raddr string
445 rport uint32
446 laddr string
447 lport uint32
448 }
449
450 func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
451 msg := channelOpenDirectMsg{
452 raddr: raddr,
453 rport: uint32(rport),
454 laddr: laddr,
455 lport: uint32(lport),
456 }
457 ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
458 if err != nil {
459 return nil, err
460 }
461 go DiscardRequests(in)
462 return ch, err
463 }
464
465 type tcpChan struct {
466 Channel
467 }
468
469
470
471 type chanConn struct {
472 Channel
473 laddr, raddr net.Addr
474 }
475
476
477 func (t *chanConn) LocalAddr() net.Addr {
478 return t.laddr
479 }
480
481
482 func (t *chanConn) RemoteAddr() net.Addr {
483 return t.raddr
484 }
485
486
487
488 func (t *chanConn) SetDeadline(deadline time.Time) error {
489 if err := t.SetReadDeadline(deadline); err != nil {
490 return err
491 }
492 return t.SetWriteDeadline(deadline)
493 }
494
495
496
497
498
499 func (t *chanConn) SetReadDeadline(deadline time.Time) error {
500
501
502 return errors.New("ssh: tcpChan: deadline not supported")
503 }
504
505
506
507 func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
508 return errors.New("ssh: tcpChan: deadline not supported")
509 }
510
View as plain text