...

Source file src/google.golang.org/protobuf/testing/protocmp/xform.go

Documentation: google.golang.org/protobuf/testing/protocmp

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package protocmp provides protobuf specific options for the
     6  // [github.com/google/go-cmp/cmp] package.
     7  //
     8  // The primary feature is the [Transform] option, which transforms [proto.Message]
     9  // types into a [Message] map that is suitable for cmp to introspect upon.
    10  // All other options in this package must be used in conjunction with [Transform].
    11  package protocmp
    12  
    13  import (
    14  	"reflect"
    15  	"strconv"
    16  
    17  	"github.com/google/go-cmp/cmp"
    18  
    19  	"google.golang.org/protobuf/encoding/protowire"
    20  	"google.golang.org/protobuf/internal/genid"
    21  	"google.golang.org/protobuf/internal/msgfmt"
    22  	"google.golang.org/protobuf/proto"
    23  	"google.golang.org/protobuf/reflect/protoreflect"
    24  	"google.golang.org/protobuf/reflect/protoregistry"
    25  	"google.golang.org/protobuf/runtime/protoiface"
    26  	"google.golang.org/protobuf/runtime/protoimpl"
    27  )
    28  
    29  var (
    30  	enumV2Type    = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
    31  	messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem()
    32  	messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem()
    33  )
    34  
    35  // Enum is a dynamic representation of a protocol buffer enum that is
    36  // suitable for [cmp.Equal] and [cmp.Diff] to compare upon.
    37  type Enum struct {
    38  	num protoreflect.EnumNumber
    39  	ed  protoreflect.EnumDescriptor
    40  }
    41  
    42  // Descriptor returns the enum descriptor.
    43  // It returns nil for a zero Enum value.
    44  func (e Enum) Descriptor() protoreflect.EnumDescriptor {
    45  	return e.ed
    46  }
    47  
    48  // Number returns the enum value as an integer.
    49  func (e Enum) Number() protoreflect.EnumNumber {
    50  	return e.num
    51  }
    52  
    53  // Equal reports whether e1 and e2 represent the same enum value.
    54  func (e1 Enum) Equal(e2 Enum) bool {
    55  	if e1.ed.FullName() != e2.ed.FullName() {
    56  		return false
    57  	}
    58  	return e1.num == e2.num
    59  }
    60  
    61  // String returns the name of the enum value if known (e.g., "ENUM_VALUE"),
    62  // otherwise it returns the formatted decimal enum number (e.g., "14").
    63  func (e Enum) String() string {
    64  	if ev := e.ed.Values().ByNumber(e.num); ev != nil {
    65  		return string(ev.Name())
    66  	}
    67  	return strconv.Itoa(int(e.num))
    68  }
    69  
    70  const (
    71  	// messageTypeKey indicates the protobuf message type.
    72  	// The value type is always messageMeta.
    73  	// From the public API, it presents itself as only the type, but the
    74  	// underlying data structure holds arbitrary metadata about the message.
    75  	messageTypeKey = "@type"
    76  
    77  	// messageInvalidKey indicates that the message is invalid.
    78  	// The value is always the boolean "true".
    79  	messageInvalidKey = "@invalid"
    80  )
    81  
    82  type messageMeta struct {
    83  	m   proto.Message
    84  	md  protoreflect.MessageDescriptor
    85  	xds map[string]protoreflect.ExtensionDescriptor
    86  }
    87  
    88  func (t messageMeta) String() string {
    89  	return string(t.md.FullName())
    90  }
    91  
    92  func (t1 messageMeta) Equal(t2 messageMeta) bool {
    93  	return t1.md.FullName() == t2.md.FullName()
    94  }
    95  
    96  // Message is a dynamic representation of a protocol buffer message that is
    97  // suitable for [cmp.Equal] and [cmp.Diff] to directly operate upon.
    98  //
    99  // Every populated known field (excluding extension fields) is stored in the map
   100  // with the key being the short name of the field (e.g., "field_name") and
   101  // the value determined by the kind and cardinality of the field.
   102  //
   103  // Singular scalars are represented by the same Go type as [protoreflect.Value],
   104  // singular messages are represented by the [Message] type,
   105  // singular enums are represented by the [Enum] type,
   106  // list fields are represented as a Go slice, and
   107  // map fields are represented as a Go map.
   108  //
   109  // Every populated extension field is stored in the map with the key being the
   110  // full name of the field surrounded by brackets (e.g., "[extension.full.name]")
   111  // and the value determined according to the same rules as known fields.
   112  //
   113  // Every unknown field is stored in the map with the key being the field number
   114  // encoded as a decimal string (e.g., "132") and the value being the raw bytes
   115  // of the encoded field (as the [protoreflect.RawFields] type).
   116  //
   117  // Message values must not be created by or mutated by users.
   118  type Message map[string]interface{}
   119  
   120  // Unwrap returns the original message value.
   121  // It returns nil if this Message was not constructed from another message.
   122  func (m Message) Unwrap() proto.Message {
   123  	mm, _ := m[messageTypeKey].(messageMeta)
   124  	return mm.m
   125  }
   126  
   127  // Descriptor return the message descriptor.
   128  // It returns nil for a zero Message value.
   129  func (m Message) Descriptor() protoreflect.MessageDescriptor {
   130  	mm, _ := m[messageTypeKey].(messageMeta)
   131  	return mm.md
   132  }
   133  
   134  // ProtoReflect returns a reflective view of m.
   135  // It only implements the read-only operations of [protoreflect.Message].
   136  // Calling any mutating operations on m panics.
   137  func (m Message) ProtoReflect() protoreflect.Message {
   138  	return (reflectMessage)(m)
   139  }
   140  
   141  // ProtoMessage is a marker method from the legacy message interface.
   142  func (m Message) ProtoMessage() {}
   143  
   144  // Reset is the required Reset method from the legacy message interface.
   145  func (m Message) Reset() {
   146  	panic("invalid mutation of a read-only message")
   147  }
   148  
   149  // String returns a formatted string for the message.
   150  // It is intended for human debugging and has no guarantees about its
   151  // exact format or the stability of its output.
   152  func (m Message) String() string {
   153  	switch {
   154  	case m == nil:
   155  		return "<nil>"
   156  	case !m.ProtoReflect().IsValid():
   157  		return "<invalid>"
   158  	default:
   159  		return msgfmt.Format(m)
   160  	}
   161  }
   162  
   163  type option struct{}
   164  
   165  // Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message].
   166  // The transformation does not mutate nor alias any converted messages.
   167  //
   168  // The google.protobuf.Any message is automatically unmarshaled such that the
   169  // "value" field is a [Message] representing the underlying message value
   170  // assuming it could be resolved and properly unmarshaled.
   171  //
   172  // This does not directly transform higher-order composite Go types.
   173  // For example, []*foopb.Message is not transformed into []Message,
   174  // but rather the individual message elements of the slice are transformed.
   175  //
   176  // Note that there are currently no custom options for Transform,
   177  // but the use of an unexported type keeps the future open.
   178  func Transform(...option) cmp.Option {
   179  	// addrType returns a pointer to t if t isn't a pointer or interface.
   180  	addrType := func(t reflect.Type) reflect.Type {
   181  		if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
   182  			return t
   183  		}
   184  		return reflect.PtrTo(t)
   185  	}
   186  
   187  	// TODO: Should this transform protoreflect.Enum types to Enum as well?
   188  	return cmp.FilterPath(func(p cmp.Path) bool {
   189  		ps := p.Last()
   190  		if isMessageType(addrType(ps.Type())) {
   191  			return true
   192  		}
   193  
   194  		// Check whether the concrete values of an interface both satisfy
   195  		// the Message interface.
   196  		if ps.Type().Kind() == reflect.Interface {
   197  			vx, vy := ps.Values()
   198  			if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() {
   199  				return false
   200  			}
   201  			return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type()))
   202  		}
   203  
   204  		return false
   205  	}, cmp.Transformer("protocmp.Transform", func(v interface{}) Message {
   206  		// For user convenience, shallow copy the message value if necessary
   207  		// in order for it to implement the message interface.
   208  		if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) {
   209  			pv := reflect.New(rv.Type())
   210  			pv.Elem().Set(rv)
   211  			v = pv.Interface()
   212  		}
   213  
   214  		m := protoimpl.X.MessageOf(v)
   215  		switch {
   216  		case m == nil:
   217  			return nil
   218  		case !m.IsValid():
   219  			return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
   220  		default:
   221  			return transformMessage(m)
   222  		}
   223  	}))
   224  }
   225  
   226  func isMessageType(t reflect.Type) bool {
   227  	// Avoid transforming the Message itself.
   228  	if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {
   229  		return false
   230  	}
   231  	return t.Implements(messageV1Type) || t.Implements(messageV2Type)
   232  }
   233  
   234  func transformMessage(m protoreflect.Message) Message {
   235  	mx := Message{}
   236  	mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
   237  
   238  	// Handle known and extension fields.
   239  	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   240  		s := fd.TextName()
   241  		if fd.IsExtension() {
   242  			mt.xds[s] = fd
   243  		}
   244  		switch {
   245  		case fd.IsList():
   246  			mx[s] = transformList(fd, v.List())
   247  		case fd.IsMap():
   248  			mx[s] = transformMap(fd, v.Map())
   249  		default:
   250  			mx[s] = transformSingular(fd, v)
   251  		}
   252  		return true
   253  	})
   254  
   255  	// Handle unknown fields.
   256  	for b := m.GetUnknown(); len(b) > 0; {
   257  		num, _, n := protowire.ConsumeField(b)
   258  		s := strconv.Itoa(int(num))
   259  		b2, _ := mx[s].(protoreflect.RawFields)
   260  		mx[s] = append(b2, b[:n]...)
   261  		b = b[n:]
   262  	}
   263  
   264  	// Expand Any messages.
   265  	if mt.md.FullName() == genid.Any_message_fullname {
   266  		// TODO: Expose Transform option to specify a custom resolver?
   267  		s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
   268  		b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
   269  		mt, err := protoregistry.GlobalTypes.FindMessageByURL(s)
   270  		if mt != nil && err == nil {
   271  			m2 := mt.New()
   272  			err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
   273  			if err == nil {
   274  				mx[string(genid.Any_Value_field_name)] = transformMessage(m2)
   275  			}
   276  		}
   277  	}
   278  
   279  	mx[messageTypeKey] = mt
   280  	return mx
   281  }
   282  
   283  func transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
   284  	t := protoKindToGoType(fd.Kind())
   285  	rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
   286  	for i := 0; i < lv.Len(); i++ {
   287  		v := reflect.ValueOf(transformSingular(fd, lv.Get(i)))
   288  		rv.Index(i).Set(v)
   289  	}
   290  	return rv.Interface()
   291  }
   292  
   293  func transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
   294  	kfd := fd.MapKey()
   295  	vfd := fd.MapValue()
   296  	kt := protoKindToGoType(kfd.Kind())
   297  	vt := protoKindToGoType(vfd.Kind())
   298  	rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
   299  	mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
   300  		kv := reflect.ValueOf(transformSingular(kfd, k.Value()))
   301  		vv := reflect.ValueOf(transformSingular(vfd, v))
   302  		rv.SetMapIndex(kv, vv)
   303  		return true
   304  	})
   305  	return rv.Interface()
   306  }
   307  
   308  func transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
   309  	switch fd.Kind() {
   310  	case protoreflect.EnumKind:
   311  		return Enum{num: v.Enum(), ed: fd.Enum()}
   312  	case protoreflect.MessageKind, protoreflect.GroupKind:
   313  		return transformMessage(v.Message())
   314  	case protoreflect.BytesKind:
   315  		// The protoreflect API does not specify whether an empty bytes is
   316  		// guaranteed to be nil or not. Always return non-nil bytes to avoid
   317  		// leaking information about the concrete proto.Message implementation.
   318  		if len(v.Bytes()) == 0 {
   319  			return []byte{}
   320  		}
   321  		return v.Bytes()
   322  	default:
   323  		return v.Interface()
   324  	}
   325  }
   326  
   327  func protoKindToGoType(k protoreflect.Kind) reflect.Type {
   328  	switch k {
   329  	case protoreflect.BoolKind:
   330  		return reflect.TypeOf(bool(false))
   331  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   332  		return reflect.TypeOf(int32(0))
   333  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   334  		return reflect.TypeOf(int64(0))
   335  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   336  		return reflect.TypeOf(uint32(0))
   337  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   338  		return reflect.TypeOf(uint64(0))
   339  	case protoreflect.FloatKind:
   340  		return reflect.TypeOf(float32(0))
   341  	case protoreflect.DoubleKind:
   342  		return reflect.TypeOf(float64(0))
   343  	case protoreflect.StringKind:
   344  		return reflect.TypeOf(string(""))
   345  	case protoreflect.BytesKind:
   346  		return reflect.TypeOf([]byte(nil))
   347  	case protoreflect.EnumKind:
   348  		return reflect.TypeOf(Enum{})
   349  	case protoreflect.MessageKind, protoreflect.GroupKind:
   350  		return reflect.TypeOf(Message{})
   351  	default:
   352  		panic("invalid kind")
   353  	}
   354  }
   355  

View as plain text