...

Source file src/google.golang.org/protobuf/proto/decode.go

Documentation: google.golang.org/protobuf/proto

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package proto
     6  
     7  import (
     8  	"google.golang.org/protobuf/encoding/protowire"
     9  	"google.golang.org/protobuf/internal/encoding/messageset"
    10  	"google.golang.org/protobuf/internal/errors"
    11  	"google.golang.org/protobuf/internal/flags"
    12  	"google.golang.org/protobuf/internal/genid"
    13  	"google.golang.org/protobuf/internal/pragma"
    14  	"google.golang.org/protobuf/reflect/protoreflect"
    15  	"google.golang.org/protobuf/reflect/protoregistry"
    16  	"google.golang.org/protobuf/runtime/protoiface"
    17  )
    18  
// UnmarshalOptions configures the unmarshaler.
//
// The zero value is ready to use: the destination message is reset before
// decoding, missing required fields are reported as an error, unknown fields
// are retained, extensions are resolved through protoregistry.GlobalTypes, and
// protowire.DefaultRecursionLimit bounds message nesting.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// Embedded marker from internal/pragma; judging by its name it exists to
	// force keyed literals such as UnmarshalOptions{Merge: true}.
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified. When merging, scalar fields present in the
	// input replace existing values and message fields are merged into the
	// existing submessages.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields. The check is
	// made once, after the entire input has been decoded.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored: they are dropped
	// instead of being kept in the message's unknown field set.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, a default limit (protowire.DefaultRecursionLimit) is applied.
	RecursionLimit int
}
    51  
    52  // Unmarshal parses the wire-format message in b and places the result in m.
    53  // The provided message must be mutable (e.g., a non-nil pointer to a message).
    54  func Unmarshal(b []byte, m Message) error {
    55  	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
    56  	return err
    57  }
    58  
    59  // Unmarshal parses the wire-format message in b and places the result in m.
    60  // The provided message must be mutable (e.g., a non-nil pointer to a message).
    61  func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
    62  	if o.RecursionLimit == 0 {
    63  		o.RecursionLimit = protowire.DefaultRecursionLimit
    64  	}
    65  	_, err := o.unmarshal(b, m.ProtoReflect())
    66  	return err
    67  }
    68  
    69  // UnmarshalState parses a wire-format message and places the result in m.
    70  //
    71  // This method permits fine-grained control over the unmarshaler.
    72  // Most users should use [Unmarshal] instead.
    73  func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
    74  	if o.RecursionLimit == 0 {
    75  		o.RecursionLimit = protowire.DefaultRecursionLimit
    76  	}
    77  	return o.unmarshal(in.Buf, in.Message)
    78  }
    79  
// unmarshal is a centralized function that all unmarshal operations go through.
// For profiling purposes, avoid changing the name of this function or
// introducing other code paths for unmarshal that do not go through this.
//
// It decodes b into m according to o. The returned UnmarshalOutput is the one
// produced by the message's fast-path Unmarshal method, or the zero value when
// the reflection-based slow path is taken. o is a value copy, so the option
// adjustments below affect only this call and the nested calls made with it.
func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
	if o.Resolver == nil {
		o.Resolver = protoregistry.GlobalTypes
	}
	// Reset the destination unless merging was requested. Nested calls always
	// see Merge == true (set below), so only the top-level message is reset.
	if !o.Merge {
		Reset(m.Interface())
	}
	// Remember the caller's AllowPartial, then relax it for nested decoding:
	// the required-field check is done once, at the end of this call
	// (presumably checkInitialized also covers submessages; it is defined
	// elsewhere).
	allowPartial := o.AllowPartial
	o.Merge = true
	o.AllowPartial = true
	methods := protoMethods(m)
	// Prefer the message's own Unmarshal method, unless the caller wants
	// unknown fields discarded and that method does not support it.
	if methods != nil && methods.Unmarshal != nil &&
		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
		in := protoiface.UnmarshalInput{
			Message:  m,
			Buf:      b,
			Resolver: o.Resolver,
			Depth:    o.RecursionLimit,
		}
		if o.DiscardUnknown {
			in.Flags |= protoiface.UnmarshalDiscardUnknown
		}
		out, err = methods.Unmarshal(in)
	} else {
		// Slow path: this call consumes one level of the recursion budget;
		// nested messages decoded via o see the decremented limit.
		o.RecursionLimit--
		if o.RecursionLimit < 0 {
			return out, errors.New("exceeded max recursion depth")
		}
		err = o.unmarshalMessageSlow(b, m)
	}
	if err != nil {
		return out, err
	}
	// Skip the required-field check when partial messages are allowed or the
	// fast path already reported the message as fully initialized.
	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
		return out, nil
	}
	return out, checkInitialized(m)
}
   121  
   122  func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
   123  	_, err := o.unmarshal(b, m)
   124  	return err
   125  }
   126  
// unmarshalMessageSlow decodes b into m one field at a time through the
// protoreflect API. unmarshal calls it when m has no usable fast-path
// Unmarshal method. Messages using the MessageSet format are handed off to
// unmarshalMessageSet.
//
// A field is kept in m's unknown field set (or dropped if o.DiscardUnknown is
// set) when it has no descriptor, when it is an extension that o.Resolver
// reports as protoregistry.NotFound, when it is an unlinked weak field (legacy
// mode only), or when its decoder returns errUnknown (for example, a map field
// whose wire type is not length-delimited). Malformed input yields errDecode;
// any other decoder or resolver error aborts decoding and is returned.
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
	md := m.Descriptor()
	if messageset.IsMessageSet(md) {
		return o.unmarshalMessageSet(b, m)
	}
	fields := md.Fields()
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		num, wtyp, tagLen := protowire.ConsumeTag(b)
		if tagLen < 0 {
			return errDecode
		}
		// Field numbers above the protobuf maximum are treated as malformed.
		if num > protowire.MaxValidNumber {
			return errDecode
		}

		// Find the field descriptor for this field number.
		fd := fields.ByNumber(num)
		// Not a declared field: if the number lies in an extension range, try
		// to resolve it as an extension. NotFound is not an error here; the
		// field simply stays unknown.
		if fd == nil && md.ExtensionRanges().Has(num) {
			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
			if err != nil && err != protoregistry.NotFound {
				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
			}
			if extType != nil {
				fd = extType.TypeDescriptor()
			}
		}
		// err == errUnknown routes the field to the unknown-field handling below.
		var err error
		if fd == nil {
			err = errUnknown
		} else if flags.ProtoLegacy {
			if fd.IsWeak() && fd.Message().IsPlaceholder() {
				err = errUnknown // weak referent is not linked in
			}
		}

		// Parse the field value.
		var valLen int
		switch {
		case err != nil:
			// Already classified as unknown; nothing to decode here.
		case fd.IsList():
			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
		case fd.IsMap():
			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
		default:
			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
		}
		if err != nil {
			if err != errUnknown {
				return err
			}
			// Skip over the raw value. The stored bytes include the tag so the
			// field is preserved verbatim.
			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
			if valLen < 0 {
				return errDecode
			}
			if !o.DiscardUnknown {
				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
			}
		}
		// Advance past this field's tag and value.
		b = b[tagLen+valLen:]
	}
	return nil
}
   190  
   191  func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
   192  	v, n, err := o.unmarshalScalar(b, wtyp, fd)
   193  	if err != nil {
   194  		return 0, err
   195  	}
   196  	switch fd.Kind() {
   197  	case protoreflect.GroupKind, protoreflect.MessageKind:
   198  		m2 := m.Mutable(fd).Message()
   199  		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
   200  			return n, err
   201  		}
   202  	default:
   203  		// Non-message scalars replace the previous value.
   204  		m.Set(fd, v)
   205  	}
   206  	return n, nil
   207  }
   208  
   209  func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
   210  	if wtyp != protowire.BytesType {
   211  		return 0, errUnknown
   212  	}
   213  	b, n = protowire.ConsumeBytes(b)
   214  	if n < 0 {
   215  		return 0, errDecode
   216  	}
   217  	var (
   218  		keyField = fd.MapKey()
   219  		valField = fd.MapValue()
   220  		key      protoreflect.Value
   221  		val      protoreflect.Value
   222  		haveKey  bool
   223  		haveVal  bool
   224  	)
   225  	switch valField.Kind() {
   226  	case protoreflect.GroupKind, protoreflect.MessageKind:
   227  		val = mapv.NewValue()
   228  	}
   229  	// Map entries are represented as a two-element message with fields
   230  	// containing the key and value.
   231  	for len(b) > 0 {
   232  		num, wtyp, n := protowire.ConsumeTag(b)
   233  		if n < 0 {
   234  			return 0, errDecode
   235  		}
   236  		if num > protowire.MaxValidNumber {
   237  			return 0, errDecode
   238  		}
   239  		b = b[n:]
   240  		err = errUnknown
   241  		switch num {
   242  		case genid.MapEntry_Key_field_number:
   243  			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
   244  			if err != nil {
   245  				break
   246  			}
   247  			haveKey = true
   248  		case genid.MapEntry_Value_field_number:
   249  			var v protoreflect.Value
   250  			v, n, err = o.unmarshalScalar(b, wtyp, valField)
   251  			if err != nil {
   252  				break
   253  			}
   254  			switch valField.Kind() {
   255  			case protoreflect.GroupKind, protoreflect.MessageKind:
   256  				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
   257  					return 0, err
   258  				}
   259  			default:
   260  				val = v
   261  			}
   262  			haveVal = true
   263  		}
   264  		if err == errUnknown {
   265  			n = protowire.ConsumeFieldValue(num, wtyp, b)
   266  			if n < 0 {
   267  				return 0, errDecode
   268  			}
   269  		} else if err != nil {
   270  			return 0, err
   271  		}
   272  		b = b[n:]
   273  	}
   274  	// Every map entry should have entries for key and value, but this is not strictly required.
   275  	if !haveKey {
   276  		key = keyField.Default()
   277  	}
   278  	if !haveVal {
   279  		switch valField.Kind() {
   280  		case protoreflect.GroupKind, protoreflect.MessageKind:
   281  		default:
   282  			val = valField.Default()
   283  		}
   284  	}
   285  	mapv.Set(key.MapKey(), val)
   286  	return n, nil
   287  }
   288  
   289  // errUnknown is used internally to indicate fields which should be added
   290  // to the unknown field set of a message. It is never returned from an exported
   291  // function.
   292  var errUnknown = errors.New("BUG: internal error (unknown)")
   293  
   294  var errDecode = errors.New("cannot parse invalid wire-format data")
   295  

View as plain text