...

Source file src/google.golang.org/protobuf/testing/protopack/pack.go

Documentation: google.golang.org/protobuf/testing/protopack

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package protopack enables manual encoding and decoding of protobuf wire data.
     6  //
     7  // This package is intended for use in debugging and/or creation of test data.
     8  // Proper usage of this package requires knowledge of the wire format.
     9  //
    10  // See https://protobuf.dev/programming-guides/encoding.
    11  package protopack
    12  
    13  import (
    14  	"fmt"
    15  	"io"
    16  	"math"
    17  	"path"
    18  	"reflect"
    19  	"strconv"
    20  	"strings"
    21  	"unicode"
    22  	"unicode/utf8"
    23  
    24  	"google.golang.org/protobuf/encoding/protowire"
    25  	"google.golang.org/protobuf/reflect/protoreflect"
    26  )
    27  
    28  // Number is the field number; aliased from the [protowire] package for convenience.
    29  type Number = protowire.Number
    30  
    31  // Number type constants; copied from the [protowire] package for convenience.
    32  const (
    33  	MinValidNumber      Number = protowire.MinValidNumber
    34  	FirstReservedNumber Number = protowire.FirstReservedNumber
    35  	LastReservedNumber  Number = protowire.LastReservedNumber
    36  	MaxValidNumber      Number = protowire.MaxValidNumber
    37  )
    38  
    39  // Type is the wire type; aliased from the [protowire] package for convenience.
    40  type Type = protowire.Type
    41  
    42  // Wire type constants; copied from the [protowire] package for convenience.
    43  const (
    44  	VarintType     Type = protowire.VarintType
    45  	Fixed32Type    Type = protowire.Fixed32Type
    46  	Fixed64Type    Type = protowire.Fixed64Type
    47  	BytesType      Type = protowire.BytesType
    48  	StartGroupType Type = protowire.StartGroupType
    49  	EndGroupType   Type = protowire.EndGroupType
    50  )
    51  
    52  type (
    53  	// Token is any other type (e.g., [Message], [Tag], [Varint], [Float32], etc).
    54  	Token token
    55  	// Message is an ordered sequence of [Token] values, where certain tokens may
    56  	// contain other tokens. It is functionally a concrete syntax tree that
    57  	// losslessly represents any arbitrary wire data (including invalid input).
    58  	Message []Token
    59  
    60  	// Tag is a tuple of the field number and the wire type.
    61  	Tag struct {
    62  		Number Number
    63  		Type   Type
    64  	}
    65  	// Bool is a boolean.
    66  	Bool bool
    67  	// Varint is a signed varint using 64-bit two's complement encoding.
    68  	Varint int64
    69  	// Svarint is a signed varint using zig-zag encoding.
    70  	Svarint int64
    71  	// Uvarint is a unsigned varint.
    72  	Uvarint uint64
    73  
    74  	// Int32 is a signed 32-bit fixed-width integer.
    75  	Int32 int32
    76  	// Uint32 is an unsigned 32-bit fixed-width integer.
    77  	Uint32 uint32
    78  	// Float32 is a 32-bit fixed-width floating point number.
    79  	Float32 float32
    80  
    81  	// Int64 is a signed 64-bit fixed-width integer.
    82  	Int64 int64
    83  	// Uint64 is an unsigned 64-bit fixed-width integer.
    84  	Uint64 uint64
    85  	// Float64 is a 64-bit fixed-width floating point number.
    86  	Float64 float64
    87  
    88  	// String is a length-prefixed string.
    89  	String string
    90  	// Bytes is a length-prefixed bytes.
    91  	Bytes []byte
    92  	// LengthPrefix is a length-prefixed message.
    93  	LengthPrefix Message
    94  
    95  	// Denormalized is a denormalized varint value, where a varint is encoded
    96  	// using more bytes than is strictly necessary. The number of extra bytes
    97  	// alone is sufficient to losslessly represent the denormalized varint.
    98  	//
    99  	// The value may be one of [Tag], [Bool], [Varint], [Svarint], or [Uvarint],
   100  	// where the varint representation of each token is denormalized.
   101  	//
   102  	// Alternatively, the value may be one of [String], [Bytes], or [LengthPrefix],
   103  	// where the varint representation of the length-prefix is denormalized.
   104  	Denormalized struct {
   105  		Count uint // number of extra bytes
   106  		Value Token
   107  	}
   108  
   109  	// Raw are bytes directly appended to output.
   110  	Raw []byte
   111  )
   112  
   113  type token interface {
   114  	isToken()
   115  }
   116  
   117  func (Message) isToken()      {}
   118  func (Tag) isToken()          {}
   119  func (Bool) isToken()         {}
   120  func (Varint) isToken()       {}
   121  func (Svarint) isToken()      {}
   122  func (Uvarint) isToken()      {}
   123  func (Int32) isToken()        {}
   124  func (Uint32) isToken()       {}
   125  func (Float32) isToken()      {}
   126  func (Int64) isToken()        {}
   127  func (Uint64) isToken()       {}
   128  func (Float64) isToken()      {}
   129  func (String) isToken()       {}
   130  func (Bytes) isToken()        {}
   131  func (LengthPrefix) isToken() {}
   132  func (Denormalized) isToken() {}
   133  func (Raw) isToken()          {}
   134  
   135  // Size reports the size in bytes of the marshaled message.
   136  func (m Message) Size() int {
   137  	var n int
   138  	for _, v := range m {
   139  		switch v := v.(type) {
   140  		case Message:
   141  			n += v.Size()
   142  		case Tag:
   143  			n += protowire.SizeTag(v.Number)
   144  		case Bool:
   145  			n += protowire.SizeVarint(protowire.EncodeBool(false))
   146  		case Varint:
   147  			n += protowire.SizeVarint(uint64(v))
   148  		case Svarint:
   149  			n += protowire.SizeVarint(protowire.EncodeZigZag(int64(v)))
   150  		case Uvarint:
   151  			n += protowire.SizeVarint(uint64(v))
   152  		case Int32, Uint32, Float32:
   153  			n += protowire.SizeFixed32()
   154  		case Int64, Uint64, Float64:
   155  			n += protowire.SizeFixed64()
   156  		case String:
   157  			n += protowire.SizeBytes(len(v))
   158  		case Bytes:
   159  			n += protowire.SizeBytes(len(v))
   160  		case LengthPrefix:
   161  			n += protowire.SizeBytes(Message(v).Size())
   162  		case Denormalized:
   163  			n += int(v.Count) + Message{v.Value}.Size()
   164  		case Raw:
   165  			n += len(v)
   166  		default:
   167  			panic(fmt.Sprintf("unknown type: %T", v))
   168  		}
   169  	}
   170  	return n
   171  }
   172  
   173  // Marshal encodes a syntax tree into the protobuf wire format.
   174  //
   175  // Example message definition:
   176  //
   177  //	message MyMessage {
   178  //		string field1 = 1;
   179  //		int64 field2 = 2;
   180  //		repeated float32 field3 = 3;
   181  //	}
   182  //
   183  // Example encoded message:
   184  //
   185  //	b := Message{
   186  //		Tag{1, BytesType}, String("Hello, world!"),
   187  //		Tag{2, VarintType}, Varint(-10),
   188  //		Tag{3, BytesType}, LengthPrefix{
   189  //			Float32(1.1), Float32(2.2), Float32(3.3),
   190  //		},
   191  //	}.Marshal()
   192  //
   193  // Resulting wire data:
   194  //
   195  //	0x0000  0a 0d 48 65 6c 6c 6f 2c  20 77 6f 72 6c 64 21 10  |..Hello, world!.|
   196  //	0x0010  f6 ff ff ff ff ff ff ff  ff 01 1a 0c cd cc 8c 3f  |...............?|
   197  //	0x0020  cd cc 0c 40 33 33 53 40                           |...@33S@|
   198  func (m Message) Marshal() []byte {
   199  	var out []byte
   200  	for _, v := range m {
   201  		switch v := v.(type) {
   202  		case Message:
   203  			out = append(out, v.Marshal()...)
   204  		case Tag:
   205  			out = protowire.AppendTag(out, v.Number, v.Type)
   206  		case Bool:
   207  			out = protowire.AppendVarint(out, protowire.EncodeBool(bool(v)))
   208  		case Varint:
   209  			out = protowire.AppendVarint(out, uint64(v))
   210  		case Svarint:
   211  			out = protowire.AppendVarint(out, protowire.EncodeZigZag(int64(v)))
   212  		case Uvarint:
   213  			out = protowire.AppendVarint(out, uint64(v))
   214  		case Int32:
   215  			out = protowire.AppendFixed32(out, uint32(v))
   216  		case Uint32:
   217  			out = protowire.AppendFixed32(out, uint32(v))
   218  		case Float32:
   219  			out = protowire.AppendFixed32(out, math.Float32bits(float32(v)))
   220  		case Int64:
   221  			out = protowire.AppendFixed64(out, uint64(v))
   222  		case Uint64:
   223  			out = protowire.AppendFixed64(out, uint64(v))
   224  		case Float64:
   225  			out = protowire.AppendFixed64(out, math.Float64bits(float64(v)))
   226  		case String:
   227  			out = protowire.AppendBytes(out, []byte(v))
   228  		case Bytes:
   229  			out = protowire.AppendBytes(out, []byte(v))
   230  		case LengthPrefix:
   231  			out = protowire.AppendBytes(out, Message(v).Marshal())
   232  		case Denormalized:
   233  			b := Message{v.Value}.Marshal()
   234  			_, n := protowire.ConsumeVarint(b)
   235  			out = append(out, b[:n]...)
   236  			for i := uint(0); i < v.Count; i++ {
   237  				out[len(out)-1] |= 0x80 // set continuation bit on previous
   238  				out = append(out, 0)
   239  			}
   240  			out = append(out, b[n:]...)
   241  		case Raw:
   242  			return append(out, v...)
   243  		default:
   244  			panic(fmt.Sprintf("unknown type: %T", v))
   245  		}
   246  	}
   247  	return out
   248  }
   249  
   250  // Unmarshal parses the input protobuf wire data as a syntax tree.
   251  // Any parsing error results in the remainder of the input being
   252  // concatenated to the message as a [Raw] type.
   253  //
   254  // Each tag (a tuple of the field number and wire type) encountered is
   255  // inserted into the syntax tree as a [Tag].
   256  //
   257  // The contents of each wire type is mapped to the following Go types:
   258  //
   259  //   - [VarintType] ⇒ [Uvarint]
   260  //   - [Fixed32Type] ⇒ [Uint32]
   261  //   - [Fixed64Type] ⇒ [Uint64]
   262  //   - [BytesType] ⇒ [Bytes]
   263  //   - [StartGroupType] and [StartGroupType] ⇒ [Message]
   264  //
   265  // Since the wire format is not self-describing, this function cannot parse
   266  // sub-messages and will leave them as the [Bytes] type. Further manual parsing
   267  // can be performed as such:
   268  //
   269  //	var m, m1, m2 Message
   270  //	m.Unmarshal(b)
   271  //	m1.Unmarshal(m[3].(Bytes))
   272  //	m[3] = LengthPrefix(m1)
   273  //	m2.Unmarshal(m[3].(LengthPrefix)[1].(Bytes))
   274  //	m[3].(LengthPrefix)[1] = LengthPrefix(m2)
   275  //
   276  // Unmarshal is useful for debugging the protobuf wire format.
   277  func (m *Message) Unmarshal(in []byte) {
   278  	m.unmarshal(in, nil, false)
   279  }
   280  
   281  // UnmarshalDescriptor parses the input protobuf wire data as a syntax tree
   282  // using the provided message descriptor for more accurate parsing of fields.
   283  // It operates like [Message.Unmarshal], but may use a wider range of Go types to
   284  // represent the wire data.
   285  //
   286  // The contents of each wire type is mapped to one of the following Go types:
   287  //
   288  //   - [VarintType] ⇒ [Bool], [Varint], [Svarint], [Uvarint]
   289  //   - [Fixed32Type] ⇒ [Int32], [Uint32], [Float32]
   290  //   - [Fixed64Type] ⇒ [Uint32], [Uint64], [Float64]
   291  //   - [BytesType] ⇒ [String], [Bytes], [LengthPrefix]
   292  //   - [StartGroupType] and [StartGroupType] ⇒ [Message]
   293  //
   294  // If the field is unknown, it uses the same mapping as [Message.Unmarshal].
   295  // Known sub-messages are parsed as a Message and packed repeated fields are
   296  // parsed as a [LengthPrefix].
   297  func (m *Message) UnmarshalDescriptor(in []byte, desc protoreflect.MessageDescriptor) {
   298  	m.unmarshal(in, desc, false)
   299  }
   300  
   301  // UnmarshalAbductive is like [Message.UnmarshalDescriptor], but infers abductively
   302  // whether any unknown bytes values is a message based on whether it is
   303  // a syntactically well-formed message.
   304  //
   305  // Note that the protobuf wire format is not fully self-describing,
   306  // so abductive inference may attempt to expand a bytes value as a message
   307  // that is not actually a message. It is a best-effort guess.
   308  func (m *Message) UnmarshalAbductive(in []byte, desc protoreflect.MessageDescriptor) {
   309  	m.unmarshal(in, desc, true)
   310  }
   311  
   312  func (m *Message) unmarshal(in []byte, desc protoreflect.MessageDescriptor, inferMessage bool) {
   313  	p := parser{in: in, out: *m}
   314  	p.parseMessage(desc, false, inferMessage)
   315  	*m = p.out
   316  }
   317  
   318  type parser struct {
   319  	in  []byte
   320  	out []Token
   321  
   322  	invalid bool
   323  }
   324  
   325  func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group, inferMessage bool) {
   326  	for len(p.in) > 0 {
   327  		v, n := protowire.ConsumeVarint(p.in)
   328  		num, typ := protowire.DecodeTag(v)
   329  		if n < 0 || num <= 0 || v > math.MaxUint32 {
   330  			p.out, p.in = append(p.out, Raw(p.in)), nil
   331  			p.invalid = true
   332  			return
   333  		}
   334  		if typ == EndGroupType && group {
   335  			return // if inside a group, then stop
   336  		}
   337  		p.out, p.in = append(p.out, Tag{num, typ}), p.in[n:]
   338  		if m := n - protowire.SizeVarint(v); m > 0 {
   339  			p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
   340  		}
   341  
   342  		// If descriptor is available, use it for more accurate parsing.
   343  		var isPacked bool
   344  		var kind protoreflect.Kind
   345  		var subDesc protoreflect.MessageDescriptor
   346  		if msgDesc != nil && !msgDesc.IsPlaceholder() {
   347  			if fieldDesc := msgDesc.Fields().ByNumber(num); fieldDesc != nil {
   348  				isPacked = fieldDesc.IsPacked()
   349  				kind = fieldDesc.Kind()
   350  				switch kind {
   351  				case protoreflect.MessageKind, protoreflect.GroupKind:
   352  					subDesc = fieldDesc.Message()
   353  					if subDesc == nil || subDesc.IsPlaceholder() {
   354  						kind = 0
   355  					}
   356  				}
   357  			}
   358  		}
   359  
   360  		switch typ {
   361  		case VarintType:
   362  			p.parseVarint(kind)
   363  		case Fixed32Type:
   364  			p.parseFixed32(kind)
   365  		case Fixed64Type:
   366  			p.parseFixed64(kind)
   367  		case BytesType:
   368  			p.parseBytes(isPacked, kind, subDesc, inferMessage)
   369  		case StartGroupType:
   370  			p.parseGroup(num, subDesc, inferMessage)
   371  		case EndGroupType:
   372  			// Handled by p.parseGroup.
   373  		default:
   374  			p.out, p.in = append(p.out, Raw(p.in)), nil
   375  			p.invalid = true
   376  		}
   377  	}
   378  }
   379  
   380  func (p *parser) parseVarint(kind protoreflect.Kind) {
   381  	v, n := protowire.ConsumeVarint(p.in)
   382  	if n < 0 {
   383  		p.out, p.in = append(p.out, Raw(p.in)), nil
   384  		p.invalid = true
   385  		return
   386  	}
   387  	switch kind {
   388  	case protoreflect.BoolKind:
   389  		switch v {
   390  		case 0:
   391  			p.out, p.in = append(p.out, Bool(false)), p.in[n:]
   392  		case 1:
   393  			p.out, p.in = append(p.out, Bool(true)), p.in[n:]
   394  		default:
   395  			p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
   396  		}
   397  	case protoreflect.Int32Kind, protoreflect.Int64Kind:
   398  		p.out, p.in = append(p.out, Varint(v)), p.in[n:]
   399  	case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
   400  		p.out, p.in = append(p.out, Svarint(protowire.DecodeZigZag(v))), p.in[n:]
   401  	default:
   402  		p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
   403  	}
   404  	if m := n - protowire.SizeVarint(v); m > 0 {
   405  		p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
   406  	}
   407  }
   408  
   409  func (p *parser) parseFixed32(kind protoreflect.Kind) {
   410  	v, n := protowire.ConsumeFixed32(p.in)
   411  	if n < 0 {
   412  		p.out, p.in = append(p.out, Raw(p.in)), nil
   413  		p.invalid = true
   414  		return
   415  	}
   416  	switch kind {
   417  	case protoreflect.FloatKind:
   418  		p.out, p.in = append(p.out, Float32(math.Float32frombits(v))), p.in[n:]
   419  	case protoreflect.Sfixed32Kind:
   420  		p.out, p.in = append(p.out, Int32(v)), p.in[n:]
   421  	default:
   422  		p.out, p.in = append(p.out, Uint32(v)), p.in[n:]
   423  	}
   424  }
   425  
   426  func (p *parser) parseFixed64(kind protoreflect.Kind) {
   427  	v, n := protowire.ConsumeFixed64(p.in)
   428  	if n < 0 {
   429  		p.out, p.in = append(p.out, Raw(p.in)), nil
   430  		p.invalid = true
   431  		return
   432  	}
   433  	switch kind {
   434  	case protoreflect.DoubleKind:
   435  		p.out, p.in = append(p.out, Float64(math.Float64frombits(v))), p.in[n:]
   436  	case protoreflect.Sfixed64Kind:
   437  		p.out, p.in = append(p.out, Int64(v)), p.in[n:]
   438  	default:
   439  		p.out, p.in = append(p.out, Uint64(v)), p.in[n:]
   440  	}
   441  }
   442  
   443  func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor, inferMessage bool) {
   444  	v, n := protowire.ConsumeVarint(p.in)
   445  	if n < 0 {
   446  		p.out, p.in = append(p.out, Raw(p.in)), nil
   447  		p.invalid = true
   448  		return
   449  	}
   450  	p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
   451  	if m := n - protowire.SizeVarint(v); m > 0 {
   452  		p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
   453  	}
   454  	if v > uint64(len(p.in)) {
   455  		p.out, p.in = append(p.out, Raw(p.in)), nil
   456  		p.invalid = true
   457  		return
   458  	}
   459  	p.out = p.out[:len(p.out)-1] // subsequent tokens contain prefix-length
   460  
   461  	if isPacked {
   462  		p.parsePacked(int(v), kind)
   463  	} else {
   464  		switch kind {
   465  		case protoreflect.MessageKind:
   466  			p2 := parser{in: p.in[:v]}
   467  			p2.parseMessage(desc, false, inferMessage)
   468  			p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
   469  		case protoreflect.StringKind:
   470  			p.out, p.in = append(p.out, String(p.in[:v])), p.in[v:]
   471  		case protoreflect.BytesKind:
   472  			p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
   473  		default:
   474  			if inferMessage {
   475  				// Check whether this is a syntactically valid message.
   476  				p2 := parser{in: p.in[:v]}
   477  				p2.parseMessage(nil, false, inferMessage)
   478  				if !p2.invalid {
   479  					p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
   480  					break
   481  				}
   482  			}
   483  			p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
   484  		}
   485  	}
   486  	if m := n - protowire.SizeVarint(v); m > 0 {
   487  		p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
   488  	}
   489  }
   490  
   491  func (p *parser) parsePacked(n int, kind protoreflect.Kind) {
   492  	p2 := parser{in: p.in[:n]}
   493  	for len(p2.in) > 0 {
   494  		switch kind {
   495  		case protoreflect.BoolKind, protoreflect.EnumKind,
   496  			protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind,
   497  			protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind:
   498  			p2.parseVarint(kind)
   499  		case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind:
   500  			p2.parseFixed32(kind)
   501  		case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind:
   502  			p2.parseFixed64(kind)
   503  		default:
   504  			panic(fmt.Sprintf("invalid packed kind: %v", kind))
   505  		}
   506  	}
   507  	p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[n:]
   508  }
   509  
   510  func (p *parser) parseGroup(startNum protowire.Number, desc protoreflect.MessageDescriptor, inferMessage bool) {
   511  	p2 := parser{in: p.in}
   512  	p2.parseMessage(desc, true, inferMessage)
   513  	if len(p2.out) > 0 {
   514  		p.out = append(p.out, Message(p2.out))
   515  	}
   516  	p.in = p2.in
   517  
   518  	// Append the trailing end group.
   519  	v, n := protowire.ConsumeVarint(p.in)
   520  	if endNum, typ := protowire.DecodeTag(v); typ == EndGroupType {
   521  		if startNum != endNum {
   522  			p.invalid = true
   523  		}
   524  		p.out, p.in = append(p.out, Tag{endNum, typ}), p.in[n:]
   525  		if m := n - protowire.SizeVarint(v); m > 0 {
   526  			p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
   527  		}
   528  	}
   529  }
   530  
   531  // Format implements a custom formatter to visualize the syntax tree.
   532  // Using "%#v" formats the Message in Go source code.
   533  func (m Message) Format(s fmt.State, r rune) {
   534  	switch r {
   535  	case 'x':
   536  		io.WriteString(s, fmt.Sprintf("%x", m.Marshal()))
   537  	case 'X':
   538  		io.WriteString(s, fmt.Sprintf("%X", m.Marshal()))
   539  	case 'v':
   540  		switch {
   541  		case s.Flag('#'):
   542  			io.WriteString(s, m.format(true, true))
   543  		case s.Flag('+'):
   544  			io.WriteString(s, m.format(false, true))
   545  		default:
   546  			io.WriteString(s, m.format(false, false))
   547  		}
   548  	default:
   549  		panic("invalid verb: " + string(r))
   550  	}
   551  }
   552  
   553  // format formats the message.
   554  // If source is enabled, this emits valid Go source.
   555  // If multi is enabled, the output may span multiple lines.
   556  func (m Message) format(source, multi bool) string {
   557  	var ss []string
   558  	var prefix, nextPrefix string
   559  	for _, v := range m {
   560  		// Ensure certain tokens have preceding or succeeding newlines.
   561  		prefix, nextPrefix = nextPrefix, " "
   562  		if multi {
   563  			switch v := v.(type) {
   564  			case Tag: // only has preceding newline
   565  				prefix = "\n"
   566  			case Denormalized: // only has preceding newline
   567  				if _, ok := v.Value.(Tag); ok {
   568  					prefix = "\n"
   569  				}
   570  			case Message, Raw: // has preceding and succeeding newlines
   571  				prefix, nextPrefix = "\n", "\n"
   572  			}
   573  		}
   574  
   575  		s := formatToken(v, source, multi)
   576  		ss = append(ss, prefix+s+",")
   577  	}
   578  
   579  	var s string
   580  	if len(ss) > 0 {
   581  		s = strings.TrimSpace(strings.Join(ss, ""))
   582  		if multi {
   583  			s = "\n\t" + strings.Join(strings.Split(s, "\n"), "\n\t") + "\n"
   584  		} else {
   585  			s = strings.TrimSuffix(s, ",")
   586  		}
   587  	}
   588  	s = fmt.Sprintf("%T{%s}", m, s)
   589  	if !source {
   590  		s = trimPackage(s)
   591  	}
   592  	return s
   593  }
   594  
   595  // formatToken formats a single token.
   596  func formatToken(t Token, source, multi bool) (s string) {
   597  	switch v := t.(type) {
   598  	case Message:
   599  		s = v.format(source, multi)
   600  	case LengthPrefix:
   601  		s = formatPacked(v, source, multi)
   602  		if s == "" {
   603  			ms := Message(v).format(source, multi)
   604  			s = fmt.Sprintf("%T(%s)", v, ms)
   605  		}
   606  	case Tag:
   607  		s = fmt.Sprintf("%T{%d, %s}", v, v.Number, formatType(v.Type, source))
   608  	case Bool, Varint, Svarint, Uvarint, Int32, Uint32, Float32, Int64, Uint64, Float64:
   609  		if source {
   610  			// Print floats in a way that preserves exact precision.
   611  			if f, _ := v.(Float32); math.IsNaN(float64(f)) || math.IsInf(float64(f), 0) {
   612  				switch {
   613  				case f > 0:
   614  					s = fmt.Sprintf("%T(math.Inf(+1))", v)
   615  				case f < 0:
   616  					s = fmt.Sprintf("%T(math.Inf(-1))", v)
   617  				case math.Float32bits(float32(math.NaN())) == math.Float32bits(float32(f)):
   618  					s = fmt.Sprintf("%T(math.NaN())", v)
   619  				default:
   620  					s = fmt.Sprintf("%T(math.Float32frombits(0x%08x))", v, math.Float32bits(float32(f)))
   621  				}
   622  				break
   623  			}
   624  			if f, _ := v.(Float64); math.IsNaN(float64(f)) || math.IsInf(float64(f), 0) {
   625  				switch {
   626  				case f > 0:
   627  					s = fmt.Sprintf("%T(math.Inf(+1))", v)
   628  				case f < 0:
   629  					s = fmt.Sprintf("%T(math.Inf(-1))", v)
   630  				case math.Float64bits(float64(math.NaN())) == math.Float64bits(float64(f)):
   631  					s = fmt.Sprintf("%T(math.NaN())", v)
   632  				default:
   633  					s = fmt.Sprintf("%T(math.Float64frombits(0x%016x))", v, math.Float64bits(float64(f)))
   634  				}
   635  				break
   636  			}
   637  		}
   638  		s = fmt.Sprintf("%T(%v)", v, v)
   639  	case String, Bytes, Raw:
   640  		s = fmt.Sprintf("%s", v)
   641  		s = fmt.Sprintf("%T(%s)", v, formatString(s))
   642  	case Denormalized:
   643  		s = fmt.Sprintf("%T{+%d, %v}", v, v.Count, formatToken(v.Value, source, multi))
   644  	default:
   645  		panic(fmt.Sprintf("unknown type: %T", v))
   646  	}
   647  	if !source {
   648  		s = trimPackage(s)
   649  	}
   650  	return s
   651  }
   652  
   653  // formatPacked returns a non-empty string if LengthPrefix looks like a packed
   654  // repeated field of primitives.
   655  func formatPacked(v LengthPrefix, source, multi bool) string {
   656  	var ss []string
   657  	for _, v := range v {
   658  		switch v.(type) {
   659  		case Bool, Varint, Svarint, Uvarint, Int32, Uint32, Float32, Int64, Uint64, Float64, Denormalized, Raw:
   660  			if v, ok := v.(Denormalized); ok {
   661  				switch v.Value.(type) {
   662  				case Bool, Varint, Svarint, Uvarint:
   663  				default:
   664  					return ""
   665  				}
   666  			}
   667  			ss = append(ss, formatToken(v, source, multi))
   668  		default:
   669  			return ""
   670  		}
   671  	}
   672  	s := fmt.Sprintf("%T{%s}", v, strings.Join(ss, ", "))
   673  	if !source {
   674  		s = trimPackage(s)
   675  	}
   676  	return s
   677  }
   678  
   679  // formatType returns the name for Type.
   680  func formatType(t Type, source bool) (s string) {
   681  	switch t {
   682  	case VarintType:
   683  		s = pkg + ".VarintType"
   684  	case Fixed32Type:
   685  		s = pkg + ".Fixed32Type"
   686  	case Fixed64Type:
   687  		s = pkg + ".Fixed64Type"
   688  	case BytesType:
   689  		s = pkg + ".BytesType"
   690  	case StartGroupType:
   691  		s = pkg + ".StartGroupType"
   692  	case EndGroupType:
   693  		s = pkg + ".EndGroupType"
   694  	default:
   695  		s = fmt.Sprintf("Type(%d)", t)
   696  	}
   697  	if !source {
   698  		s = strings.TrimSuffix(trimPackage(s), "Type")
   699  	}
   700  	return s
   701  }
   702  
   703  // formatString returns a quoted string for s.
   704  func formatString(s string) string {
   705  	// Use quoted string if it the same length as a raw string literal.
   706  	// Otherwise, attempt to use the raw string form.
   707  	qs := strconv.Quote(s)
   708  	if len(qs) == 1+len(s)+1 {
   709  		return qs
   710  	}
   711  
   712  	// Disallow newlines to ensure output is a single line.
   713  	// Disallow non-printable runes for readability purposes.
   714  	rawInvalid := func(r rune) bool {
   715  		return r == '`' || r == '\n' || r == utf8.RuneError || !unicode.IsPrint(r)
   716  	}
   717  	if strings.IndexFunc(s, rawInvalid) < 0 {
   718  		return "`" + s + "`"
   719  	}
   720  	return qs
   721  }
   722  
   723  var pkg = path.Base(reflect.TypeOf(Tag{}).PkgPath())
   724  
   725  func trimPackage(s string) string {
   726  	return strings.TrimPrefix(strings.TrimPrefix(s, pkg), ".")
   727  }
   728  

View as plain text