...

Source file src/google.golang.org/protobuf/testing/prototest/message.go

Documentation: google.golang.org/protobuf/testing/prototest

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package prototest exercises protobuf reflection.
     6  package prototest
     7  
     8  import (
     9  	"bytes"
    10  	"fmt"
    11  	"math"
    12  	"reflect"
    13  	"sort"
    14  	"strings"
    15  	"testing"
    16  
    17  	"google.golang.org/protobuf/encoding/prototext"
    18  	"google.golang.org/protobuf/encoding/protowire"
    19  	"google.golang.org/protobuf/proto"
    20  	"google.golang.org/protobuf/reflect/protoreflect"
    21  	"google.golang.org/protobuf/reflect/protoregistry"
    22  )
    23  
    24  // TODO: Test invalid field descriptors or oneof descriptors.
    25  // TODO: This should test the functionality that can be provided by fast-paths.
    26  
    27  // Message tests a message implementation.
    28  type Message struct {
    29  	// Resolver is used to determine the list of extension fields to test with.
    30  	// If nil, this defaults to using protoregistry.GlobalTypes.
    31  	Resolver interface {
    32  		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
    33  		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
    34  		RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool)
    35  	}
    36  }
    37  
    38  // Test performs tests on a [protoreflect.MessageType] implementation.
    39  func (test Message) Test(t testing.TB, mt protoreflect.MessageType) {
    40  	testType(t, mt)
    41  
    42  	md := mt.Descriptor()
    43  	m1 := mt.New()
    44  	for i := 0; i < md.Fields().Len(); i++ {
    45  		fd := md.Fields().Get(i)
    46  		testField(t, m1, fd)
    47  	}
    48  	if test.Resolver == nil {
    49  		test.Resolver = protoregistry.GlobalTypes
    50  	}
    51  	var extTypes []protoreflect.ExtensionType
    52  	test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e protoreflect.ExtensionType) bool {
    53  		extTypes = append(extTypes, e)
    54  		return true
    55  	})
    56  	for _, xt := range extTypes {
    57  		testField(t, m1, xt.TypeDescriptor())
    58  	}
    59  	for i := 0; i < md.Oneofs().Len(); i++ {
    60  		testOneof(t, m1, md.Oneofs().Get(i))
    61  	}
    62  	testUnknown(t, m1)
    63  
    64  	// Test round-trip marshal/unmarshal.
    65  	m2 := mt.New().Interface()
    66  	populateMessage(m2.ProtoReflect(), 1, nil)
    67  	for _, xt := range extTypes {
    68  		m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
    69  	}
    70  	b, err := proto.MarshalOptions{
    71  		AllowPartial: true,
    72  	}.Marshal(m2)
    73  	if err != nil {
    74  		t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2))
    75  	}
    76  	m3 := mt.New().Interface()
    77  	if err := (proto.UnmarshalOptions{
    78  		AllowPartial: true,
    79  		Resolver:     test.Resolver,
    80  	}.Unmarshal(b, m3)); err != nil {
    81  		t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2))
    82  	}
    83  	if !proto.Equal(m2, m3) {
    84  		t.Errorf("round-trip marshal/unmarshal did not preserve message\nOriginal:\n%v\nNew:\n%v", prototext.Format(m2), prototext.Format(m3))
    85  	}
    86  }
    87  
    88  func testType(t testing.TB, mt protoreflect.MessageType) {
    89  	m := mt.New().Interface()
    90  	want := reflect.TypeOf(m)
    91  	if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want {
    92  		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want)
    93  	}
    94  	if got := reflect.TypeOf(m.ProtoReflect().New().Interface()); got != want {
    95  		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().New().Interface()): %v != %v", got, want)
    96  	}
    97  	if got := reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()); got != want {
    98  		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()): %v != %v", got, want)
    99  	}
   100  	if mt, ok := mt.(protoreflect.MessageFieldTypes); ok {
   101  		testFieldTypes(t, mt)
   102  	}
   103  }
   104  
   105  func testFieldTypes(t testing.TB, mt protoreflect.MessageFieldTypes) {
   106  	descName := func(d protoreflect.Descriptor) protoreflect.FullName {
   107  		if d == nil {
   108  			return "<nil>"
   109  		}
   110  		return d.FullName()
   111  	}
   112  	typeName := func(mt protoreflect.MessageType) protoreflect.FullName {
   113  		if mt == nil {
   114  			return "<nil>"
   115  		}
   116  		return mt.Descriptor().FullName()
   117  	}
   118  	adjustExpr := func(idx int, expr string) string {
   119  		expr = strings.Replace(expr, "fd.", "md.Fields().Get(i).", -1)
   120  		expr = strings.Replace(expr, "(fd)", "(md.Fields().Get(i))", -1)
   121  		expr = strings.Replace(expr, "mti.", "mt.Message(i).", -1)
   122  		expr = strings.Replace(expr, "(i)", fmt.Sprintf("(%d)", idx), -1)
   123  		return expr
   124  	}
   125  	checkEnumDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.EnumDescriptor) {
   126  		if got != want {
   127  			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
   128  		}
   129  	}
   130  	checkMessageDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageDescriptor) {
   131  		if got != want {
   132  			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
   133  		}
   134  	}
   135  	checkMessageType := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageType) {
   136  		if got != want {
   137  			t.Errorf("type mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), typeName(got), typeName(want))
   138  		}
   139  	}
   140  
   141  	fds := mt.Descriptor().Fields()
   142  	m := mt.New()
   143  	for i := 0; i < fds.Len(); i++ {
   144  		fd := fds.Get(i)
   145  		switch {
   146  		case fd.IsList():
   147  			if fd.Enum() != nil {
   148  				checkEnumDesc(i,
   149  					"mt.Enum(i).Descriptor()", "fd.Enum()",
   150  					mt.Enum(i).Descriptor(), fd.Enum())
   151  			}
   152  			if fd.Message() != nil {
   153  				checkMessageDesc(i,
   154  					"mt.Message(i).Descriptor()", "fd.Message()",
   155  					mt.Message(i).Descriptor(), fd.Message())
   156  				checkMessageType(i,
   157  					"mt.Message(i)", "m.NewField(fd).List().NewElement().Message().Type()",
   158  					mt.Message(i), m.NewField(fd).List().NewElement().Message().Type())
   159  			}
   160  		case fd.IsMap():
   161  			mti := mt.Message(i)
   162  			if m := mti.New(); m != nil {
   163  				checkMessageDesc(i,
   164  					"m.Descriptor()", "fd.Message()",
   165  					m.Descriptor(), fd.Message())
   166  			}
   167  			if m := mti.Zero(); m != nil {
   168  				checkMessageDesc(i,
   169  					"m.Descriptor()", "fd.Message()",
   170  					m.Descriptor(), fd.Message())
   171  			}
   172  			checkMessageDesc(i,
   173  				"mti.Descriptor()", "fd.Message()",
   174  				mti.Descriptor(), fd.Message())
   175  			if mti := mti.(protoreflect.MessageFieldTypes); mti != nil {
   176  				if fd.MapValue().Enum() != nil {
   177  					checkEnumDesc(i,
   178  						"mti.Enum(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Enum()",
   179  						mti.Enum(fd.MapValue().Index()).Descriptor(), fd.MapValue().Enum())
   180  				}
   181  				if fd.MapValue().Message() != nil {
   182  					checkMessageDesc(i,
   183  						"mti.Message(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Message()",
   184  						mti.Message(fd.MapValue().Index()).Descriptor(), fd.MapValue().Message())
   185  					checkMessageType(i,
   186  						"mti.Message(fd.MapValue().Index())", "m.NewField(fd).Map().NewValue().Message().Type()",
   187  						mti.Message(fd.MapValue().Index()), m.NewField(fd).Map().NewValue().Message().Type())
   188  				}
   189  			}
   190  		default:
   191  			if fd.Enum() != nil {
   192  				checkEnumDesc(i,
   193  					"mt.Enum(i).Descriptor()", "fd.Enum()",
   194  					mt.Enum(i).Descriptor(), fd.Enum())
   195  			}
   196  			if fd.Message() != nil {
   197  				checkMessageDesc(i,
   198  					"mt.Message(i).Descriptor()", "fd.Message()",
   199  					mt.Message(i).Descriptor(), fd.Message())
   200  				checkMessageType(i,
   201  					"mt.Message(i)", "m.NewField(fd).Message().Type()",
   202  					mt.Message(i), m.NewField(fd).Message().Type())
   203  			}
   204  		}
   205  	}
   206  }
   207  
// testField exercises set/get/has/clear of a single field fd of m.
//
// fd may be a regular field or an extension. List and map fields are first
// handed to testFieldList and testFieldMap. For singular scalar fields, the
// value returned by NewField is compared with fd.Default(), and float/double
// fields additionally go through testFieldFloat. Then, for every kind of
// field, it:
//   - sets the field to values built from seeds 1, 0, minVal and maxVal,
//     checking Has, Get and Range after each Set;
//   - clears it and checks Has, plus Get for lists, maps and scalars;
//   - stores an empty/default value and checks Has (lists, maps) or Get
//     (scalars);
//   - checks that Set with a value of the wrong Go type panics.
func testField(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) {
	name := fd.FullName()
	num := fd.Number()

	switch {
	case fd.IsList():
		testFieldList(t, m, fd)
	case fd.IsMap():
		testFieldMap(t, m, fd)
	case fd.Message() != nil:
		// Singular message field: no default-value check is done here.
	default:
		// Singular scalar: a freshly created value must equal the default.
		if got, want := m.NewField(fd), fd.Default(); !valueEqual(got, want) {
			t.Errorf("Message.NewField(%v) = %v, want default value %v", name, formatValue(got), formatValue(want))
		}
		if fd.Kind() == protoreflect.FloatKind || fd.Kind() == protoreflect.DoubleKind {
			testFieldFloat(t, m, fd)
		}
	}

	// Set to a non-zero value, the zero value, different non-zero values.
	for _, n := range []seed{1, 0, minVal, maxVal} {
		v := newValue(m, fd, n, nil)
		m.Set(fd, v)
		// Compute the expected Has result. The rules are applied in order,
		// and a later rule overrides an earlier one, so the order matters.
		wantHas := true
		if n == 0 {
			// Proto3 non-message fields are expected to be unset at zero...
			if fd.Syntax() == protoreflect.Proto3 && fd.Message() == nil {
				wantHas = false
			}
			// ...except extensions, which are expected to remain set...
			if fd.IsExtension() {
				wantHas = true
			}
			// ...repeated fields (seed 0 yields an empty list/map) are not...
			if fd.Cardinality() == protoreflect.Repeated {
				wantHas = false
			}
			// ...and a oneof member is always set once assigned.
			if fd.ContainingOneof() != nil {
				wantHas = true
			}
		}
		// A proto3 singular enum outside a oneof is expected to be unset
		// whenever the generated value is enum number 0, whatever the seed.
		if fd.Syntax() == protoreflect.Proto3 && fd.Cardinality() != protoreflect.Repeated && fd.ContainingOneof() == nil && fd.Kind() == protoreflect.EnumKind && v.Enum() == 0 {
			wantHas = false
		}
		if got, want := m.Has(fd), wantHas; got != want {
			t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want)
		}
		if got, want := m.Get(fd), v; !valueEqual(got, want) {
			t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
		}
		// Range must visit fd exactly when Has is expected to be true, and
		// must report the value that was just set.
		found := false
		m.Range(func(d protoreflect.FieldDescriptor, got protoreflect.Value) bool {
			if fd != d {
				return true
			}
			found = true
			if want := v; !valueEqual(got, want) {
				t.Errorf("after setting %q:\nMessage.Range got value %v, want %v", name, formatValue(got), formatValue(want))
			}
			return true
		})
		if got, want := wantHas, found; got != want {
			t.Errorf("after setting %q:\nMessageRange saw field: %v, want %v", name, got, want)
		}
	}

	// After Clear, Has must be false and Get must return an empty list/map
	// or, for scalars, the default value.
	m.Clear(fd)
	if got, want := m.Has(fd), false; got != want {
		t.Errorf("after clearing %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
	}
	switch {
	case fd.IsList():
		if got := m.Get(fd); got.List().Len() != 0 {
			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
		}
	case fd.IsMap():
		if got := m.Get(fd); got.Map().Len() != 0 {
			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
		}
	case fd.Message() == nil:
		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
		}
	}

	// Set to the default value.
	switch {
	case fd.IsList() || fd.IsMap():
		// Store an empty mutable container; Has is expected to be true only
		// when the condition on the right-hand side below holds.
		m.Set(fd, m.Mutable(fd))
		if got, want := m.Has(fd), (fd.IsExtension() && fd.Cardinality() != protoreflect.Repeated) || fd.ContainingOneof() != nil; got != want {
			t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
		}
	case fd.Message() == nil:
		// Re-storing the (cleared) value read by Get must keep Get at default.
		m.Set(fd, m.Get(fd))
		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
			t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
		}
	}
	m.Clear(fd)

	// Set to the wrong type: a string value, or an int32 for string fields.
	v := protoreflect.ValueOfString("")
	if fd.Kind() == protoreflect.StringKind {
		v = protoreflect.ValueOfInt32(0)
	}
	if !panics(func() {
		m.Set(fd, v)
	}) {
		t.Errorf("setting %v to %T succeeds, want panic", name, v.Interface())
	}
}
   317  
   318  // testFieldMap tests set/get/has/clear of entries in a map field.
   319  func testFieldMap(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) {
   320  	name := fd.FullName()
   321  	num := fd.Number()
   322  
   323  	// New values.
   324  	m.Clear(fd) // start with an empty map
   325  	mapv := m.Get(fd).Map()
   326  	if mapv.IsValid() {
   327  		t.Errorf("after clearing field: message.Get(%v).IsValid() = true, want false", name)
   328  	}
   329  	if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
   330  		t.Errorf("message.Get(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
   331  	}
   332  	if !panics(func() {
   333  		m.Set(fd, protoreflect.ValueOfMap(mapv))
   334  	}) {
   335  		t.Errorf("message.Set(%v, <invalid>) does not panic", name)
   336  	}
   337  	if !panics(func() {
   338  		mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, nil))
   339  	}) {
   340  		t.Errorf("message.Get(%v).Set(...) of invalid map does not panic", name)
   341  	}
   342  	mapv = m.Mutable(fd).Map() // mutable map
   343  	if !mapv.IsValid() {
   344  		t.Errorf("message.Mutable(%v).IsValid() = false, want true", name)
   345  	}
   346  	if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
   347  		t.Errorf("message.Mutable(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
   348  	}
   349  
   350  	// Add values.
   351  	want := make(testMap)
   352  	for i, n := range []seed{1, 0, minVal, maxVal} {
   353  		if got, want := m.Has(fd), i > 0; got != want {
   354  			t.Errorf("after inserting %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
   355  		}
   356  
   357  		k := newMapKey(fd, n)
   358  		v := newMapValue(fd, mapv, n, nil)
   359  		mapv.Set(k, v)
   360  		want.Set(k, v)
   361  		if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) {
   362  			t.Errorf("after inserting %d elements to %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
   363  		}
   364  	}
   365  
   366  	// Set values.
   367  	want.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
   368  		nv := newMapValue(fd, mapv, 10, nil)
   369  		mapv.Set(k, nv)
   370  		want.Set(k, nv)
   371  		if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) {
   372  			t.Errorf("after setting element %v of %q:\nMessage.Get(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
   373  		}
   374  		return true
   375  	})
   376  
   377  	// Clear values.
   378  	want.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
   379  		mapv.Clear(k)
   380  		want.Clear(k)
   381  		if got, want := m.Has(fd), want.Len() > 0; got != want {
   382  			t.Errorf("after clearing elements of %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
   383  		}
   384  		if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) {
   385  			t.Errorf("after clearing elements of %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
   386  		}
   387  		return true
   388  	})
   389  	if mapv := m.Get(fd).Map(); mapv.IsValid() {
   390  		t.Errorf("after clearing all elements: message.Get(%v).IsValid() = true, want false %v", name, formatValue(protoreflect.ValueOfMap(mapv)))
   391  	}
   392  
   393  	// Non-existent map keys.
   394  	missingKey := newMapKey(fd, 1)
   395  	if got, want := mapv.Has(missingKey), false; got != want {
   396  		t.Errorf("non-existent map key in %q: Map.Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
   397  	}
   398  	if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
   399  		t.Errorf("non-existent map key in %q: Map.Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
   400  	}
   401  	mapv.Clear(missingKey) // noop
   402  
   403  	// Mutable.
   404  	if fd.MapValue().Message() == nil {
   405  		if !panics(func() {
   406  			mapv.Mutable(newMapKey(fd, 1))
   407  		}) {
   408  			t.Errorf("Mutable on %q succeeds, want panic", name)
   409  		}
   410  	} else {
   411  		k := newMapKey(fd, 1)
   412  		v := mapv.Mutable(k)
   413  		if got, want := mapv.Len(), 1; got != want {
   414  			t.Errorf("after Mutable on %q, Map.Len() = %v, want %v", name, got, want)
   415  		}
   416  		populateMessage(v.Message(), 1, nil)
   417  		if !valueEqual(mapv.Get(k), v) {
   418  			t.Errorf("after Mutable on %q, changing new mutable value does not change map entry", name)
   419  		}
   420  		mapv.Clear(k)
   421  	}
   422  }
   423  
   424  type testMap map[interface{}]protoreflect.Value
   425  
   426  func (m testMap) Get(k protoreflect.MapKey) protoreflect.Value     { return m[k.Interface()] }
   427  func (m testMap) Set(k protoreflect.MapKey, v protoreflect.Value)  { m[k.Interface()] = v }
   428  func (m testMap) Has(k protoreflect.MapKey) bool                   { return m.Get(k).IsValid() }
   429  func (m testMap) Clear(k protoreflect.MapKey)                      { delete(m, k.Interface()) }
   430  func (m testMap) Mutable(k protoreflect.MapKey) protoreflect.Value { panic("unimplemented") }
   431  func (m testMap) Len() int                                         { return len(m) }
   432  func (m testMap) NewValue() protoreflect.Value                     { panic("unimplemented") }
   433  func (m testMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) {
   434  	for k, v := range m {
   435  		if !f(protoreflect.ValueOf(k).MapKey(), v) {
   436  			return
   437  		}
   438  	}
   439  }
   440  func (m testMap) IsValid() bool { return true }
   441  
   442  // testFieldList exercises set/get/append/truncate of values in a list.
   443  func testFieldList(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) {
   444  	name := fd.FullName()
   445  	num := fd.Number()
   446  
   447  	m.Clear(fd) // start with an empty list
   448  	list := m.Get(fd).List()
   449  	if list.IsValid() {
   450  		t.Errorf("message.Get(%v).IsValid() = true, want false", name)
   451  	}
   452  	if !panics(func() {
   453  		m.Set(fd, protoreflect.ValueOfList(list))
   454  	}) {
   455  		t.Errorf("message.Set(%v, <invalid>) does not panic", name)
   456  	}
   457  	if !panics(func() {
   458  		list.Append(newListElement(fd, list, 0, nil))
   459  	}) {
   460  		t.Errorf("message.Get(%v).Append(...) of invalid list does not panic", name)
   461  	}
   462  	if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
   463  		t.Errorf("message.Get(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
   464  	}
   465  	list = m.Mutable(fd).List() // mutable list
   466  	if !list.IsValid() {
   467  		t.Errorf("message.Get(%v).IsValid() = false, want true", name)
   468  	}
   469  	if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
   470  		t.Errorf("message.Mutable(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
   471  	}
   472  
   473  	// Append values.
   474  	var want protoreflect.List = &testList{}
   475  	for i, n := range []seed{1, 0, minVal, maxVal} {
   476  		if got, want := m.Has(fd), i > 0; got != want {
   477  			t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
   478  		}
   479  		v := newListElement(fd, list, n, nil)
   480  		want.Append(v)
   481  		list.Append(v)
   482  
   483  		if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) {
   484  			t.Errorf("after appending %d elements to %q:\nMessage.Get(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
   485  		}
   486  	}
   487  
   488  	// Set values.
   489  	for i := 0; i < want.Len(); i++ {
   490  		v := newListElement(fd, list, seed(i+10), nil)
   491  		want.Set(i, v)
   492  		list.Set(i, v)
   493  		if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) {
   494  			t.Errorf("after setting element %d of %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
   495  		}
   496  	}
   497  
   498  	// Truncate.
   499  	for want.Len() > 0 {
   500  		n := want.Len() - 1
   501  		want.Truncate(n)
   502  		list.Truncate(n)
   503  		if got, want := m.Has(fd), want.Len() > 0; got != want {
   504  			t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want)
   505  		}
   506  		if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) {
   507  			t.Errorf("after truncating %q to %d:\nMessage.Get(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
   508  		}
   509  	}
   510  
   511  	// AppendMutable.
   512  	if fd.Message() == nil {
   513  		if !panics(func() {
   514  			list.AppendMutable()
   515  		}) {
   516  			t.Errorf("AppendMutable on %q succeeds, want panic", name)
   517  		}
   518  	} else {
   519  		v := list.AppendMutable()
   520  		if got, want := list.Len(), 1; got != want {
   521  			t.Errorf("after AppendMutable on %q, list.Len() = %v, want %v", name, got, want)
   522  		}
   523  		populateMessage(v.Message(), 1, nil)
   524  		if !valueEqual(list.Get(0), v) {
   525  			t.Errorf("after AppendMutable on %q, changing new mutable value does not change list item 0", name)
   526  		}
   527  		want.Truncate(0)
   528  	}
   529  }
   530  
   531  type testList struct {
   532  	a []protoreflect.Value
   533  }
   534  
   535  func (l *testList) Append(v protoreflect.Value)       { l.a = append(l.a, v) }
   536  func (l *testList) AppendMutable() protoreflect.Value { panic("unimplemented") }
   537  func (l *testList) Get(n int) protoreflect.Value      { return l.a[n] }
   538  func (l *testList) Len() int                          { return len(l.a) }
   539  func (l *testList) Set(n int, v protoreflect.Value)   { l.a[n] = v }
   540  func (l *testList) Truncate(n int)                    { l.a = l.a[:n] }
   541  func (l *testList) NewElement() protoreflect.Value    { panic("unimplemented") }
   542  func (l *testList) IsValid() bool                     { return true }
   543  
   544  // testFieldFloat exercises some interesting floating-point scalar field values.
   545  func testFieldFloat(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) {
   546  	name := fd.FullName()
   547  	num := fd.Number()
   548  
   549  	for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
   550  		var val protoreflect.Value
   551  		if fd.Kind() == protoreflect.FloatKind {
   552  			val = protoreflect.ValueOfFloat32(float32(v))
   553  		} else {
   554  			val = protoreflect.ValueOfFloat64(float64(v))
   555  		}
   556  		m.Set(fd, val)
   557  		// Note that Has is true for -0.
   558  		if got, want := m.Has(fd), true; got != want {
   559  			t.Errorf("after setting %v to %v: Message.Has(%v) = %v, want %v", name, v, num, got, want)
   560  		}
   561  		if got, want := m.Get(fd), val; !valueEqual(got, want) {
   562  			t.Errorf("after setting %v: Message.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
   563  		}
   564  	}
   565  }
   566  
   567  // testOneof tests the behavior of fields in a oneof.
   568  func testOneof(t testing.TB, m protoreflect.Message, od protoreflect.OneofDescriptor) {
   569  	for _, mutable := range []bool{false, true} {
   570  		for i := 0; i < od.Fields().Len(); i++ {
   571  			fda := od.Fields().Get(i)
   572  			if mutable {
   573  				// Set fields by requesting a mutable reference.
   574  				if !fda.IsMap() && !fda.IsList() && fda.Message() == nil {
   575  					continue
   576  				}
   577  				_ = m.Mutable(fda)
   578  			} else {
   579  				// Set fields explicitly.
   580  				m.Set(fda, newValue(m, fda, 1, nil))
   581  			}
   582  			if got, want := m.WhichOneof(od), fda; got != want {
   583  				t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want)
   584  			}
   585  			for j := 0; j < od.Fields().Len(); j++ {
   586  				fdb := od.Fields().Get(j)
   587  				if got, want := m.Has(fdb), i == j; got != want {
   588  					t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
   589  				}
   590  			}
   591  		}
   592  	}
   593  }
   594  
   595  // testUnknown tests the behavior of unknown fields.
   596  func testUnknown(t testing.TB, m protoreflect.Message) {
   597  	var b []byte
   598  	b = protowire.AppendTag(b, 1000, protowire.VarintType)
   599  	b = protowire.AppendVarint(b, 1001)
   600  	m.SetUnknown(protoreflect.RawFields(b))
   601  	if got, want := []byte(m.GetUnknown()), b; !bytes.Equal(got, want) {
   602  		t.Errorf("after setting unknown fields:\nGetUnknown() = %v, want %v", got, want)
   603  	}
   604  }
   605  
   606  func formatValue(v protoreflect.Value) string {
   607  	switch v := v.Interface().(type) {
   608  	case protoreflect.List:
   609  		var buf bytes.Buffer
   610  		buf.WriteString("list[")
   611  		for i := 0; i < v.Len(); i++ {
   612  			if i > 0 {
   613  				buf.WriteString(" ")
   614  			}
   615  			buf.WriteString(formatValue(v.Get(i)))
   616  		}
   617  		buf.WriteString("]")
   618  		return buf.String()
   619  	case protoreflect.Map:
   620  		var buf bytes.Buffer
   621  		buf.WriteString("map[")
   622  		var keys []protoreflect.MapKey
   623  		v.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
   624  			keys = append(keys, k)
   625  			return true
   626  		})
   627  		sort.Slice(keys, func(i, j int) bool {
   628  			return keys[i].String() < keys[j].String()
   629  		})
   630  		for i, k := range keys {
   631  			if i > 0 {
   632  				buf.WriteString(" ")
   633  			}
   634  			buf.WriteString(formatValue(k.Value()))
   635  			buf.WriteString(":")
   636  			buf.WriteString(formatValue(v.Get(k)))
   637  		}
   638  		buf.WriteString("]")
   639  		return buf.String()
   640  	case protoreflect.Message:
   641  		b, err := prototext.Marshal(v.Interface())
   642  		if err != nil {
   643  			return fmt.Sprintf("<%v>", err)
   644  		}
   645  		return fmt.Sprintf("%v{%s}", v.Descriptor().FullName(), b)
   646  	case string:
   647  		return fmt.Sprintf("%q", v)
   648  	default:
   649  		return fmt.Sprint(v)
   650  	}
   651  }
   652  
   653  func valueEqual(a, b protoreflect.Value) bool {
   654  	ai, bi := a.Interface(), b.Interface()
   655  	switch ai.(type) {
   656  	case protoreflect.Message:
   657  		return proto.Equal(
   658  			a.Message().Interface(),
   659  			b.Message().Interface(),
   660  		)
   661  	case protoreflect.List:
   662  		lista, listb := a.List(), b.List()
   663  		if lista.Len() != listb.Len() {
   664  			return false
   665  		}
   666  		for i := 0; i < lista.Len(); i++ {
   667  			if !valueEqual(lista.Get(i), listb.Get(i)) {
   668  				return false
   669  			}
   670  		}
   671  		return true
   672  	case protoreflect.Map:
   673  		mapa, mapb := a.Map(), b.Map()
   674  		if mapa.Len() != mapb.Len() {
   675  			return false
   676  		}
   677  		equal := true
   678  		mapa.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
   679  			if !valueEqual(v, mapb.Get(k)) {
   680  				equal = false
   681  				return false
   682  			}
   683  			return true
   684  		})
   685  		return equal
   686  	case []byte:
   687  		return bytes.Equal(a.Bytes(), b.Bytes())
   688  	case float32:
   689  		// NaNs are equal, but must be the same NaN.
   690  		return math.Float32bits(ai.(float32)) == math.Float32bits(bi.(float32))
   691  	case float64:
   692  		// NaNs are equal, but must be the same NaN.
   693  		return math.Float64bits(ai.(float64)) == math.Float64bits(bi.(float64))
   694  	default:
   695  		return ai == bi
   696  	}
   697  }
   698  
   699  // A seed is used to vary the content of a value.
   700  //
   701  // A seed of 0 is the zero value. Messages do not have a zero-value; a 0-seeded messages
   702  // is unpopulated.
   703  //
   704  // A seed of minVal or maxVal is the least or greatest value of the value type.
   705  type seed int
   706  
   707  const (
   708  	minVal seed = -1
   709  	maxVal seed = -2
   710  )
   711  
   712  // newSeed creates new seed values from a base, for example to create seeds for the
   713  // elements in a list. If the input seed is minVal or maxVal, so is the output.
   714  func newSeed(n seed, adjust ...int) seed {
   715  	switch n {
   716  	case minVal, maxVal:
   717  		return n
   718  	}
   719  	for _, a := range adjust {
   720  		n = 10*n + seed(a)
   721  	}
   722  	return n
   723  }
   724  
   725  // newValue returns a new value assignable to a field.
   726  //
   727  // The stack parameter is used to avoid infinite recursion when populating circular
   728  // data structures.
   729  func newValue(m protoreflect.Message, fd protoreflect.FieldDescriptor, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value {
   730  	switch {
   731  	case fd.IsList():
   732  		if n == 0 {
   733  			return m.New().Mutable(fd)
   734  		}
   735  		list := m.NewField(fd).List()
   736  		list.Append(newListElement(fd, list, 0, stack))
   737  		list.Append(newListElement(fd, list, minVal, stack))
   738  		list.Append(newListElement(fd, list, maxVal, stack))
   739  		list.Append(newListElement(fd, list, n, stack))
   740  		return protoreflect.ValueOfList(list)
   741  	case fd.IsMap():
   742  		if n == 0 {
   743  			return m.New().Mutable(fd)
   744  		}
   745  		mapv := m.NewField(fd).Map()
   746  		mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
   747  		mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
   748  		mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
   749  		mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
   750  		return protoreflect.ValueOfMap(mapv)
   751  	case fd.Message() != nil:
   752  		return populateMessage(m.NewField(fd).Message(), n, stack)
   753  	default:
   754  		return newScalarValue(fd, n)
   755  	}
   756  }
   757  
   758  func newListElement(fd protoreflect.FieldDescriptor, list protoreflect.List, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value {
   759  	if fd.Message() == nil {
   760  		return newScalarValue(fd, n)
   761  	}
   762  	return populateMessage(list.NewElement().Message(), n, stack)
   763  }
   764  
   765  func newMapKey(fd protoreflect.FieldDescriptor, n seed) protoreflect.MapKey {
   766  	kd := fd.MapKey()
   767  	return newScalarValue(kd, n).MapKey()
   768  }
   769  
   770  func newMapValue(fd protoreflect.FieldDescriptor, mapv protoreflect.Map, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value {
   771  	vd := fd.MapValue()
   772  	if vd.Message() == nil {
   773  		return newScalarValue(vd, n)
   774  	}
   775  	return populateMessage(mapv.NewValue().Message(), n, stack)
   776  }
   777  
   778  func newScalarValue(fd protoreflect.FieldDescriptor, n seed) protoreflect.Value {
   779  	switch fd.Kind() {
   780  	case protoreflect.BoolKind:
   781  		return protoreflect.ValueOfBool(n != 0)
   782  	case protoreflect.EnumKind:
   783  		vals := fd.Enum().Values()
   784  		var i int
   785  		switch n {
   786  		case minVal:
   787  			i = 0
   788  		case maxVal:
   789  			i = vals.Len() - 1
   790  		default:
   791  			i = int(n) % vals.Len()
   792  		}
   793  		return protoreflect.ValueOfEnum(vals.Get(i).Number())
   794  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   795  		switch n {
   796  		case minVal:
   797  			return protoreflect.ValueOfInt32(math.MinInt32)
   798  		case maxVal:
   799  			return protoreflect.ValueOfInt32(math.MaxInt32)
   800  		default:
   801  			return protoreflect.ValueOfInt32(int32(n))
   802  		}
   803  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   804  		switch n {
   805  		case minVal:
   806  			// Only use 0 for the zero value.
   807  			return protoreflect.ValueOfUint32(1)
   808  		case maxVal:
   809  			return protoreflect.ValueOfUint32(math.MaxInt32)
   810  		default:
   811  			return protoreflect.ValueOfUint32(uint32(n))
   812  		}
   813  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   814  		switch n {
   815  		case minVal:
   816  			return protoreflect.ValueOfInt64(math.MinInt64)
   817  		case maxVal:
   818  			return protoreflect.ValueOfInt64(math.MaxInt64)
   819  		default:
   820  			return protoreflect.ValueOfInt64(int64(n))
   821  		}
   822  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   823  		switch n {
   824  		case minVal:
   825  			// Only use 0 for the zero value.
   826  			return protoreflect.ValueOfUint64(1)
   827  		case maxVal:
   828  			return protoreflect.ValueOfUint64(math.MaxInt64)
   829  		default:
   830  			return protoreflect.ValueOfUint64(uint64(n))
   831  		}
   832  	case protoreflect.FloatKind:
   833  		switch n {
   834  		case minVal:
   835  			return protoreflect.ValueOfFloat32(math.SmallestNonzeroFloat32)
   836  		case maxVal:
   837  			return protoreflect.ValueOfFloat32(math.MaxFloat32)
   838  		default:
   839  			return protoreflect.ValueOfFloat32(1.5 * float32(n))
   840  		}
   841  	case protoreflect.DoubleKind:
   842  		switch n {
   843  		case minVal:
   844  			return protoreflect.ValueOfFloat64(math.SmallestNonzeroFloat64)
   845  		case maxVal:
   846  			return protoreflect.ValueOfFloat64(math.MaxFloat64)
   847  		default:
   848  			return protoreflect.ValueOfFloat64(1.5 * float64(n))
   849  		}
   850  	case protoreflect.StringKind:
   851  		if n == 0 {
   852  			return protoreflect.ValueOfString("")
   853  		}
   854  		return protoreflect.ValueOfString(fmt.Sprintf("%d", n))
   855  	case protoreflect.BytesKind:
   856  		if n == 0 {
   857  			return protoreflect.ValueOfBytes(nil)
   858  		}
   859  		return protoreflect.ValueOfBytes([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
   860  	}
   861  	panic("unhandled kind")
   862  }
   863  
   864  func populateMessage(m protoreflect.Message, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value {
   865  	if n == 0 {
   866  		return protoreflect.ValueOfMessage(m)
   867  	}
   868  	md := m.Descriptor()
   869  	for _, x := range stack {
   870  		if md == x {
   871  			return protoreflect.ValueOfMessage(m)
   872  		}
   873  	}
   874  	stack = append(stack, md)
   875  	for i := 0; i < md.Fields().Len(); i++ {
   876  		fd := md.Fields().Get(i)
   877  		if fd.IsWeak() {
   878  			continue
   879  		}
   880  		m.Set(fd, newValue(m, fd, newSeed(n, i), stack))
   881  	}
   882  	return protoreflect.ValueOfMessage(m)
   883  }
   884  
   885  func panics(f func()) (didPanic bool) {
   886  	defer func() {
   887  		if err := recover(); err != nil {
   888  			didPanic = true
   889  		}
   890  	}()
   891  	f()
   892  	return false
   893  }
   894  

View as plain text