1
2
3
4
5 package ssh
6
7 import (
8 "bytes"
9 "math/big"
10 "math/rand"
11 "reflect"
12 "testing"
13 "testing/quick"
14 )
15
// intLengthTests pairs an integer with the length intLength is expected to
// report for it. Each expected length is written as 4 plus the number of
// payload bytes; the 4 is presumably the length prefix.
var intLengthTests = []struct {
	val, length int
}{
	{val: 0, length: 4 + 0},
	{val: 1, length: 4 + 1},
	{val: 127, length: 4 + 1},
	{val: 128, length: 4 + 2},
	{val: -1, length: 4 + 1},
}
25
26 func TestIntLength(t *testing.T) {
27 for _, test := range intLengthTests {
28 v := new(big.Int).SetInt64(int64(test.val))
29 length := intLength(v)
30 if length != test.length {
31 t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
32 }
33 }
34 }
35
// msgAllTypes is a test message with one field of each of several kinds
// (bool, fixed-size byte array, unsigned integers, string, string list,
// byte slice, *big.Int, and a trailing "rest" byte slice). It is used by
// TestMarshalUnmarshal for Marshal/Unmarshal round-trips.
type msgAllTypes struct {
	// The sshtype tag presumably supplies the message-type byte that Marshal
	// writes first and Unmarshal checks — confirm against the marshaler.
	Bool    bool `sshtype:"21"`
	Array   [16]byte
	Uint64  uint64
	Uint32  uint32
	Uint8   uint8
	String  string
	Strings []string
	Bytes   []byte
	Int     *big.Int
	// Rest is tagged ssh:"rest"; presumably it absorbs all remaining packet
	// bytes, so it is kept as the last field — verify with the marshaler.
	Rest []byte `ssh:"rest"`
}
48
49 func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
50 m := &msgAllTypes{}
51 m.Bool = rand.Intn(2) == 1
52 randomBytes(m.Array[:], rand)
53 m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
54 m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
55 m.Uint8 = uint8(rand.Intn(1 << 8))
56 m.String = string(m.Array[:])
57 m.Strings = randomNameList(rand)
58 m.Bytes = m.Array[:]
59 m.Int = randomInt(rand)
60 m.Rest = m.Array[:]
61 return reflect.ValueOf(m)
62 }
63
64 func TestMarshalUnmarshal(t *testing.T) {
65 rand := rand.New(rand.NewSource(0))
66 iface := &msgAllTypes{}
67 ty := reflect.ValueOf(iface).Type()
68
69 n := 100
70 if testing.Short() {
71 n = 5
72 }
73 for j := 0; j < n; j++ {
74 v, ok := quick.Value(ty, rand)
75 if !ok {
76 t.Errorf("failed to create value")
77 break
78 }
79
80 m1 := v.Elem().Interface()
81 m2 := iface
82
83 marshaled := Marshal(m1)
84 if err := Unmarshal(marshaled, m2); err != nil {
85 t.Errorf("Unmarshal %#v: %s", m1, err)
86 break
87 }
88
89 if !reflect.DeepEqual(v.Interface(), m2) {
90 t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
91 break
92 }
93 }
94 }
95
96 func TestUnmarshalEmptyPacket(t *testing.T) {
97 var b []byte
98 var m channelRequestSuccessMsg
99 if err := Unmarshal(b, &m); err == nil {
100 t.Fatalf("unmarshal of empty slice succeeded")
101 }
102 }
103
104 func TestUnmarshalUnexpectedPacket(t *testing.T) {
105 type S struct {
106 I uint32 `sshtype:"43"`
107 S string
108 B bool
109 }
110
111 s := S{11, "hello", true}
112 packet := Marshal(s)
113 packet[0] = 42
114 roundtrip := S{}
115 err := Unmarshal(packet, &roundtrip)
116 if err == nil {
117 t.Fatal("expected error, not nil")
118 }
119 }
120
121 func TestMarshalPtr(t *testing.T) {
122 s := struct {
123 S string
124 }{"hello"}
125
126 m1 := Marshal(s)
127 m2 := Marshal(&s)
128 if !bytes.Equal(m1, m2) {
129 t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
130 }
131 }
132
133 func TestBareMarshalUnmarshal(t *testing.T) {
134 type S struct {
135 I uint32
136 S string
137 B bool
138 }
139
140 s := S{42, "hello", true}
141 packet := Marshal(s)
142 roundtrip := S{}
143 Unmarshal(packet, &roundtrip)
144
145 if !reflect.DeepEqual(s, roundtrip) {
146 t.Errorf("got %#v, want %#v", roundtrip, s)
147 }
148 }
149
150 func TestBareMarshal(t *testing.T) {
151 type S2 struct {
152 I uint32
153 }
154 s := S2{42}
155 packet := Marshal(s)
156 i, rest, ok := parseUint32(packet)
157 if len(rest) > 0 || !ok {
158 t.Errorf("parseInt(%q): parse error", packet)
159 }
160 if i != s.I {
161 t.Errorf("got %d, want %d", i, s.I)
162 }
163 }
164
165 func TestUnmarshalShortKexInitPacket(t *testing.T) {
166
167
168 packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff}
169 kim := &kexInitMsg{}
170 if err := Unmarshal(packet, kim); err == nil {
171 t.Error("truncated packet unmarshaled without error")
172 }
173 }
174
175 func TestMarshalMultiTag(t *testing.T) {
176 var res struct {
177 A uint32 `sshtype:"1|2"`
178 }
179
180 good1 := struct {
181 A uint32 `sshtype:"1"`
182 }{
183 1,
184 }
185 good2 := struct {
186 A uint32 `sshtype:"2"`
187 }{
188 1,
189 }
190
191 if e := Unmarshal(Marshal(good1), &res); e != nil {
192 t.Errorf("error unmarshaling multipart tag: %v", e)
193 }
194
195 if e := Unmarshal(Marshal(good2), &res); e != nil {
196 t.Errorf("error unmarshaling multipart tag: %v", e)
197 }
198
199 bad1 := struct {
200 A uint32 `sshtype:"3"`
201 }{
202 1,
203 }
204 if e := Unmarshal(Marshal(bad1), &res); e == nil {
205 t.Errorf("bad struct unmarshaled without error")
206 }
207 }
208
// randomBytes fills out in place, taking the low byte of one Int31 draw
// from rand per element.
func randomBytes(out []byte, rand *rand.Rand) {
	for idx := range out {
		out[idx] = byte(rand.Int31())
	}
}
214
// randomNameList returns a non-nil list of 0-15 names. Each name is 1-16
// bytes drawn from the letters 'a' through 'p'.
func randomNameList(rand *rand.Rand) []string {
	count := int(rand.Int31() & 15)
	names := make([]string, 0, count)
	for len(names) < count {
		word := make([]byte, 1+int(rand.Int31()&15))
		for k := range word {
			word[k] = 'a' + byte(rand.Int31()&15)
		}
		names = append(names, string(word))
	}
	return names
}
226
// randomInt returns a *big.Int holding a random value in the int32 range,
// including negatives. The value comes from reinterpreting one Uint32 draw
// as signed.
func randomInt(rand *rand.Rand) *big.Int {
	signed := int32(rand.Uint32())
	return big.NewInt(int64(signed))
}
230
231 func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
232 ki := &kexInitMsg{}
233 randomBytes(ki.Cookie[:], rand)
234 ki.KexAlgos = randomNameList(rand)
235 ki.ServerHostKeyAlgos = randomNameList(rand)
236 ki.CiphersClientServer = randomNameList(rand)
237 ki.CiphersServerClient = randomNameList(rand)
238 ki.MACsClientServer = randomNameList(rand)
239 ki.MACsServerClient = randomNameList(rand)
240 ki.CompressionClientServer = randomNameList(rand)
241 ki.CompressionServerClient = randomNameList(rand)
242 ki.LanguagesClientServer = randomNameList(rand)
243 ki.LanguagesServerClient = randomNameList(rand)
244 if rand.Int31()&1 == 1 {
245 ki.FirstKexFollows = true
246 }
247 return reflect.ValueOf(ki)
248 }
249
250 func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
251 dhi := &kexDHInitMsg{}
252 dhi.X = randomInt(rand)
253 return reflect.ValueOf(dhi)
254 }
255
256 var (
257 _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
258 _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
259
260 _kexInit = Marshal(_kexInitMsg)
261 _kexDHInit = Marshal(_kexDHInitMsg)
262 )
263
264 func BenchmarkMarshalKexInitMsg(b *testing.B) {
265 for i := 0; i < b.N; i++ {
266 Marshal(_kexInitMsg)
267 }
268 }
269
270 func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
271 m := new(kexInitMsg)
272 for i := 0; i < b.N; i++ {
273 Unmarshal(_kexInit, m)
274 }
275 }
276
277 func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
278 for i := 0; i < b.N; i++ {
279 Marshal(_kexDHInitMsg)
280 }
281 }
282
283 func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
284 m := new(kexDHInitMsg)
285 for i := 0; i < b.N; i++ {
286 Unmarshal(_kexDHInit, m)
287 }
288 }
289
View as plain text