...
1
2
3
4
5 package ssh
6
7 import (
8 "io"
9 "sync"
10 "testing"
11 )
12
13
14
// memTransport is one endpoint of an in-memory, bidirectional packet pipe.
// Packets written through an endpoint are queued on its peer (the write
// field) and handed out again by the peer's readPacket.
type memTransport struct {
	// Mutex guards this endpoint's eof and pending fields. It also guards
	// the peer's writeCount, because writePacket and getWriteCount run
	// while holding the peer's lock rather than their own.
	sync.Mutex
	// Cond is signalled when a packet is queued on this endpoint and
	// broadcast when the endpoint is closed. readPacket waits on it while
	// holding the embedded Mutex.
	*sync.Cond

	// write is the peer endpoint that receives packets written here.
	write *memTransport
	// pending holds packets queued for this endpoint, oldest first.
	pending [][]byte
	// eof is set once this endpoint has been closed.
	eof bool
	// writeCount is the number of packets successfully written through
	// this endpoint.
	writeCount uint64
}
23
24 func (t *memTransport) readPacket() ([]byte, error) {
25 t.Lock()
26 defer t.Unlock()
27 for {
28 if len(t.pending) > 0 {
29 r := t.pending[0]
30 t.pending = t.pending[1:]
31 return r, nil
32 }
33 if t.eof {
34 return nil, io.EOF
35 }
36 t.Cond.Wait()
37 }
38 }
39
40 func (t *memTransport) closeSelf() error {
41 t.Lock()
42 defer t.Unlock()
43 if t.eof {
44 return io.EOF
45 }
46 t.eof = true
47 t.Cond.Broadcast()
48 return nil
49 }
50
51 func (t *memTransport) Close() error {
52 err := t.write.closeSelf()
53 t.closeSelf()
54 return err
55 }
56
57 func (t *memTransport) writePacket(p []byte) error {
58 t.write.Lock()
59 defer t.write.Unlock()
60 if t.write.eof {
61 return io.EOF
62 }
63 c := make([]byte, len(p))
64 copy(c, p)
65 t.write.pending = append(t.write.pending, c)
66 t.write.Cond.Signal()
67 t.writeCount++
68 return nil
69 }
70
71 func (t *memTransport) getWriteCount() uint64 {
72 t.write.Lock()
73 defer t.write.Unlock()
74 return t.writeCount
75 }
76
77 func memPipe() (a, b packetConn) {
78 t1 := memTransport{}
79 t2 := memTransport{}
80 t1.write = &t2
81 t2.write = &t1
82 t1.Cond = sync.NewCond(&t1.Mutex)
83 t2.Cond = sync.NewCond(&t2.Mutex)
84 return &t1, &t2
85 }
86
87 func TestMemPipe(t *testing.T) {
88 a, b := memPipe()
89 if err := a.writePacket([]byte{42}); err != nil {
90 t.Fatalf("writePacket: %v", err)
91 }
92 if wc := a.(*memTransport).getWriteCount(); wc != 1 {
93 t.Fatalf("got %v, want 1", wc)
94 }
95 if err := a.Close(); err != nil {
96 t.Fatal("Close: ", err)
97 }
98 p, err := b.readPacket()
99 if err != nil {
100 t.Fatal("readPacket: ", err)
101 }
102 if len(p) != 1 || p[0] != 42 {
103 t.Fatalf("got %v, want {42}", p)
104 }
105 p, err = b.readPacket()
106 if err != io.EOF {
107 t.Fatalf("got %v, %v, want EOF", p, err)
108 }
109 if wc := b.(*memTransport).getWriteCount(); wc != 0 {
110 t.Fatalf("got %v, want 0", wc)
111 }
112 }
113
114 func TestDoubleClose(t *testing.T) {
115 a, _ := memPipe()
116 err := a.Close()
117 if err != nil {
118 t.Errorf("Close: %v", err)
119 }
120 err = a.Close()
121 if err != io.EOF {
122 t.Errorf("expect EOF on double close.")
123 }
124 }
125
View as plain text