1
2
3
4
5 package packet
6
7 import (
8 "bytes"
9 "encoding/hex"
10 "fmt"
11 "golang.org/x/crypto/openpgp/errors"
12 "io"
13 "testing"
14 )
15
16 func TestReadFull(t *testing.T) {
17 var out [4]byte
18
19 b := bytes.NewBufferString("foo")
20 n, err := readFull(b, out[:3])
21 if n != 3 || err != nil {
22 t.Errorf("full read failed n:%d err:%s", n, err)
23 }
24
25 b = bytes.NewBufferString("foo")
26 n, err = readFull(b, out[:4])
27 if n != 3 || err != io.ErrUnexpectedEOF {
28 t.Errorf("partial read failed n:%d err:%s", n, err)
29 }
30
31 b = bytes.NewBuffer(nil)
32 n, err = readFull(b, out[:3])
33 if n != 0 || err != io.ErrUnexpectedEOF {
34 t.Errorf("empty read failed n:%d err:%s", n, err)
35 }
36 }
37
// readerFromHex hex-decodes s and returns a reader over the decoded bytes.
// Callers in this file pass literal test-table strings, so malformed hex
// means a broken test case and triggers a panic rather than an error return.
func readerFromHex(s string) io.Reader {
	raw, decodeErr := hex.DecodeString(s)
	if decodeErr == nil {
		return bytes.NewBuffer(raw)
	}
	panic("readerFromHex: bad input")
}
45
// readLengthTests pairs hex-encoded inputs for readLength with the expected
// decoded length, partial-length flag and error. Fields left out of a case
// are expected to be their zero values. Inputs that end before the length
// field is complete must fail with io.ErrUnexpectedEOF.
var readLengthTests = []struct {
	hexInput  string
	length    int64
	isPartial bool
	err       error
}{
	// No length octet at all.
	{hexInput: "", err: io.ErrUnexpectedEOF},

	// One-octet length.
	{hexInput: "1f", length: 31},

	// Two-octet length: ((first-192) << 8) + second + 192.
	{hexInput: "c0", err: io.ErrUnexpectedEOF},
	{hexInput: "c101", length: 256 + 1 + 192},

	// Partial body lengths: 1 << (first & 0x1f).
	{hexInput: "e0", length: 1, isPartial: true},
	{hexInput: "e1", length: 2, isPartial: true},
	{hexInput: "e2", length: 4, isPartial: true},

	// Five-octet length: 0xff followed by a big-endian 32-bit value.
	{hexInput: "ff", err: io.ErrUnexpectedEOF},
	{hexInput: "ff00", err: io.ErrUnexpectedEOF},
	{hexInput: "ff0000", err: io.ErrUnexpectedEOF},
	{hexInput: "ff000000", err: io.ErrUnexpectedEOF},
	{hexInput: "ff00000000"},
	{hexInput: "ff01020304", length: 16909060},
}
66
67 func TestReadLength(t *testing.T) {
68 for i, test := range readLengthTests {
69 length, isPartial, err := readLength(readerFromHex(test.hexInput))
70 if test.err != nil {
71 if err != test.err {
72 t.Errorf("%d: expected different error got:%s want:%s", i, err, test.err)
73 }
74 continue
75 }
76 if err != nil {
77 t.Errorf("%d: unexpected error: %s", i, err)
78 continue
79 }
80 if length != test.length || isPartial != test.isPartial {
81 t.Errorf("%d: bad result got:(%d,%t) want:(%d,%t)", i, length, isPartial, test.length, test.isPartial)
82 }
83 }
84 }
85
// partialLengthReaderTests pairs hex-encoded streams of partial-length
// chunks with either the expected error or the expected concatenated body
// (hex-encoded). Omitted fields are zero: no error, empty body.
var partialLengthReaderTests = []struct {
	hexInput  string
	err       error
	hexOutput string
}{
	// Partial length of 1 with no body byte.
	{hexInput: "e0", err: io.ErrUnexpectedEOF},
	// One body byte, then the next length header is missing.
	{hexInput: "e001", err: io.ErrUnexpectedEOF},
	// Partial chunk "01", then a final one-octet length of 1 with body "02".
	{hexInput: "e0010102", hexOutput: "0102"},
	// A lone five-octet final length of zero.
	{hexInput: "ff00000000"},
	// Two 2-byte partial chunks followed by a final zero length.
	{hexInput: "e10102e1030400", hexOutput: "01020304"},
	// Partial length of 2 but only one body byte present.
	{hexInput: "e101", err: io.ErrUnexpectedEOF},
}
98
99 func TestPartialLengthReader(t *testing.T) {
100 for i, test := range partialLengthReaderTests {
101 r := &partialLengthReader{readerFromHex(test.hexInput), 0, true}
102 out, err := io.ReadAll(r)
103 if test.err != nil {
104 if err != test.err {
105 t.Errorf("%d: expected different error got:%s want:%s", i, err, test.err)
106 }
107 continue
108 }
109 if err != nil {
110 t.Errorf("%d: unexpected error: %s", i, err)
111 continue
112 }
113
114 got := fmt.Sprintf("%x", out)
115 if got != test.hexOutput {
116 t.Errorf("%d: got:%s want:%s", i, test.hexOutput, got)
117 }
118 }
119 }
120
// readHeaderTests pairs hex-encoded packet headers (plus any body bytes) for
// readHeader with the expected outcome. structuralError and unexpectedEOF
// mark the expected failure. Otherwise tag and length are the decoded header
// values and hexOutput is the body readable from the returned contents.
// Omitted fields are zero.
var readHeaderTests = []struct {
	hexInput        string
	structuralError bool
	unexpectedEOF   bool
	tag             int
	length          int64
	hexOutput       string
}{
	// Empty input, and a first byte without the top bit set.
	{hexInput: ""},
	{hexInput: "7f", structuralError: true},

	// Old-format headers (first byte 10ttttll): tag in bits 5-2, length
	// type in bits 1-0 (0: one octet, 1: two octets, 2: four octets,
	// 3: indeterminate, expected here as length -1).
	{hexInput: "80", unexpectedEOF: true},
	{hexInput: "8001", unexpectedEOF: true, length: 1},
	{hexInput: "800102", length: 1, hexOutput: "02"},
	{hexInput: "81000102", length: 1, hexOutput: "02"},
	{hexInput: "820000000102", length: 1, hexOutput: "02"},
	{hexInput: "860000000102", tag: 1, length: 1, hexOutput: "02"},
	{hexInput: "83010203", length: -1, hexOutput: "010203"},

	// New-format headers (first byte 11tttttt): tag in the low six bits,
	// followed by a new-format length.
	{hexInput: "c0", unexpectedEOF: true},
	{hexInput: "c000"},
	{hexInput: "c00102", length: 1, hexOutput: "02"},
	{hexInput: "c0020203", length: 2, hexOutput: "0203"},
	{hexInput: "c00202", unexpectedEOF: true, length: 2},
	{hexInput: "c3020203", tag: 3, length: 2, hexOutput: "0203"},
}
149
150 func TestReadHeader(t *testing.T) {
151 for i, test := range readHeaderTests {
152 tag, length, contents, err := readHeader(readerFromHex(test.hexInput))
153 if test.structuralError {
154 if _, ok := err.(errors.StructuralError); ok {
155 continue
156 }
157 t.Errorf("%d: expected StructuralError, got:%s", i, err)
158 continue
159 }
160 if err != nil {
161 if len(test.hexInput) == 0 && err == io.EOF {
162 continue
163 }
164 if !test.unexpectedEOF || err != io.ErrUnexpectedEOF {
165 t.Errorf("%d: unexpected error from readHeader: %s", i, err)
166 }
167 continue
168 }
169 if int(tag) != test.tag || length != test.length {
170 t.Errorf("%d: got:(%d,%d) want:(%d,%d)", i, int(tag), length, test.tag, test.length)
171 continue
172 }
173
174 body, err := io.ReadAll(contents)
175 if err != nil {
176 if !test.unexpectedEOF || err != io.ErrUnexpectedEOF {
177 t.Errorf("%d: unexpected error from contents: %s", i, err)
178 }
179 continue
180 }
181 if test.unexpectedEOF {
182 t.Errorf("%d: expected ErrUnexpectedEOF from contents but got no error", i)
183 continue
184 }
185 got := fmt.Sprintf("%x", body)
186 if got != test.hexOutput {
187 t.Errorf("%d: got:%s want:%s", i, got, test.hexOutput)
188 }
189 }
190 }
191
192 func TestSerializeHeader(t *testing.T) {
193 tag := packetTypePublicKey
194 lengths := []int{0, 1, 2, 64, 192, 193, 8000, 8384, 8385, 10000}
195
196 for _, length := range lengths {
197 buf := bytes.NewBuffer(nil)
198 serializeHeader(buf, tag, length)
199 tag2, length2, _, err := readHeader(buf)
200 if err != nil {
201 t.Errorf("length %d, err: %s", length, err)
202 }
203 if tag2 != tag {
204 t.Errorf("length %d, tag incorrect (got %d, want %d)", length, tag2, tag)
205 }
206 if int(length2) != length {
207 t.Errorf("length %d, length incorrect (got %d)", length, length2)
208 }
209 }
210 }
211
212 func TestPartialLengths(t *testing.T) {
213 buf := bytes.NewBuffer(nil)
214 w := new(partialLengthWriter)
215 w.w = noOpCloser{buf}
216
217 const maxChunkSize = 64
218
219 var b [maxChunkSize]byte
220 var n uint8
221 for l := 1; l <= maxChunkSize; l++ {
222 for i := 0; i < l; i++ {
223 b[i] = n
224 n++
225 }
226 m, err := w.Write(b[:l])
227 if m != l {
228 t.Errorf("short write got: %d want: %d", m, l)
229 }
230 if err != nil {
231 t.Errorf("error from write: %s", err)
232 }
233 }
234 if err := w.Close(); err != nil {
235 t.Fatal(err)
236 }
237
238
239 first, err := buf.ReadByte()
240 if err != nil {
241 t.Fatal(err)
242 }
243 if plen := 1 << (first & 0x1f); plen < 512 {
244 t.Errorf("first packet too short: got %d want at least %d", plen, 512)
245 }
246 if err := buf.UnreadByte(); err != nil {
247 t.Fatal(err)
248 }
249
250 want := (maxChunkSize * (maxChunkSize + 1)) / 2
251 copyBuf := bytes.NewBuffer(nil)
252 r := &partialLengthReader{buf, 0, true}
253 m, err := io.Copy(copyBuf, r)
254 if m != int64(want) {
255 t.Errorf("short copy got: %d want: %d", m, want)
256 }
257 if err != nil {
258 t.Errorf("error from copy: %s", err)
259 }
260
261 copyBytes := copyBuf.Bytes()
262 for i := 0; i < want; i++ {
263 if copyBytes[i] != uint8(i) {
264 t.Errorf("bad pattern in copy at %d", i)
265 break
266 }
267 }
268 }
269
270 func TestPartialLengthsShortWrite(t *testing.T) {
271 buf := bytes.NewBuffer(nil)
272 w := &partialLengthWriter{
273 w: noOpCloser{buf},
274 }
275 data := bytes.Repeat([]byte("a"), 510)
276 if _, err := w.Write(data); err != nil {
277 t.Fatal(err)
278 }
279 if err := w.Close(); err != nil {
280 t.Fatal(err)
281 }
282 copyBuf := bytes.NewBuffer(nil)
283 r := &partialLengthReader{buf, 0, true}
284 if _, err := io.Copy(copyBuf, r); err != nil {
285 t.Fatal(err)
286 }
287 if !bytes.Equal(copyBuf.Bytes(), data) {
288 t.Errorf("got %q want %q", buf.Bytes(), data)
289 }
290 }
291
View as plain text