...

Source file src/google.golang.org/protobuf/testing/protocmp/util.go

Documentation: google.golang.org/protobuf/testing/protocmp

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package protocmp
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"math"
    11  	"reflect"
    12  	"strings"
    13  
    14  	"github.com/google/go-cmp/cmp"
    15  	"github.com/google/go-cmp/cmp/cmpopts"
    16  
    17  	"google.golang.org/protobuf/proto"
    18  	"google.golang.org/protobuf/reflect/protoreflect"
    19  )
    20  
    21  var (
    22  	enumReflectType    = reflect.TypeOf(Enum{})
    23  	messageReflectType = reflect.TypeOf(Message{})
    24  )
    25  
    26  // FilterEnum filters opt to only be applicable on a standalone [Enum],
    27  // singular fields of enums, list fields of enums, or map fields of enum values,
    28  // where the enum is the same type as the specified enum.
    29  //
    30  // The Go type of the last path step may be an:
    31  //   - [Enum] for singular fields, elements of a repeated field,
    32  //     values of a map field, or standalone [Enum] values
    33  //   - [][Enum] for list fields
    34  //   - map[K][Enum] for map fields
    35  //   - interface{} for a [Message] map entry value
    36  //
    37  // This must be used in conjunction with [Transform].
    38  func FilterEnum(enum protoreflect.Enum, opt cmp.Option) cmp.Option {
    39  	return FilterDescriptor(enum.Descriptor(), opt)
    40  }
    41  
    42  // FilterMessage filters opt to only be applicable on a standalone [Message] values,
    43  // singular fields of messages, list fields of messages, or map fields of
    44  // message values, where the message is the same type as the specified message.
    45  //
    46  // The Go type of the last path step may be an:
    47  //   - [Message] for singular fields, elements of a repeated field,
    48  //     values of a map field, or standalone [Message] values
    49  //   - [][Message] for list fields
    50  //   - map[K][Message] for map fields
    51  //   - interface{} for a [Message] map entry value
    52  //
    53  // This must be used in conjunction with [Transform].
    54  func FilterMessage(message proto.Message, opt cmp.Option) cmp.Option {
    55  	return FilterDescriptor(message.ProtoReflect().Descriptor(), opt)
    56  }
    57  
    58  // FilterField filters opt to only be applicable on the specified field
    59  // in the message. It panics if a field of the given name does not exist.
    60  //
    61  // The Go type of the last path step may be an:
    62  //   - T for singular fields
    63  //   - []T for list fields
    64  //   - map[K]T for map fields
    65  //   - interface{} for a [Message] map entry value
    66  //
    67  // This must be used in conjunction with [Transform].
    68  func FilterField(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option {
    69  	md := message.ProtoReflect().Descriptor()
    70  	return FilterDescriptor(mustFindFieldDescriptor(md, name), opt)
    71  }
    72  
    73  // FilterOneof filters opt to only be applicable on all fields within the
    74  // specified oneof in the message. It panics if a oneof of the given name
    75  // does not exist.
    76  //
    77  // The Go type of the last path step may be an:
    78  //   - T for singular fields
    79  //   - []T for list fields
    80  //   - map[K]T for map fields
    81  //   - interface{} for a [Message] map entry value
    82  //
    83  // This must be used in conjunction with [Transform].
    84  func FilterOneof(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option {
    85  	md := message.ProtoReflect().Descriptor()
    86  	return FilterDescriptor(mustFindOneofDescriptor(md, name), opt)
    87  }
    88  
    89  // FilterDescriptor ignores the specified descriptor.
    90  //
    91  // The following descriptor types may be specified:
    92  //   - [protoreflect.EnumDescriptor]
    93  //   - [protoreflect.MessageDescriptor]
    94  //   - [protoreflect.FieldDescriptor]
    95  //   - [protoreflect.OneofDescriptor]
    96  //
    97  // For the behavior of each, see the corresponding filter function.
    98  // Since this filter accepts a [protoreflect.FieldDescriptor], it can be used
    99  // to also filter for extension fields as a [protoreflect.ExtensionDescriptor]
   100  // is just an alias to [protoreflect.FieldDescriptor].
   101  //
   102  // This must be used in conjunction with [Transform].
   103  func FilterDescriptor(desc protoreflect.Descriptor, opt cmp.Option) cmp.Option {
   104  	f := newNameFilters(desc)
   105  	return cmp.FilterPath(f.Filter, opt)
   106  }
   107  
   108  // IgnoreEnums ignores all enums of the specified types.
   109  // It is equivalent to FilterEnum(enum, cmp.Ignore()) for each enum.
   110  //
   111  // This must be used in conjunction with [Transform].
   112  func IgnoreEnums(enums ...protoreflect.Enum) cmp.Option {
   113  	var ds []protoreflect.Descriptor
   114  	for _, e := range enums {
   115  		ds = append(ds, e.Descriptor())
   116  	}
   117  	return IgnoreDescriptors(ds...)
   118  }
   119  
   120  // IgnoreMessages ignores all messages of the specified types.
   121  // It is equivalent to [FilterMessage](message, [cmp.Ignore]()) for each message.
   122  //
   123  // This must be used in conjunction with [Transform].
   124  func IgnoreMessages(messages ...proto.Message) cmp.Option {
   125  	var ds []protoreflect.Descriptor
   126  	for _, m := range messages {
   127  		ds = append(ds, m.ProtoReflect().Descriptor())
   128  	}
   129  	return IgnoreDescriptors(ds...)
   130  }
   131  
   132  // IgnoreFields ignores the specified fields in the specified message.
   133  // It is equivalent to [FilterField](message, name, [cmp.Ignore]()) for each field
   134  // in the message.
   135  //
   136  // This must be used in conjunction with [Transform].
   137  func IgnoreFields(message proto.Message, names ...protoreflect.Name) cmp.Option {
   138  	var ds []protoreflect.Descriptor
   139  	md := message.ProtoReflect().Descriptor()
   140  	for _, s := range names {
   141  		ds = append(ds, mustFindFieldDescriptor(md, s))
   142  	}
   143  	return IgnoreDescriptors(ds...)
   144  }
   145  
   146  // IgnoreOneofs ignores fields of the specified oneofs in the specified message.
   147  // It is equivalent to FilterOneof(message, name, cmp.Ignore()) for each oneof
   148  // in the message.
   149  //
   150  // This must be used in conjunction with [Transform].
   151  func IgnoreOneofs(message proto.Message, names ...protoreflect.Name) cmp.Option {
   152  	var ds []protoreflect.Descriptor
   153  	md := message.ProtoReflect().Descriptor()
   154  	for _, s := range names {
   155  		ds = append(ds, mustFindOneofDescriptor(md, s))
   156  	}
   157  	return IgnoreDescriptors(ds...)
   158  }
   159  
   160  // IgnoreDescriptors ignores the specified set of descriptors.
   161  // It is equivalent to [FilterDescriptor](desc, [cmp.Ignore]()) for each descriptor.
   162  //
   163  // This must be used in conjunction with [Transform].
   164  func IgnoreDescriptors(descs ...protoreflect.Descriptor) cmp.Option {
   165  	return cmp.FilterPath(newNameFilters(descs...).Filter, cmp.Ignore())
   166  }
   167  
   168  func mustFindFieldDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.FieldDescriptor {
   169  	d := findDescriptor(md, s)
   170  	if fd, ok := d.(protoreflect.FieldDescriptor); ok && fd.TextName() == string(s) {
   171  		return fd
   172  	}
   173  
   174  	var suggestion string
   175  	switch d := d.(type) {
   176  	case protoreflect.FieldDescriptor:
   177  		suggestion = fmt.Sprintf("; consider specifying field %q instead", d.TextName())
   178  	case protoreflect.OneofDescriptor:
   179  		suggestion = fmt.Sprintf("; consider specifying oneof %q with IgnoreOneofs instead", d.Name())
   180  	}
   181  	panic(fmt.Sprintf("message %q has no field %q%s", md.FullName(), s, suggestion))
   182  }
   183  
   184  func mustFindOneofDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.OneofDescriptor {
   185  	d := findDescriptor(md, s)
   186  	if od, ok := d.(protoreflect.OneofDescriptor); ok && d.Name() == s {
   187  		return od
   188  	}
   189  
   190  	var suggestion string
   191  	switch d := d.(type) {
   192  	case protoreflect.OneofDescriptor:
   193  		suggestion = fmt.Sprintf("; consider specifying oneof %q instead", d.Name())
   194  	case protoreflect.FieldDescriptor:
   195  		suggestion = fmt.Sprintf("; consider specifying field %q with IgnoreFields instead", d.TextName())
   196  	}
   197  	panic(fmt.Sprintf("message %q has no oneof %q%s", md.FullName(), s, suggestion))
   198  }
   199  
   200  func findDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.Descriptor {
   201  	// Exact match.
   202  	if fd := md.Fields().ByTextName(string(s)); fd != nil {
   203  		return fd
   204  	}
   205  	if od := md.Oneofs().ByName(s); od != nil && !od.IsSynthetic() {
   206  		return od
   207  	}
   208  
   209  	// Best-effort match.
   210  	//
   211  	// It's a common user mistake to use the CamelCased field name as it appears
   212  	// in the generated Go struct. Instead of complaining that it doesn't exist,
   213  	// suggest the real protobuf name that the user may have desired.
   214  	normalize := func(s protoreflect.Name) string {
   215  		return strings.Replace(strings.ToLower(string(s)), "_", "", -1)
   216  	}
   217  	for i := 0; i < md.Fields().Len(); i++ {
   218  		if fd := md.Fields().Get(i); normalize(fd.Name()) == normalize(s) {
   219  			return fd
   220  		}
   221  	}
   222  	for i := 0; i < md.Oneofs().Len(); i++ {
   223  		if od := md.Oneofs().Get(i); normalize(od.Name()) == normalize(s) {
   224  			return od
   225  		}
   226  	}
   227  	return nil
   228  }
   229  
   230  type nameFilters struct {
   231  	names map[protoreflect.FullName]bool
   232  }
   233  
   234  func newNameFilters(descs ...protoreflect.Descriptor) *nameFilters {
   235  	f := &nameFilters{names: make(map[protoreflect.FullName]bool)}
   236  	for _, d := range descs {
   237  		switch d := d.(type) {
   238  		case protoreflect.EnumDescriptor:
   239  			f.names[d.FullName()] = true
   240  		case protoreflect.MessageDescriptor:
   241  			f.names[d.FullName()] = true
   242  		case protoreflect.FieldDescriptor:
   243  			f.names[d.FullName()] = true
   244  		case protoreflect.OneofDescriptor:
   245  			for i := 0; i < d.Fields().Len(); i++ {
   246  				f.names[d.Fields().Get(i).FullName()] = true
   247  			}
   248  		default:
   249  			panic("invalid descriptor type")
   250  		}
   251  	}
   252  	return f
   253  }
   254  
   255  func (f *nameFilters) Filter(p cmp.Path) bool {
   256  	vx, vy := p.Last().Values()
   257  	return (f.filterValue(vx) && f.filterValue(vy)) || f.filterFields(p)
   258  }
   259  
   260  func (f *nameFilters) filterFields(p cmp.Path) bool {
   261  	// Trim off trailing type-assertions so that the filter can match on the
   262  	// concrete value held within an interface value.
   263  	if _, ok := p.Last().(cmp.TypeAssertion); ok {
   264  		p = p[:len(p)-1]
   265  	}
   266  
   267  	// Filter for Message maps.
   268  	mi, ok := p.Index(-1).(cmp.MapIndex)
   269  	if !ok {
   270  		return false
   271  	}
   272  	ps := p.Index(-2)
   273  	if ps.Type() != messageReflectType {
   274  		return false
   275  	}
   276  
   277  	// Check field name.
   278  	vx, vy := ps.Values()
   279  	mx := vx.Interface().(Message)
   280  	my := vy.Interface().(Message)
   281  	k := mi.Key().String()
   282  	if f.filterFieldName(mx, k) && f.filterFieldName(my, k) {
   283  		return true
   284  	}
   285  
   286  	// Check field value.
   287  	vx, vy = mi.Values()
   288  	if f.filterFieldValue(vx) && f.filterFieldValue(vy) {
   289  		return true
   290  	}
   291  
   292  	return false
   293  }
   294  
   295  func (f *nameFilters) filterFieldName(m Message, k string) bool {
   296  	if _, ok := m[k]; !ok {
   297  		return true // treat missing fields as already filtered
   298  	}
   299  	var fd protoreflect.FieldDescriptor
   300  	switch mm := m[messageTypeKey].(messageMeta); {
   301  	case protoreflect.Name(k).IsValid():
   302  		fd = mm.md.Fields().ByTextName(k)
   303  	default:
   304  		fd = mm.xds[k]
   305  	}
   306  	if fd != nil {
   307  		return f.names[fd.FullName()]
   308  	}
   309  	return false
   310  }
   311  
   312  func (f *nameFilters) filterFieldValue(v reflect.Value) bool {
   313  	if !v.IsValid() {
   314  		return true // implies missing slice element or map entry
   315  	}
   316  	v = v.Elem() // map entries are always populated values
   317  	switch t := v.Type(); {
   318  	case t == enumReflectType || t == messageReflectType:
   319  		// Check for singular message or enum field.
   320  		return f.filterValue(v)
   321  	case t.Kind() == reflect.Slice && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
   322  		// Check for list field of enum or message type.
   323  		return f.filterValue(v.Index(0))
   324  	case t.Kind() == reflect.Map && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
   325  		// Check for map field of enum or message type.
   326  		return f.filterValue(v.MapIndex(v.MapKeys()[0]))
   327  	}
   328  	return false
   329  }
   330  
   331  func (f *nameFilters) filterValue(v reflect.Value) bool {
   332  	if !v.IsValid() {
   333  		return true // implies missing slice element or map entry
   334  	}
   335  	if !v.CanInterface() {
   336  		return false // implies unexported struct field
   337  	}
   338  	switch v := v.Interface().(type) {
   339  	case Enum:
   340  		return v.Descriptor() != nil && f.names[v.Descriptor().FullName()]
   341  	case Message:
   342  		return v.Descriptor() != nil && f.names[v.Descriptor().FullName()]
   343  	}
   344  	return false
   345  }
   346  
   347  // IgnoreDefaultScalars ignores singular scalars that are unpopulated or
   348  // explicitly set to the default value.
   349  // This option does not effect elements in a list or entries in a map.
   350  //
   351  // This must be used in conjunction with [Transform].
   352  func IgnoreDefaultScalars() cmp.Option {
   353  	return cmp.FilterPath(func(p cmp.Path) bool {
   354  		// Filter for Message maps.
   355  		mi, ok := p.Index(-1).(cmp.MapIndex)
   356  		if !ok {
   357  			return false
   358  		}
   359  		ps := p.Index(-2)
   360  		if ps.Type() != messageReflectType {
   361  			return false
   362  		}
   363  
   364  		// Check whether both fields are default or unpopulated scalars.
   365  		vx, vy := ps.Values()
   366  		mx := vx.Interface().(Message)
   367  		my := vy.Interface().(Message)
   368  		k := mi.Key().String()
   369  		return isDefaultScalar(mx, k) && isDefaultScalar(my, k)
   370  	}, cmp.Ignore())
   371  }
   372  
   373  func isDefaultScalar(m Message, k string) bool {
   374  	if _, ok := m[k]; !ok {
   375  		return true
   376  	}
   377  
   378  	var fd protoreflect.FieldDescriptor
   379  	switch mm := m[messageTypeKey].(messageMeta); {
   380  	case protoreflect.Name(k).IsValid():
   381  		fd = mm.md.Fields().ByTextName(k)
   382  	default:
   383  		fd = mm.xds[k]
   384  	}
   385  	if fd == nil || !fd.Default().IsValid() {
   386  		return false
   387  	}
   388  	switch fd.Kind() {
   389  	case protoreflect.BytesKind:
   390  		v, ok := m[k].([]byte)
   391  		return ok && bytes.Equal(fd.Default().Bytes(), v)
   392  	case protoreflect.FloatKind:
   393  		v, ok := m[k].(float32)
   394  		return ok && equalFloat64(fd.Default().Float(), float64(v))
   395  	case protoreflect.DoubleKind:
   396  		v, ok := m[k].(float64)
   397  		return ok && equalFloat64(fd.Default().Float(), float64(v))
   398  	case protoreflect.EnumKind:
   399  		v, ok := m[k].(Enum)
   400  		return ok && fd.Default().Enum() == v.Number()
   401  	default:
   402  		return reflect.DeepEqual(fd.Default().Interface(), m[k])
   403  	}
   404  }
   405  
   406  func equalFloat64(x, y float64) bool {
   407  	return x == y || (math.IsNaN(x) && math.IsNaN(y))
   408  }
   409  
   410  // IgnoreEmptyMessages ignores messages that are empty or unpopulated.
   411  // It applies to standalone [Message] values, singular message fields,
   412  // list fields of messages, and map fields of message values.
   413  //
   414  // This must be used in conjunction with [Transform].
   415  func IgnoreEmptyMessages() cmp.Option {
   416  	return cmp.FilterPath(func(p cmp.Path) bool {
   417  		vx, vy := p.Last().Values()
   418  		return (isEmptyMessage(vx) && isEmptyMessage(vy)) || isEmptyMessageFields(p)
   419  	}, cmp.Ignore())
   420  }
   421  
   422  func isEmptyMessageFields(p cmp.Path) bool {
   423  	// Filter for Message maps.
   424  	mi, ok := p.Index(-1).(cmp.MapIndex)
   425  	if !ok {
   426  		return false
   427  	}
   428  	ps := p.Index(-2)
   429  	if ps.Type() != messageReflectType {
   430  		return false
   431  	}
   432  
   433  	// Check field value.
   434  	vx, vy := mi.Values()
   435  	if isEmptyMessageFieldValue(vx) && isEmptyMessageFieldValue(vy) {
   436  		return true
   437  	}
   438  
   439  	return false
   440  }
   441  
   442  func isEmptyMessageFieldValue(v reflect.Value) bool {
   443  	if !v.IsValid() {
   444  		return true // implies missing slice element or map entry
   445  	}
   446  	v = v.Elem() // map entries are always populated values
   447  	switch t := v.Type(); {
   448  	case t == messageReflectType:
   449  		// Check singular field for empty message.
   450  		if !isEmptyMessage(v) {
   451  			return false
   452  		}
   453  	case t.Kind() == reflect.Slice && t.Elem() == messageReflectType:
   454  		// Check list field for all empty message elements.
   455  		for i := 0; i < v.Len(); i++ {
   456  			if !isEmptyMessage(v.Index(i)) {
   457  				return false
   458  			}
   459  		}
   460  	case t.Kind() == reflect.Map && t.Elem() == messageReflectType:
   461  		// Check map field for all empty message values.
   462  		for _, k := range v.MapKeys() {
   463  			if !isEmptyMessage(v.MapIndex(k)) {
   464  				return false
   465  			}
   466  		}
   467  	default:
   468  		return false
   469  	}
   470  	return true
   471  }
   472  
   473  func isEmptyMessage(v reflect.Value) bool {
   474  	if !v.IsValid() {
   475  		return true // implies missing slice element or map entry
   476  	}
   477  	if !v.CanInterface() {
   478  		return false // implies unexported struct field
   479  	}
   480  	if m, ok := v.Interface().(Message); ok {
   481  		for k := range m {
   482  			if k != messageTypeKey && k != messageInvalidKey {
   483  				return false
   484  			}
   485  		}
   486  		return true
   487  	}
   488  	return false
   489  }
   490  
   491  // IgnoreUnknown ignores unknown fields in all messages.
   492  //
   493  // This must be used in conjunction with [Transform].
   494  func IgnoreUnknown() cmp.Option {
   495  	return cmp.FilterPath(func(p cmp.Path) bool {
   496  		// Filter for Message maps.
   497  		mi, ok := p.Index(-1).(cmp.MapIndex)
   498  		if !ok {
   499  			return false
   500  		}
   501  		ps := p.Index(-2)
   502  		if ps.Type() != messageReflectType {
   503  			return false
   504  		}
   505  
   506  		// Filter for unknown fields (which always have a numeric map key).
   507  		return strings.Trim(mi.Key().String(), "0123456789") == ""
   508  	}, cmp.Ignore())
   509  }
   510  
   511  // SortRepeated sorts repeated fields of the specified element type.
   512  // The less function must be of the form "func(T, T) bool" where T is the
   513  // Go element type for the repeated field kind.
   514  //
   515  // The element type T can be one of the following:
   516  //   - Go type for a protobuf scalar kind except for an enum
   517  //     (i.e., bool, int32, int64, uint32, uint64, float32, float64, string, and []byte)
   518  //   - E where E is a concrete enum type that implements [protoreflect.Enum]
   519  //   - M where M is a concrete message type that implement [proto.Message]
   520  //
   521  // This option only applies to repeated fields within a protobuf message.
   522  // It does not operate on higher-order Go types that seem like a repeated field.
   523  // For example, a []T outside the context of a protobuf message will not be
   524  // handled by this option. To sort Go slices that are not repeated fields,
   525  // consider using [github.com/google/go-cmp/cmp/cmpopts.SortSlices] instead.
   526  //
   527  // This must be used in conjunction with [Transform].
   528  func SortRepeated(lessFunc interface{}) cmp.Option {
   529  	t, ok := checkTTBFunc(lessFunc)
   530  	if !ok {
   531  		panic(fmt.Sprintf("invalid less function: %T", lessFunc))
   532  	}
   533  
   534  	var opt cmp.Option
   535  	var sliceType reflect.Type
   536  	switch vf := reflect.ValueOf(lessFunc); {
   537  	case t.Implements(enumV2Type):
   538  		et := reflect.Zero(t).Interface().(protoreflect.Enum).Type()
   539  		lessFunc = func(x, y Enum) bool {
   540  			vx := reflect.ValueOf(et.New(x.Number()))
   541  			vy := reflect.ValueOf(et.New(y.Number()))
   542  			return vf.Call([]reflect.Value{vx, vy})[0].Bool()
   543  		}
   544  		opt = FilterDescriptor(et.Descriptor(), cmpopts.SortSlices(lessFunc))
   545  		sliceType = reflect.SliceOf(enumReflectType)
   546  	case t.Implements(messageV2Type):
   547  		mt := reflect.Zero(t).Interface().(protoreflect.ProtoMessage).ProtoReflect().Type()
   548  		lessFunc = func(x, y Message) bool {
   549  			mx := mt.New().Interface()
   550  			my := mt.New().Interface()
   551  			proto.Merge(mx, x)
   552  			proto.Merge(my, y)
   553  			vx := reflect.ValueOf(mx)
   554  			vy := reflect.ValueOf(my)
   555  			return vf.Call([]reflect.Value{vx, vy})[0].Bool()
   556  		}
   557  		opt = FilterDescriptor(mt.Descriptor(), cmpopts.SortSlices(lessFunc))
   558  		sliceType = reflect.SliceOf(messageReflectType)
   559  	default:
   560  		switch t {
   561  		case reflect.TypeOf(bool(false)):
   562  		case reflect.TypeOf(int32(0)):
   563  		case reflect.TypeOf(int64(0)):
   564  		case reflect.TypeOf(uint32(0)):
   565  		case reflect.TypeOf(uint64(0)):
   566  		case reflect.TypeOf(float32(0)):
   567  		case reflect.TypeOf(float64(0)):
   568  		case reflect.TypeOf(string("")):
   569  		case reflect.TypeOf([]byte(nil)):
   570  		default:
   571  			panic(fmt.Sprintf("invalid element type: %v", t))
   572  		}
   573  		opt = cmpopts.SortSlices(lessFunc)
   574  		sliceType = reflect.SliceOf(t)
   575  	}
   576  
   577  	return cmp.FilterPath(func(p cmp.Path) bool {
   578  		// Filter to only apply to repeated fields within a message.
   579  		if t := p.Index(-1).Type(); t == nil || t != sliceType {
   580  			return false
   581  		}
   582  		if t := p.Index(-2).Type(); t == nil || t.Kind() != reflect.Interface {
   583  			return false
   584  		}
   585  		if t := p.Index(-3).Type(); t == nil || t != messageReflectType {
   586  			return false
   587  		}
   588  		return true
   589  	}, opt)
   590  }
   591  
   592  func checkTTBFunc(lessFunc interface{}) (reflect.Type, bool) {
   593  	switch t := reflect.TypeOf(lessFunc); {
   594  	case t == nil:
   595  		return nil, false
   596  	case t.NumIn() != 2 || t.In(0) != t.In(1) || t.IsVariadic():
   597  		return nil, false
   598  	case t.NumOut() != 1 || t.Out(0) != reflect.TypeOf(false):
   599  		return nil, false
   600  	default:
   601  		return t.In(0), true
   602  	}
   603  }
   604  
   605  // SortRepeatedFields sorts the specified repeated fields.
   606  // Sorting a repeated field is useful for treating the list as a multiset
   607  // (i.e., a set where each value can appear multiple times).
   608  // It panics if the field does not exist or is not a repeated field.
   609  //
   610  // The sort ordering is as follows:
   611  //   - Booleans are sorted where false is sorted before true.
   612  //   - Integers are sorted in ascending order.
   613  //   - Floating-point numbers are sorted in ascending order according to
   614  //     the total ordering defined by IEEE-754 (section 5.10).
   615  //   - Strings and bytes are sorted lexicographically in ascending order.
   616  //   - [Enum] values are sorted in ascending order based on its numeric value.
   617  //   - [Message] values are sorted according to some arbitrary ordering
   618  //     which is undefined and may change in future implementations.
   619  //
   620  // The ordering chosen for repeated messages is unlikely to be aesthetically
   621  // preferred by humans. Consider using a custom sort function:
   622  //
   623  //	FilterField(m, "foo_field", SortRepeated(func(x, y *foopb.MyMessage) bool {
   624  //	    ... // user-provided definition for less
   625  //	}))
   626  //
   627  // This must be used in conjunction with [Transform].
   628  func SortRepeatedFields(message proto.Message, names ...protoreflect.Name) cmp.Option {
   629  	var opts cmp.Options
   630  	md := message.ProtoReflect().Descriptor()
   631  	for _, name := range names {
   632  		fd := mustFindFieldDescriptor(md, name)
   633  		if !fd.IsList() {
   634  			panic(fmt.Sprintf("message field %q is not repeated", fd.FullName()))
   635  		}
   636  
   637  		var lessFunc interface{}
   638  		switch fd.Kind() {
   639  		case protoreflect.BoolKind:
   640  			lessFunc = func(x, y bool) bool { return !x && y }
   641  		case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   642  			lessFunc = func(x, y int32) bool { return x < y }
   643  		case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   644  			lessFunc = func(x, y int64) bool { return x < y }
   645  		case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   646  			lessFunc = func(x, y uint32) bool { return x < y }
   647  		case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   648  			lessFunc = func(x, y uint64) bool { return x < y }
   649  		case protoreflect.FloatKind:
   650  			lessFunc = lessF32
   651  		case protoreflect.DoubleKind:
   652  			lessFunc = lessF64
   653  		case protoreflect.StringKind:
   654  			lessFunc = func(x, y string) bool { return x < y }
   655  		case protoreflect.BytesKind:
   656  			lessFunc = func(x, y []byte) bool { return bytes.Compare(x, y) < 0 }
   657  		case protoreflect.EnumKind:
   658  			lessFunc = func(x, y Enum) bool { return x.Number() < y.Number() }
   659  		case protoreflect.MessageKind, protoreflect.GroupKind:
   660  			lessFunc = func(x, y Message) bool { return x.String() < y.String() }
   661  		default:
   662  			panic(fmt.Sprintf("invalid kind: %v", fd.Kind()))
   663  		}
   664  		opts = append(opts, FilterDescriptor(fd, cmpopts.SortSlices(lessFunc)))
   665  	}
   666  	return opts
   667  }
   668  
   669  func lessF32(x, y float32) bool {
   670  	// Bit-wise implementation of IEEE-754, section 5.10.
   671  	xi := int32(math.Float32bits(x))
   672  	yi := int32(math.Float32bits(y))
   673  	xi ^= int32(uint32(xi>>31) >> 1)
   674  	yi ^= int32(uint32(yi>>31) >> 1)
   675  	return xi < yi
   676  }
   677  func lessF64(x, y float64) bool {
   678  	// Bit-wise implementation of IEEE-754, section 5.10.
   679  	xi := int64(math.Float64bits(x))
   680  	yi := int64(math.Float64bits(y))
   681  	xi ^= int64(uint64(xi>>63) >> 1)
   682  	yi ^= int64(uint64(yi>>63) >> 1)
   683  	return xi < yi
   684  }
   685  

View as plain text