1
2
3
4
5
6
7
8
9
10
11 package protopack
12
13 import (
14 "fmt"
15 "io"
16 "math"
17 "path"
18 "reflect"
19 "strconv"
20 "strings"
21 "unicode"
22 "unicode/utf8"
23
24 "google.golang.org/protobuf/encoding/protowire"
25 "google.golang.org/protobuf/reflect/protoreflect"
26 )
27
28
29 type Number = protowire.Number
30
31
32 const (
33 MinValidNumber Number = protowire.MinValidNumber
34 FirstReservedNumber Number = protowire.FirstReservedNumber
35 LastReservedNumber Number = protowire.LastReservedNumber
36 MaxValidNumber Number = protowire.MaxValidNumber
37 )
38
39
40 type Type = protowire.Type
41
42
43 const (
44 VarintType Type = protowire.VarintType
45 Fixed32Type Type = protowire.Fixed32Type
46 Fixed64Type Type = protowire.Fixed64Type
47 BytesType Type = protowire.BytesType
48 StartGroupType Type = protowire.StartGroupType
49 EndGroupType Type = protowire.EndGroupType
50 )
51
52 type (
53
54 Token token
55
56
57
58 Message []Token
59
60
61 Tag struct {
62 Number Number
63 Type Type
64 }
65
66 Bool bool
67
68 Varint int64
69
70 Svarint int64
71
72 Uvarint uint64
73
74
75 Int32 int32
76
77 Uint32 uint32
78
79 Float32 float32
80
81
82 Int64 int64
83
84 Uint64 uint64
85
86 Float64 float64
87
88
89 String string
90
91 Bytes []byte
92
93 LengthPrefix Message
94
95
96
97
98
99
100
101
102
103
104 Denormalized struct {
105 Count uint
106 Value Token
107 }
108
109
110 Raw []byte
111 )
112
113 type token interface {
114 isToken()
115 }
116
117 func (Message) isToken() {}
118 func (Tag) isToken() {}
119 func (Bool) isToken() {}
120 func (Varint) isToken() {}
121 func (Svarint) isToken() {}
122 func (Uvarint) isToken() {}
123 func (Int32) isToken() {}
124 func (Uint32) isToken() {}
125 func (Float32) isToken() {}
126 func (Int64) isToken() {}
127 func (Uint64) isToken() {}
128 func (Float64) isToken() {}
129 func (String) isToken() {}
130 func (Bytes) isToken() {}
131 func (LengthPrefix) isToken() {}
132 func (Denormalized) isToken() {}
133 func (Raw) isToken() {}
134
135
136 func (m Message) Size() int {
137 var n int
138 for _, v := range m {
139 switch v := v.(type) {
140 case Message:
141 n += v.Size()
142 case Tag:
143 n += protowire.SizeTag(v.Number)
144 case Bool:
145 n += protowire.SizeVarint(protowire.EncodeBool(false))
146 case Varint:
147 n += protowire.SizeVarint(uint64(v))
148 case Svarint:
149 n += protowire.SizeVarint(protowire.EncodeZigZag(int64(v)))
150 case Uvarint:
151 n += protowire.SizeVarint(uint64(v))
152 case Int32, Uint32, Float32:
153 n += protowire.SizeFixed32()
154 case Int64, Uint64, Float64:
155 n += protowire.SizeFixed64()
156 case String:
157 n += protowire.SizeBytes(len(v))
158 case Bytes:
159 n += protowire.SizeBytes(len(v))
160 case LengthPrefix:
161 n += protowire.SizeBytes(Message(v).Size())
162 case Denormalized:
163 n += int(v.Count) + Message{v.Value}.Size()
164 case Raw:
165 n += len(v)
166 default:
167 panic(fmt.Sprintf("unknown type: %T", v))
168 }
169 }
170 return n
171 }
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198 func (m Message) Marshal() []byte {
199 var out []byte
200 for _, v := range m {
201 switch v := v.(type) {
202 case Message:
203 out = append(out, v.Marshal()...)
204 case Tag:
205 out = protowire.AppendTag(out, v.Number, v.Type)
206 case Bool:
207 out = protowire.AppendVarint(out, protowire.EncodeBool(bool(v)))
208 case Varint:
209 out = protowire.AppendVarint(out, uint64(v))
210 case Svarint:
211 out = protowire.AppendVarint(out, protowire.EncodeZigZag(int64(v)))
212 case Uvarint:
213 out = protowire.AppendVarint(out, uint64(v))
214 case Int32:
215 out = protowire.AppendFixed32(out, uint32(v))
216 case Uint32:
217 out = protowire.AppendFixed32(out, uint32(v))
218 case Float32:
219 out = protowire.AppendFixed32(out, math.Float32bits(float32(v)))
220 case Int64:
221 out = protowire.AppendFixed64(out, uint64(v))
222 case Uint64:
223 out = protowire.AppendFixed64(out, uint64(v))
224 case Float64:
225 out = protowire.AppendFixed64(out, math.Float64bits(float64(v)))
226 case String:
227 out = protowire.AppendBytes(out, []byte(v))
228 case Bytes:
229 out = protowire.AppendBytes(out, []byte(v))
230 case LengthPrefix:
231 out = protowire.AppendBytes(out, Message(v).Marshal())
232 case Denormalized:
233 b := Message{v.Value}.Marshal()
234 _, n := protowire.ConsumeVarint(b)
235 out = append(out, b[:n]...)
236 for i := uint(0); i < v.Count; i++ {
237 out[len(out)-1] |= 0x80
238 out = append(out, 0)
239 }
240 out = append(out, b[n:]...)
241 case Raw:
242 return append(out, v...)
243 default:
244 panic(fmt.Sprintf("unknown type: %T", v))
245 }
246 }
247 return out
248 }
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277 func (m *Message) Unmarshal(in []byte) {
278 m.unmarshal(in, nil, false)
279 }
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297 func (m *Message) UnmarshalDescriptor(in []byte, desc protoreflect.MessageDescriptor) {
298 m.unmarshal(in, desc, false)
299 }
300
301
302
303
304
305
306
307
308 func (m *Message) UnmarshalAbductive(in []byte, desc protoreflect.MessageDescriptor) {
309 m.unmarshal(in, desc, true)
310 }
311
312 func (m *Message) unmarshal(in []byte, desc protoreflect.MessageDescriptor, inferMessage bool) {
313 p := parser{in: in, out: *m}
314 p.parseMessage(desc, false, inferMessage)
315 *m = p.out
316 }
317
318 type parser struct {
319 in []byte
320 out []Token
321
322 invalid bool
323 }
324
325 func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group, inferMessage bool) {
326 for len(p.in) > 0 {
327 v, n := protowire.ConsumeVarint(p.in)
328 num, typ := protowire.DecodeTag(v)
329 if n < 0 || num <= 0 || v > math.MaxUint32 {
330 p.out, p.in = append(p.out, Raw(p.in)), nil
331 p.invalid = true
332 return
333 }
334 if typ == EndGroupType && group {
335 return
336 }
337 p.out, p.in = append(p.out, Tag{num, typ}), p.in[n:]
338 if m := n - protowire.SizeVarint(v); m > 0 {
339 p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
340 }
341
342
343 var isPacked bool
344 var kind protoreflect.Kind
345 var subDesc protoreflect.MessageDescriptor
346 if msgDesc != nil && !msgDesc.IsPlaceholder() {
347 if fieldDesc := msgDesc.Fields().ByNumber(num); fieldDesc != nil {
348 isPacked = fieldDesc.IsPacked()
349 kind = fieldDesc.Kind()
350 switch kind {
351 case protoreflect.MessageKind, protoreflect.GroupKind:
352 subDesc = fieldDesc.Message()
353 if subDesc == nil || subDesc.IsPlaceholder() {
354 kind = 0
355 }
356 }
357 }
358 }
359
360 switch typ {
361 case VarintType:
362 p.parseVarint(kind)
363 case Fixed32Type:
364 p.parseFixed32(kind)
365 case Fixed64Type:
366 p.parseFixed64(kind)
367 case BytesType:
368 p.parseBytes(isPacked, kind, subDesc, inferMessage)
369 case StartGroupType:
370 p.parseGroup(num, subDesc, inferMessage)
371 case EndGroupType:
372
373 default:
374 p.out, p.in = append(p.out, Raw(p.in)), nil
375 p.invalid = true
376 }
377 }
378 }
379
380 func (p *parser) parseVarint(kind protoreflect.Kind) {
381 v, n := protowire.ConsumeVarint(p.in)
382 if n < 0 {
383 p.out, p.in = append(p.out, Raw(p.in)), nil
384 p.invalid = true
385 return
386 }
387 switch kind {
388 case protoreflect.BoolKind:
389 switch v {
390 case 0:
391 p.out, p.in = append(p.out, Bool(false)), p.in[n:]
392 case 1:
393 p.out, p.in = append(p.out, Bool(true)), p.in[n:]
394 default:
395 p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
396 }
397 case protoreflect.Int32Kind, protoreflect.Int64Kind:
398 p.out, p.in = append(p.out, Varint(v)), p.in[n:]
399 case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
400 p.out, p.in = append(p.out, Svarint(protowire.DecodeZigZag(v))), p.in[n:]
401 default:
402 p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
403 }
404 if m := n - protowire.SizeVarint(v); m > 0 {
405 p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
406 }
407 }
408
409 func (p *parser) parseFixed32(kind protoreflect.Kind) {
410 v, n := protowire.ConsumeFixed32(p.in)
411 if n < 0 {
412 p.out, p.in = append(p.out, Raw(p.in)), nil
413 p.invalid = true
414 return
415 }
416 switch kind {
417 case protoreflect.FloatKind:
418 p.out, p.in = append(p.out, Float32(math.Float32frombits(v))), p.in[n:]
419 case protoreflect.Sfixed32Kind:
420 p.out, p.in = append(p.out, Int32(v)), p.in[n:]
421 default:
422 p.out, p.in = append(p.out, Uint32(v)), p.in[n:]
423 }
424 }
425
426 func (p *parser) parseFixed64(kind protoreflect.Kind) {
427 v, n := protowire.ConsumeFixed64(p.in)
428 if n < 0 {
429 p.out, p.in = append(p.out, Raw(p.in)), nil
430 p.invalid = true
431 return
432 }
433 switch kind {
434 case protoreflect.DoubleKind:
435 p.out, p.in = append(p.out, Float64(math.Float64frombits(v))), p.in[n:]
436 case protoreflect.Sfixed64Kind:
437 p.out, p.in = append(p.out, Int64(v)), p.in[n:]
438 default:
439 p.out, p.in = append(p.out, Uint64(v)), p.in[n:]
440 }
441 }
442
443 func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor, inferMessage bool) {
444 v, n := protowire.ConsumeVarint(p.in)
445 if n < 0 {
446 p.out, p.in = append(p.out, Raw(p.in)), nil
447 p.invalid = true
448 return
449 }
450 p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
451 if m := n - protowire.SizeVarint(v); m > 0 {
452 p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
453 }
454 if v > uint64(len(p.in)) {
455 p.out, p.in = append(p.out, Raw(p.in)), nil
456 p.invalid = true
457 return
458 }
459 p.out = p.out[:len(p.out)-1]
460
461 if isPacked {
462 p.parsePacked(int(v), kind)
463 } else {
464 switch kind {
465 case protoreflect.MessageKind:
466 p2 := parser{in: p.in[:v]}
467 p2.parseMessage(desc, false, inferMessage)
468 p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
469 case protoreflect.StringKind:
470 p.out, p.in = append(p.out, String(p.in[:v])), p.in[v:]
471 case protoreflect.BytesKind:
472 p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
473 default:
474 if inferMessage {
475
476 p2 := parser{in: p.in[:v]}
477 p2.parseMessage(nil, false, inferMessage)
478 if !p2.invalid {
479 p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
480 break
481 }
482 }
483 p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
484 }
485 }
486 if m := n - protowire.SizeVarint(v); m > 0 {
487 p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
488 }
489 }
490
491 func (p *parser) parsePacked(n int, kind protoreflect.Kind) {
492 p2 := parser{in: p.in[:n]}
493 for len(p2.in) > 0 {
494 switch kind {
495 case protoreflect.BoolKind, protoreflect.EnumKind,
496 protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind,
497 protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind:
498 p2.parseVarint(kind)
499 case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind:
500 p2.parseFixed32(kind)
501 case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind:
502 p2.parseFixed64(kind)
503 default:
504 panic(fmt.Sprintf("invalid packed kind: %v", kind))
505 }
506 }
507 p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[n:]
508 }
509
510 func (p *parser) parseGroup(startNum protowire.Number, desc protoreflect.MessageDescriptor, inferMessage bool) {
511 p2 := parser{in: p.in}
512 p2.parseMessage(desc, true, inferMessage)
513 if len(p2.out) > 0 {
514 p.out = append(p.out, Message(p2.out))
515 }
516 p.in = p2.in
517
518
519 v, n := protowire.ConsumeVarint(p.in)
520 if endNum, typ := protowire.DecodeTag(v); typ == EndGroupType {
521 if startNum != endNum {
522 p.invalid = true
523 }
524 p.out, p.in = append(p.out, Tag{endNum, typ}), p.in[n:]
525 if m := n - protowire.SizeVarint(v); m > 0 {
526 p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
527 }
528 }
529 }
530
531
532
533 func (m Message) Format(s fmt.State, r rune) {
534 switch r {
535 case 'x':
536 io.WriteString(s, fmt.Sprintf("%x", m.Marshal()))
537 case 'X':
538 io.WriteString(s, fmt.Sprintf("%X", m.Marshal()))
539 case 'v':
540 switch {
541 case s.Flag('#'):
542 io.WriteString(s, m.format(true, true))
543 case s.Flag('+'):
544 io.WriteString(s, m.format(false, true))
545 default:
546 io.WriteString(s, m.format(false, false))
547 }
548 default:
549 panic("invalid verb: " + string(r))
550 }
551 }
552
553
554
555
556 func (m Message) format(source, multi bool) string {
557 var ss []string
558 var prefix, nextPrefix string
559 for _, v := range m {
560
561 prefix, nextPrefix = nextPrefix, " "
562 if multi {
563 switch v := v.(type) {
564 case Tag:
565 prefix = "\n"
566 case Denormalized:
567 if _, ok := v.Value.(Tag); ok {
568 prefix = "\n"
569 }
570 case Message, Raw:
571 prefix, nextPrefix = "\n", "\n"
572 }
573 }
574
575 s := formatToken(v, source, multi)
576 ss = append(ss, prefix+s+",")
577 }
578
579 var s string
580 if len(ss) > 0 {
581 s = strings.TrimSpace(strings.Join(ss, ""))
582 if multi {
583 s = "\n\t" + strings.Join(strings.Split(s, "\n"), "\n\t") + "\n"
584 } else {
585 s = strings.TrimSuffix(s, ",")
586 }
587 }
588 s = fmt.Sprintf("%T{%s}", m, s)
589 if !source {
590 s = trimPackage(s)
591 }
592 return s
593 }
594
595
596 func formatToken(t Token, source, multi bool) (s string) {
597 switch v := t.(type) {
598 case Message:
599 s = v.format(source, multi)
600 case LengthPrefix:
601 s = formatPacked(v, source, multi)
602 if s == "" {
603 ms := Message(v).format(source, multi)
604 s = fmt.Sprintf("%T(%s)", v, ms)
605 }
606 case Tag:
607 s = fmt.Sprintf("%T{%d, %s}", v, v.Number, formatType(v.Type, source))
608 case Bool, Varint, Svarint, Uvarint, Int32, Uint32, Float32, Int64, Uint64, Float64:
609 if source {
610
611 if f, _ := v.(Float32); math.IsNaN(float64(f)) || math.IsInf(float64(f), 0) {
612 switch {
613 case f > 0:
614 s = fmt.Sprintf("%T(math.Inf(+1))", v)
615 case f < 0:
616 s = fmt.Sprintf("%T(math.Inf(-1))", v)
617 case math.Float32bits(float32(math.NaN())) == math.Float32bits(float32(f)):
618 s = fmt.Sprintf("%T(math.NaN())", v)
619 default:
620 s = fmt.Sprintf("%T(math.Float32frombits(0x%08x))", v, math.Float32bits(float32(f)))
621 }
622 break
623 }
624 if f, _ := v.(Float64); math.IsNaN(float64(f)) || math.IsInf(float64(f), 0) {
625 switch {
626 case f > 0:
627 s = fmt.Sprintf("%T(math.Inf(+1))", v)
628 case f < 0:
629 s = fmt.Sprintf("%T(math.Inf(-1))", v)
630 case math.Float64bits(float64(math.NaN())) == math.Float64bits(float64(f)):
631 s = fmt.Sprintf("%T(math.NaN())", v)
632 default:
633 s = fmt.Sprintf("%T(math.Float64frombits(0x%016x))", v, math.Float64bits(float64(f)))
634 }
635 break
636 }
637 }
638 s = fmt.Sprintf("%T(%v)", v, v)
639 case String, Bytes, Raw:
640 s = fmt.Sprintf("%s", v)
641 s = fmt.Sprintf("%T(%s)", v, formatString(s))
642 case Denormalized:
643 s = fmt.Sprintf("%T{+%d, %v}", v, v.Count, formatToken(v.Value, source, multi))
644 default:
645 panic(fmt.Sprintf("unknown type: %T", v))
646 }
647 if !source {
648 s = trimPackage(s)
649 }
650 return s
651 }
652
653
654
655 func formatPacked(v LengthPrefix, source, multi bool) string {
656 var ss []string
657 for _, v := range v {
658 switch v.(type) {
659 case Bool, Varint, Svarint, Uvarint, Int32, Uint32, Float32, Int64, Uint64, Float64, Denormalized, Raw:
660 if v, ok := v.(Denormalized); ok {
661 switch v.Value.(type) {
662 case Bool, Varint, Svarint, Uvarint:
663 default:
664 return ""
665 }
666 }
667 ss = append(ss, formatToken(v, source, multi))
668 default:
669 return ""
670 }
671 }
672 s := fmt.Sprintf("%T{%s}", v, strings.Join(ss, ", "))
673 if !source {
674 s = trimPackage(s)
675 }
676 return s
677 }
678
679
680 func formatType(t Type, source bool) (s string) {
681 switch t {
682 case VarintType:
683 s = pkg + ".VarintType"
684 case Fixed32Type:
685 s = pkg + ".Fixed32Type"
686 case Fixed64Type:
687 s = pkg + ".Fixed64Type"
688 case BytesType:
689 s = pkg + ".BytesType"
690 case StartGroupType:
691 s = pkg + ".StartGroupType"
692 case EndGroupType:
693 s = pkg + ".EndGroupType"
694 default:
695 s = fmt.Sprintf("Type(%d)", t)
696 }
697 if !source {
698 s = strings.TrimSuffix(trimPackage(s), "Type")
699 }
700 return s
701 }
702
703
// formatString returns s as a Go string literal, choosing the more readable
// form:
//   - The double-quoted form, when quoting adds no escapes.
//   - A raw (backquoted) literal, when s is valid UTF-8 made only of printable
//     runes with no backquote or newline, so the result stays on one line.
//   - The double-quoted form in all other cases.
func formatString(s string) string {
	quoted := strconv.Quote(s)
	if len(quoted) == len(s)+2 {
		// Only the two surrounding quotes were added: no escaping happened.
		return quoted
	}

	unfitForRaw := func(r rune) bool {
		return r == '`' || r == '\n' || r == utf8.RuneError || !unicode.IsPrint(r)
	}
	if strings.IndexFunc(s, unfitForRaw) >= 0 {
		return quoted
	}
	return "`" + s + "`"
}
722
723 var pkg = path.Base(reflect.TypeOf(Tag{}).PkgPath())
724
725 func trimPackage(s string) string {
726 return strings.TrimPrefix(strings.TrimPrefix(s, pkg), ".")
727 }
728
View as plain text