...

Source file src/google.golang.org/protobuf/internal/impl/message_reflect.go

Documentation: google.golang.org/protobuf/internal/impl

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package impl
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  
    11  	"google.golang.org/protobuf/internal/detrand"
    12  	"google.golang.org/protobuf/internal/pragma"
    13  	"google.golang.org/protobuf/reflect/protoreflect"
    14  )
    15  
// reflectMessageInfo groups the lookup tables and accessor functions used to
// serve reflection for one message type. The builder methods on MessageInfo in
// this file assign to identically named fields (mi.fields, mi.oneofs, ...),
// presumably through embedding — NOTE(review): confirm at the MessageInfo
// declaration.
type reflectMessageInfo struct {
	// fields maps every declared field number to its fieldInfo.
	// oneofs maps every oneof name to its oneofInfo.
	// Both are filled in by makeKnownFieldsFunc.
	fields map[protoreflect.FieldNumber]*fieldInfo
	oneofs map[protoreflect.Name]*oneofInfo

	// fieldTypes contains the zero value of an enum or message field.
	// For lists, it contains the element type.
	// For maps, it contains the entry value type.
	// The map is allocated lazily by makeFieldTypes and stays nil when no
	// field needs an entry.
	fieldTypes map[protoreflect.FieldNumber]interface{}

	// denseFields is a subset of fields where:
	//	0 < fieldDesc.Number() < len(denseFields)
	// It provides faster access to the fieldInfo, but may be incomplete.
	// Its length is twice the number of declared fields; slots with no
	// matching field number are nil.
	denseFields []*fieldInfo

	// rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
	rangeInfos []interface{} // either *fieldInfo or *oneofInfo

	// Accessors installed by makeUnknownFieldsFunc (getUnknown, setUnknown)
	// and makeExtensionFieldsFunc (extensionMap). Each takes the pointer to
	// the message data as its first argument.
	getUnknown   func(pointer) protoreflect.RawFields
	setUnknown   func(pointer, protoreflect.RawFields)
	extensionMap func(pointer) *extensionMap

	// nilMessage provides the value MessageOf returns for a nil message pointer.
	nilMessage atomicNilMessage
}
    39  
// makeReflectFuncs generates the set of functions to support reflection.
//
// t is the Go type handed on to the unknown-field and extension-field
// builders; si describes the layout of the Go struct.
func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
	// Per-field and per-oneof tables: fields, oneofs, denseFields, rangeInfos.
	mi.makeKnownFieldsFunc(si)
	// getUnknown / setUnknown accessors.
	mi.makeUnknownFieldsFunc(t, si)
	// extensionMap accessor.
	mi.makeExtensionFieldsFunc(t, si)
	// Zero values for enum and message typed fields (fieldTypes).
	mi.makeFieldTypes(si)
}
    47  
    48  // makeKnownFieldsFunc generates functions for operations that can be performed
    49  // on each protobuf message field. It takes in a reflect.Type representing the
    50  // Go struct and matches message fields with struct fields.
    51  //
    52  // This code assumes that the struct is well-formed and panics if there are
    53  // any discrepancies.
    54  func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
    55  	mi.fields = map[protoreflect.FieldNumber]*fieldInfo{}
    56  	md := mi.Desc
    57  	fds := md.Fields()
    58  	for i := 0; i < fds.Len(); i++ {
    59  		fd := fds.Get(i)
    60  		fs := si.fieldsByNumber[fd.Number()]
    61  		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
    62  		if isOneof {
    63  			fs = si.oneofsByName[fd.ContainingOneof().Name()]
    64  		}
    65  		var fi fieldInfo
    66  		switch {
    67  		case fs.Type == nil:
    68  			fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
    69  		case isOneof:
    70  			fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
    71  		case fd.IsMap():
    72  			fi = fieldInfoForMap(fd, fs, mi.Exporter)
    73  		case fd.IsList():
    74  			fi = fieldInfoForList(fd, fs, mi.Exporter)
    75  		case fd.IsWeak():
    76  			fi = fieldInfoForWeakMessage(fd, si.weakOffset)
    77  		case fd.Message() != nil:
    78  			fi = fieldInfoForMessage(fd, fs, mi.Exporter)
    79  		default:
    80  			fi = fieldInfoForScalar(fd, fs, mi.Exporter)
    81  		}
    82  		mi.fields[fd.Number()] = &fi
    83  	}
    84  
    85  	mi.oneofs = map[protoreflect.Name]*oneofInfo{}
    86  	for i := 0; i < md.Oneofs().Len(); i++ {
    87  		od := md.Oneofs().Get(i)
    88  		mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
    89  	}
    90  
    91  	mi.denseFields = make([]*fieldInfo, fds.Len()*2)
    92  	for i := 0; i < fds.Len(); i++ {
    93  		if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
    94  			mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
    95  		}
    96  	}
    97  
    98  	for i := 0; i < fds.Len(); {
    99  		fd := fds.Get(i)
   100  		if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
   101  			mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
   102  			i += od.Fields().Len()
   103  		} else {
   104  			mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
   105  			i++
   106  		}
   107  	}
   108  
   109  	// Introduce instability to iteration order, but keep it deterministic.
   110  	if len(mi.rangeInfos) > 1 && detrand.Bool() {
   111  		i := detrand.Intn(len(mi.rangeInfos) - 1)
   112  		mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
   113  	}
   114  }
   115  
   116  func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
   117  	switch {
   118  	case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
   119  		// Handle as []byte.
   120  		mi.getUnknown = func(p pointer) protoreflect.RawFields {
   121  			if p.IsNil() {
   122  				return nil
   123  			}
   124  			return *p.Apply(mi.unknownOffset).Bytes()
   125  		}
   126  		mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
   127  			if p.IsNil() {
   128  				panic("invalid SetUnknown on nil Message")
   129  			}
   130  			*p.Apply(mi.unknownOffset).Bytes() = b
   131  		}
   132  	case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
   133  		// Handle as *[]byte.
   134  		mi.getUnknown = func(p pointer) protoreflect.RawFields {
   135  			if p.IsNil() {
   136  				return nil
   137  			}
   138  			bp := p.Apply(mi.unknownOffset).BytesPtr()
   139  			if *bp == nil {
   140  				return nil
   141  			}
   142  			return **bp
   143  		}
   144  		mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
   145  			if p.IsNil() {
   146  				panic("invalid SetUnknown on nil Message")
   147  			}
   148  			bp := p.Apply(mi.unknownOffset).BytesPtr()
   149  			if *bp == nil {
   150  				*bp = new([]byte)
   151  			}
   152  			**bp = b
   153  		}
   154  	default:
   155  		mi.getUnknown = func(pointer) protoreflect.RawFields {
   156  			return nil
   157  		}
   158  		mi.setUnknown = func(p pointer, _ protoreflect.RawFields) {
   159  			if p.IsNil() {
   160  				panic("invalid SetUnknown on nil Message")
   161  			}
   162  		}
   163  	}
   164  }
   165  
   166  func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
   167  	if si.extensionOffset.IsValid() {
   168  		mi.extensionMap = func(p pointer) *extensionMap {
   169  			if p.IsNil() {
   170  				return (*extensionMap)(nil)
   171  			}
   172  			v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
   173  			return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
   174  		}
   175  	} else {
   176  		mi.extensionMap = func(pointer) *extensionMap {
   177  			return (*extensionMap)(nil)
   178  		}
   179  	}
   180  }
   181  func (mi *MessageInfo) makeFieldTypes(si structInfo) {
   182  	md := mi.Desc
   183  	fds := md.Fields()
   184  	for i := 0; i < fds.Len(); i++ {
   185  		var ft reflect.Type
   186  		fd := fds.Get(i)
   187  		fs := si.fieldsByNumber[fd.Number()]
   188  		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
   189  		if isOneof {
   190  			fs = si.oneofsByName[fd.ContainingOneof().Name()]
   191  		}
   192  		var isMessage bool
   193  		switch {
   194  		case fs.Type == nil:
   195  			continue // never occurs for officially generated message types
   196  		case isOneof:
   197  			if fd.Enum() != nil || fd.Message() != nil {
   198  				ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
   199  			}
   200  		case fd.IsMap():
   201  			if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
   202  				ft = fs.Type.Elem()
   203  			}
   204  			isMessage = fd.MapValue().Message() != nil
   205  		case fd.IsList():
   206  			if fd.Enum() != nil || fd.Message() != nil {
   207  				ft = fs.Type.Elem()
   208  			}
   209  			isMessage = fd.Message() != nil
   210  		case fd.Enum() != nil:
   211  			ft = fs.Type
   212  			if fd.HasPresence() && ft.Kind() == reflect.Ptr {
   213  				ft = ft.Elem()
   214  			}
   215  		case fd.Message() != nil:
   216  			ft = fs.Type
   217  			if fd.IsWeak() {
   218  				ft = nil
   219  			}
   220  			isMessage = true
   221  		}
   222  		if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
   223  			ft = reflect.PtrTo(ft) // never occurs for officially generated message types
   224  		}
   225  		if ft != nil {
   226  			if mi.fieldTypes == nil {
   227  				mi.fieldTypes = make(map[protoreflect.FieldNumber]interface{})
   228  			}
   229  			mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
   230  		}
   231  	}
   232  }
   233  
// extensionMap holds the extension fields of a message, keyed by the extension
// field number converted to int32. Its methods take a pointer receiver; Range,
// Has and Get accept a nil receiver, whereas Clear, Set and Mutable
// dereference it unconditionally.
type extensionMap map[int32]ExtensionField
   235  
   236  func (m *extensionMap) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
   237  	if m != nil {
   238  		for _, x := range *m {
   239  			xd := x.Type().TypeDescriptor()
   240  			v := x.Value()
   241  			if xd.IsList() && v.List().Len() == 0 {
   242  				continue
   243  			}
   244  			if !f(xd, v) {
   245  				return
   246  			}
   247  		}
   248  	}
   249  }
   250  func (m *extensionMap) Has(xt protoreflect.ExtensionType) (ok bool) {
   251  	if m == nil {
   252  		return false
   253  	}
   254  	xd := xt.TypeDescriptor()
   255  	x, ok := (*m)[int32(xd.Number())]
   256  	if !ok {
   257  		return false
   258  	}
   259  	switch {
   260  	case xd.IsList():
   261  		return x.Value().List().Len() > 0
   262  	case xd.IsMap():
   263  		return x.Value().Map().Len() > 0
   264  	case xd.Message() != nil:
   265  		return x.Value().Message().IsValid()
   266  	}
   267  	return true
   268  }
   269  func (m *extensionMap) Clear(xt protoreflect.ExtensionType) {
   270  	delete(*m, int32(xt.TypeDescriptor().Number()))
   271  }
   272  func (m *extensionMap) Get(xt protoreflect.ExtensionType) protoreflect.Value {
   273  	xd := xt.TypeDescriptor()
   274  	if m != nil {
   275  		if x, ok := (*m)[int32(xd.Number())]; ok {
   276  			return x.Value()
   277  		}
   278  	}
   279  	return xt.Zero()
   280  }
   281  func (m *extensionMap) Set(xt protoreflect.ExtensionType, v protoreflect.Value) {
   282  	xd := xt.TypeDescriptor()
   283  	isValid := true
   284  	switch {
   285  	case !xt.IsValidValue(v):
   286  		isValid = false
   287  	case xd.IsList():
   288  		isValid = v.List().IsValid()
   289  	case xd.IsMap():
   290  		isValid = v.Map().IsValid()
   291  	case xd.Message() != nil:
   292  		isValid = v.Message().IsValid()
   293  	}
   294  	if !isValid {
   295  		panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
   296  	}
   297  
   298  	if *m == nil {
   299  		*m = make(map[int32]ExtensionField)
   300  	}
   301  	var x ExtensionField
   302  	x.Set(xt, v)
   303  	(*m)[int32(xd.Number())] = x
   304  }
   305  func (m *extensionMap) Mutable(xt protoreflect.ExtensionType) protoreflect.Value {
   306  	xd := xt.TypeDescriptor()
   307  	if xd.Kind() != protoreflect.MessageKind && xd.Kind() != protoreflect.GroupKind && !xd.IsList() && !xd.IsMap() {
   308  		panic("invalid Mutable on field with non-composite type")
   309  	}
   310  	if x, ok := (*m)[int32(xd.Number())]; ok {
   311  		return x.Value()
   312  	}
   313  	v := xt.New()
   314  	m.Set(xt, v)
   315  	return v
   316  }
   317  
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	// Marker types from package pragma. Judging by their names they forbid
	// unkeyed literals, comparison and copying of this struct —
	// NOTE(review): confirm against package pragma.
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// atomicMessageInfo is the message info of the enclosing message. As
	// described above, it must be set atomically.
	atomicMessageInfo *MessageInfo
}
   365  
// messageState is a distinct defined type with the same underlying struct as
// MessageState. Its methods are not defined in this part of the file.
type messageState MessageState

// Compile-time assertions that *messageState implements protoreflect.Message
// and unwrapper.
var (
	_ protoreflect.Message = (*messageState)(nil)
	_ unwrapper            = (*messageState)(nil)
)
   372  
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
type messageDataType struct {
	p  pointer      // pointer to the message data
	mi *MessageInfo // message type information
}
   381  
// Two views over the same (data pointer, message info) tuple:
//
//   - messageReflectWrapper is the reflective view; MessageOf returns it as a
//     protoreflect.Message.
//   - messageIfaceWrapper is the message view; its ProtoReflect method converts
//     the same pointer to a *messageReflectWrapper.
type (
	messageReflectWrapper messageDataType
	messageIfaceWrapper   messageDataType
)

// Compile-time assertions of the interfaces each wrapper implements.
var (
	_ protoreflect.Message      = (*messageReflectWrapper)(nil)
	_ unwrapper                 = (*messageReflectWrapper)(nil)
	_ protoreflect.ProtoMessage = (*messageIfaceWrapper)(nil)
	_ unwrapper                 = (*messageIfaceWrapper)(nil)
)
   393  
   394  // MessageOf returns a reflective view over a message. The input must be a
   395  // pointer to a named Go struct. If the provided type has a ProtoReflect method,
   396  // it must be implemented by calling this method.
   397  func (mi *MessageInfo) MessageOf(m interface{}) protoreflect.Message {
   398  	if reflect.TypeOf(m) != mi.GoReflectType {
   399  		panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
   400  	}
   401  	p := pointerOfIface(m)
   402  	if p.IsNil() {
   403  		return mi.nilMessage.Init(mi)
   404  	}
   405  	return &messageReflectWrapper{p, mi}
   406  }
   407  
   408  func (m *messageReflectWrapper) pointer() pointer          { return m.p }
   409  func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
   410  
   411  // Reset implements the v1 proto.Message.Reset method.
   412  func (m *messageIfaceWrapper) Reset() {
   413  	if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
   414  		mr.Reset()
   415  		return
   416  	}
   417  	rv := reflect.ValueOf(m.protoUnwrap())
   418  	if rv.Kind() == reflect.Ptr && !rv.IsNil() {
   419  		rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
   420  	}
   421  }
   422  func (m *messageIfaceWrapper) ProtoReflect() protoreflect.Message {
   423  	return (*messageReflectWrapper)(m)
   424  }
   425  func (m *messageIfaceWrapper) protoUnwrap() interface{} {
   426  	return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
   427  }
   428  
   429  // checkField verifies that the provided field descriptor is valid.
   430  // Exactly one of the returned values is populated.
   431  func (mi *MessageInfo) checkField(fd protoreflect.FieldDescriptor) (*fieldInfo, protoreflect.ExtensionType) {
   432  	var fi *fieldInfo
   433  	if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
   434  		fi = mi.denseFields[n]
   435  	} else {
   436  		fi = mi.fields[n]
   437  	}
   438  	if fi != nil {
   439  		if fi.fieldDesc != fd {
   440  			if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
   441  				panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
   442  			}
   443  			panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
   444  		}
   445  		return fi, nil
   446  	}
   447  
   448  	if fd.IsExtension() {
   449  		if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
   450  			// TODO: Should this be exact containing message descriptor match?
   451  			panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
   452  		}
   453  		if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
   454  			panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
   455  		}
   456  		xtd, ok := fd.(protoreflect.ExtensionTypeDescriptor)
   457  		if !ok {
   458  			panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
   459  		}
   460  		return nil, xtd.Type()
   461  	}
   462  	panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
   463  }
   464  

View as plain text