...
Source file
src/image/jpeg/huffman.go
1
2
3
4
5 package jpeg
6
7 import (
8 "io"
9 )
10
11
const (
	// maxCodeLength is the longest Huffman code, in bits. A DHT table header
	// carries one code count for every length from 1 to maxCodeLength.
	maxCodeLength = 16

	// maxNCodes is the largest number of codes (and therefore values) that a
	// single Huffman table may hold.
	maxNCodes = 256

	// lutSize is the number of input bits used to index huffman.lut, so any
	// code of at most lutSize bits decodes with a single table look-up.
	lutSize = 8
)
19
20
21 type huffman struct {
22
23 nCodes int32
24
25
26
27
28 lut [1 << lutSize]uint16
29
30 vals [maxNCodes]uint8
31
32
33 minCodes [maxCodeLength]int32
34
35
36 maxCodes [maxCodeLength]int32
37
38 valsIndices [maxCodeLength]int32
39 }
40
41
42
43 var errShortHuffmanData = FormatError("short Huffman data")
44
45
46
47
48 func (d *decoder) ensureNBits(n int32) error {
49 for {
50 c, err := d.readByteStuffedByte()
51 if err != nil {
52 if err == io.ErrUnexpectedEOF {
53 return errShortHuffmanData
54 }
55 return err
56 }
57 d.bits.a = d.bits.a<<8 | uint32(c)
58 d.bits.n += 8
59 if d.bits.m == 0 {
60 d.bits.m = 1 << 7
61 } else {
62 d.bits.m <<= 8
63 }
64 if d.bits.n >= n {
65 break
66 }
67 }
68 return nil
69 }
70
71
72
73 func (d *decoder) receiveExtend(t uint8) (int32, error) {
74 if d.bits.n < int32(t) {
75 if err := d.ensureNBits(int32(t)); err != nil {
76 return 0, err
77 }
78 }
79 d.bits.n -= int32(t)
80 d.bits.m >>= t
81 s := int32(1) << t
82 x := int32(d.bits.a>>uint8(d.bits.n)) & (s - 1)
83 if x < s>>1 {
84 x += ((-1) << t) + 1
85 }
86 return x, nil
87 }
88
89
90
91 func (d *decoder) processDHT(n int) error {
92 for n > 0 {
93 if n < 17 {
94 return FormatError("DHT has wrong length")
95 }
96 if err := d.readFull(d.tmp[:17]); err != nil {
97 return err
98 }
99 tc := d.tmp[0] >> 4
100 if tc > maxTc {
101 return FormatError("bad Tc value")
102 }
103 th := d.tmp[0] & 0x0f
104
105 if th > maxTh || (d.baseline && th > 1) {
106 return FormatError("bad Th value")
107 }
108 h := &d.huff[tc][th]
109
110
111
112
113 h.nCodes = 0
114 var nCodes [maxCodeLength]int32
115 for i := range nCodes {
116 nCodes[i] = int32(d.tmp[i+1])
117 h.nCodes += nCodes[i]
118 }
119 if h.nCodes == 0 {
120 return FormatError("Huffman table has zero length")
121 }
122 if h.nCodes > maxNCodes {
123 return FormatError("Huffman table has excessive length")
124 }
125 n -= int(h.nCodes) + 17
126 if n < 0 {
127 return FormatError("DHT has wrong length")
128 }
129 if err := d.readFull(h.vals[:h.nCodes]); err != nil {
130 return err
131 }
132
133
134 for i := range h.lut {
135 h.lut[i] = 0
136 }
137 var x, code uint32
138 for i := uint32(0); i < lutSize; i++ {
139 code <<= 1
140 for j := int32(0); j < nCodes[i]; j++ {
141
142
143
144
145
146 base := uint8(code << (7 - i))
147 lutValue := uint16(h.vals[x])<<8 | uint16(2+i)
148 for k := uint8(0); k < 1<<(7-i); k++ {
149 h.lut[base|k] = lutValue
150 }
151 code++
152 x++
153 }
154 }
155
156
157 var c, index int32
158 for i, n := range nCodes {
159 if n == 0 {
160 h.minCodes[i] = -1
161 h.maxCodes[i] = -1
162 h.valsIndices[i] = -1
163 } else {
164 h.minCodes[i] = c
165 h.maxCodes[i] = c + n - 1
166 h.valsIndices[i] = index
167 c += n
168 index += n
169 }
170 c <<= 1
171 }
172 }
173 return nil
174 }
175
176
177
178 func (d *decoder) decodeHuffman(h *huffman) (uint8, error) {
179 if h.nCodes == 0 {
180 return 0, FormatError("uninitialized Huffman table")
181 }
182
183 if d.bits.n < 8 {
184 if err := d.ensureNBits(8); err != nil {
185 if err != errMissingFF00 && err != errShortHuffmanData {
186 return 0, err
187 }
188
189
190
191 if d.bytes.nUnreadable != 0 {
192 d.unreadByteStuffedByte()
193 }
194 goto slowPath
195 }
196 }
197 if v := h.lut[(d.bits.a>>uint32(d.bits.n-lutSize))&0xff]; v != 0 {
198 n := (v & 0xff) - 1
199 d.bits.n -= int32(n)
200 d.bits.m >>= n
201 return uint8(v >> 8), nil
202 }
203
204 slowPath:
205 for i, code := 0, int32(0); i < maxCodeLength; i++ {
206 if d.bits.n == 0 {
207 if err := d.ensureNBits(1); err != nil {
208 return 0, err
209 }
210 }
211 if d.bits.a&d.bits.m != 0 {
212 code |= 1
213 }
214 d.bits.n--
215 d.bits.m >>= 1
216 if code <= h.maxCodes[i] {
217 return h.vals[h.valsIndices[i]+code-h.minCodes[i]], nil
218 }
219 code <<= 1
220 }
221 return 0, FormatError("bad Huffman code")
222 }
223
224 func (d *decoder) decodeBit() (bool, error) {
225 if d.bits.n == 0 {
226 if err := d.ensureNBits(1); err != nil {
227 return false, err
228 }
229 }
230 ret := d.bits.a&d.bits.m != 0
231 d.bits.n--
232 d.bits.m >>= 1
233 return ret, nil
234 }
235
236 func (d *decoder) decodeBits(n int32) (uint32, error) {
237 if d.bits.n < n {
238 if err := d.ensureNBits(n); err != nil {
239 return 0, err
240 }
241 }
242 ret := d.bits.a >> uint32(d.bits.n-n)
243 ret &= (1 << uint32(n)) - 1
244 d.bits.n -= n
245 d.bits.m >>= uint32(n)
246 return ret, nil
247 }
248
View as plain text