...

Source file src/google.golang.org/protobuf/internal/impl/decode.go

Documentation: google.golang.org/protobuf/internal/impl

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package impl
     6  
     7  import (
     8  	"math/bits"
     9  
    10  	"google.golang.org/protobuf/encoding/protowire"
    11  	"google.golang.org/protobuf/internal/errors"
    12  	"google.golang.org/protobuf/internal/flags"
    13  	"google.golang.org/protobuf/proto"
    14  	"google.golang.org/protobuf/reflect/protoreflect"
    15  	"google.golang.org/protobuf/reflect/protoregistry"
    16  	"google.golang.org/protobuf/runtime/protoiface"
    17  )
    18  
    19  var errDecode = errors.New("cannot parse invalid wire-format data")
    20  var errRecursionDepth = errors.New("exceeded maximum recursion depth")
    21  
    22  type unmarshalOptions struct {
    23  	flags    protoiface.UnmarshalInputFlags
    24  	resolver interface {
    25  		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
    26  		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
    27  	}
    28  	depth int
    29  }
    30  
    31  func (o unmarshalOptions) Options() proto.UnmarshalOptions {
    32  	return proto.UnmarshalOptions{
    33  		Merge:          true,
    34  		AllowPartial:   true,
    35  		DiscardUnknown: o.DiscardUnknown(),
    36  		Resolver:       o.resolver,
    37  	}
    38  }
    39  
    40  func (o unmarshalOptions) DiscardUnknown() bool {
    41  	return o.flags&protoiface.UnmarshalDiscardUnknown != 0
    42  }
    43  
    44  func (o unmarshalOptions) IsDefault() bool {
    45  	return o.flags == 0 && o.resolver == protoregistry.GlobalTypes
    46  }
    47  
    48  var lazyUnmarshalOptions = unmarshalOptions{
    49  	resolver: protoregistry.GlobalTypes,
    50  	depth:    protowire.DefaultRecursionLimit,
    51  }
    52  
    53  type unmarshalOutput struct {
    54  	n           int // number of bytes consumed
    55  	initialized bool
    56  }
    57  
    58  // unmarshal is protoreflect.Methods.Unmarshal.
    59  func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
    60  	var p pointer
    61  	if ms, ok := in.Message.(*messageState); ok {
    62  		p = ms.pointer()
    63  	} else {
    64  		p = in.Message.(*messageReflectWrapper).pointer()
    65  	}
    66  	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
    67  		flags:    in.Flags,
    68  		resolver: in.Resolver,
    69  		depth:    in.Depth,
    70  	})
    71  	var flags protoiface.UnmarshalOutputFlags
    72  	if out.initialized {
    73  		flags |= protoiface.UnmarshalInitialized
    74  	}
    75  	return protoiface.UnmarshalOutput{
    76  		Flags: flags,
    77  	}, err
    78  }
    79  
    80  // errUnknown is returned during unmarshaling to indicate a parse error that
    81  // should result in a field being placed in the unknown fields section (for example,
    82  // when the wire type doesn't match) as opposed to the entire unmarshal operation
    83  // failing (for example, when a field extends past the available input).
    84  //
    85  // This is a sentinel error which should never be visible to the user.
    86  var errUnknown = errors.New("unknown")
    87  
// unmarshalPointer decodes the wire-format bytes in b into the message whose
// storage is located at p.
//
// groupTag is the field number of the enclosing group when the message is
// group-encoded, or 0 otherwise. Decoding stops at an END_GROUP tag, which
// must carry exactly groupTag. An END_GROUP with another number, or running
// out of input while groupTag is non-zero, yields errDecode.
//
// opts.depth is decremented on entry, and errRecursionDepth is returned once
// it goes negative. A field whose decoder returns errUnknown is validated with
// protowire.ConsumeFieldValue. Unless opts.DiscardUnknown() is set or the
// message has no unknown-fields storage, its tag and raw bytes are then
// appended to the unknown fields. Any other decoder error aborts the call.
//
// On success out.n is the number of bytes consumed (including a terminating
// END_GROUP tag). out.initialized is true only if every required field was
// seen and every field/extension decoder that tracks initialization reported
// an initialized value.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	// Every call spends one unit of the recursion budget.
	opts.depth--
	if opts.depth < 0 {
		return out, errRecursionDepth
	}
	if flags.ProtoLegacy && mi.isMessageSet {
		// Legacy MessageSet-encoded messages are handed off entirely.
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	// One bit per required field seen (f.validation.requiredBit).
	var requiredMask uint64
	// Extension map, fetched/created on the first field not found among the
	// coder fields.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Tags encoded in one or two varint bytes are decoded inline; longer
		// ones fall back to protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		// Field numbers outside [MinValidNumber, MaxValidNumber] are malformed.
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errDecode
			}
			// Mark the group as closed so the post-loop check passes; this
			// break leaves the for loop (it is not inside the switch).
			groupTag = 0
			break
		}

		// Small field numbers index the dense slice; others use the sparse
		// coderFields lookup. f is nil when the number is not a known field.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// Deliberately shadows the named result: every path that does not
		// decode the field successfully leaves err == errUnknown, which routes
		// the bytes to the unknown-fields handling below.
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				// Exits the switch only; the field is treated as unknown.
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			// Note: the extension map is allocated here even if the field later
			// turns out to be unknown (unmarshalExtension returns errUnknown).
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				// Message has no extension storage: keep as unknown.
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Unknown field: it must still be well-formed to skip over it.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				// Re-append the tag (already consumed above) and the raw value.
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// A group-encoded message must end with its END_GROUP tag.
	if groupTag != 0 {
		return out, errDecode
	}
	// Counting distinct bits detects any required field that was never seen.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
   206  
   207  func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
   208  	x := exts[int32(num)]
   209  	xt := x.Type()
   210  	if xt == nil {
   211  		var err error
   212  		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
   213  		if err != nil {
   214  			if err == protoregistry.NotFound {
   215  				return out, errUnknown
   216  			}
   217  			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
   218  		}
   219  	}
   220  	xi := getExtensionFieldInfo(xt)
   221  	if xi.funcs.unmarshal == nil {
   222  		return out, errUnknown
   223  	}
   224  	if flags.LazyUnmarshalExtensions {
   225  		if opts.IsDefault() && x.canLazy(xt) {
   226  			out, valid := skipExtension(b, xi, num, wtyp, opts)
   227  			switch valid {
   228  			case ValidationValid:
   229  				if out.initialized {
   230  					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
   231  					exts[int32(num)] = x
   232  					return out, nil
   233  				}
   234  			case ValidationInvalid:
   235  				return out, errDecode
   236  			case ValidationUnknown:
   237  			}
   238  		}
   239  	}
   240  	ival := x.Value()
   241  	if !ival.IsValid() && xi.unmarshalNeedsValue {
   242  		// Create a new message, list, or map value to fill in.
   243  		// For enums, create a prototype value to let the unmarshal func know the
   244  		// concrete type.
   245  		ival = xt.New()
   246  	}
   247  	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
   248  	if err != nil {
   249  		return out, err
   250  	}
   251  	if xi.funcs.isInit == nil {
   252  		out.initialized = true
   253  	}
   254  	x.Set(xt, v)
   255  	exts[int32(num)] = x
   256  	return out, nil
   257  }
   258  
   259  func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
   260  	if xi.validation.mi == nil {
   261  		return out, ValidationUnknown
   262  	}
   263  	xi.validation.mi.init()
   264  	switch xi.validation.typ {
   265  	case validationTypeMessage:
   266  		if wtyp != protowire.BytesType {
   267  			return out, ValidationUnknown
   268  		}
   269  		v, n := protowire.ConsumeBytes(b)
   270  		if n < 0 {
   271  			return out, ValidationUnknown
   272  		}
   273  		out, st := xi.validation.mi.validate(v, 0, opts)
   274  		out.n = n
   275  		return out, st
   276  	case validationTypeGroup:
   277  		if wtyp != protowire.StartGroupType {
   278  			return out, ValidationUnknown
   279  		}
   280  		out, st := xi.validation.mi.validate(b, num, opts)
   281  		return out, st
   282  	default:
   283  		return out, ValidationUnknown
   284  	}
   285  }
   286  

View as plain text