...

Source file src/google.golang.org/protobuf/internal/impl/codec_field.go

Documentation: google.golang.org/protobuf/internal/impl

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package impl
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  	"sync"
    11  
    12  	"google.golang.org/protobuf/encoding/protowire"
    13  	"google.golang.org/protobuf/internal/errors"
    14  	"google.golang.org/protobuf/proto"
    15  	"google.golang.org/protobuf/reflect/protoreflect"
    16  	"google.golang.org/protobuf/reflect/protoregistry"
    17  	"google.golang.org/protobuf/runtime/protoiface"
    18  )
    19  
    20  type errInvalidUTF8 struct{}
    21  
    22  func (errInvalidUTF8) Error() string     { return "string field contains invalid UTF-8" }
    23  func (errInvalidUTF8) InvalidUTF8() bool { return true }
    24  func (errInvalidUTF8) Unwrap() error     { return errors.Error }
    25  
    26  // initOneofFieldCoders initializes the fast-path functions for the fields in a oneof.
    27  //
    28  // For size, marshal, and isInit operations, functions are set only on the first field
    29  // in the oneof. The functions are called when the oneof is non-nil, and will dispatch
    30  // to the appropriate field-specific function as necessary.
    31  //
    32  // The unmarshal function is set on each field individually as usual.
    33  func (mi *MessageInfo) initOneofFieldCoders(od protoreflect.OneofDescriptor, si structInfo) {
    34  	fs := si.oneofsByName[od.Name()]
    35  	ft := fs.Type
    36  	oneofFields := make(map[reflect.Type]*coderFieldInfo)
    37  	needIsInit := false
    38  	fields := od.Fields()
    39  	for i, lim := 0, fields.Len(); i < lim; i++ {
    40  		fd := od.Fields().Get(i)
    41  		num := fd.Number()
    42  		// Make a copy of the original coderFieldInfo for use in unmarshaling.
    43  		//
    44  		// oneofFields[oneofType].funcs.marshal is the field-specific marshal function.
    45  		//
    46  		// mi.coderFields[num].marshal is set on only the first field in the oneof,
    47  		// and dispatches to the field-specific marshaler in oneofFields.
    48  		cf := *mi.coderFields[num]
    49  		ot := si.oneofWrappersByNumber[num]
    50  		cf.ft = ot.Field(0).Type
    51  		cf.mi, cf.funcs = fieldCoder(fd, cf.ft)
    52  		oneofFields[ot] = &cf
    53  		if cf.funcs.isInit != nil {
    54  			needIsInit = true
    55  		}
    56  		mi.coderFields[num].funcs.unmarshal = func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
    57  			var vw reflect.Value         // pointer to wrapper type
    58  			vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
    59  			if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
    60  				vw = vi.Elem()
    61  			} else {
    62  				vw = reflect.New(ot)
    63  			}
    64  			out, err := cf.funcs.unmarshal(b, pointerOfValue(vw).Apply(zeroOffset), wtyp, &cf, opts)
    65  			if err != nil {
    66  				return out, err
    67  			}
    68  			vi.Set(vw)
    69  			return out, nil
    70  		}
    71  	}
    72  	getInfo := func(p pointer) (pointer, *coderFieldInfo) {
    73  		v := p.AsValueOf(ft).Elem()
    74  		if v.IsNil() {
    75  			return pointer{}, nil
    76  		}
    77  		v = v.Elem() // interface -> *struct
    78  		if v.IsNil() {
    79  			return pointer{}, nil
    80  		}
    81  		return pointerOfValue(v).Apply(zeroOffset), oneofFields[v.Elem().Type()]
    82  	}
    83  	first := mi.coderFields[od.Fields().Get(0).Number()]
    84  	first.funcs.size = func(p pointer, _ *coderFieldInfo, opts marshalOptions) int {
    85  		p, info := getInfo(p)
    86  		if info == nil || info.funcs.size == nil {
    87  			return 0
    88  		}
    89  		return info.funcs.size(p, info, opts)
    90  	}
    91  	first.funcs.marshal = func(b []byte, p pointer, _ *coderFieldInfo, opts marshalOptions) ([]byte, error) {
    92  		p, info := getInfo(p)
    93  		if info == nil || info.funcs.marshal == nil {
    94  			return b, nil
    95  		}
    96  		return info.funcs.marshal(b, p, info, opts)
    97  	}
    98  	first.funcs.merge = func(dst, src pointer, _ *coderFieldInfo, opts mergeOptions) {
    99  		srcp, srcinfo := getInfo(src)
   100  		if srcinfo == nil || srcinfo.funcs.merge == nil {
   101  			return
   102  		}
   103  		dstp, dstinfo := getInfo(dst)
   104  		if dstinfo != srcinfo {
   105  			dst.AsValueOf(ft).Elem().Set(reflect.New(src.AsValueOf(ft).Elem().Elem().Elem().Type()))
   106  			dstp = pointerOfValue(dst.AsValueOf(ft).Elem().Elem()).Apply(zeroOffset)
   107  		}
   108  		srcinfo.funcs.merge(dstp, srcp, srcinfo, opts)
   109  	}
   110  	if needIsInit {
   111  		first.funcs.isInit = func(p pointer, _ *coderFieldInfo) error {
   112  			p, info := getInfo(p)
   113  			if info == nil || info.funcs.isInit == nil {
   114  				return nil
   115  			}
   116  			return info.funcs.isInit(p, info)
   117  		}
   118  	}
   119  }
   120  
   121  func makeWeakMessageFieldCoder(fd protoreflect.FieldDescriptor) pointerCoderFuncs {
   122  	var once sync.Once
   123  	var messageType protoreflect.MessageType
   124  	lazyInit := func() {
   125  		once.Do(func() {
   126  			messageName := fd.Message().FullName()
   127  			messageType, _ = protoregistry.GlobalTypes.FindMessageByName(messageName)
   128  		})
   129  	}
   130  
   131  	return pointerCoderFuncs{
   132  		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   133  			m, ok := p.WeakFields().get(f.num)
   134  			if !ok {
   135  				return 0
   136  			}
   137  			lazyInit()
   138  			if messageType == nil {
   139  				panic(fmt.Sprintf("weak message %v is not linked in", fd.Message().FullName()))
   140  			}
   141  			return sizeMessage(m, f.tagsize, opts)
   142  		},
   143  		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   144  			m, ok := p.WeakFields().get(f.num)
   145  			if !ok {
   146  				return b, nil
   147  			}
   148  			lazyInit()
   149  			if messageType == nil {
   150  				panic(fmt.Sprintf("weak message %v is not linked in", fd.Message().FullName()))
   151  			}
   152  			return appendMessage(b, m, f.wiretag, opts)
   153  		},
   154  		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
   155  			fs := p.WeakFields()
   156  			m, ok := fs.get(f.num)
   157  			if !ok {
   158  				lazyInit()
   159  				if messageType == nil {
   160  					return unmarshalOutput{}, errUnknown
   161  				}
   162  				m = messageType.New().Interface()
   163  				fs.set(f.num, m)
   164  			}
   165  			return consumeMessage(b, m, wtyp, opts)
   166  		},
   167  		isInit: func(p pointer, f *coderFieldInfo) error {
   168  			m, ok := p.WeakFields().get(f.num)
   169  			if !ok {
   170  				return nil
   171  			}
   172  			return proto.CheckInitialized(m)
   173  		},
   174  		merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
   175  			sm, ok := src.WeakFields().get(f.num)
   176  			if !ok {
   177  				return
   178  			}
   179  			dm, ok := dst.WeakFields().get(f.num)
   180  			if !ok {
   181  				lazyInit()
   182  				if messageType == nil {
   183  					panic(fmt.Sprintf("weak message %v is not linked in", fd.Message().FullName()))
   184  				}
   185  				dm = messageType.New().Interface()
   186  				dst.WeakFields().set(f.num, dm)
   187  			}
   188  			opts.Merge(dm, sm)
   189  		},
   190  	}
   191  }
   192  
   193  func makeMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
   194  	if mi := getMessageInfo(ft); mi != nil {
   195  		funcs := pointerCoderFuncs{
   196  			size:      sizeMessageInfo,
   197  			marshal:   appendMessageInfo,
   198  			unmarshal: consumeMessageInfo,
   199  			merge:     mergeMessage,
   200  		}
   201  		if needsInitCheck(mi.Desc) {
   202  			funcs.isInit = isInitMessageInfo
   203  		}
   204  		return funcs
   205  	} else {
   206  		return pointerCoderFuncs{
   207  			size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   208  				m := asMessage(p.AsValueOf(ft).Elem())
   209  				return sizeMessage(m, f.tagsize, opts)
   210  			},
   211  			marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   212  				m := asMessage(p.AsValueOf(ft).Elem())
   213  				return appendMessage(b, m, f.wiretag, opts)
   214  			},
   215  			unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
   216  				mp := p.AsValueOf(ft).Elem()
   217  				if mp.IsNil() {
   218  					mp.Set(reflect.New(ft.Elem()))
   219  				}
   220  				return consumeMessage(b, asMessage(mp), wtyp, opts)
   221  			},
   222  			isInit: func(p pointer, f *coderFieldInfo) error {
   223  				m := asMessage(p.AsValueOf(ft).Elem())
   224  				return proto.CheckInitialized(m)
   225  			},
   226  			merge: mergeMessage,
   227  		}
   228  	}
   229  }
   230  
   231  func sizeMessageInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   232  	return protowire.SizeBytes(f.mi.sizePointer(p.Elem(), opts)) + f.tagsize
   233  }
   234  
   235  func appendMessageInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   236  	b = protowire.AppendVarint(b, f.wiretag)
   237  	b = protowire.AppendVarint(b, uint64(f.mi.sizePointer(p.Elem(), opts)))
   238  	return f.mi.marshalAppendPointer(b, p.Elem(), opts)
   239  }
   240  
   241  func consumeMessageInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
   242  	if wtyp != protowire.BytesType {
   243  		return out, errUnknown
   244  	}
   245  	v, n := protowire.ConsumeBytes(b)
   246  	if n < 0 {
   247  		return out, errDecode
   248  	}
   249  	if p.Elem().IsNil() {
   250  		p.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
   251  	}
   252  	o, err := f.mi.unmarshalPointer(v, p.Elem(), 0, opts)
   253  	if err != nil {
   254  		return out, err
   255  	}
   256  	out.n = n
   257  	out.initialized = o.initialized
   258  	return out, nil
   259  }
   260  
   261  func isInitMessageInfo(p pointer, f *coderFieldInfo) error {
   262  	return f.mi.checkInitializedPointer(p.Elem())
   263  }
   264  
   265  func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
   266  	return protowire.SizeBytes(proto.Size(m)) + tagsize
   267  }
   268  
   269  func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
   270  	b = protowire.AppendVarint(b, wiretag)
   271  	b = protowire.AppendVarint(b, uint64(proto.Size(m)))
   272  	return opts.Options().MarshalAppend(b, m)
   273  }
   274  
   275  func consumeMessage(b []byte, m proto.Message, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
   276  	if wtyp != protowire.BytesType {
   277  		return out, errUnknown
   278  	}
   279  	v, n := protowire.ConsumeBytes(b)
   280  	if n < 0 {
   281  		return out, errDecode
   282  	}
   283  	o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
   284  		Buf:     v,
   285  		Message: m.ProtoReflect(),
   286  	})
   287  	if err != nil {
   288  		return out, err
   289  	}
   290  	out.n = n
   291  	out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
   292  	return out, nil
   293  }
   294  
   295  func sizeMessageValue(v protoreflect.Value, tagsize int, opts marshalOptions) int {
   296  	m := v.Message().Interface()
   297  	return sizeMessage(m, tagsize, opts)
   298  }
   299  
   300  func appendMessageValue(b []byte, v protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
   301  	m := v.Message().Interface()
   302  	return appendMessage(b, m, wiretag, opts)
   303  }
   304  
   305  func consumeMessageValue(b []byte, v protoreflect.Value, _ protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (protoreflect.Value, unmarshalOutput, error) {
   306  	m := v.Message().Interface()
   307  	out, err := consumeMessage(b, m, wtyp, opts)
   308  	return v, out, err
   309  }
   310  
   311  func isInitMessageValue(v protoreflect.Value) error {
   312  	m := v.Message().Interface()
   313  	return proto.CheckInitialized(m)
   314  }
   315  
   316  var coderMessageValue = valueCoderFuncs{
   317  	size:      sizeMessageValue,
   318  	marshal:   appendMessageValue,
   319  	unmarshal: consumeMessageValue,
   320  	isInit:    isInitMessageValue,
   321  	merge:     mergeMessageValue,
   322  }
   323  
   324  func sizeGroupValue(v protoreflect.Value, tagsize int, opts marshalOptions) int {
   325  	m := v.Message().Interface()
   326  	return sizeGroup(m, tagsize, opts)
   327  }
   328  
   329  func appendGroupValue(b []byte, v protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
   330  	m := v.Message().Interface()
   331  	return appendGroup(b, m, wiretag, opts)
   332  }
   333  
   334  func consumeGroupValue(b []byte, v protoreflect.Value, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (protoreflect.Value, unmarshalOutput, error) {
   335  	m := v.Message().Interface()
   336  	out, err := consumeGroup(b, m, num, wtyp, opts)
   337  	return v, out, err
   338  }
   339  
   340  var coderGroupValue = valueCoderFuncs{
   341  	size:      sizeGroupValue,
   342  	marshal:   appendGroupValue,
   343  	unmarshal: consumeGroupValue,
   344  	isInit:    isInitMessageValue,
   345  	merge:     mergeMessageValue,
   346  }
   347  
   348  func makeGroupFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
   349  	num := fd.Number()
   350  	if mi := getMessageInfo(ft); mi != nil {
   351  		funcs := pointerCoderFuncs{
   352  			size:      sizeGroupType,
   353  			marshal:   appendGroupType,
   354  			unmarshal: consumeGroupType,
   355  			merge:     mergeMessage,
   356  		}
   357  		if needsInitCheck(mi.Desc) {
   358  			funcs.isInit = isInitMessageInfo
   359  		}
   360  		return funcs
   361  	} else {
   362  		return pointerCoderFuncs{
   363  			size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   364  				m := asMessage(p.AsValueOf(ft).Elem())
   365  				return sizeGroup(m, f.tagsize, opts)
   366  			},
   367  			marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   368  				m := asMessage(p.AsValueOf(ft).Elem())
   369  				return appendGroup(b, m, f.wiretag, opts)
   370  			},
   371  			unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
   372  				mp := p.AsValueOf(ft).Elem()
   373  				if mp.IsNil() {
   374  					mp.Set(reflect.New(ft.Elem()))
   375  				}
   376  				return consumeGroup(b, asMessage(mp), num, wtyp, opts)
   377  			},
   378  			isInit: func(p pointer, f *coderFieldInfo) error {
   379  				m := asMessage(p.AsValueOf(ft).Elem())
   380  				return proto.CheckInitialized(m)
   381  			},
   382  			merge: mergeMessage,
   383  		}
   384  	}
   385  }
   386  
   387  func sizeGroupType(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   388  	return 2*f.tagsize + f.mi.sizePointer(p.Elem(), opts)
   389  }
   390  
   391  func appendGroupType(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   392  	b = protowire.AppendVarint(b, f.wiretag) // start group
   393  	b, err := f.mi.marshalAppendPointer(b, p.Elem(), opts)
   394  	b = protowire.AppendVarint(b, f.wiretag+1) // end group
   395  	return b, err
   396  }
   397  
   398  func consumeGroupType(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
   399  	if wtyp != protowire.StartGroupType {
   400  		return out, errUnknown
   401  	}
   402  	if p.Elem().IsNil() {
   403  		p.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
   404  	}
   405  	return f.mi.unmarshalPointer(b, p.Elem(), f.num, opts)
   406  }
   407  
   408  func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int {
   409  	return 2*tagsize + proto.Size(m)
   410  }
   411  
   412  func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
   413  	b = protowire.AppendVarint(b, wiretag) // start group
   414  	b, err := opts.Options().MarshalAppend(b, m)
   415  	b = protowire.AppendVarint(b, wiretag+1) // end group
   416  	return b, err
   417  }
   418  
   419  func consumeGroup(b []byte, m proto.Message, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
   420  	if wtyp != protowire.StartGroupType {
   421  		return out, errUnknown
   422  	}
   423  	b, n := protowire.ConsumeGroup(num, b)
   424  	if n < 0 {
   425  		return out, errDecode
   426  	}
   427  	o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
   428  		Buf:     b,
   429  		Message: m.ProtoReflect(),
   430  	})
   431  	if err != nil {
   432  		return out, err
   433  	}
   434  	out.n = n
   435  	out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
   436  	return out, nil
   437  }
   438  
   439  func makeMessageSliceFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
   440  	if mi := getMessageInfo(ft); mi != nil {
   441  		funcs := pointerCoderFuncs{
   442  			size:      sizeMessageSliceInfo,
   443  			marshal:   appendMessageSliceInfo,
   444  			unmarshal: consumeMessageSliceInfo,
   445  			merge:     mergeMessageSlice,
   446  		}
   447  		if needsInitCheck(mi.Desc) {
   448  			funcs.isInit = isInitMessageSliceInfo
   449  		}
   450  		return funcs
   451  	}
   452  	return pointerCoderFuncs{
   453  		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   454  			return sizeMessageSlice(p, ft, f.tagsize, opts)
   455  		},
   456  		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   457  			return appendMessageSlice(b, p, f.wiretag, ft, opts)
   458  		},
   459  		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
   460  			return consumeMessageSlice(b, p, ft, wtyp, opts)
   461  		},
   462  		isInit: func(p pointer, f *coderFieldInfo) error {
   463  			return isInitMessageSlice(p, ft)
   464  		},
   465  		merge: mergeMessageSlice,
   466  	}
   467  }
   468  
   469  func sizeMessageSliceInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   470  	s := p.PointerSlice()
   471  	n := 0
   472  	for _, v := range s {
   473  		n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize
   474  	}
   475  	return n
   476  }
   477  
   478  func appendMessageSliceInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   479  	s := p.PointerSlice()
   480  	var err error
   481  	for _, v := range s {
   482  		b = protowire.AppendVarint(b, f.wiretag)
   483  		siz := f.mi.sizePointer(v, opts)
   484  		b = protowire.AppendVarint(b, uint64(siz))
   485  		b, err = f.mi.marshalAppendPointer(b, v, opts)
   486  		if err != nil {
   487  			return b, err
   488  		}
   489  	}
   490  	return b, nil
   491  }
   492  
   493  func consumeMessageSliceInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
   494  	if wtyp != protowire.BytesType {
   495  		return out, errUnknown
   496  	}
   497  	v, n := protowire.ConsumeBytes(b)
   498  	if n < 0 {
   499  		return out, errDecode
   500  	}
   501  	m := reflect.New(f.mi.GoReflectType.Elem()).Interface()
   502  	mp := pointerOfIface(m)
   503  	o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
   504  	if err != nil {
   505  		return out, err
   506  	}
   507  	p.AppendPointerSlice(mp)
   508  	out.n = n
   509  	out.initialized = o.initialized
   510  	return out, nil
   511  }
   512  
   513  func isInitMessageSliceInfo(p pointer, f *coderFieldInfo) error {
   514  	s := p.PointerSlice()
   515  	for _, v := range s {
   516  		if err := f.mi.checkInitializedPointer(v); err != nil {
   517  			return err
   518  		}
   519  	}
   520  	return nil
   521  }
   522  
   523  func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
   524  	s := p.PointerSlice()
   525  	n := 0
   526  	for _, v := range s {
   527  		m := asMessage(v.AsValueOf(goType.Elem()))
   528  		n += protowire.SizeBytes(proto.Size(m)) + tagsize
   529  	}
   530  	return n
   531  }
   532  
   533  func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type, opts marshalOptions) ([]byte, error) {
   534  	s := p.PointerSlice()
   535  	var err error
   536  	for _, v := range s {
   537  		m := asMessage(v.AsValueOf(goType.Elem()))
   538  		b = protowire.AppendVarint(b, wiretag)
   539  		siz := proto.Size(m)
   540  		b = protowire.AppendVarint(b, uint64(siz))
   541  		b, err = opts.Options().MarshalAppend(b, m)
   542  		if err != nil {
   543  			return b, err
   544  		}
   545  	}
   546  	return b, nil
   547  }
   548  
   549  func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
   550  	if wtyp != protowire.BytesType {
   551  		return out, errUnknown
   552  	}
   553  	v, n := protowire.ConsumeBytes(b)
   554  	if n < 0 {
   555  		return out, errDecode
   556  	}
   557  	mp := reflect.New(goType.Elem())
   558  	o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
   559  		Buf:     v,
   560  		Message: asMessage(mp).ProtoReflect(),
   561  	})
   562  	if err != nil {
   563  		return out, err
   564  	}
   565  	p.AppendPointerSlice(pointerOfValue(mp))
   566  	out.n = n
   567  	out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
   568  	return out, nil
   569  }
   570  
   571  func isInitMessageSlice(p pointer, goType reflect.Type) error {
   572  	s := p.PointerSlice()
   573  	for _, v := range s {
   574  		m := asMessage(v.AsValueOf(goType.Elem()))
   575  		if err := proto.CheckInitialized(m); err != nil {
   576  			return err
   577  		}
   578  	}
   579  	return nil
   580  }
   581  
   582  // Slices of messages
   583  
   584  func sizeMessageSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
   585  	list := listv.List()
   586  	n := 0
   587  	for i, llen := 0, list.Len(); i < llen; i++ {
   588  		m := list.Get(i).Message().Interface()
   589  		n += protowire.SizeBytes(proto.Size(m)) + tagsize
   590  	}
   591  	return n
   592  }
   593  
   594  func appendMessageSliceValue(b []byte, listv protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
   595  	list := listv.List()
   596  	mopts := opts.Options()
   597  	for i, llen := 0, list.Len(); i < llen; i++ {
   598  		m := list.Get(i).Message().Interface()
   599  		b = protowire.AppendVarint(b, wiretag)
   600  		siz := proto.Size(m)
   601  		b = protowire.AppendVarint(b, uint64(siz))
   602  		var err error
   603  		b, err = mopts.MarshalAppend(b, m)
   604  		if err != nil {
   605  			return b, err
   606  		}
   607  	}
   608  	return b, nil
   609  }
   610  
   611  func consumeMessageSliceValue(b []byte, listv protoreflect.Value, _ protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (_ protoreflect.Value, out unmarshalOutput, err error) {
   612  	list := listv.List()
   613  	if wtyp != protowire.BytesType {
   614  		return protoreflect.Value{}, out, errUnknown
   615  	}
   616  	v, n := protowire.ConsumeBytes(b)
   617  	if n < 0 {
   618  		return protoreflect.Value{}, out, errDecode
   619  	}
   620  	m := list.NewElement()
   621  	o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
   622  		Buf:     v,
   623  		Message: m.Message(),
   624  	})
   625  	if err != nil {
   626  		return protoreflect.Value{}, out, err
   627  	}
   628  	list.Append(m)
   629  	out.n = n
   630  	out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
   631  	return listv, out, nil
   632  }
   633  
   634  func isInitMessageSliceValue(listv protoreflect.Value) error {
   635  	list := listv.List()
   636  	for i, llen := 0, list.Len(); i < llen; i++ {
   637  		m := list.Get(i).Message().Interface()
   638  		if err := proto.CheckInitialized(m); err != nil {
   639  			return err
   640  		}
   641  	}
   642  	return nil
   643  }
   644  
   645  var coderMessageSliceValue = valueCoderFuncs{
   646  	size:      sizeMessageSliceValue,
   647  	marshal:   appendMessageSliceValue,
   648  	unmarshal: consumeMessageSliceValue,
   649  	isInit:    isInitMessageSliceValue,
   650  	merge:     mergeMessageListValue,
   651  }
   652  
   653  func sizeGroupSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
   654  	list := listv.List()
   655  	n := 0
   656  	for i, llen := 0, list.Len(); i < llen; i++ {
   657  		m := list.Get(i).Message().Interface()
   658  		n += 2*tagsize + proto.Size(m)
   659  	}
   660  	return n
   661  }
   662  
   663  func appendGroupSliceValue(b []byte, listv protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
   664  	list := listv.List()
   665  	mopts := opts.Options()
   666  	for i, llen := 0, list.Len(); i < llen; i++ {
   667  		m := list.Get(i).Message().Interface()
   668  		b = protowire.AppendVarint(b, wiretag) // start group
   669  		var err error
   670  		b, err = mopts.MarshalAppend(b, m)
   671  		if err != nil {
   672  			return b, err
   673  		}
   674  		b = protowire.AppendVarint(b, wiretag+1) // end group
   675  	}
   676  	return b, nil
   677  }
   678  
   679  func consumeGroupSliceValue(b []byte, listv protoreflect.Value, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (_ protoreflect.Value, out unmarshalOutput, err error) {
   680  	list := listv.List()
   681  	if wtyp != protowire.StartGroupType {
   682  		return protoreflect.Value{}, out, errUnknown
   683  	}
   684  	b, n := protowire.ConsumeGroup(num, b)
   685  	if n < 0 {
   686  		return protoreflect.Value{}, out, errDecode
   687  	}
   688  	m := list.NewElement()
   689  	o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
   690  		Buf:     b,
   691  		Message: m.Message(),
   692  	})
   693  	if err != nil {
   694  		return protoreflect.Value{}, out, err
   695  	}
   696  	list.Append(m)
   697  	out.n = n
   698  	out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
   699  	return listv, out, nil
   700  }
   701  
   702  var coderGroupSliceValue = valueCoderFuncs{
   703  	size:      sizeGroupSliceValue,
   704  	marshal:   appendGroupSliceValue,
   705  	unmarshal: consumeGroupSliceValue,
   706  	isInit:    isInitMessageSliceValue,
   707  	merge:     mergeMessageListValue,
   708  }
   709  
   710  func makeGroupSliceFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
   711  	num := fd.Number()
   712  	if mi := getMessageInfo(ft); mi != nil {
   713  		funcs := pointerCoderFuncs{
   714  			size:      sizeGroupSliceInfo,
   715  			marshal:   appendGroupSliceInfo,
   716  			unmarshal: consumeGroupSliceInfo,
   717  			merge:     mergeMessageSlice,
   718  		}
   719  		if needsInitCheck(mi.Desc) {
   720  			funcs.isInit = isInitMessageSliceInfo
   721  		}
   722  		return funcs
   723  	}
   724  	return pointerCoderFuncs{
   725  		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   726  			return sizeGroupSlice(p, ft, f.tagsize, opts)
   727  		},
   728  		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   729  			return appendGroupSlice(b, p, f.wiretag, ft, opts)
   730  		},
   731  		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
   732  			return consumeGroupSlice(b, p, num, wtyp, ft, opts)
   733  		},
   734  		isInit: func(p pointer, f *coderFieldInfo) error {
   735  			return isInitMessageSlice(p, ft)
   736  		},
   737  		merge: mergeMessageSlice,
   738  	}
   739  }
   740  
   741  func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, _ marshalOptions) int {
   742  	s := p.PointerSlice()
   743  	n := 0
   744  	for _, v := range s {
   745  		m := asMessage(v.AsValueOf(messageType.Elem()))
   746  		n += 2*tagsize + proto.Size(m)
   747  	}
   748  	return n
   749  }
   750  
   751  func appendGroupSlice(b []byte, p pointer, wiretag uint64, messageType reflect.Type, opts marshalOptions) ([]byte, error) {
   752  	s := p.PointerSlice()
   753  	var err error
   754  	for _, v := range s {
   755  		m := asMessage(v.AsValueOf(messageType.Elem()))
   756  		b = protowire.AppendVarint(b, wiretag) // start group
   757  		b, err = opts.Options().MarshalAppend(b, m)
   758  		if err != nil {
   759  			return b, err
   760  		}
   761  		b = protowire.AppendVarint(b, wiretag+1) // end group
   762  	}
   763  	return b, nil
   764  }
   765  
   766  func consumeGroupSlice(b []byte, p pointer, num protowire.Number, wtyp protowire.Type, goType reflect.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
   767  	if wtyp != protowire.StartGroupType {
   768  		return out, errUnknown
   769  	}
   770  	b, n := protowire.ConsumeGroup(num, b)
   771  	if n < 0 {
   772  		return out, errDecode
   773  	}
   774  	mp := reflect.New(goType.Elem())
   775  	o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
   776  		Buf:     b,
   777  		Message: asMessage(mp).ProtoReflect(),
   778  	})
   779  	if err != nil {
   780  		return out, err
   781  	}
   782  	p.AppendPointerSlice(pointerOfValue(mp))
   783  	out.n = n
   784  	out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
   785  	return out, nil
   786  }
   787  
   788  func sizeGroupSliceInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
   789  	s := p.PointerSlice()
   790  	n := 0
   791  	for _, v := range s {
   792  		n += 2*f.tagsize + f.mi.sizePointer(v, opts)
   793  	}
   794  	return n
   795  }
   796  
   797  func appendGroupSliceInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   798  	s := p.PointerSlice()
   799  	var err error
   800  	for _, v := range s {
   801  		b = protowire.AppendVarint(b, f.wiretag) // start group
   802  		b, err = f.mi.marshalAppendPointer(b, v, opts)
   803  		if err != nil {
   804  			return b, err
   805  		}
   806  		b = protowire.AppendVarint(b, f.wiretag+1) // end group
   807  	}
   808  	return b, nil
   809  }
   810  
   811  func consumeGroupSliceInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
   812  	if wtyp != protowire.StartGroupType {
   813  		return unmarshalOutput{}, errUnknown
   814  	}
   815  	m := reflect.New(f.mi.GoReflectType.Elem()).Interface()
   816  	mp := pointerOfIface(m)
   817  	out, err := f.mi.unmarshalPointer(b, mp, f.num, opts)
   818  	if err != nil {
   819  		return out, err
   820  	}
   821  	p.AppendPointerSlice(mp)
   822  	return out, nil
   823  }
   824  
   825  func asMessage(v reflect.Value) protoreflect.ProtoMessage {
   826  	if m, ok := v.Interface().(protoreflect.ProtoMessage); ok {
   827  		return m
   828  	}
   829  	return legacyWrapMessage(v).Interface()
   830  }
   831  

View as plain text