1
2
3
4
5
6 package protorange
7
8 import (
9 "bytes"
10 "errors"
11
12 "google.golang.org/protobuf/internal/genid"
13 "google.golang.org/protobuf/internal/order"
14 "google.golang.org/protobuf/proto"
15 "google.golang.org/protobuf/reflect/protopath"
16 "google.golang.org/protobuf/reflect/protoreflect"
17 "google.golang.org/protobuf/reflect/protoregistry"
18 )
19
var (
	// Break may be returned by a push or pop callback to cut the traversal
	// short at the current level.
	//
	// When push returns Break for a value, that value's children are not
	// visited. Iteration over the value's remaining siblings also stops: the
	// other fields of the same message, the other elements of the same list,
	// or the other entries of the same map. Traversal then resumes after the
	// enclosing value. Returned by pop, Break likewise stops the remaining
	// siblings. Pop callbacks still run for every pushed step, and Range never
	// reports Break as an error.
	Break = errors.New("break traversal of children in current value")

	// Terminate may be returned by a push or pop callback to stop the whole
	// traversal. Pop callbacks are still invoked for every step that has
	// already been pushed, and Range returns nil rather than Terminate.
	Terminate = errors.New("terminate range operation")
)
30
31
32
33
34 func Range(m protoreflect.Message, f func(protopath.Values) error) error {
35 return Options{}.Range(m, f, nil)
36 }
37
38
// Options configures how Options.Range traverses a message.
type Options struct {
	// Stable selects a deterministic visiting order.
	//
	// When true, message fields are visited using order.NumberFieldOrder
	// (presumably ascending field number — confirm in internal/order), and
	// map entries are visited using order.GenericKeyOrder. When false,
	// order.AnyFieldOrder and order.AnyKeyOrder are used. Treat that order as
	// unspecified and possibly non-deterministic.
	Stable bool

	// Resolver looks up message types by type URL when expanding
	// google.protobuf.Any messages. It is also passed to proto.Unmarshal when
	// decoding the Any payload.
	//
	// If nil, Options.Range substitutes protoregistry.GlobalTypes. An Any
	// whose type cannot be found, or whose payload fails to unmarshal, is
	// traversed as an ordinary message instead of being expanded.
	Resolver interface {
		protoregistry.ExtensionTypeResolver
		protoregistry.MessageTypeResolver
	}
}
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102 func (o Options) Range(m protoreflect.Message, push, pop func(protopath.Values) error) error {
103 var err error
104 p := new(protopath.Values)
105 if o.Resolver == nil {
106 o.Resolver = protoregistry.GlobalTypes
107 }
108
109 pushStep(p, protopath.Root(m.Descriptor()), protoreflect.ValueOfMessage(m))
110 if push != nil {
111 err = amendError(err, push(*p))
112 }
113 if err == nil {
114 err = o.rangeMessage(p, m, push, pop)
115 }
116 if pop != nil {
117 err = amendError(err, pop(*p))
118 }
119 popStep(p)
120
121 if err == Break || err == Terminate {
122 err = nil
123 }
124 return err
125 }
126
127 func (o Options) rangeMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (err error) {
128 if ok, err := o.rangeAnyMessage(p, m, push, pop); ok {
129 return err
130 }
131
132 fieldOrder := order.AnyFieldOrder
133 if o.Stable {
134 fieldOrder = order.NumberFieldOrder
135 }
136 order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
137 pushStep(p, protopath.FieldAccess(fd), v)
138 if push != nil {
139 err = amendError(err, push(*p))
140 }
141 if err == nil {
142 switch {
143 case fd.IsMap():
144 err = o.rangeMap(p, fd, v.Map(), push, pop)
145 case fd.IsList():
146 err = o.rangeList(p, fd, v.List(), push, pop)
147 case fd.Message() != nil:
148 err = o.rangeMessage(p, v.Message(), push, pop)
149 }
150 }
151 if pop != nil {
152 err = amendError(err, pop(*p))
153 }
154 popStep(p)
155 return err == nil
156 })
157
158 if b := m.GetUnknown(); len(b) > 0 && err == nil {
159 pushStep(p, protopath.UnknownAccess(), protoreflect.ValueOfBytes(b))
160 if push != nil {
161 err = amendError(err, push(*p))
162 }
163 if pop != nil {
164 err = amendError(err, pop(*p))
165 }
166 popStep(p)
167 }
168
169 if err == Break {
170 err = nil
171 }
172 return err
173 }
174
175 func (o Options) rangeAnyMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (ok bool, err error) {
176 md := m.Descriptor()
177 if md.FullName() != "google.protobuf.Any" {
178 return false, nil
179 }
180
181 fds := md.Fields()
182 url := m.Get(fds.ByNumber(genid.Any_TypeUrl_field_number)).String()
183 val := m.Get(fds.ByNumber(genid.Any_Value_field_number)).Bytes()
184 mt, errFind := o.Resolver.FindMessageByURL(url)
185 if errFind != nil {
186 return false, nil
187 }
188
189
190 m2 := mt.New()
191 errUnmarshal := proto.UnmarshalOptions{
192 Merge: true,
193 AllowPartial: true,
194 Resolver: o.Resolver,
195 }.Unmarshal(val, m2.Interface())
196 if errUnmarshal != nil {
197
198
199 return false, nil
200 }
201
202
203 b1, errMarshal := proto.MarshalOptions{
204 AllowPartial: true,
205 Deterministic: true,
206 }.Marshal(m2.Interface())
207 if errMarshal != nil {
208 return true, errMarshal
209 }
210
211 pushStep(p, protopath.AnyExpand(m2.Descriptor()), protoreflect.ValueOfMessage(m2))
212 if push != nil {
213 err = amendError(err, push(*p))
214 }
215 if err == nil {
216 err = o.rangeMessage(p, m2, push, pop)
217 }
218 if pop != nil {
219 err = amendError(err, pop(*p))
220 }
221 popStep(p)
222
223
224 b2, errMarshal := proto.MarshalOptions{
225 AllowPartial: true,
226 Deterministic: true,
227 }.Marshal(m2.Interface())
228 if errMarshal != nil {
229 return true, errMarshal
230 }
231
232
233 if !bytes.Equal(b1, b2) {
234 m.Set(fds.ByNumber(genid.Any_Value_field_number), protoreflect.ValueOfBytes(b2))
235 }
236
237 if err == Break {
238 err = nil
239 }
240 return true, err
241 }
242
243 func (o Options) rangeList(p *protopath.Values, fd protoreflect.FieldDescriptor, ls protoreflect.List, push, pop func(protopath.Values) error) (err error) {
244 for i := 0; i < ls.Len() && err == nil; i++ {
245 v := ls.Get(i)
246 pushStep(p, protopath.ListIndex(i), v)
247 if push != nil {
248 err = amendError(err, push(*p))
249 }
250 if err == nil && fd.Message() != nil {
251 err = o.rangeMessage(p, v.Message(), push, pop)
252 }
253 if pop != nil {
254 err = amendError(err, pop(*p))
255 }
256 popStep(p)
257 }
258
259 if err == Break {
260 err = nil
261 }
262 return err
263 }
264
265 func (o Options) rangeMap(p *protopath.Values, fd protoreflect.FieldDescriptor, ms protoreflect.Map, push, pop func(protopath.Values) error) (err error) {
266 keyOrder := order.AnyKeyOrder
267 if o.Stable {
268 keyOrder = order.GenericKeyOrder
269 }
270 order.RangeEntries(ms, keyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool {
271 pushStep(p, protopath.MapIndex(k), v)
272 if push != nil {
273 err = amendError(err, push(*p))
274 }
275 if err == nil && fd.MapValue().Message() != nil {
276 err = o.rangeMessage(p, v.Message(), push, pop)
277 }
278 if pop != nil {
279 err = amendError(err, pop(*p))
280 }
281 popStep(p)
282 return err == nil
283 })
284
285 if err == Break {
286 err = nil
287 }
288 return err
289 }
290
291 func pushStep(p *protopath.Values, s protopath.Step, v protoreflect.Value) {
292 p.Path = append(p.Path, s)
293 p.Values = append(p.Values, v)
294 }
295
296 func popStep(p *protopath.Values) {
297 p.Path = p.Path[:len(p.Path)-1]
298 p.Values = p.Values[:len(p.Values)-1]
299 }
300
301
302
303
304
305 func amendError(prev, curr error) error {
306 switch {
307 case curr == nil:
308 return prev
309 case curr == Break && prev != nil:
310 return prev
311 case curr == Terminate && prev != nil && prev != Break:
312 return prev
313 default:
314 return curr
315 }
316 }
317
View as plain text