1
2
3
4
5 package main
6
7 import (
8 "strings"
9 "text/template"
10 )
11
// WireType names a protobuf wire encoding. For most wire types the string
// value doubles as the suffix of the protowire Consume*/Append*/Size* helper
// that the generated code calls (e.g. "Varint" -> protowire.ConsumeVarint).
type WireType string

// Wire encodings understood by the generator.
const (
	WireVarint  = WireType("Varint")
	WireFixed32 = WireType("Fixed32")
	WireFixed64 = WireType("Fixed64")
	WireBytes   = WireType("Bytes")
	WireGroup   = WireType("Group")
)
21
22 func (w WireType) Expr() Expr {
23 if w == WireGroup {
24 return "protowire.StartGroupType"
25 }
26 return "protowire." + Expr(w) + "Type"
27 }
28
29 func (w WireType) Packable() bool {
30 return w == WireVarint || w == WireFixed32 || w == WireFixed64
31 }
32
33 func (w WireType) ConstSize() bool {
34 return w == WireFixed32 || w == WireFixed64
35 }
36
// GoType is a Go type written as Go source, such as "int32" or "[]byte".
type GoType string

// Go type names referenced by the generator. They are untyped string
// constants, so they can be used wherever either a GoType or a plain string
// is expected.
const (
	GoBool    = "bool"
	GoInt32   = "int32"
	GoUint32  = "uint32"
	GoInt64   = "int64"
	GoUint64  = "uint64"
	GoFloat32 = "float32"
	GoFloat64 = "float64"
	GoString  = "string"
	GoBytes   = "[]byte"
)

// GoTypes enumerates every GoType constant above, in declaration order.
var GoTypes = []GoType{
	GoBool,
	GoInt32, GoUint32,
	GoInt64, GoUint64,
	GoFloat32, GoFloat64,
	GoString,
	GoBytes,
}
62
63 func (g GoType) Zero() Expr {
64 switch g {
65 case GoBool:
66 return "false"
67 case GoString:
68 return `""`
69 case GoBytes:
70 return "nil"
71 }
72 return "0"
73 }
74
75
76 func (g GoType) Kind() Expr {
77 if g == "" || g == GoBytes {
78 return ""
79 }
80 return "reflect." + Expr(strings.ToUpper(string(g[:1]))+string(g[1:]))
81 }
82
83
84 func (g GoType) PointerMethod() Expr {
85 if g == GoBytes {
86 return "Bytes"
87 }
88 return Expr(strings.ToUpper(string(g[:1])) + string(g[1:]))
89 }
90
91 type ProtoKind struct {
92 Name string
93 WireType WireType
94
95
96 ToValue Expr
97 FromValue Expr
98
99
100 GoType GoType
101 ToGoType Expr
102 ToGoTypeNoZero Expr
103 FromGoType Expr
104 NoPointer bool
105 NoValueCodec bool
106 }
107
108 func (k ProtoKind) Expr() Expr {
109 return "protoreflect." + Expr(k.Name) + "Kind"
110 }
111
// ProtoKinds lists the protobuf kinds the generator emits code for. The
// templates range over this slice, so the order here is the order of the
// switch cases in the generated output.
//
// Every Expr field is a Go source fragment spliced into generated code. In
// the templates in this file, ToValue is evaluated after the wire value has
// been consumed into a local v. FromValue is evaluated with v bound to the
// protoreflect.Value being encoded or sized. ToGoType, ToGoTypeNoZero and
// FromGoType are not referenced by those templates; presumably they are
// consumed by generator code elsewhere (TODO confirm).
var ProtoKinds = []ProtoKind{
	{
		Name:       "Bool",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfBool(protowire.DecodeBool(v))",
		FromValue:  "protowire.EncodeBool(v.Bool())",
		GoType:     GoBool,
		ToGoType:   "protowire.DecodeBool(v)",
		FromGoType: "protowire.EncodeBool(v)",
	},
	{
		// Enum has no GoType, so only the protoreflect.Value conversions are set.
		Name:      "Enum",
		WireType:  WireVarint,
		ToValue:   "protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))",
		FromValue: "uint64(v.Enum())",
	},
	{
		Name:       "Int32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint64(int32(v.Int()))",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint64(v)",
	},
	{
		// Sint32 masks v to its low 32 bits before zigzag decoding.
		Name:       "Sint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32)))",
		FromValue:  "protowire.EncodeZigZag(int64(int32(v.Int())))",
		GoType:     GoInt32,
		ToGoType:   "int32(protowire.DecodeZigZag(v & math.MaxUint32))",
		FromGoType: "protowire.EncodeZigZag(int64(v))",
	},
	{
		Name:       "Uint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint64(uint32(v.Uint()))",
		GoType:     GoUint32,
		ToGoType:   "uint32(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Int64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Sint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(protowire.DecodeZigZag(v))",
		FromValue:  "protowire.EncodeZigZag(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "protowire.DecodeZigZag(v)",
		FromGoType: "protowire.EncodeZigZag(v)",
	},
	{
		Name:       "Uint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Sfixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint32(v.Int())",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint32(v)",
	},
	{
		Name:       "Fixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint32(v.Uint())",
		GoType:     GoUint32,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Float",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))",
		FromValue:  "math.Float32bits(float32(v.Float()))",
		GoType:     GoFloat32,
		ToGoType:   "math.Float32frombits(v)",
		FromGoType: "math.Float32bits(v)",
	},
	{
		Name:       "Sfixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Fixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Double",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfFloat64(math.Float64frombits(v))",
		FromValue:  "math.Float64bits(v.Float())",
		GoType:     GoFloat64,
		ToGoType:   "math.Float64frombits(v)",
		FromGoType: "math.Float64bits(v)",
	},
	{
		Name:       "String",
		WireType:   WireBytes,
		ToValue:    "protoreflect.ValueOfString(string(v))",
		FromValue:  "v.String()",
		GoType:     GoString,
		ToGoType:   "string(v)",
		FromGoType: "v",
	},
	{
		// Bytes copies v by appending it onto the zero-length emptyBuf array.
		// The result therefore never aliases the input, and an empty v yields
		// a non-nil empty slice. ToGoTypeNoZero appends onto a nil slice
		// instead, so an empty v yields nil.
		Name:           "Bytes",
		WireType:       WireBytes,
		ToValue:        "protoreflect.ValueOfBytes(append(emptyBuf[:], v...))",
		FromValue:      "v.Bytes()",
		GoType:         GoBytes,
		ToGoType:       "append(emptyBuf[:], v...)",
		ToGoTypeNoZero: "append(([]byte)(nil), v...)",
		FromGoType:     "v",
		NoPointer:      true,
	},
	{
		// Message and Group have no GoType. ToValue wraps the raw bytes v
		// without copying; the decode template documents that Message values
		// alias the input data.
		Name:         "Message",
		WireType:     WireBytes,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
	{
		Name:         "Group",
		WireType:     WireGroup,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
}
271
272 func generateProtoDecode() string {
273 return mustExecute(protoDecodeTemplate, ProtoKinds)
274 }
275
// protoDecodeTemplate renders UnmarshalOptions.unmarshalScalar and
// UnmarshalOptions.unmarshalList, plus the emptyBuf array that the Bytes
// conversions append to. It is executed over a []ProtoKind (ProtoKinds, via
// generateProtoDecode) and emits one switch case per kind.
//
// For each kind, the generated case does the following:
//   - returns errUnknown when the wire type differs from the kind's WireType;
//   - consumes the value with protowire.Consume<WireType>, or ConsumeGroup for
//     groups, and returns errDecode on a negative length;
//   - for String, validates UTF-8 when strs.EnforceUTF8(fd) is true;
//   - converts v with the kind's ToValue expression.
//
// For Packable wire types, unmarshalList also accepts a length-delimited
// packed run of values. Message and Group list elements are decoded
// recursively through o.unmarshalMessage.
//
// The {{- and -}} trim markers only control line joins in the emitted text.
// NOTE(review): the output is presumably gofmt'd by the caller; confirm.
var protoDecodeTemplate = template.Must(template.New("").Parse(`
// unmarshalScalar decodes a value of the given kind.
//
// Message values are decoded into a []byte which aliases the input data.
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		if wtyp != {{.WireType.Expr}} {
			return val, 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := protowire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := protowire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return val, 0, errDecode
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		return {{.ToValue}}, n, nil
	{{- end}}
	default:
		return val, 0, errUnknown
	}
}

func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if .WireType.Packable}}
		if wtyp == protowire.BytesType {
			buf, n := protowire.ConsumeBytes(b)
			if n < 0 {
				return 0, errDecode
			}
			for len(buf) > 0 {
				v, n := protowire.Consume{{.WireType}}(buf)
				if n < 0 {
					return 0, errDecode
				}
				buf = buf[n:]
				list.Append({{.ToValue}})
			}
			return n, nil
		}
		{{- end}}
		if wtyp != {{.WireType.Expr}} {
			return 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := protowire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := protowire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return 0, errDecode
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		{{if or (eq .Name "Message") (eq .Name "Group") -}}
		m := list.NewElement()
		if err := o.unmarshalMessage(v, m.Message()); err != nil {
			return 0, err
		}
		list.Append(m)
		{{- else -}}
		list.Append({{.ToValue}})
		{{- end}}
		return n, nil
	{{- end}}
	default:
		return 0, errUnknown
	}
}

// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
var emptyBuf [0]byte
`))
363
364 func generateProtoEncode() string {
365 return mustExecute(protoEncodeTemplate, ProtoKinds)
366 }
367
// protoEncodeTemplate renders two things over a []ProtoKind (ProtoKinds, via
// generateProtoEncode):
//   - the wireTypes map from each protoreflect.Kind to its protowire.Type;
//   - MarshalOptions.marshalSingular, which appends one value v of the
//     field's kind to b.
//
// How marshalSingular encodes each kind:
//   - String: validated as UTF-8 when strs.EnforceUTF8(fd) is true, then
//     appended with protowire.AppendString.
//   - Message: written between appendSpeculativeLength and
//     finishSpeculativeLength.
//   - Group: the message bytes followed by an EndGroupType tag for
//     fd.Number().
//   - Every other kind: protowire.Append<WireType> applied to the kind's
//     FromValue expression.
//
// Kinds not in the slice hit the default case and return an "invalid kind"
// error.
var protoEncodeTemplate = template.Must(template.New("").Parse(`
var wireTypes = map[protoreflect.Kind]protowire.Type{
{{- range .}}
	{{.Expr}}: {{.WireType.Expr}},
{{- end}}
}

func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if (eq .Name "String") }}
		if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
			return b, errors.InvalidUTF8(string(fd.FullName()))
		}
		b = protowire.AppendString(b, {{.FromValue}})
		{{- else if (eq .Name "Message") -}}
		var pos int
		var err error
		b, pos = appendSpeculativeLength(b)
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = finishSpeculativeLength(b, pos)
		{{- else if (eq .Name "Group") -}}
		var err error
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = protowire.AppendVarint(b, protowire.EncodeTag(fd.Number(), protowire.EndGroupType))
		{{- else -}}
		b = protowire.Append{{.WireType}}(b, {{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return b, errors.New("invalid kind %v", fd.Kind())
	}
	return b, nil
}
`))
410
411 func generateProtoSize() string {
412 return mustExecute(protoSizeTemplate, ProtoKinds)
413 }
414
// protoSizeTemplate renders MarshalOptions.sizeSingular over a []ProtoKind
// (ProtoKinds, via generateProtoSize). The generated function picks a
// protowire size helper per kind:
//   - Message: SizeBytes(o.size(msg)).
//   - Fixed32 and Fixed64 wire types: the argument-free SizeFixed32() or
//     SizeFixed64().
//   - Other Bytes wire types: SizeBytes(len(FromValue)).
//   - Group: SizeGroup(num, o.size(msg)).
//   - Everything else: SizeVarint(FromValue).
//
// Message is tested before the Bytes branch because its WireType is also
// Bytes, but it must be sized from the nested message rather than from
// len(v). Unknown kinds size to 0.
var protoSizeTemplate = template.Must(template.New("").Parse(`
func (o MarshalOptions) sizeSingular(num protowire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
	switch kind {
	{{- range .}}
	case {{.Expr}}:
		{{if (eq .Name "Message") -}}
		return protowire.SizeBytes(o.size(v.Message()))
		{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
		return protowire.Size{{.WireType}}()
		{{- else if (eq .WireType "Bytes") -}}
		return protowire.Size{{.WireType}}(len({{.FromValue}}))
		{{- else if (eq .WireType "Group") -}}
		return protowire.Size{{.WireType}}(num, o.size(v.Message()))
		{{- else -}}
		return protowire.Size{{.WireType}}({{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return 0
	}
}
`))
437
View as plain text