...

Source file src/google.golang.org/protobuf/reflect/protorange/range.go

Documentation: google.golang.org/protobuf/reflect/protorange

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package protorange provides functionality to traverse a message value.
     6  package protorange
     7  
     8  import (
     9  	"bytes"
    10  	"errors"
    11  
    12  	"google.golang.org/protobuf/internal/genid"
    13  	"google.golang.org/protobuf/internal/order"
    14  	"google.golang.org/protobuf/proto"
    15  	"google.golang.org/protobuf/reflect/protopath"
    16  	"google.golang.org/protobuf/reflect/protoreflect"
    17  	"google.golang.org/protobuf/reflect/protoregistry"
    18  )
    19  
    20  var (
    21  	// Break breaks traversal of children in the current value.
    22  	// It has no effect when traversing values that are not composite types
    23  	// (e.g., messages, lists, and maps).
    24  	Break = errors.New("break traversal of children in current value")
    25  
    26  	// Terminate terminates the entire range operation.
    27  	// All necessary Pop operations continue to be called.
    28  	Terminate = errors.New("terminate range operation")
    29  )
    30  
    31  // Range performs a depth-first traversal over reachable values in a message.
    32  //
    33  // See [Options.Range] for details.
    34  func Range(m protoreflect.Message, f func(protopath.Values) error) error {
    35  	return Options{}.Range(m, f, nil)
    36  }
    37  
// Options configures traversal of a message value tree.
//
// The zero value is usable: fields and map entries are visited in an
// undefined order, and google.protobuf.Any messages are expanded using
// protoregistry.GlobalTypes.
type Options struct {
	// Stable specifies whether to visit message fields and map entries
	// in a stable ordering. If false, then the ordering is undefined and
	// may be non-deterministic.
	//
	// Message fields are visited in ascending order by field number.
	// Map entries are visited in ascending order, where
	// boolean keys are ordered such that false sorts before true,
	// numeric keys are ordered based on the numeric value, and
	// string keys are lexicographically ordered by Unicode codepoints.
	Stable bool

	// Resolver is used for looking up types when expanding google.protobuf.Any
	// messages. If nil, this defaults to using protoregistry.GlobalTypes.
	// It is also passed as the Resolver to proto.UnmarshalOptions when the
	// payload of an Any message is decoded. An Any whose type URL cannot be
	// resolved, or whose payload fails to decode, is traversed as an ordinary
	// message instead of being expanded.
	// To prevent expansion of Any messages, pass an empty protoregistry.Types:
	//
	//	Options{Resolver: (*protoregistry.Types)(nil)}
	//
	Resolver interface {
		protoregistry.ExtensionTypeResolver
		protoregistry.MessageTypeResolver
	}
}
    62  
    63  // Range performs a depth-first traversal over reachable values in a message.
    64  // The first push and the last pop are to push/pop a [protopath.Root] step.
    65  // If push or pop return any non-nil error (other than [Break] or [Terminate]),
    66  // it terminates the traversal and is returned by Range.
    67  //
    68  // The rules for traversing a message is as follows:
    69  //
    70  //   - For messages, iterate over every populated known and extension field.
    71  //     Each field is preceded by a push of a [protopath.FieldAccess] step,
    72  //     followed by recursive application of the rules on the field value,
    73  //     and succeeded by a pop of that step.
    74  //     If the message has unknown fields, then push an [protopath.UnknownAccess] step
    75  //     followed immediately by pop of that step.
    76  //
    77  //   - As an exception to the above rule, if the current message is a
    78  //     google.protobuf.Any message, expand the underlying message (if resolvable).
    79  //     The expanded message is preceded by a push of a [protopath.AnyExpand] step,
    80  //     followed by recursive application of the rules on the underlying message,
    81  //     and succeeded by a pop of that step. Mutations to the expanded message
    82  //     are written back to the Any message when popping back out.
    83  //
    84  //   - For lists, iterate over every element. Each element is preceded by a push
    85  //     of a [protopath.ListIndex] step, followed by recursive application of the rules
    86  //     on the list element, and succeeded by a pop of that step.
    87  //
    88  //   - For maps, iterate over every entry. Each entry is preceded by a push
    89  //     of a [protopath.MapIndex] step, followed by recursive application of the rules
    90  //     on the map entry value, and succeeded by a pop of that step.
    91  //
    92  // Mutations should only be made to the last value, otherwise the effects on
    93  // traversal will be undefined. If the mutation is made to the last value
    94  // during to a push, then the effects of the mutation will affect traversal.
    95  // For example, if the last value is currently a message, and the push function
    96  // populates a few fields in that message, then the newly modified fields
    97  // will be traversed.
    98  //
    99  // The [protopath.Values] provided to push functions is only valid until the
   100  // corresponding pop call and the values provided to a pop call is only valid
   101  // for the duration of the pop call itself.
   102  func (o Options) Range(m protoreflect.Message, push, pop func(protopath.Values) error) error {
   103  	var err error
   104  	p := new(protopath.Values)
   105  	if o.Resolver == nil {
   106  		o.Resolver = protoregistry.GlobalTypes
   107  	}
   108  
   109  	pushStep(p, protopath.Root(m.Descriptor()), protoreflect.ValueOfMessage(m))
   110  	if push != nil {
   111  		err = amendError(err, push(*p))
   112  	}
   113  	if err == nil {
   114  		err = o.rangeMessage(p, m, push, pop)
   115  	}
   116  	if pop != nil {
   117  		err = amendError(err, pop(*p))
   118  	}
   119  	popStep(p)
   120  
   121  	if err == Break || err == Terminate {
   122  		err = nil
   123  	}
   124  	return err
   125  }
   126  
   127  func (o Options) rangeMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (err error) {
   128  	if ok, err := o.rangeAnyMessage(p, m, push, pop); ok {
   129  		return err
   130  	}
   131  
   132  	fieldOrder := order.AnyFieldOrder
   133  	if o.Stable {
   134  		fieldOrder = order.NumberFieldOrder
   135  	}
   136  	order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   137  		pushStep(p, protopath.FieldAccess(fd), v)
   138  		if push != nil {
   139  			err = amendError(err, push(*p))
   140  		}
   141  		if err == nil {
   142  			switch {
   143  			case fd.IsMap():
   144  				err = o.rangeMap(p, fd, v.Map(), push, pop)
   145  			case fd.IsList():
   146  				err = o.rangeList(p, fd, v.List(), push, pop)
   147  			case fd.Message() != nil:
   148  				err = o.rangeMessage(p, v.Message(), push, pop)
   149  			}
   150  		}
   151  		if pop != nil {
   152  			err = amendError(err, pop(*p))
   153  		}
   154  		popStep(p)
   155  		return err == nil
   156  	})
   157  
   158  	if b := m.GetUnknown(); len(b) > 0 && err == nil {
   159  		pushStep(p, protopath.UnknownAccess(), protoreflect.ValueOfBytes(b))
   160  		if push != nil {
   161  			err = amendError(err, push(*p))
   162  		}
   163  		if pop != nil {
   164  			err = amendError(err, pop(*p))
   165  		}
   166  		popStep(p)
   167  	}
   168  
   169  	if err == Break {
   170  		err = nil
   171  	}
   172  	return err
   173  }
   174  
   175  func (o Options) rangeAnyMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (ok bool, err error) {
   176  	md := m.Descriptor()
   177  	if md.FullName() != "google.protobuf.Any" {
   178  		return false, nil
   179  	}
   180  
   181  	fds := md.Fields()
   182  	url := m.Get(fds.ByNumber(genid.Any_TypeUrl_field_number)).String()
   183  	val := m.Get(fds.ByNumber(genid.Any_Value_field_number)).Bytes()
   184  	mt, errFind := o.Resolver.FindMessageByURL(url)
   185  	if errFind != nil {
   186  		return false, nil
   187  	}
   188  
   189  	// Unmarshal the raw encoded message value into a structured message value.
   190  	m2 := mt.New()
   191  	errUnmarshal := proto.UnmarshalOptions{
   192  		Merge:        true,
   193  		AllowPartial: true,
   194  		Resolver:     o.Resolver,
   195  	}.Unmarshal(val, m2.Interface())
   196  	if errUnmarshal != nil {
   197  		// If the the underlying message cannot be unmarshaled,
   198  		// then just treat this as an normal message type.
   199  		return false, nil
   200  	}
   201  
   202  	// Marshal Any before ranging to detect possible mutations.
   203  	b1, errMarshal := proto.MarshalOptions{
   204  		AllowPartial:  true,
   205  		Deterministic: true,
   206  	}.Marshal(m2.Interface())
   207  	if errMarshal != nil {
   208  		return true, errMarshal
   209  	}
   210  
   211  	pushStep(p, protopath.AnyExpand(m2.Descriptor()), protoreflect.ValueOfMessage(m2))
   212  	if push != nil {
   213  		err = amendError(err, push(*p))
   214  	}
   215  	if err == nil {
   216  		err = o.rangeMessage(p, m2, push, pop)
   217  	}
   218  	if pop != nil {
   219  		err = amendError(err, pop(*p))
   220  	}
   221  	popStep(p)
   222  
   223  	// Marshal Any after ranging to detect possible mutations.
   224  	b2, errMarshal := proto.MarshalOptions{
   225  		AllowPartial:  true,
   226  		Deterministic: true,
   227  	}.Marshal(m2.Interface())
   228  	if errMarshal != nil {
   229  		return true, errMarshal
   230  	}
   231  
   232  	// Mutations detected, write the new sequence of bytes to the Any message.
   233  	if !bytes.Equal(b1, b2) {
   234  		m.Set(fds.ByNumber(genid.Any_Value_field_number), protoreflect.ValueOfBytes(b2))
   235  	}
   236  
   237  	if err == Break {
   238  		err = nil
   239  	}
   240  	return true, err
   241  }
   242  
   243  func (o Options) rangeList(p *protopath.Values, fd protoreflect.FieldDescriptor, ls protoreflect.List, push, pop func(protopath.Values) error) (err error) {
   244  	for i := 0; i < ls.Len() && err == nil; i++ {
   245  		v := ls.Get(i)
   246  		pushStep(p, protopath.ListIndex(i), v)
   247  		if push != nil {
   248  			err = amendError(err, push(*p))
   249  		}
   250  		if err == nil && fd.Message() != nil {
   251  			err = o.rangeMessage(p, v.Message(), push, pop)
   252  		}
   253  		if pop != nil {
   254  			err = amendError(err, pop(*p))
   255  		}
   256  		popStep(p)
   257  	}
   258  
   259  	if err == Break {
   260  		err = nil
   261  	}
   262  	return err
   263  }
   264  
   265  func (o Options) rangeMap(p *protopath.Values, fd protoreflect.FieldDescriptor, ms protoreflect.Map, push, pop func(protopath.Values) error) (err error) {
   266  	keyOrder := order.AnyKeyOrder
   267  	if o.Stable {
   268  		keyOrder = order.GenericKeyOrder
   269  	}
   270  	order.RangeEntries(ms, keyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool {
   271  		pushStep(p, protopath.MapIndex(k), v)
   272  		if push != nil {
   273  			err = amendError(err, push(*p))
   274  		}
   275  		if err == nil && fd.MapValue().Message() != nil {
   276  			err = o.rangeMessage(p, v.Message(), push, pop)
   277  		}
   278  		if pop != nil {
   279  			err = amendError(err, pop(*p))
   280  		}
   281  		popStep(p)
   282  		return err == nil
   283  	})
   284  
   285  	if err == Break {
   286  		err = nil
   287  	}
   288  	return err
   289  }
   290  
   291  func pushStep(p *protopath.Values, s protopath.Step, v protoreflect.Value) {
   292  	p.Path = append(p.Path, s)
   293  	p.Values = append(p.Values, v)
   294  }
   295  
   296  func popStep(p *protopath.Values) {
   297  	p.Path = p.Path[:len(p.Path)-1]
   298  	p.Values = p.Values[:len(p.Values)-1]
   299  }
   300  
   301  // amendError amends the previous error with the current error if it is
   302  // considered more serious. The precedence order for errors is:
   303  //
   304  //	nil < Break < Terminate < previous non-nil < current non-nil
   305  func amendError(prev, curr error) error {
   306  	switch {
   307  	case curr == nil:
   308  		return prev
   309  	case curr == Break && prev != nil:
   310  		return prev
   311  	case curr == Terminate && prev != nil && prev != Break:
   312  		return prev
   313  	default:
   314  		return curr
   315  	}
   316  }
   317  

View as plain text