1
2
3
4
5 package protocmp
6
7 import (
8 "bytes"
9 "fmt"
10 "math"
11 "reflect"
12 "strings"
13
14 "github.com/google/go-cmp/cmp"
15 "github.com/google/go-cmp/cmp/cmpopts"
16
17 "google.golang.org/protobuf/proto"
18 "google.golang.org/protobuf/reflect/protoreflect"
19 )
20
21 var (
22 enumReflectType = reflect.TypeOf(Enum{})
23 messageReflectType = reflect.TypeOf(Message{})
24 )
25
26
27
28
29
30
31
32
33
34
35
36
37
38 func FilterEnum(enum protoreflect.Enum, opt cmp.Option) cmp.Option {
39 return FilterDescriptor(enum.Descriptor(), opt)
40 }
41
42
43
44
45
46
47
48
49
50
51
52
53
54 func FilterMessage(message proto.Message, opt cmp.Option) cmp.Option {
55 return FilterDescriptor(message.ProtoReflect().Descriptor(), opt)
56 }
57
58
59
60
61
62
63
64
65
66
67
68 func FilterField(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option {
69 md := message.ProtoReflect().Descriptor()
70 return FilterDescriptor(mustFindFieldDescriptor(md, name), opt)
71 }
72
73
74
75
76
77
78
79
80
81
82
83
84 func FilterOneof(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option {
85 md := message.ProtoReflect().Descriptor()
86 return FilterDescriptor(mustFindOneofDescriptor(md, name), opt)
87 }
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103 func FilterDescriptor(desc protoreflect.Descriptor, opt cmp.Option) cmp.Option {
104 f := newNameFilters(desc)
105 return cmp.FilterPath(f.Filter, opt)
106 }
107
108
109
110
111
112 func IgnoreEnums(enums ...protoreflect.Enum) cmp.Option {
113 var ds []protoreflect.Descriptor
114 for _, e := range enums {
115 ds = append(ds, e.Descriptor())
116 }
117 return IgnoreDescriptors(ds...)
118 }
119
120
121
122
123
124 func IgnoreMessages(messages ...proto.Message) cmp.Option {
125 var ds []protoreflect.Descriptor
126 for _, m := range messages {
127 ds = append(ds, m.ProtoReflect().Descriptor())
128 }
129 return IgnoreDescriptors(ds...)
130 }
131
132
133
134
135
136
137 func IgnoreFields(message proto.Message, names ...protoreflect.Name) cmp.Option {
138 var ds []protoreflect.Descriptor
139 md := message.ProtoReflect().Descriptor()
140 for _, s := range names {
141 ds = append(ds, mustFindFieldDescriptor(md, s))
142 }
143 return IgnoreDescriptors(ds...)
144 }
145
146
147
148
149
150
151 func IgnoreOneofs(message proto.Message, names ...protoreflect.Name) cmp.Option {
152 var ds []protoreflect.Descriptor
153 md := message.ProtoReflect().Descriptor()
154 for _, s := range names {
155 ds = append(ds, mustFindOneofDescriptor(md, s))
156 }
157 return IgnoreDescriptors(ds...)
158 }
159
160
161
162
163
164 func IgnoreDescriptors(descs ...protoreflect.Descriptor) cmp.Option {
165 return cmp.FilterPath(newNameFilters(descs...).Filter, cmp.Ignore())
166 }
167
168 func mustFindFieldDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.FieldDescriptor {
169 d := findDescriptor(md, s)
170 if fd, ok := d.(protoreflect.FieldDescriptor); ok && fd.TextName() == string(s) {
171 return fd
172 }
173
174 var suggestion string
175 switch d := d.(type) {
176 case protoreflect.FieldDescriptor:
177 suggestion = fmt.Sprintf("; consider specifying field %q instead", d.TextName())
178 case protoreflect.OneofDescriptor:
179 suggestion = fmt.Sprintf("; consider specifying oneof %q with IgnoreOneofs instead", d.Name())
180 }
181 panic(fmt.Sprintf("message %q has no field %q%s", md.FullName(), s, suggestion))
182 }
183
184 func mustFindOneofDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.OneofDescriptor {
185 d := findDescriptor(md, s)
186 if od, ok := d.(protoreflect.OneofDescriptor); ok && d.Name() == s {
187 return od
188 }
189
190 var suggestion string
191 switch d := d.(type) {
192 case protoreflect.OneofDescriptor:
193 suggestion = fmt.Sprintf("; consider specifying oneof %q instead", d.Name())
194 case protoreflect.FieldDescriptor:
195 suggestion = fmt.Sprintf("; consider specifying field %q with IgnoreFields instead", d.TextName())
196 }
197 panic(fmt.Sprintf("message %q has no oneof %q%s", md.FullName(), s, suggestion))
198 }
199
200 func findDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.Descriptor {
201
202 if fd := md.Fields().ByTextName(string(s)); fd != nil {
203 return fd
204 }
205 if od := md.Oneofs().ByName(s); od != nil && !od.IsSynthetic() {
206 return od
207 }
208
209
210
211
212
213
214 normalize := func(s protoreflect.Name) string {
215 return strings.Replace(strings.ToLower(string(s)), "_", "", -1)
216 }
217 for i := 0; i < md.Fields().Len(); i++ {
218 if fd := md.Fields().Get(i); normalize(fd.Name()) == normalize(s) {
219 return fd
220 }
221 }
222 for i := 0; i < md.Oneofs().Len(); i++ {
223 if od := md.Oneofs().Get(i); normalize(od.Name()) == normalize(s) {
224 return od
225 }
226 }
227 return nil
228 }
229
// nameFilters is a set of protobuf full names (of enums, messages, and fields)
// used to decide whether a cmp path step should be filtered.
type nameFilters struct {
	// names holds the full names to match; newNameFilters only stores true.
	names map[protoreflect.FullName]bool
}
233
234 func newNameFilters(descs ...protoreflect.Descriptor) *nameFilters {
235 f := &nameFilters{names: make(map[protoreflect.FullName]bool)}
236 for _, d := range descs {
237 switch d := d.(type) {
238 case protoreflect.EnumDescriptor:
239 f.names[d.FullName()] = true
240 case protoreflect.MessageDescriptor:
241 f.names[d.FullName()] = true
242 case protoreflect.FieldDescriptor:
243 f.names[d.FullName()] = true
244 case protoreflect.OneofDescriptor:
245 for i := 0; i < d.Fields().Len(); i++ {
246 f.names[d.Fields().Get(i).FullName()] = true
247 }
248 default:
249 panic("invalid descriptor type")
250 }
251 }
252 return f
253 }
254
255 func (f *nameFilters) Filter(p cmp.Path) bool {
256 vx, vy := p.Last().Values()
257 return (f.filterValue(vx) && f.filterValue(vy)) || f.filterFields(p)
258 }
259
260 func (f *nameFilters) filterFields(p cmp.Path) bool {
261
262
263 if _, ok := p.Last().(cmp.TypeAssertion); ok {
264 p = p[:len(p)-1]
265 }
266
267
268 mi, ok := p.Index(-1).(cmp.MapIndex)
269 if !ok {
270 return false
271 }
272 ps := p.Index(-2)
273 if ps.Type() != messageReflectType {
274 return false
275 }
276
277
278 vx, vy := ps.Values()
279 mx := vx.Interface().(Message)
280 my := vy.Interface().(Message)
281 k := mi.Key().String()
282 if f.filterFieldName(mx, k) && f.filterFieldName(my, k) {
283 return true
284 }
285
286
287 vx, vy = mi.Values()
288 if f.filterFieldValue(vx) && f.filterFieldValue(vy) {
289 return true
290 }
291
292 return false
293 }
294
295 func (f *nameFilters) filterFieldName(m Message, k string) bool {
296 if _, ok := m[k]; !ok {
297 return true
298 }
299 var fd protoreflect.FieldDescriptor
300 switch mm := m[messageTypeKey].(messageMeta); {
301 case protoreflect.Name(k).IsValid():
302 fd = mm.md.Fields().ByTextName(k)
303 default:
304 fd = mm.xds[k]
305 }
306 if fd != nil {
307 return f.names[fd.FullName()]
308 }
309 return false
310 }
311
312 func (f *nameFilters) filterFieldValue(v reflect.Value) bool {
313 if !v.IsValid() {
314 return true
315 }
316 v = v.Elem()
317 switch t := v.Type(); {
318 case t == enumReflectType || t == messageReflectType:
319
320 return f.filterValue(v)
321 case t.Kind() == reflect.Slice && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
322
323 return f.filterValue(v.Index(0))
324 case t.Kind() == reflect.Map && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
325
326 return f.filterValue(v.MapIndex(v.MapKeys()[0]))
327 }
328 return false
329 }
330
331 func (f *nameFilters) filterValue(v reflect.Value) bool {
332 if !v.IsValid() {
333 return true
334 }
335 if !v.CanInterface() {
336 return false
337 }
338 switch v := v.Interface().(type) {
339 case Enum:
340 return v.Descriptor() != nil && f.names[v.Descriptor().FullName()]
341 case Message:
342 return v.Descriptor() != nil && f.names[v.Descriptor().FullName()]
343 }
344 return false
345 }
346
347
348
349
350
351
352 func IgnoreDefaultScalars() cmp.Option {
353 return cmp.FilterPath(func(p cmp.Path) bool {
354
355 mi, ok := p.Index(-1).(cmp.MapIndex)
356 if !ok {
357 return false
358 }
359 ps := p.Index(-2)
360 if ps.Type() != messageReflectType {
361 return false
362 }
363
364
365 vx, vy := ps.Values()
366 mx := vx.Interface().(Message)
367 my := vy.Interface().(Message)
368 k := mi.Key().String()
369 return isDefaultScalar(mx, k) && isDefaultScalar(my, k)
370 }, cmp.Ignore())
371 }
372
373 func isDefaultScalar(m Message, k string) bool {
374 if _, ok := m[k]; !ok {
375 return true
376 }
377
378 var fd protoreflect.FieldDescriptor
379 switch mm := m[messageTypeKey].(messageMeta); {
380 case protoreflect.Name(k).IsValid():
381 fd = mm.md.Fields().ByTextName(k)
382 default:
383 fd = mm.xds[k]
384 }
385 if fd == nil || !fd.Default().IsValid() {
386 return false
387 }
388 switch fd.Kind() {
389 case protoreflect.BytesKind:
390 v, ok := m[k].([]byte)
391 return ok && bytes.Equal(fd.Default().Bytes(), v)
392 case protoreflect.FloatKind:
393 v, ok := m[k].(float32)
394 return ok && equalFloat64(fd.Default().Float(), float64(v))
395 case protoreflect.DoubleKind:
396 v, ok := m[k].(float64)
397 return ok && equalFloat64(fd.Default().Float(), float64(v))
398 case protoreflect.EnumKind:
399 v, ok := m[k].(Enum)
400 return ok && fd.Default().Enum() == v.Number()
401 default:
402 return reflect.DeepEqual(fd.Default().Interface(), m[k])
403 }
404 }
405
// equalFloat64 reports whether x and y are equal, treating any two NaNs as
// equal to each other (and NaN as unequal to every non-NaN value).
func equalFloat64(x, y float64) bool {
	xNaN, yNaN := math.IsNaN(x), math.IsNaN(y)
	if xNaN || yNaN {
		return xNaN && yNaN
	}
	return x == y
}
409
410
411
412
413
414
415 func IgnoreEmptyMessages() cmp.Option {
416 return cmp.FilterPath(func(p cmp.Path) bool {
417 vx, vy := p.Last().Values()
418 return (isEmptyMessage(vx) && isEmptyMessage(vy)) || isEmptyMessageFields(p)
419 }, cmp.Ignore())
420 }
421
422 func isEmptyMessageFields(p cmp.Path) bool {
423
424 mi, ok := p.Index(-1).(cmp.MapIndex)
425 if !ok {
426 return false
427 }
428 ps := p.Index(-2)
429 if ps.Type() != messageReflectType {
430 return false
431 }
432
433
434 vx, vy := mi.Values()
435 if isEmptyMessageFieldValue(vx) && isEmptyMessageFieldValue(vy) {
436 return true
437 }
438
439 return false
440 }
441
442 func isEmptyMessageFieldValue(v reflect.Value) bool {
443 if !v.IsValid() {
444 return true
445 }
446 v = v.Elem()
447 switch t := v.Type(); {
448 case t == messageReflectType:
449
450 if !isEmptyMessage(v) {
451 return false
452 }
453 case t.Kind() == reflect.Slice && t.Elem() == messageReflectType:
454
455 for i := 0; i < v.Len(); i++ {
456 if !isEmptyMessage(v.Index(i)) {
457 return false
458 }
459 }
460 case t.Kind() == reflect.Map && t.Elem() == messageReflectType:
461
462 for _, k := range v.MapKeys() {
463 if !isEmptyMessage(v.MapIndex(k)) {
464 return false
465 }
466 }
467 default:
468 return false
469 }
470 return true
471 }
472
473 func isEmptyMessage(v reflect.Value) bool {
474 if !v.IsValid() {
475 return true
476 }
477 if !v.CanInterface() {
478 return false
479 }
480 if m, ok := v.Interface().(Message); ok {
481 for k := range m {
482 if k != messageTypeKey && k != messageInvalidKey {
483 return false
484 }
485 }
486 return true
487 }
488 return false
489 }
490
491
492
493
494 func IgnoreUnknown() cmp.Option {
495 return cmp.FilterPath(func(p cmp.Path) bool {
496
497 mi, ok := p.Index(-1).(cmp.MapIndex)
498 if !ok {
499 return false
500 }
501 ps := p.Index(-2)
502 if ps.Type() != messageReflectType {
503 return false
504 }
505
506
507 return strings.Trim(mi.Key().String(), "0123456789") == ""
508 }, cmp.Ignore())
509 }
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528 func SortRepeated(lessFunc interface{}) cmp.Option {
529 t, ok := checkTTBFunc(lessFunc)
530 if !ok {
531 panic(fmt.Sprintf("invalid less function: %T", lessFunc))
532 }
533
534 var opt cmp.Option
535 var sliceType reflect.Type
536 switch vf := reflect.ValueOf(lessFunc); {
537 case t.Implements(enumV2Type):
538 et := reflect.Zero(t).Interface().(protoreflect.Enum).Type()
539 lessFunc = func(x, y Enum) bool {
540 vx := reflect.ValueOf(et.New(x.Number()))
541 vy := reflect.ValueOf(et.New(y.Number()))
542 return vf.Call([]reflect.Value{vx, vy})[0].Bool()
543 }
544 opt = FilterDescriptor(et.Descriptor(), cmpopts.SortSlices(lessFunc))
545 sliceType = reflect.SliceOf(enumReflectType)
546 case t.Implements(messageV2Type):
547 mt := reflect.Zero(t).Interface().(protoreflect.ProtoMessage).ProtoReflect().Type()
548 lessFunc = func(x, y Message) bool {
549 mx := mt.New().Interface()
550 my := mt.New().Interface()
551 proto.Merge(mx, x)
552 proto.Merge(my, y)
553 vx := reflect.ValueOf(mx)
554 vy := reflect.ValueOf(my)
555 return vf.Call([]reflect.Value{vx, vy})[0].Bool()
556 }
557 opt = FilterDescriptor(mt.Descriptor(), cmpopts.SortSlices(lessFunc))
558 sliceType = reflect.SliceOf(messageReflectType)
559 default:
560 switch t {
561 case reflect.TypeOf(bool(false)):
562 case reflect.TypeOf(int32(0)):
563 case reflect.TypeOf(int64(0)):
564 case reflect.TypeOf(uint32(0)):
565 case reflect.TypeOf(uint64(0)):
566 case reflect.TypeOf(float32(0)):
567 case reflect.TypeOf(float64(0)):
568 case reflect.TypeOf(string("")):
569 case reflect.TypeOf([]byte(nil)):
570 default:
571 panic(fmt.Sprintf("invalid element type: %v", t))
572 }
573 opt = cmpopts.SortSlices(lessFunc)
574 sliceType = reflect.SliceOf(t)
575 }
576
577 return cmp.FilterPath(func(p cmp.Path) bool {
578
579 if t := p.Index(-1).Type(); t == nil || t != sliceType {
580 return false
581 }
582 if t := p.Index(-2).Type(); t == nil || t.Kind() != reflect.Interface {
583 return false
584 }
585 if t := p.Index(-3).Type(); t == nil || t != messageReflectType {
586 return false
587 }
588 return true
589 }, opt)
590 }
591
// checkTTBFunc reports whether lessFunc has the form func(T, T) bool. If so,
// it returns T.
//
// Any other value returns (nil, false) rather than panicking, so callers can
// report their own error. This includes nil, non-function values, variadic
// functions, and functions with mismatched parameter or result types.
func checkTTBFunc(lessFunc interface{}) (reflect.Type, bool) {
	switch t := reflect.TypeOf(lessFunc); {
	case t == nil || t.Kind() != reflect.Func:
		// NumIn/NumOut panic on non-function types, so reject them up front.
		return nil, false
	case t.NumIn() != 2 || t.In(0) != t.In(1) || t.IsVariadic():
		return nil, false
	case t.NumOut() != 1 || t.Out(0) != reflect.TypeOf(false):
		return nil, false
	default:
		return t.In(0), true
	}
}
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628 func SortRepeatedFields(message proto.Message, names ...protoreflect.Name) cmp.Option {
629 var opts cmp.Options
630 md := message.ProtoReflect().Descriptor()
631 for _, name := range names {
632 fd := mustFindFieldDescriptor(md, name)
633 if !fd.IsList() {
634 panic(fmt.Sprintf("message field %q is not repeated", fd.FullName()))
635 }
636
637 var lessFunc interface{}
638 switch fd.Kind() {
639 case protoreflect.BoolKind:
640 lessFunc = func(x, y bool) bool { return !x && y }
641 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
642 lessFunc = func(x, y int32) bool { return x < y }
643 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
644 lessFunc = func(x, y int64) bool { return x < y }
645 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
646 lessFunc = func(x, y uint32) bool { return x < y }
647 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
648 lessFunc = func(x, y uint64) bool { return x < y }
649 case protoreflect.FloatKind:
650 lessFunc = lessF32
651 case protoreflect.DoubleKind:
652 lessFunc = lessF64
653 case protoreflect.StringKind:
654 lessFunc = func(x, y string) bool { return x < y }
655 case protoreflect.BytesKind:
656 lessFunc = func(x, y []byte) bool { return bytes.Compare(x, y) < 0 }
657 case protoreflect.EnumKind:
658 lessFunc = func(x, y Enum) bool { return x.Number() < y.Number() }
659 case protoreflect.MessageKind, protoreflect.GroupKind:
660 lessFunc = func(x, y Message) bool { return x.String() < y.String() }
661 default:
662 panic(fmt.Sprintf("invalid kind: %v", fd.Kind()))
663 }
664 opts = append(opts, FilterDescriptor(fd, cmpopts.SortSlices(lessFunc)))
665 }
666 return opts
667 }
668
// lessF32 reports whether x orders before y under the IEEE 754-2008
// totalOrder predicate. The order is -NaN < -Inf < ... < -0 < +0 < ... <
// +Inf < +NaN.
//
// It works bitwise: for negative values, every bit except the sign is flipped.
// Comparing the resulting signed integers then yields the total order.
func lessF32(x, y float32) bool {
	orderKey := func(f float32) int32 {
		bits := int32(math.Float32bits(f))
		mask := int32(uint32(bits>>31) >> 1) // 0x7FFFFFFF if negative, else 0
		return bits ^ mask
	}
	return orderKey(x) < orderKey(y)
}
// lessF64 reports whether x orders before y under the IEEE 754-2008
// totalOrder predicate. The order is -NaN < -Inf < ... < -0 < +0 < ... <
// +Inf < +NaN.
//
// It works bitwise: for negative values, every bit except the sign is flipped.
// Comparing the resulting signed integers then yields the total order.
func lessF64(x, y float64) bool {
	orderKey := func(f float64) int64 {
		bits := int64(math.Float64bits(f))
		mask := int64(uint64(bits>>63) >> 1) // 0x7FFF...FF if negative, else 0
		return bits ^ mask
	}
	return orderKey(x) < orderKey(y)
}
685
View as plain text