1
2
3
4
5
6
7 package zstd
8
9 import (
10 "encoding/binary"
11 "errors"
12 "fmt"
13 "io"
14 )
15
16
17
// fuzzing is a hook that fuzz tests set to true. While it is set, the
// decoder rejects window sizes that reference zstd would refuse, so that
// fuzz results can be compared against that implementation.
var fuzzing bool
19
20
21 type Reader struct {
22
23 r io.Reader
24
25
26
27
28 sawFrameHeader bool
29
30
31 hasChecksum bool
32
33
34 readOneFrame bool
35
36
37 frameSizeUnknown bool
38
39
40
41 remainingFrameSize uint64
42
43
44
45 blockOffset int64
46
47
48 buffer []byte
49
50 off int
51
52
53 repeatedOffset1 uint32
54 repeatedOffset2 uint32
55 repeatedOffset3 uint32
56
57
58 huffmanTable []uint16
59 huffmanTableBits int
60
61
62 window window
63
64
65 compressedBuf []byte
66
67
68 literals []byte
69
70
71 seqTables [3][]fseBaselineEntry
72 seqTableBits [3]uint8
73
74
75 seqTableBuffers [3][]fseBaselineEntry
76
77
78 scratch [16]byte
79
80
81 fseScratch []fseEntry
82
83
84 checksum xxhash64
85 }
86
87
88 func NewReader(input io.Reader) *Reader {
89 r := new(Reader)
90 r.Reset(input)
91 return r
92 }
93
94
95
96 func (r *Reader) Reset(input io.Reader) {
97 r.r = input
98
99
100
101 r.sawFrameHeader = false
102 r.hasChecksum = false
103 r.readOneFrame = false
104 r.frameSizeUnknown = false
105 r.remainingFrameSize = 0
106 r.blockOffset = 0
107 r.buffer = r.buffer[:0]
108 r.off = 0
109
110
111
112
113
114
115
116
117
118
119
120
121
122 }
123
124
125 func (r *Reader) Read(p []byte) (int, error) {
126 if err := r.refillIfNeeded(); err != nil {
127 return 0, err
128 }
129 n := copy(p, r.buffer[r.off:])
130 r.off += n
131 return n, nil
132 }
133
134
135 func (r *Reader) ReadByte() (byte, error) {
136 if err := r.refillIfNeeded(); err != nil {
137 return 0, err
138 }
139 ret := r.buffer[r.off]
140 r.off++
141 return ret, nil
142 }
143
144
145 func (r *Reader) refillIfNeeded() error {
146 for r.off >= len(r.buffer) {
147 if err := r.refill(); err != nil {
148 return err
149 }
150 r.off = 0
151 }
152 return nil
153 }
154
155
156 func (r *Reader) refill() error {
157 if !r.sawFrameHeader {
158 if err := r.readFrameHeader(); err != nil {
159 return err
160 }
161 }
162 return r.readBlock()
163 }
164
165
166 func (r *Reader) readFrameHeader() error {
167 retry:
168 relativeOffset := 0
169
170
171 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
172
173 if err == io.EOF && !r.readOneFrame {
174 err = io.ErrUnexpectedEOF
175 }
176 return r.wrapError(relativeOffset, err)
177 }
178
179 if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
180 if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
181
182 r.blockOffset += int64(relativeOffset) + 4
183 if err := r.skipFrame(); err != nil {
184 return err
185 }
186 r.readOneFrame = true
187 goto retry
188 }
189
190 return r.makeError(relativeOffset, "invalid magic number")
191 }
192
193 relativeOffset += 4
194
195
196 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
197 return r.wrapNonEOFError(relativeOffset, err)
198 }
199 descriptor := r.scratch[0]
200
201 singleSegment := descriptor&(1<<5) != 0
202
203 fcsFieldSize := 1 << (descriptor >> 6)
204 if fcsFieldSize == 1 && !singleSegment {
205 fcsFieldSize = 0
206 }
207
208 var windowDescriptorSize int
209 if singleSegment {
210 windowDescriptorSize = 0
211 } else {
212 windowDescriptorSize = 1
213 }
214
215 if descriptor&(1<<3) != 0 {
216 return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
217 }
218
219 r.hasChecksum = descriptor&(1<<2) != 0
220 if r.hasChecksum {
221 r.checksum.reset()
222 }
223
224
225 dictionaryIdSize := 0
226 if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
227 dictionaryIdSize = 1 << (dictIdFlag - 1)
228 }
229
230 relativeOffset++
231
232 headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize
233
234 if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
235 return r.wrapNonEOFError(relativeOffset, err)
236 }
237
238
239
240 var windowSize int
241 if !singleSegment {
242
243 windowDescriptor := r.scratch[0]
244 exponent := uint64(windowDescriptor >> 3)
245 mantissa := uint64(windowDescriptor & 7)
246 windowLog := exponent + 10
247 windowBase := uint64(1) << windowLog
248 windowAdd := (windowBase / 8) * mantissa
249 windowSize = int(windowBase + windowAdd)
250
251
252 if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
253 return r.makeError(relativeOffset, "windowSize too large")
254 }
255 }
256
257
258 if dictionaryIdSize != 0 {
259 dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
260
261 for _, b := range dictionaryId {
262 if b != 0 {
263 return r.makeError(relativeOffset, "dictionaries are not supported")
264 }
265 }
266 }
267
268
269 r.frameSizeUnknown = false
270 r.remainingFrameSize = 0
271 fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
272 switch fcsFieldSize {
273 case 0:
274 r.frameSizeUnknown = true
275 case 1:
276 r.remainingFrameSize = uint64(fb[0])
277 case 2:
278 r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
279 case 4:
280 r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
281 case 8:
282 r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
283 default:
284 panic("unreachable")
285 }
286
287
288
289
290 if singleSegment {
291 windowSize = int(r.remainingFrameSize)
292 }
293
294
295 if windowSize > 8<<20 {
296 windowSize = 8 << 20
297 }
298
299 relativeOffset += headerSize
300
301 r.sawFrameHeader = true
302 r.readOneFrame = true
303 r.blockOffset += int64(relativeOffset)
304
305
306 r.repeatedOffset1 = 1
307 r.repeatedOffset2 = 4
308 r.repeatedOffset3 = 8
309 r.huffmanTableBits = 0
310 r.window.reset(windowSize)
311 r.seqTables[0] = nil
312 r.seqTables[1] = nil
313 r.seqTables[2] = nil
314
315 return nil
316 }
317
318
319 func (r *Reader) skipFrame() error {
320 relativeOffset := 0
321
322 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
323 return r.wrapNonEOFError(relativeOffset, err)
324 }
325
326 relativeOffset += 4
327
328 size := binary.LittleEndian.Uint32(r.scratch[:4])
329 if size == 0 {
330 r.blockOffset += int64(relativeOffset)
331 return nil
332 }
333
334 if seeker, ok := r.r.(io.Seeker); ok {
335 r.blockOffset += int64(relativeOffset)
336
337
338 prev, err := seeker.Seek(0, io.SeekCurrent)
339 if err != nil {
340 return r.wrapError(0, err)
341 }
342 end, err := seeker.Seek(0, io.SeekEnd)
343 if err != nil {
344 return r.wrapError(0, err)
345 }
346 if prev > end-int64(size) {
347 r.blockOffset += end - prev
348 return r.makeEOFError(0)
349 }
350
351
352 _, err = seeker.Seek(prev+int64(size), io.SeekStart)
353 if err != nil {
354 return r.wrapError(0, err)
355 }
356 r.blockOffset += int64(size)
357 return nil
358 }
359
360 var skip []byte
361 const chunk = 1 << 20
362 for size >= chunk {
363 if len(skip) == 0 {
364 skip = make([]byte, chunk)
365 }
366 if _, err := io.ReadFull(r.r, skip); err != nil {
367 return r.wrapNonEOFError(relativeOffset, err)
368 }
369 relativeOffset += chunk
370 size -= chunk
371 }
372 if size > 0 {
373 if len(skip) == 0 {
374 skip = make([]byte, size)
375 }
376 if _, err := io.ReadFull(r.r, skip); err != nil {
377 return r.wrapNonEOFError(relativeOffset, err)
378 }
379 relativeOffset += int(size)
380 }
381
382 r.blockOffset += int64(relativeOffset)
383
384 return nil
385 }
386
387
388 func (r *Reader) readBlock() error {
389 relativeOffset := 0
390
391
392 if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
393 return r.wrapNonEOFError(relativeOffset, err)
394 }
395
396 relativeOffset += 3
397
398 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
399
400 lastBlock := header&1 != 0
401 blockType := (header >> 1) & 3
402 blockSize := int(header >> 3)
403
404
405
406
407 if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
408 return r.makeError(relativeOffset, "block size too large")
409 }
410
411
412 switch blockType {
413 case 0:
414 r.setBufferSize(blockSize)
415 if _, err := io.ReadFull(r.r, r.buffer); err != nil {
416 return r.wrapNonEOFError(relativeOffset, err)
417 }
418 relativeOffset += blockSize
419 r.blockOffset += int64(relativeOffset)
420 case 1:
421 r.setBufferSize(blockSize)
422 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
423 return r.wrapNonEOFError(relativeOffset, err)
424 }
425 relativeOffset++
426 v := r.scratch[0]
427 for i := range r.buffer {
428 r.buffer[i] = v
429 }
430 r.blockOffset += int64(relativeOffset)
431 case 2:
432 r.blockOffset += int64(relativeOffset)
433 if err := r.compressedBlock(blockSize); err != nil {
434 return err
435 }
436 r.blockOffset += int64(blockSize)
437 case 3:
438 return r.makeError(relativeOffset, "invalid block type")
439 }
440
441 if !r.frameSizeUnknown {
442 if uint64(len(r.buffer)) > r.remainingFrameSize {
443 return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
444 }
445 r.remainingFrameSize -= uint64(len(r.buffer))
446 }
447
448 if r.hasChecksum {
449 r.checksum.update(r.buffer)
450 }
451
452 if !lastBlock {
453 r.window.save(r.buffer)
454 } else {
455 if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
456 return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
457 }
458
459 if r.hasChecksum {
460 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
461 return r.wrapNonEOFError(0, err)
462 }
463
464 inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
465 dataChecksum := uint32(r.checksum.digest())
466 if inputChecksum != dataChecksum {
467 return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
468 }
469
470 r.blockOffset += 4
471 }
472 r.sawFrameHeader = false
473 }
474
475 return nil
476 }
477
478
479
480 func (r *Reader) setBufferSize(size int) {
481 if cap(r.buffer) < size {
482 need := size - cap(r.buffer)
483 r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
484 }
485 r.buffer = r.buffer[:size]
486 }
487
488
// zstdError reports a decompression failure together with the byte offset
// in the compressed input at which it was detected.
type zstdError struct {
	offset int64 // position in the compressed input
	err    error // underlying cause, exposed through Unwrap
}

// Error formats the failure as "zstd decompression error at OFFSET: CAUSE".
func (ze *zstdError) Error() string {
	const format = "zstd decompression error at %d: %v"
	return fmt.Sprintf(format, ze.offset, ze.err)
}

// Unwrap returns the underlying cause, so errors.Is and errors.As can
// inspect it (for example, to detect io.ErrUnexpectedEOF).
func (ze *zstdError) Unwrap() error {
	cause := ze.err
	return cause
}
501
502 func (r *Reader) makeEOFError(off int) error {
503 return r.wrapError(off, io.ErrUnexpectedEOF)
504 }
505
506 func (r *Reader) wrapNonEOFError(off int, err error) error {
507 if err == io.EOF {
508 err = io.ErrUnexpectedEOF
509 }
510 return r.wrapError(off, err)
511 }
512
513 func (r *Reader) makeError(off int, msg string) error {
514 return r.wrapError(off, errors.New(msg))
515 }
516
517 func (r *Reader) wrapError(off int, err error) error {
518 if err == io.EOF {
519 return err
520 }
521 return &zstdError{r.blockOffset + int64(off), err}
522 }
523
View as plain text