1
2
3
4
5 package ssh
6
7 import (
8 "bytes"
9 "crypto"
10 "crypto/rand"
11 "encoding/binary"
12 "io"
13 "testing"
14
15 "golang.org/x/crypto/chacha20"
16 "golang.org/x/crypto/internal/poly1305"
17 )
18
19 func TestDefaultCiphersExist(t *testing.T) {
20 for _, cipherAlgo := range supportedCiphers {
21 if _, ok := cipherModes[cipherAlgo]; !ok {
22 t.Errorf("supported cipher %q is unknown", cipherAlgo)
23 }
24 }
25 for _, cipherAlgo := range preferredCiphers {
26 if _, ok := cipherModes[cipherAlgo]; !ok {
27 t.Errorf("preferred cipher %q is unknown", cipherAlgo)
28 }
29 }
30 }
31
32 func TestPacketCiphers(t *testing.T) {
33 defaultMac := "hmac-sha2-256"
34 defaultCipher := "aes128-ctr"
35 for cipher := range cipherModes {
36 t.Run("cipher="+cipher,
37 func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) })
38 }
39 for mac := range macModes {
40 t.Run("mac="+mac,
41 func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) })
42 }
43 }
44
45 func testPacketCipher(t *testing.T, cipher, mac string) {
46 kr := &kexResult{Hash: crypto.SHA1}
47 algs := directionAlgorithms{
48 Cipher: cipher,
49 MAC: mac,
50 Compression: "none",
51 }
52 client, err := newPacketCipher(clientKeys, algs, kr)
53 if err != nil {
54 t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
55 }
56 server, err := newPacketCipher(clientKeys, algs, kr)
57 if err != nil {
58 t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
59 }
60
61 want := "bla bla"
62 input := []byte(want)
63 buf := &bytes.Buffer{}
64 if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil {
65 t.Fatalf("writeCipherPacket(%q, %q): %v", cipher, mac, err)
66 }
67
68 packet, err := server.readCipherPacket(0, buf)
69 if err != nil {
70 t.Fatalf("readCipherPacket(%q, %q): %v", cipher, mac, err)
71 }
72
73 if string(packet) != want {
74 t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
75 }
76 }
77
78 func TestCBCOracleCounterMeasure(t *testing.T) {
79 kr := &kexResult{Hash: crypto.SHA1}
80 algs := directionAlgorithms{
81 Cipher: aes128cbcID,
82 MAC: "hmac-sha1",
83 Compression: "none",
84 }
85 client, err := newPacketCipher(clientKeys, algs, kr)
86 if err != nil {
87 t.Fatalf("newPacketCipher(client): %v", err)
88 }
89
90 want := "bla bla"
91 input := []byte(want)
92 buf := &bytes.Buffer{}
93 if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil {
94 t.Errorf("writeCipherPacket: %v", err)
95 }
96
97 packetSize := buf.Len()
98 buf.Write(make([]byte, 2*maxPacket))
99
100
101
102 lastRead := -1
103 for i := 0; i < packetSize; i++ {
104 server, err := newPacketCipher(clientKeys, algs, kr)
105 if err != nil {
106 t.Fatalf("newPacketCipher(client): %v", err)
107 }
108
109 fresh := &bytes.Buffer{}
110 fresh.Write(buf.Bytes())
111 fresh.Bytes()[i] ^= 0x01
112
113 before := fresh.Len()
114 _, err = server.readCipherPacket(0, fresh)
115 if err == nil {
116 t.Errorf("corrupt byte %d: readCipherPacket succeeded ", i)
117 continue
118 }
119 if _, ok := err.(cbcError); !ok {
120 t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
121 continue
122 }
123
124 after := fresh.Len()
125 bytesRead := before - after
126 if bytesRead < maxPacket {
127 t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
128 continue
129 }
130
131 if i > 0 && bytesRead != lastRead {
132 t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
133 }
134 lastRead = bytesRead
135 }
136 }
137
138 func TestCVE202143565(t *testing.T) {
139 tests := []struct {
140 cipher string
141 constructPacket func(packetCipher) io.Reader
142 }{
143 {
144 cipher: gcm128CipherID,
145 constructPacket: func(client packetCipher) io.Reader {
146 internalCipher := client.(*gcmCipher)
147 b := &bytes.Buffer{}
148 prefix := [4]byte{}
149 if _, err := b.Write(prefix[:]); err != nil {
150 t.Fatal(err)
151 }
152 internalCipher.buf = internalCipher.aead.Seal(internalCipher.buf[:0], internalCipher.iv, []byte{}, prefix[:])
153 if _, err := b.Write(internalCipher.buf); err != nil {
154 t.Fatal(err)
155 }
156 internalCipher.incIV()
157
158 return b
159 },
160 },
161 {
162 cipher: chacha20Poly1305ID,
163 constructPacket: func(client packetCipher) io.Reader {
164 internalCipher := client.(*chacha20Poly1305Cipher)
165 b := &bytes.Buffer{}
166
167 nonce := make([]byte, 12)
168 s, err := chacha20.NewUnauthenticatedCipher(internalCipher.contentKey[:], nonce)
169 if err != nil {
170 t.Fatal(err)
171 }
172 var polyKey, discardBuf [32]byte
173 s.XORKeyStream(polyKey[:], polyKey[:])
174 s.XORKeyStream(discardBuf[:], discardBuf[:])
175
176 internalCipher.buf = make([]byte, 4+poly1305.TagSize)
177 binary.BigEndian.PutUint32(internalCipher.buf, 0)
178 ls, err := chacha20.NewUnauthenticatedCipher(internalCipher.lengthKey[:], nonce)
179 if err != nil {
180 t.Fatal(err)
181 }
182 ls.XORKeyStream(internalCipher.buf, internalCipher.buf[:4])
183 if _, err := io.ReadFull(rand.Reader, internalCipher.buf[4:4]); err != nil {
184 t.Fatal(err)
185 }
186
187 s.XORKeyStream(internalCipher.buf[4:], internalCipher.buf[4:4])
188
189 var tag [poly1305.TagSize]byte
190 poly1305.Sum(&tag, internalCipher.buf[:4], &polyKey)
191
192 copy(internalCipher.buf[4:], tag[:])
193
194 if _, err := b.Write(internalCipher.buf); err != nil {
195 t.Fatal(err)
196 }
197
198 return b
199 },
200 },
201 }
202
203 for _, tc := range tests {
204 mac := "hmac-sha2-256"
205
206 kr := &kexResult{Hash: crypto.SHA1}
207 algs := directionAlgorithms{
208 Cipher: tc.cipher,
209 MAC: mac,
210 Compression: "none",
211 }
212 client, err := newPacketCipher(clientKeys, algs, kr)
213 if err != nil {
214 t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err)
215 }
216 server, err := newPacketCipher(clientKeys, algs, kr)
217 if err != nil {
218 t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err)
219 }
220
221 b := tc.constructPacket(client)
222
223 wantErr := "ssh: empty packet"
224 _, err = server.readCipherPacket(0, b)
225 if err == nil {
226 t.Fatalf("readCipherPacket(%q, %q): didn't fail with empty packet", tc.cipher, mac)
227 } else if err.Error() != wantErr {
228 t.Fatalf("readCipherPacket(%q, %q): unexpected error, got %q, want %q", tc.cipher, mac, err, wantErr)
229 }
230 }
231 }
232
View as plain text