1
2
3
4
5
6
7
8
9
10
11 package tlog
12
13 import (
14 "crypto/sha256"
15 "encoding/base64"
16 "errors"
17 "fmt"
18 "math/bits"
19 )
20
21
// HashSize is the number of bytes in a Hash.
const HashSize = 32

// A Hash is a fixed-size digest identifying a log record or a tree node.
type Hash [HashSize]byte

// String returns the padded standard base64 encoding of h.
func (h Hash) String() string {
	enc := base64.StdEncoding
	return enc.EncodeToString(h[:])
}

// MarshalJSON encodes h as a JSON string holding its base64 form.
func (h Hash) MarshalJSON() ([]byte, error) {
	text := h.String()
	out := make([]byte, 0, len(text)+2)
	out = append(out, '"')
	out = append(out, text...)
	out = append(out, '"')
	return out, nil
}

// UnmarshalJSON decodes the JSON form produced by MarshalJSON.
// The input must be a quoted 44-character base64 string ending in a
// single '=' pad. On any failure *h is left unmodified.
func (h *Hash) UnmarshalJSON(data []byte) error {
	const quotedLen = 1 + 44 + 1 // quote, padded base64 text, quote
	if len(data) != quotedLen {
		return errors.New("cannot decode hash")
	}
	last := len(data) - 1
	if data[0] != '"' || data[last-1] != '=' || data[last] != '"' {
		return errors.New("cannot decode hash")
	}
	// Strip both quotes and the '=' pad, then decode the unpadded text.
	// It must produce exactly HashSize bytes.
	var decoded Hash
	n, err := base64.RawStdEncoding.Decode(decoded[:], data[1:last-1])
	if err != nil || n != HashSize {
		return errors.New("cannot decode hash")
	}
	*h = decoded
	return nil
}

// ParseHash parses the base64 (String) form of a hash.
func ParseHash(s string) (Hash, error) {
	var h Hash
	raw, err := base64.StdEncoding.DecodeString(s)
	if err == nil && len(raw) == HashSize {
		copy(h[:], raw)
		return h, nil
	}
	return Hash{}, fmt.Errorf("malformed hash")
}
69
70
71
// maxpow2 returns k, the largest power of 2 strictly smaller than n,
// together with l = log₂ k (so k = 1<<l). For n <= 2 it returns
// k = 1, l = 0.
//
// The result comes directly from the bit length of n-1 rather than from
// a shift-and-compare loop. Such a loop never terminates for n > 1<<62:
// 1<<63 wraps to a negative int64 and larger shifts yield 0, and both
// still compare less than n. CheckTree passes its caller-supplied tree
// size down to maxpow2 (via runTreeProof), so every int64 must be handled.
func maxpow2(n int64) (k int64, l int) {
	if n <= 2 {
		return 1, 0
	}
	// For n >= 3, the largest power of 2 below n is the top set bit of n-1.
	l = bits.Len64(uint64(n-1)) - 1
	return 1 << uint(l), l
}
79
80 var zeroPrefix = []byte{0x00}
81
82
83 func RecordHash(data []byte) Hash {
84
85
86 h := sha256.New()
87 h.Write(zeroPrefix)
88 h.Write(data)
89 var h1 Hash
90 h.Sum(h1[:0])
91 return h1
92 }
93
94
95 func NodeHash(left, right Hash) Hash {
96
97
98
99
100 var buf [1 + HashSize + HashSize]byte
101 buf[0] = 0x01
102 copy(buf[1:], left[:])
103 copy(buf[1+HashSize:], right[:])
104 return sha256.Sum256(buf[:])
105 }
106
107
108
109
110
111
112
113
114
115
116
// StoredHashIndex maps the tree coordinate (level, n), the n'th hash at
// the given level, to its position in the linear sequence of stored hashes.
//
// Hashes are stored in the order they become computable. Each record hash
// is followed by the hashes of any subtrees that record completes.
func StoredHashIndex(level int, n int64) int64 {
	// A level-L hash is stored right after its right child, the level-(L-1)
	// hash number 2n+1. Descend to the level-0 record that finishes the
	// subtree, then add back one slot per level skipped.
	leaf := n
	for l := 0; l < level; l++ {
		leaf = leaf<<1 | 1
	}

	// The level-0 hash of record m sits at m + m/2 + m/4 + ... .
	var pos int64
	for m := leaf; m > 0; m /= 2 {
		pos += m
	}
	return pos + int64(level)
}
133
134
135
136 func SplitStoredHashIndex(index int64) (level int, n int64) {
137
138
139
140 n = index / 2
141 indexN := StoredHashIndex(0, n)
142 if indexN > index {
143 panic("bad math")
144 }
145 for {
146
147 x := indexN + 1 + int64(bits.TrailingZeros64(uint64(n+1)))
148 if x > index {
149 break
150 }
151 n++
152 indexN = x
153 }
154
155
156 level = int(index - indexN)
157 return level, n >> uint(level)
158 }
159
160
161
162 func StoredHashCount(n int64) int64 {
163 if n == 0 {
164 return 0
165 }
166
167 numHash := StoredHashIndex(0, n-1) + 1
168
169 for i := uint64(n - 1); i&1 != 0; i >>= 1 {
170 numHash++
171 }
172 return numHash
173 }
174
175
176
177
178
179
180
181
182 func StoredHashes(n int64, data []byte, r HashReader) ([]Hash, error) {
183 return StoredHashesForRecordHash(n, RecordHash(data), r)
184 }
185
186
187
188 func StoredHashesForRecordHash(n int64, h Hash, r HashReader) ([]Hash, error) {
189
190 hashes := []Hash{h}
191
192
193
194
195 m := int(bits.TrailingZeros64(uint64(n + 1)))
196 indexes := make([]int64, m)
197 for i := 0; i < m; i++ {
198
199
200 indexes[m-1-i] = StoredHashIndex(i, n>>uint(i)-1)
201 }
202
203
204 old, err := r.ReadHashes(indexes)
205 if err != nil {
206 return nil, err
207 }
208 if len(old) != len(indexes) {
209 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(old))
210 }
211
212
213 for i := 0; i < m; i++ {
214 h = NodeHash(old[m-1-i], h)
215 hashes = append(hashes, h)
216 }
217 return hashes, nil
218 }
219
220
221 type HashReader interface {
222
223
224
225
226
227 ReadHashes(indexes []int64) ([]Hash, error)
228 }
229
230
231 type HashReaderFunc func([]int64) ([]Hash, error)
232
233 func (f HashReaderFunc) ReadHashes(indexes []int64) ([]Hash, error) {
234 return f(indexes)
235 }
236
237
238
239
240
241
242 func TreeHash(n int64, r HashReader) (Hash, error) {
243 if n == 0 {
244 return Hash{}, nil
245 }
246 indexes := subTreeIndex(0, n, nil)
247 hashes, err := r.ReadHashes(indexes)
248 if err != nil {
249 return Hash{}, err
250 }
251 if len(hashes) != len(indexes) {
252 return Hash{}, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
253 }
254 hash, hashes := subTreeHash(0, n, hashes)
255 if len(hashes) != 0 {
256 panic("tlog: bad index math in TreeHash")
257 }
258 return hash, nil
259 }
260
261
262
263
264
265 func subTreeIndex(lo, hi int64, need []int64) []int64 {
266
267 for lo < hi {
268 k, level := maxpow2(hi - lo + 1)
269 if lo&(k-1) != 0 {
270 panic("tlog: bad math in subTreeIndex")
271 }
272 need = append(need, StoredHashIndex(level, lo>>uint(level)))
273 lo += k
274 }
275 return need
276 }
277
278
279
280
281
282 func subTreeHash(lo, hi int64, hashes []Hash) (Hash, []Hash) {
283
284
285
286
287 numTree := 0
288 for lo < hi {
289 k, _ := maxpow2(hi - lo + 1)
290 if lo&(k-1) != 0 || lo >= hi {
291 panic("tlog: bad math in subTreeHash")
292 }
293 numTree++
294 lo += k
295 }
296
297 if len(hashes) < numTree {
298 panic("tlog: bad index math in subTreeHash")
299 }
300
301
302 h := hashes[numTree-1]
303 for i := numTree - 2; i >= 0; i-- {
304 h = NodeHash(hashes[i], h)
305 }
306 return h, hashes[numTree:]
307 }
308
309
310
311 type RecordProof []Hash
312
313
314 func ProveRecord(t, n int64, r HashReader) (RecordProof, error) {
315 if t < 0 || n < 0 || n >= t {
316 return nil, fmt.Errorf("tlog: invalid inputs in ProveRecord")
317 }
318 indexes := leafProofIndex(0, t, n, nil)
319 if len(indexes) == 0 {
320 return RecordProof{}, nil
321 }
322 hashes, err := r.ReadHashes(indexes)
323 if err != nil {
324 return nil, err
325 }
326 if len(hashes) != len(indexes) {
327 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
328 }
329
330 p, hashes := leafProof(0, t, n, hashes)
331 if len(hashes) != 0 {
332 panic("tlog: bad index math in ProveRecord")
333 }
334 return p, nil
335 }
336
337
338
339
340
341 func leafProofIndex(lo, hi, n int64, need []int64) []int64 {
342
343 if !(lo <= n && n < hi) {
344 panic("tlog: bad math in leafProofIndex")
345 }
346 if lo+1 == hi {
347 return need
348 }
349 if k, _ := maxpow2(hi - lo); n < lo+k {
350 need = leafProofIndex(lo, lo+k, n, need)
351 need = subTreeIndex(lo+k, hi, need)
352 } else {
353 need = subTreeIndex(lo, lo+k, need)
354 need = leafProofIndex(lo+k, hi, n, need)
355 }
356 return need
357 }
358
359
360
361
362 func leafProof(lo, hi, n int64, hashes []Hash) (RecordProof, []Hash) {
363
364 if !(lo <= n && n < hi) {
365 panic("tlog: bad math in leafProof")
366 }
367
368 if lo+1 == hi {
369
370
371 return RecordProof{}, hashes
372 }
373
374
375
376 var p RecordProof
377 var th Hash
378 if k, _ := maxpow2(hi - lo); n < lo+k {
379
380 p, hashes = leafProof(lo, lo+k, n, hashes)
381 th, hashes = subTreeHash(lo+k, hi, hashes)
382 } else {
383
384 th, hashes = subTreeHash(lo, lo+k, hashes)
385 p, hashes = leafProof(lo+k, hi, n, hashes)
386 }
387 return append(p, th), hashes
388 }
389
390 var errProofFailed = errors.New("invalid transparency proof")
391
392
393
394 func CheckRecord(p RecordProof, t int64, th Hash, n int64, h Hash) error {
395 if t < 0 || n < 0 || n >= t {
396 return fmt.Errorf("tlog: invalid inputs in CheckRecord")
397 }
398 th2, err := runRecordProof(p, 0, t, n, h)
399 if err != nil {
400 return err
401 }
402 if th2 == th {
403 return nil
404 }
405 return errProofFailed
406 }
407
408
409
410
411 func runRecordProof(p RecordProof, lo, hi, n int64, leafHash Hash) (Hash, error) {
412
413 if !(lo <= n && n < hi) {
414 panic("tlog: bad math in runRecordProof")
415 }
416
417 if lo+1 == hi {
418
419
420 if len(p) != 0 {
421 return Hash{}, errProofFailed
422 }
423 return leafHash, nil
424 }
425
426 if len(p) == 0 {
427 return Hash{}, errProofFailed
428 }
429
430 k, _ := maxpow2(hi - lo)
431 if n < lo+k {
432 th, err := runRecordProof(p[:len(p)-1], lo, lo+k, n, leafHash)
433 if err != nil {
434 return Hash{}, err
435 }
436 return NodeHash(th, p[len(p)-1]), nil
437 } else {
438 th, err := runRecordProof(p[:len(p)-1], lo+k, hi, n, leafHash)
439 if err != nil {
440 return Hash{}, err
441 }
442 return NodeHash(p[len(p)-1], th), nil
443 }
444 }
445
446
447
448
449 type TreeProof []Hash
450
451
452
453 func ProveTree(t, n int64, h HashReader) (TreeProof, error) {
454 if t < 1 || n < 1 || n > t {
455 return nil, fmt.Errorf("tlog: invalid inputs in ProveTree")
456 }
457 indexes := treeProofIndex(0, t, n, nil)
458 if len(indexes) == 0 {
459 return TreeProof{}, nil
460 }
461 hashes, err := h.ReadHashes(indexes)
462 if err != nil {
463 return nil, err
464 }
465 if len(hashes) != len(indexes) {
466 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
467 }
468
469 p, hashes := treeProof(0, t, n, hashes)
470 if len(hashes) != 0 {
471 panic("tlog: bad index math in ProveTree")
472 }
473 return p, nil
474 }
475
476
477
478
479 func treeProofIndex(lo, hi, n int64, need []int64) []int64 {
480
481 if !(lo < n && n <= hi) {
482 panic("tlog: bad math in treeProofIndex")
483 }
484
485 if n == hi {
486 if lo == 0 {
487 return need
488 }
489 return subTreeIndex(lo, hi, need)
490 }
491
492 if k, _ := maxpow2(hi - lo); n <= lo+k {
493 need = treeProofIndex(lo, lo+k, n, need)
494 need = subTreeIndex(lo+k, hi, need)
495 } else {
496 need = subTreeIndex(lo, lo+k, need)
497 need = treeProofIndex(lo+k, hi, n, need)
498 }
499 return need
500 }
501
502
503
504
505 func treeProof(lo, hi, n int64, hashes []Hash) (TreeProof, []Hash) {
506
507 if !(lo < n && n <= hi) {
508 panic("tlog: bad math in treeProof")
509 }
510
511
512 if n == hi {
513 if lo == 0 {
514
515
516 return TreeProof{}, hashes
517 }
518 th, hashes := subTreeHash(lo, hi, hashes)
519 return TreeProof{th}, hashes
520 }
521
522
523
524 var p TreeProof
525 var th Hash
526 if k, _ := maxpow2(hi - lo); n <= lo+k {
527
528 p, hashes = treeProof(lo, lo+k, n, hashes)
529 th, hashes = subTreeHash(lo+k, hi, hashes)
530 } else {
531
532 th, hashes = subTreeHash(lo, lo+k, hashes)
533 p, hashes = treeProof(lo+k, hi, n, hashes)
534 }
535 return append(p, th), hashes
536 }
537
538
539
540 func CheckTree(p TreeProof, t int64, th Hash, n int64, h Hash) error {
541 if t < 1 || n < 1 || n > t {
542 return fmt.Errorf("tlog: invalid inputs in CheckTree")
543 }
544 h2, th2, err := runTreeProof(p, 0, t, n, h)
545 if err != nil {
546 return err
547 }
548 if th2 == th && h2 == h {
549 return nil
550 }
551 return errProofFailed
552 }
553
554
555
556
557
558 func runTreeProof(p TreeProof, lo, hi, n int64, old Hash) (Hash, Hash, error) {
559
560 if !(lo < n && n <= hi) {
561 panic("tlog: bad math in runTreeProof")
562 }
563
564
565 if n == hi {
566 if lo == 0 {
567 if len(p) != 0 {
568 return Hash{}, Hash{}, errProofFailed
569 }
570 return old, old, nil
571 }
572 if len(p) != 1 {
573 return Hash{}, Hash{}, errProofFailed
574 }
575 return p[0], p[0], nil
576 }
577
578 if len(p) == 0 {
579 return Hash{}, Hash{}, errProofFailed
580 }
581
582
583 k, _ := maxpow2(hi - lo)
584 if n <= lo+k {
585 oh, th, err := runTreeProof(p[:len(p)-1], lo, lo+k, n, old)
586 if err != nil {
587 return Hash{}, Hash{}, err
588 }
589 return oh, NodeHash(th, p[len(p)-1]), nil
590 } else {
591 oh, th, err := runTreeProof(p[:len(p)-1], lo+k, hi, n, old)
592 if err != nil {
593 return Hash{}, Hash{}, err
594 }
595 return NodeHash(p[len(p)-1], oh), NodeHash(p[len(p)-1], th), nil
596 }
597 }
598
View as plain text