1
2
3
4
5 package ssh
6
7 import (
8 "reflect"
9 "testing"
10 )
11
12 func TestFindAgreedAlgorithms(t *testing.T) {
13 initKex := func(k *kexInitMsg) {
14 if k.KexAlgos == nil {
15 k.KexAlgos = []string{"kex1"}
16 }
17 if k.ServerHostKeyAlgos == nil {
18 k.ServerHostKeyAlgos = []string{"hostkey1"}
19 }
20 if k.CiphersClientServer == nil {
21 k.CiphersClientServer = []string{"cipher1"}
22
23 }
24 if k.CiphersServerClient == nil {
25 k.CiphersServerClient = []string{"cipher1"}
26
27 }
28 if k.MACsClientServer == nil {
29 k.MACsClientServer = []string{"mac1"}
30
31 }
32 if k.MACsServerClient == nil {
33 k.MACsServerClient = []string{"mac1"}
34
35 }
36 if k.CompressionClientServer == nil {
37 k.CompressionClientServer = []string{"compression1"}
38
39 }
40 if k.CompressionServerClient == nil {
41 k.CompressionServerClient = []string{"compression1"}
42
43 }
44 if k.LanguagesClientServer == nil {
45 k.LanguagesClientServer = []string{"language1"}
46
47 }
48 if k.LanguagesServerClient == nil {
49 k.LanguagesServerClient = []string{"language1"}
50
51 }
52 }
53
54 initDirAlgs := func(a *directionAlgorithms) {
55 if a.Cipher == "" {
56 a.Cipher = "cipher1"
57 }
58 if a.MAC == "" {
59 a.MAC = "mac1"
60 }
61 if a.Compression == "" {
62 a.Compression = "compression1"
63 }
64 }
65
66 initAlgs := func(a *algorithms) {
67 if a.kex == "" {
68 a.kex = "kex1"
69 }
70 if a.hostKey == "" {
71 a.hostKey = "hostkey1"
72 }
73 initDirAlgs(&a.r)
74 initDirAlgs(&a.w)
75 }
76
77 type testcase struct {
78 name string
79 clientIn, serverIn kexInitMsg
80 wantClient, wantServer algorithms
81 wantErr bool
82 }
83
84 cases := []testcase{
85 {
86 name: "standard",
87 },
88
89 {
90 name: "no common hostkey",
91 serverIn: kexInitMsg{
92 ServerHostKeyAlgos: []string{"hostkey2"},
93 },
94 wantErr: true,
95 },
96
97 {
98 name: "no common kex",
99 serverIn: kexInitMsg{
100 KexAlgos: []string{"kex2"},
101 },
102 wantErr: true,
103 },
104
105 {
106 name: "no common cipher",
107 serverIn: kexInitMsg{
108 CiphersClientServer: []string{"cipher2"},
109 },
110 wantErr: true,
111 },
112
113 {
114 name: "client decides cipher",
115 serverIn: kexInitMsg{
116 CiphersClientServer: []string{"cipher1", "cipher2"},
117 CiphersServerClient: []string{"cipher2", "cipher3"},
118 },
119 clientIn: kexInitMsg{
120 CiphersClientServer: []string{"cipher2", "cipher1"},
121 CiphersServerClient: []string{"cipher3", "cipher2"},
122 },
123 wantClient: algorithms{
124 r: directionAlgorithms{
125 Cipher: "cipher3",
126 },
127 w: directionAlgorithms{
128 Cipher: "cipher2",
129 },
130 },
131 wantServer: algorithms{
132 w: directionAlgorithms{
133 Cipher: "cipher3",
134 },
135 r: directionAlgorithms{
136 Cipher: "cipher2",
137 },
138 },
139 },
140
141
142
143 }
144
145 for i := range cases {
146 initKex(&cases[i].clientIn)
147 initKex(&cases[i].serverIn)
148 initAlgs(&cases[i].wantClient)
149 initAlgs(&cases[i].wantServer)
150 }
151
152 for _, c := range cases {
153 t.Run(c.name, func(t *testing.T) {
154 serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
155 clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
156
157 serverHasErr := serverErr != nil
158 clientHasErr := clientErr != nil
159 if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
160 t.Fatalf("got client/server error (%v, %v), want hasError %v",
161 clientErr, serverErr, c.wantErr)
162
163 }
164 if c.wantErr {
165 return
166 }
167
168 if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
169 t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
170 }
171 if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
172 t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
173 }
174 })
175 }
176 }
177
View as plain text