...

Source file src/google.golang.org/protobuf/internal/cmd/generate-types/proto.go

Documentation: google.golang.org/protobuf/internal/cmd/generate-types

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"strings"
     9  	"text/template"
    10  )
    11  
    12  type WireType string
    13  
    14  const (
    15  	WireVarint  WireType = "Varint"
    16  	WireFixed32 WireType = "Fixed32"
    17  	WireFixed64 WireType = "Fixed64"
    18  	WireBytes   WireType = "Bytes"
    19  	WireGroup   WireType = "Group"
    20  )
    21  
    22  func (w WireType) Expr() Expr {
    23  	if w == WireGroup {
    24  		return "protowire.StartGroupType"
    25  	}
    26  	return "protowire." + Expr(w) + "Type"
    27  }
    28  
    29  func (w WireType) Packable() bool {
    30  	return w == WireVarint || w == WireFixed32 || w == WireFixed64
    31  }
    32  
    33  func (w WireType) ConstSize() bool {
    34  	return w == WireFixed32 || w == WireFixed64
    35  }
    36  
    37  type GoType string
    38  
    39  var GoTypes = []GoType{
    40  	GoBool,
    41  	GoInt32,
    42  	GoUint32,
    43  	GoInt64,
    44  	GoUint64,
    45  	GoFloat32,
    46  	GoFloat64,
    47  	GoString,
    48  	GoBytes,
    49  }
    50  
    51  const (
    52  	GoBool    = "bool"
    53  	GoInt32   = "int32"
    54  	GoUint32  = "uint32"
    55  	GoInt64   = "int64"
    56  	GoUint64  = "uint64"
    57  	GoFloat32 = "float32"
    58  	GoFloat64 = "float64"
    59  	GoString  = "string"
    60  	GoBytes   = "[]byte"
    61  )
    62  
    63  func (g GoType) Zero() Expr {
    64  	switch g {
    65  	case GoBool:
    66  		return "false"
    67  	case GoString:
    68  		return `""`
    69  	case GoBytes:
    70  		return "nil"
    71  	}
    72  	return "0"
    73  }
    74  
    75  // Kind is the reflect.Kind of the type.
    76  func (g GoType) Kind() Expr {
    77  	if g == "" || g == GoBytes {
    78  		return ""
    79  	}
    80  	return "reflect." + Expr(strings.ToUpper(string(g[:1]))+string(g[1:]))
    81  }
    82  
    83  // PointerMethod is the "internal/impl".pointer method used to access a pointer to this type.
    84  func (g GoType) PointerMethod() Expr {
    85  	if g == GoBytes {
    86  		return "Bytes"
    87  	}
    88  	return Expr(strings.ToUpper(string(g[:1])) + string(g[1:]))
    89  }
    90  
    91  type ProtoKind struct {
    92  	Name     string
    93  	WireType WireType
    94  
    95  	// Conversions to/from protoreflect.Value.
    96  	ToValue   Expr
    97  	FromValue Expr
    98  
    99  	// Conversions to/from generated structures.
   100  	GoType         GoType
   101  	ToGoType       Expr
   102  	ToGoTypeNoZero Expr
   103  	FromGoType     Expr
   104  	NoPointer      bool
   105  	NoValueCodec   bool
   106  }
   107  
   108  func (k ProtoKind) Expr() Expr {
   109  	return "protoreflect." + Expr(k.Name) + "Kind"
   110  }
   111  
// ProtoKinds describes every protobuf field kind. The templates in this file
// range over it, so its order is the order of the generated switch cases.
//
// In each conversion snippet, v is the value being converted. For ToValue it
// is the raw value returned by the kind's protowire.Consume* call. For
// FromValue it is the protoreflect.Value being encoded, and for FromGoType it
// is a value of GoType. ToGoType and ToGoTypeNoZero are used outside this
// file, where v is presumably the same raw wire value (TODO confirm).
var ProtoKinds = []ProtoKind{
	// Varint-encoded kinds.
	{
		Name:       "Bool",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfBool(protowire.DecodeBool(v))",
		FromValue:  "protowire.EncodeBool(v.Bool())",
		GoType:     GoBool,
		ToGoType:   "protowire.DecodeBool(v)",
		FromGoType: "protowire.EncodeBool(v)",
	},
	// Enum sets no GoType and no Go-type conversions.
	{
		Name:      "Enum",
		WireType:  WireVarint,
		ToValue:   "protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))",
		FromValue: "uint64(v.Enum())",
	},
	{
		Name:       "Int32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint64(int32(v.Int()))",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint64(v)",
	},
	// Sint32 and Sint64 are zig-zag encoded. Sint32 masks the varint to its
	// low 32 bits before decoding.
	{
		Name:       "Sint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32)))",
		FromValue:  "protowire.EncodeZigZag(int64(int32(v.Int())))",
		GoType:     GoInt32,
		ToGoType:   "int32(protowire.DecodeZigZag(v & math.MaxUint32))",
		FromGoType: "protowire.EncodeZigZag(int64(v))",
	},
	{
		Name:       "Uint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint64(uint32(v.Uint()))",
		GoType:     GoUint32,
		ToGoType:   "uint32(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Int64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Sint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(protowire.DecodeZigZag(v))",
		FromValue:  "protowire.EncodeZigZag(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "protowire.DecodeZigZag(v)",
		FromGoType: "protowire.EncodeZigZag(v)",
	},
	{
		Name:       "Uint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	// Fixed 32-bit kinds.
	{
		Name:       "Sfixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint32(v.Int())",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint32(v)",
	},
	{
		Name:       "Fixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint32(v.Uint())",
		GoType:     GoUint32,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Float",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))",
		FromValue:  "math.Float32bits(float32(v.Float()))",
		GoType:     GoFloat32,
		ToGoType:   "math.Float32frombits(v)",
		FromGoType: "math.Float32bits(v)",
	},
	// Fixed 64-bit kinds.
	{
		Name:       "Sfixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Fixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Double",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfFloat64(math.Float64frombits(v))",
		FromValue:  "math.Float64bits(v.Float())",
		GoType:     GoFloat64,
		ToGoType:   "math.Float64frombits(v)",
		FromGoType: "math.Float64bits(v)",
	},
	// Length-delimited kinds.
	{
		Name:       "String",
		WireType:   WireBytes,
		ToValue:    "protoreflect.ValueOfString(string(v))",
		FromValue:  "v.String()",
		GoType:     GoString,
		ToGoType:   "string(v)",
		FromGoType: "v",
	},
	// Bytes copies v instead of aliasing the input. Appending to emptyBuf gives
	// a non-nil slice even when v is empty; ToGoTypeNoZero appends to nil, so
	// an empty v yields nil.
	{
		Name:           "Bytes",
		WireType:       WireBytes,
		ToValue:        "protoreflect.ValueOfBytes(append(emptyBuf[:], v...))",
		FromValue:      "v.Bytes()",
		GoType:         GoBytes,
		ToGoType:       "append(emptyBuf[:], v...)",
		ToGoTypeNoZero: "append(([]byte)(nil), v...)",
		FromGoType:     "v",
		NoPointer:      true,
	},
	// Message and Group carry no Go-type conversions. Their ToValue wraps the
	// still-encoded bytes, and the templates special-case both kinds.
	{
		Name:         "Message",
		WireType:     WireBytes,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
	{
		Name:         "Group",
		WireType:     WireGroup,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
}
   271  
// generateProtoDecode executes protoDecodeTemplate over ProtoKinds via
// mustExecute (defined elsewhere in this package) and returns the generated
// Go source.
func generateProtoDecode() string {
	return mustExecute(protoDecodeTemplate, ProtoKinds)
}

// protoDecodeTemplate emits two functions with one switch case per ProtoKind:
//
//   - UnmarshalOptions.unmarshalScalar decodes a single value into a
//     protoreflect.Value. It returns errUnknown when the wire type does not
//     match and errDecode when the protowire.Consume* call fails.
//   - UnmarshalOptions.unmarshalList decodes one list element and appends it.
//     For Packable wire types it also accepts a BytesType field and decodes a
//     packed run of elements from it. Message and Group elements are decoded
//     into a fresh list.NewElement() via o.unmarshalMessage.
//
// String cases reject invalid UTF-8 when strs.EnforceUTF8(fd) is true. The
// template also declares emptyBuf, which the Bytes conversions in ProtoKinds
// append to.
//
// The template text is emitted verbatim into generated code, so its
// whitespace-trimming markers ({{- and -}}) control the generated layout.
var protoDecodeTemplate = template.Must(template.New("").Parse(`
// unmarshalScalar decodes a value of the given kind.
//
// Message values are decoded into a []byte which aliases the input data.
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		if wtyp != {{.WireType.Expr}} {
			return val, 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := protowire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := protowire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return val, 0, errDecode
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		return {{.ToValue}}, n, nil
	{{- end}}
	default:
		return val, 0, errUnknown
	}
}

func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if .WireType.Packable}}
		if wtyp == protowire.BytesType {
			buf, n := protowire.ConsumeBytes(b)
			if n < 0 {
				return 0, errDecode
			}
			for len(buf) > 0 {
				v, n := protowire.Consume{{.WireType}}(buf)
				if n < 0 {
					return 0, errDecode
				}
				buf = buf[n:]
				list.Append({{.ToValue}})
			}
			return n, nil
		}
		{{- end}}
		if wtyp != {{.WireType.Expr}} {
			return 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := protowire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := protowire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return 0, errDecode
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		{{if or (eq .Name "Message") (eq .Name "Group") -}}
		m := list.NewElement()
		if err := o.unmarshalMessage(v, m.Message()); err != nil {
			return 0, err
		}
		list.Append(m)
		{{- else -}}
		list.Append({{.ToValue}})
		{{- end}}
		return n, nil
	{{- end}}
	default:
		return 0, errUnknown
	}
}

// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
var emptyBuf [0]byte
`))
   363  
// generateProtoEncode executes protoEncodeTemplate over ProtoKinds via
// mustExecute (defined elsewhere in this package) and returns the generated
// Go source.
func generateProtoEncode() string {
	return mustExecute(protoEncodeTemplate, ProtoKinds)
}

// protoEncodeTemplate emits two things:
//
//   - wireTypes, a map from each protoreflect.Kind to its protowire.Type.
//     Group maps to StartGroupType via WireType.Expr.
//   - MarshalOptions.marshalSingular, which appends the encoding of one
//     value to b.
//
// marshalSingular handles the special kinds as follows:
//
//   - String checks UTF-8 when strs.EnforceUTF8(fd) is true.
//   - Message is length-prefixed using appendSpeculativeLength and
//     finishSpeculativeLength.
//   - Group appends the message followed by an end-group tag. The start tag
//     is not written here; presumably the caller writes it (TODO confirm).
//
// The errors package used by the generated code is not the standard library
// one: its New takes format arguments.
var protoEncodeTemplate = template.Must(template.New("").Parse(`
var wireTypes = map[protoreflect.Kind]protowire.Type{
{{- range .}}
	{{.Expr}}: {{.WireType.Expr}},
{{- end}}
}

func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if (eq .Name "String") }}
		if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
			return b, errors.InvalidUTF8(string(fd.FullName()))
		}
		b = protowire.AppendString(b, {{.FromValue}})
		{{- else if (eq .Name "Message") -}}
		var pos int
		var err error
		b, pos = appendSpeculativeLength(b)
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = finishSpeculativeLength(b, pos)
		{{- else if (eq .Name "Group") -}}
		var err error
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = protowire.AppendVarint(b, protowire.EncodeTag(fd.Number(), protowire.EndGroupType))
		{{- else -}}
		b = protowire.Append{{.WireType}}(b, {{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return b, errors.New("invalid kind %v", fd.Kind())
	}
	return b, nil
}
`))
   410  
// generateProtoSize executes protoSizeTemplate over ProtoKinds via
// mustExecute (defined elsewhere in this package) and returns the generated
// Go source.
func generateProtoSize() string {
	return mustExecute(protoSizeTemplate, ProtoKinds)
}

// protoSizeTemplate emits MarshalOptions.sizeSingular, which returns the size
// of one value's encoding using the matching protowire.Size* helper:
//
//   - fixed-width kinds use the argument-free size helpers;
//   - length-delimited kinds use the length of FromValue;
//   - Message and Group size the nested message with o.size, and Group also
//     uses the field number;
//   - any other kind sizes its FromValue.
//
// An unrecognized kind yields 0.
var protoSizeTemplate = template.Must(template.New("").Parse(`
func (o MarshalOptions) sizeSingular(num protowire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
	switch kind {
	{{- range .}}
	case {{.Expr}}:
		{{if (eq .Name "Message") -}}
		return protowire.SizeBytes(o.size(v.Message()))
		{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
		return protowire.Size{{.WireType}}()
		{{- else if (eq .WireType "Bytes") -}}
		return protowire.Size{{.WireType}}(len({{.FromValue}}))
		{{- else if (eq .WireType "Group") -}}
		return protowire.Size{{.WireType}}(num, o.size(v.Message()))
		{{- else -}}
		return protowire.Size{{.WireType}}({{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return 0
	}
}
`))
   437  

View as plain text