1
2
3
4
5 package impl
6
7 import (
8 "math/bits"
9
10 "google.golang.org/protobuf/encoding/protowire"
11 "google.golang.org/protobuf/internal/errors"
12 "google.golang.org/protobuf/internal/flags"
13 "google.golang.org/protobuf/proto"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 "google.golang.org/protobuf/reflect/protoregistry"
16 "google.golang.org/protobuf/runtime/protoiface"
17 )
18
19 var errDecode = errors.New("cannot parse invalid wire-format data")
20 var errRecursionDepth = errors.New("exceeded maximum recursion depth")
21
// unmarshalOptions is the internal, flag-based form of the unmarshal options
// that is threaded through the fast-path decoder.
type unmarshalOptions struct {
	// flags holds the caller-supplied input flags; see DiscardUnknown.
	flags protoiface.UnmarshalInputFlags
	// resolver looks up extension types. unmarshalExtension queries it by
	// containing-message name and field number.
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
	// depth is the remaining nesting budget. unmarshalPointer decrements it on
	// entry and fails with errRecursionDepth once it becomes negative.
	depth int
}
30
31 func (o unmarshalOptions) Options() proto.UnmarshalOptions {
32 return proto.UnmarshalOptions{
33 Merge: true,
34 AllowPartial: true,
35 DiscardUnknown: o.DiscardUnknown(),
36 Resolver: o.resolver,
37 }
38 }
39
40 func (o unmarshalOptions) DiscardUnknown() bool {
41 return o.flags&protoiface.UnmarshalDiscardUnknown != 0
42 }
43
44 func (o unmarshalOptions) IsDefault() bool {
45 return o.flags == 0 && o.resolver == protoregistry.GlobalTypes
46 }
47
48 var lazyUnmarshalOptions = unmarshalOptions{
49 resolver: protoregistry.GlobalTypes,
50 depth: protowire.DefaultRecursionLimit,
51 }
52
// unmarshalOutput is the result of decoding one field, extension, or message
// body.
type unmarshalOutput struct {
	// n is the number of input bytes consumed.
	n int
	// initialized reports whether the decoded value was found to be fully
	// initialized. For a message body, this means all required fields were
	// present and all checked sub-values were initialized.
	initialized bool
}
57
58
59 func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
60 var p pointer
61 if ms, ok := in.Message.(*messageState); ok {
62 p = ms.pointer()
63 } else {
64 p = in.Message.(*messageReflectWrapper).pointer()
65 }
66 out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
67 flags: in.Flags,
68 resolver: in.Resolver,
69 depth: in.Depth,
70 })
71 var flags protoiface.UnmarshalOutputFlags
72 if out.initialized {
73 flags |= protoiface.UnmarshalInitialized
74 }
75 return protoiface.UnmarshalOutput{
76 Flags: flags,
77 }, err
78 }
79
80
81
82
83
84
85
86 var errUnknown = errors.New("unknown")
87
// unmarshalPointer decodes the wire-format bytes in b into the message of type
// mi located at p. Existing contents of p are not cleared first.
//
// groupTag is zero when b holds a whole message. When it is non-zero, b starts
// with the body of a group. Decoding stops after the END_GROUP tag whose field
// number equals groupTag, and that tag is counted as consumed. Running out of
// input before that tag is errDecode.
//
// Every call spends one unit of opts.depth and fails with errRecursionDepth
// once the budget goes negative. When legacy support is enabled and mi is a
// MessageSet, decoding is delegated to unmarshalMessageSet.
//
// Some fields cannot be decoded as known fields or extensions: there is no
// decoder, no extension storage, or an unresolvable extension, or a decoder
// reports errUnknown. Such fields are skipped. Unless opts.DiscardUnknown() is
// set or the message has no unknown-field storage, they are appended verbatim
// to the message's unknown bytes.
//
// On success, out.n is the number of bytes consumed. out.initialized reports
// whether all required fields were seen and every decoded value that is
// checked for initialization was initialized.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	// Each nesting level consumes one unit of the recursion budget.
	opts.depth--
	if opts.depth < 0 {
		return out, errRecursionDepth
	}
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	var requiredMask uint64
	// exts is fetched (and its map allocated if nil) only when the first
	// extension is encountered.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type). One- and two-byte tags
		// are decoded inline as a fast path; longer ones use ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		// Reject field numbers outside [MinValidNumber, MaxValidNumber].
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			// END_GROUP is only valid as the terminator of the group being
			// decoded. Clearing groupTag tells the post-loop check that the
			// group was closed; break leaves the for loop.
			if num != groupTag {
				return out, errDecode
			}
			groupTag = 0
			break
		}

		// Small field numbers are indexed directly; others use the map.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// err starts as errUnknown so that any path which does not decode the
		// field falls through to the unknown-field handling below.
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Not a known field: try to decode it as an extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip the field's value. Unless discarding, keep its tag and raw
			// bytes in the unknown fields.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ended while still inside a group.
	if groupTag != 0 {
		return out, errDecode
	}
	// NOTE(review): this assumes each required field contributes a distinct
	// requiredBit. Confirm how messages with more than 64 required fields are
	// handled.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
206
207 func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
208 x := exts[int32(num)]
209 xt := x.Type()
210 if xt == nil {
211 var err error
212 xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
213 if err != nil {
214 if err == protoregistry.NotFound {
215 return out, errUnknown
216 }
217 return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
218 }
219 }
220 xi := getExtensionFieldInfo(xt)
221 if xi.funcs.unmarshal == nil {
222 return out, errUnknown
223 }
224 if flags.LazyUnmarshalExtensions {
225 if opts.IsDefault() && x.canLazy(xt) {
226 out, valid := skipExtension(b, xi, num, wtyp, opts)
227 switch valid {
228 case ValidationValid:
229 if out.initialized {
230 x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
231 exts[int32(num)] = x
232 return out, nil
233 }
234 case ValidationInvalid:
235 return out, errDecode
236 case ValidationUnknown:
237 }
238 }
239 }
240 ival := x.Value()
241 if !ival.IsValid() && xi.unmarshalNeedsValue {
242
243
244
245 ival = xt.New()
246 }
247 v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
248 if err != nil {
249 return out, err
250 }
251 if xi.funcs.isInit == nil {
252 out.initialized = true
253 }
254 x.Set(xt, v)
255 exts[int32(num)] = x
256 return out, nil
257 }
258
259 func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
260 if xi.validation.mi == nil {
261 return out, ValidationUnknown
262 }
263 xi.validation.mi.init()
264 switch xi.validation.typ {
265 case validationTypeMessage:
266 if wtyp != protowire.BytesType {
267 return out, ValidationUnknown
268 }
269 v, n := protowire.ConsumeBytes(b)
270 if n < 0 {
271 return out, ValidationUnknown
272 }
273 out, st := xi.validation.mi.validate(v, 0, opts)
274 out.n = n
275 return out, st
276 case validationTypeGroup:
277 if wtyp != protowire.StartGroupType {
278 return out, ValidationUnknown
279 }
280 out, st := xi.validation.mi.validate(b, num, opts)
281 return out, st
282 default:
283 return out, ValidationUnknown
284 }
285 }
286
View as plain text