...

Source file src/google.golang.org/protobuf/reflect/protoregistry/registry_test.go

Documentation: google.golang.org/protobuf/reflect/protoregistry

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package protoregistry_test
     6  
     7  import (
     8  	"fmt"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/google/go-cmp/cmp"
    13  	"github.com/google/go-cmp/cmp/cmpopts"
    14  
    15  	"google.golang.org/protobuf/encoding/prototext"
    16  	pimpl "google.golang.org/protobuf/internal/impl"
    17  	"google.golang.org/protobuf/reflect/protodesc"
    18  	"google.golang.org/protobuf/reflect/protoreflect"
    19  	"google.golang.org/protobuf/reflect/protoregistry"
    20  
    21  	testpb "google.golang.org/protobuf/internal/testprotos/registry"
    22  	"google.golang.org/protobuf/types/descriptorpb"
    23  )
    24  
    25  func mustMakeFile(s string) protoreflect.FileDescriptor {
    26  	pb := new(descriptorpb.FileDescriptorProto)
    27  	if err := prototext.Unmarshal([]byte(s), pb); err != nil {
    28  		panic(err)
    29  	}
    30  	fd, err := protodesc.NewFile(pb, nil)
    31  	if err != nil {
    32  		panic(err)
    33  	}
    34  	return fd
    35  }
    36  
    37  func TestFiles(t *testing.T) {
    38  	type (
    39  		file struct {
    40  			Path string
    41  			Pkg  protoreflect.FullName
    42  		}
    43  		testFile struct {
    44  			inFile  protoreflect.FileDescriptor
    45  			wantErr string
    46  		}
    47  		testFindDesc struct {
    48  			inName    protoreflect.FullName
    49  			wantFound bool
    50  		}
    51  		testRangePkg struct {
    52  			inPkg     protoreflect.FullName
    53  			wantFiles []file
    54  		}
    55  		testFindPath struct {
    56  			inPath    string
    57  			wantFiles []file
    58  			wantErr   string
    59  		}
    60  	)
    61  
    62  	tests := []struct {
    63  		files     []testFile
    64  		findDescs []testFindDesc
    65  		rangePkgs []testRangePkg
    66  		findPaths []testFindPath
    67  	}{{
    68  		// Test that overlapping packages and files are permitted.
    69  		files: []testFile{
    70  			{inFile: mustMakeFile(`syntax:"proto2" name:"test1.proto" package:"foo.bar"`)},
    71  			{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"my.test"`)},
    72  			{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"foo.bar.baz"`), wantErr: "already registered"},
    73  			{inFile: mustMakeFile(`syntax:"proto2" name:"test2.proto" package:"my.test.package"`)},
    74  			{inFile: mustMakeFile(`syntax:"proto2" name:"weird" package:"foo.bar"`)},
    75  			{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/baz/../test.proto" package:"my.test"`)},
    76  		},
    77  
    78  		rangePkgs: []testRangePkg{{
    79  			inPkg: "nothing",
    80  		}, {
    81  			inPkg: "",
    82  		}, {
    83  			inPkg: ".",
    84  		}, {
    85  			inPkg: "foo",
    86  		}, {
    87  			inPkg: "foo.",
    88  		}, {
    89  			inPkg: "foo..",
    90  		}, {
    91  			inPkg: "foo.bar",
    92  			wantFiles: []file{
    93  				{"test1.proto", "foo.bar"},
    94  				{"weird", "foo.bar"},
    95  			},
    96  		}, {
    97  			inPkg: "my.test",
    98  			wantFiles: []file{
    99  				{"foo/bar/baz/../test.proto", "my.test"},
   100  				{"foo/bar/test.proto", "my.test"},
   101  			},
   102  		}, {
   103  			inPkg: "fo",
   104  		}},
   105  
   106  		findPaths: []testFindPath{{
   107  			inPath:  "nothing",
   108  			wantErr: "not found",
   109  		}, {
   110  			inPath: "weird",
   111  			wantFiles: []file{
   112  				{"weird", "foo.bar"},
   113  			},
   114  		}, {
   115  			inPath: "foo/bar/test.proto",
   116  			wantFiles: []file{
   117  				{"foo/bar/test.proto", "my.test"},
   118  			},
   119  		}},
   120  	}, {
   121  		// Test when new enum conflicts with existing package.
   122  		files: []testFile{{
   123  			inFile: mustMakeFile(`syntax:"proto2" name:"test1a.proto" package:"foo.bar.baz"`),
   124  		}, {
   125  			inFile:  mustMakeFile(`syntax:"proto2" name:"test1b.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
   126  			wantErr: `file "test1b.proto" has a name conflict over foo`,
   127  		}},
   128  	}, {
   129  		// Test when new package conflicts with existing enum.
   130  		files: []testFile{{
   131  			inFile: mustMakeFile(`syntax:"proto2" name:"test2a.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
   132  		}, {
   133  			inFile:  mustMakeFile(`syntax:"proto2" name:"test2b.proto" package:"foo.bar.baz"`),
   134  			wantErr: `file "test2b.proto" has a package name conflict over foo`,
   135  		}},
   136  	}, {
   137  		// Test when new enum conflicts with existing enum in same package.
   138  		files: []testFile{{
   139  			inFile: mustMakeFile(`syntax:"proto2" name:"test3a.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE" number:0}]}]`),
   140  		}, {
   141  			inFile:  mustMakeFile(`syntax:"proto2" name:"test3b.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE2" number:0}]}]`),
   142  			wantErr: `file "test3b.proto" has a name conflict over foo.BAR`,
   143  		}},
   144  	}, {
   145  		files: []testFile{{
   146  			inFile: mustMakeFile(`
   147  				syntax:  "proto2"
   148  				name:    "test1.proto"
   149  				package: "fizz.buzz"
   150  				message_type: [{
   151  					name: "Message"
   152  					field: [
   153  						{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING oneof_index:0}
   154  					]
   155  					oneof_decl:      [{name:"Oneof"}]
   156  					extension_range: [{start:1000 end:2000}]
   157  
   158  					enum_type: [
   159  						{name:"Enum" value:[{name:"EnumValue" number:0}]}
   160  					]
   161  					nested_type: [
   162  						{name:"Message" field:[{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}]}
   163  					]
   164  					extension: [
   165  						{name:"Extension" number:1001 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
   166  					]
   167  				}]
   168  				enum_type: [{
   169  					name:  "Enum"
   170  					value: [{name:"EnumValue" number:0}]
   171  				}]
   172  				extension: [
   173  					{name:"Extension" number:1000 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
   174  				]
   175  				service: [{
   176  					name: "Service"
   177  					method: [{
   178  						name:             "Method"
   179  						input_type:       ".fizz.buzz.Message"
   180  						output_type:      ".fizz.buzz.Message"
   181  						client_streaming: true
   182  						server_streaming: true
   183  					}]
   184  				}]
   185  			`),
   186  		}, {
   187  			inFile: mustMakeFile(`
   188  				syntax:  "proto2"
   189  				name:    "test2.proto"
   190  				package: "fizz.buzz.gazz"
   191  				enum_type: [{
   192  					name:  "Enum"
   193  					value: [{name:"EnumValue" number:0}]
   194  				}]
   195  			`),
   196  		}, {
   197  			inFile: mustMakeFile(`
   198  				syntax:  "proto2"
   199  				name:    "test3.proto"
   200  				package: "fizz.buzz"
   201  				enum_type: [{
   202  					name:  "Enum1"
   203  					value: [{name:"EnumValue1" number:0}]
   204  				}, {
   205  					name:  "Enum2"
   206  					value: [{name:"EnumValue2" number:0}]
   207  				}]
   208  			`),
   209  		}, {
   210  			// Make sure we can register without package name.
   211  			inFile: mustMakeFile(`
   212  				name:   "weird"
   213  				syntax: "proto2"
   214  				message_type: [{
   215  					name: "Message"
   216  					nested_type: [{
   217  						name: "Message"
   218  						nested_type: [{
   219  							name: "Message"
   220  						}]
   221  					}]
   222  				}]
   223  			`),
   224  		}},
   225  		findDescs: []testFindDesc{
   226  			{inName: "fizz.buzz.message", wantFound: false},
   227  			{inName: "fizz.buzz.Message", wantFound: true},
   228  			{inName: "fizz.buzz.Message.X", wantFound: false},
   229  			{inName: "fizz.buzz.Field", wantFound: false},
   230  			{inName: "fizz.buzz.Oneof", wantFound: false},
   231  			{inName: "fizz.buzz.Message.Field", wantFound: true},
   232  			{inName: "fizz.buzz.Message.Field.X", wantFound: false},
   233  			{inName: "fizz.buzz.Message.Oneof", wantFound: true},
   234  			{inName: "fizz.buzz.Message.Oneof.X", wantFound: false},
   235  			{inName: "fizz.buzz.Message.Message", wantFound: true},
   236  			{inName: "fizz.buzz.Message.Message.X", wantFound: false},
   237  			{inName: "fizz.buzz.Message.Enum", wantFound: true},
   238  			{inName: "fizz.buzz.Message.Enum.X", wantFound: false},
   239  			{inName: "fizz.buzz.Message.EnumValue", wantFound: true},
   240  			{inName: "fizz.buzz.Message.EnumValue.X", wantFound: false},
   241  			{inName: "fizz.buzz.Message.Extension", wantFound: true},
   242  			{inName: "fizz.buzz.Message.Extension.X", wantFound: false},
   243  			{inName: "fizz.buzz.enum", wantFound: false},
   244  			{inName: "fizz.buzz.Enum", wantFound: true},
   245  			{inName: "fizz.buzz.Enum.X", wantFound: false},
   246  			{inName: "fizz.buzz.EnumValue", wantFound: true},
   247  			{inName: "fizz.buzz.EnumValue.X", wantFound: false},
   248  			{inName: "fizz.buzz.Enum.EnumValue", wantFound: false},
   249  			{inName: "fizz.buzz.Extension", wantFound: true},
   250  			{inName: "fizz.buzz.Extension.X", wantFound: false},
   251  			{inName: "fizz.buzz.service", wantFound: false},
   252  			{inName: "fizz.buzz.Service", wantFound: true},
   253  			{inName: "fizz.buzz.Service.X", wantFound: false},
   254  			{inName: "fizz.buzz.Method", wantFound: false},
   255  			{inName: "fizz.buzz.Service.Method", wantFound: true},
   256  			{inName: "fizz.buzz.Service.Method.X", wantFound: false},
   257  
   258  			{inName: "fizz.buzz.gazz", wantFound: false},
   259  			{inName: "fizz.buzz.gazz.Enum", wantFound: true},
   260  			{inName: "fizz.buzz.gazz.EnumValue", wantFound: true},
   261  			{inName: "fizz.buzz.gazz.Enum.EnumValue", wantFound: false},
   262  
   263  			{inName: "fizz.buzz", wantFound: false},
   264  			{inName: "fizz.buzz.Enum1", wantFound: true},
   265  			{inName: "fizz.buzz.EnumValue1", wantFound: true},
   266  			{inName: "fizz.buzz.Enum1.EnumValue1", wantFound: false},
   267  			{inName: "fizz.buzz.Enum2", wantFound: true},
   268  			{inName: "fizz.buzz.EnumValue2", wantFound: true},
   269  			{inName: "fizz.buzz.Enum2.EnumValue2", wantFound: false},
   270  			{inName: "fizz.buzz.Enum3", wantFound: false},
   271  
   272  			{inName: "", wantFound: false},
   273  			{inName: "Message", wantFound: true},
   274  			{inName: "Message.Message", wantFound: true},
   275  			{inName: "Message.Message.Message", wantFound: true},
   276  			{inName: "Message.Message.Message.Message", wantFound: false},
   277  		},
   278  	}}
   279  
   280  	sortFiles := cmpopts.SortSlices(func(x, y file) bool {
   281  		return x.Path < y.Path || (x.Path == y.Path && x.Pkg < y.Pkg)
   282  	})
   283  	for _, tt := range tests {
   284  		t.Run("", func(t *testing.T) {
   285  			var files protoregistry.Files
   286  			for i, tc := range tt.files {
   287  				gotErr := files.RegisterFile(tc.inFile)
   288  				if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
   289  					t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
   290  				}
   291  			}
   292  
   293  			for _, tc := range tt.findDescs {
   294  				d, _ := files.FindDescriptorByName(tc.inName)
   295  				gotFound := d != nil
   296  				if gotFound != tc.wantFound {
   297  					t.Errorf("FindDescriptorByName(%v) find mismatch: got %v, want %v", tc.inName, gotFound, tc.wantFound)
   298  				}
   299  			}
   300  
   301  			for _, tc := range tt.rangePkgs {
   302  				var gotFiles []file
   303  				var gotCnt int
   304  				wantCnt := files.NumFilesByPackage(tc.inPkg)
   305  				files.RangeFilesByPackage(tc.inPkg, func(fd protoreflect.FileDescriptor) bool {
   306  					gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
   307  					gotCnt++
   308  					return true
   309  				})
   310  				if gotCnt != wantCnt {
   311  					t.Errorf("NumFilesByPackage(%v) = %v, want %v", tc.inPkg, gotCnt, wantCnt)
   312  				}
   313  				if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
   314  					t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
   315  				}
   316  			}
   317  
   318  			for _, tc := range tt.findPaths {
   319  				var gotFiles []file
   320  				fd, gotErr := files.FindFileByPath(tc.inPath)
   321  				if gotErr == nil {
   322  					gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
   323  				}
   324  				if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
   325  					t.Errorf("FindFileByPath(%v) = %v, want %v", tc.inPath, gotErr, tc.wantErr)
   326  				}
   327  				if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
   328  					t.Errorf("FindFileByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff)
   329  				}
   330  			}
   331  		})
   332  	}
   333  }
   334  
   335  func TestTypes(t *testing.T) {
   336  	mt1 := pimpl.Export{}.MessageTypeOf(&testpb.Message1{})
   337  	et1 := pimpl.Export{}.EnumTypeOf(testpb.Enum1_ONE)
   338  	xt1 := testpb.E_StringField
   339  	xt2 := testpb.E_Message4_MessageField
   340  	registry := new(protoregistry.Types)
   341  	if err := registry.RegisterMessage(mt1); err != nil {
   342  		t.Fatalf("registry.RegisterMessage(%v) returns unexpected error: %v", mt1.Descriptor().FullName(), err)
   343  	}
   344  	if err := registry.RegisterEnum(et1); err != nil {
   345  		t.Fatalf("registry.RegisterEnum(%v) returns unexpected error: %v", et1.Descriptor().FullName(), err)
   346  	}
   347  	if err := registry.RegisterExtension(xt1); err != nil {
   348  		t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt1.TypeDescriptor().FullName(), err)
   349  	}
   350  	if err := registry.RegisterExtension(xt2); err != nil {
   351  		t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt2.TypeDescriptor().FullName(), err)
   352  	}
   353  
   354  	t.Run("FindMessageByName", func(t *testing.T) {
   355  		tests := []struct {
   356  			name         string
   357  			messageType  protoreflect.MessageType
   358  			wantErr      bool
   359  			wantNotFound bool
   360  		}{{
   361  			name:        "testprotos.Message1",
   362  			messageType: mt1,
   363  		}, {
   364  			name:         "testprotos.NoSuchMessage",
   365  			wantErr:      true,
   366  			wantNotFound: true,
   367  		}, {
   368  			name:    "testprotos.Enum1",
   369  			wantErr: true,
   370  		}, {
   371  			name:    "testprotos.Enum2",
   372  			wantErr: true,
   373  		}, {
   374  			name:    "testprotos.Enum3",
   375  			wantErr: true,
   376  		}}
   377  		for _, tc := range tests {
   378  			got, err := registry.FindMessageByName(protoreflect.FullName(tc.name))
   379  			gotErr := err != nil
   380  			if gotErr != tc.wantErr {
   381  				t.Errorf("FindMessageByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
   382  				continue
   383  			}
   384  			if tc.wantNotFound && err != protoregistry.NotFound {
   385  				t.Errorf("FindMessageByName(%v) got error: %v, want NotFound error", tc.name, err)
   386  				continue
   387  			}
   388  			if got != tc.messageType {
   389  				t.Errorf("FindMessageByName(%v) got wrong value: %v", tc.name, got)
   390  			}
   391  		}
   392  	})
   393  
   394  	t.Run("FindMessageByURL", func(t *testing.T) {
   395  		tests := []struct {
   396  			name         string
   397  			messageType  protoreflect.MessageType
   398  			wantErr      bool
   399  			wantNotFound bool
   400  		}{{
   401  			name:        "testprotos.Message1",
   402  			messageType: mt1,
   403  		}, {
   404  			name:         "type.googleapis.com/testprotos.Nada",
   405  			wantErr:      true,
   406  			wantNotFound: true,
   407  		}, {
   408  			name:    "testprotos.Enum1",
   409  			wantErr: true,
   410  		}}
   411  		for _, tc := range tests {
   412  			got, err := registry.FindMessageByURL(tc.name)
   413  			gotErr := err != nil
   414  			if gotErr != tc.wantErr {
   415  				t.Errorf("FindMessageByURL(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
   416  				continue
   417  			}
   418  			if tc.wantNotFound && err != protoregistry.NotFound {
   419  				t.Errorf("FindMessageByURL(%v) got error: %v, want NotFound error", tc.name, err)
   420  				continue
   421  			}
   422  			if got != tc.messageType {
   423  				t.Errorf("FindMessageByURL(%v) got wrong value: %v", tc.name, got)
   424  			}
   425  		}
   426  	})
   427  
   428  	t.Run("FindEnumByName", func(t *testing.T) {
   429  		tests := []struct {
   430  			name         string
   431  			enumType     protoreflect.EnumType
   432  			wantErr      bool
   433  			wantNotFound bool
   434  		}{{
   435  			name:     "testprotos.Enum1",
   436  			enumType: et1,
   437  		}, {
   438  			name:         "testprotos.None",
   439  			wantErr:      true,
   440  			wantNotFound: true,
   441  		}, {
   442  			name:    "testprotos.Message1",
   443  			wantErr: true,
   444  		}}
   445  		for _, tc := range tests {
   446  			got, err := registry.FindEnumByName(protoreflect.FullName(tc.name))
   447  			gotErr := err != nil
   448  			if gotErr != tc.wantErr {
   449  				t.Errorf("FindEnumByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
   450  				continue
   451  			}
   452  			if tc.wantNotFound && err != protoregistry.NotFound {
   453  				t.Errorf("FindEnumByName(%v) got error: %v, want NotFound error", tc.name, err)
   454  				continue
   455  			}
   456  			if got != tc.enumType {
   457  				t.Errorf("FindEnumByName(%v) got wrong value: %v", tc.name, got)
   458  			}
   459  		}
   460  	})
   461  
   462  	t.Run("FindExtensionByName", func(t *testing.T) {
   463  		tests := []struct {
   464  			name          string
   465  			extensionType protoreflect.ExtensionType
   466  			wantErr       bool
   467  			wantNotFound  bool
   468  		}{{
   469  			name:          "testprotos.string_field",
   470  			extensionType: xt1,
   471  		}, {
   472  			name:          "testprotos.Message4.message_field",
   473  			extensionType: xt2,
   474  		}, {
   475  			name:         "testprotos.None",
   476  			wantErr:      true,
   477  			wantNotFound: true,
   478  		}, {
   479  			name:    "testprotos.Message1",
   480  			wantErr: true,
   481  		}}
   482  		for _, tc := range tests {
   483  			got, err := registry.FindExtensionByName(protoreflect.FullName(tc.name))
   484  			gotErr := err != nil
   485  			if gotErr != tc.wantErr {
   486  				t.Errorf("FindExtensionByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
   487  				continue
   488  			}
   489  			if tc.wantNotFound && err != protoregistry.NotFound {
   490  				t.Errorf("FindExtensionByName(%v) got error: %v, want NotFound error", tc.name, err)
   491  				continue
   492  			}
   493  			if got != tc.extensionType {
   494  				t.Errorf("FindExtensionByName(%v) got wrong value: %v", tc.name, got)
   495  			}
   496  		}
   497  	})
   498  
   499  	t.Run("FindExtensionByNumber", func(t *testing.T) {
   500  		tests := []struct {
   501  			parent        string
   502  			number        int32
   503  			extensionType protoreflect.ExtensionType
   504  			wantErr       bool
   505  			wantNotFound  bool
   506  		}{{
   507  			parent:        "testprotos.Message1",
   508  			number:        11,
   509  			extensionType: xt1,
   510  		}, {
   511  			parent:       "testprotos.Message1",
   512  			number:       13,
   513  			wantErr:      true,
   514  			wantNotFound: true,
   515  		}, {
   516  			parent:        "testprotos.Message1",
   517  			number:        21,
   518  			extensionType: xt2,
   519  		}, {
   520  			parent:       "testprotos.Message1",
   521  			number:       23,
   522  			wantErr:      true,
   523  			wantNotFound: true,
   524  		}, {
   525  			parent:       "testprotos.NoSuchMessage",
   526  			number:       11,
   527  			wantErr:      true,
   528  			wantNotFound: true,
   529  		}, {
   530  			parent:       "testprotos.Message1",
   531  			number:       30,
   532  			wantErr:      true,
   533  			wantNotFound: true,
   534  		}, {
   535  			parent:       "testprotos.Message1",
   536  			number:       99,
   537  			wantErr:      true,
   538  			wantNotFound: true,
   539  		}}
   540  		for _, tc := range tests {
   541  			got, err := registry.FindExtensionByNumber(protoreflect.FullName(tc.parent), protoreflect.FieldNumber(tc.number))
   542  			gotErr := err != nil
   543  			if gotErr != tc.wantErr {
   544  				t.Errorf("FindExtensionByNumber(%v, %d) = (_, %v), want error? %t", tc.parent, tc.number, err, tc.wantErr)
   545  				continue
   546  			}
   547  			if tc.wantNotFound && err != protoregistry.NotFound {
   548  				t.Errorf("FindExtensionByNumber(%v, %d) got error %v, want NotFound error", tc.parent, tc.number, err)
   549  				continue
   550  			}
   551  			if got != tc.extensionType {
   552  				t.Errorf("FindExtensionByNumber(%v, %d) got wrong value: %v", tc.parent, tc.number, got)
   553  			}
   554  		}
   555  	})
   556  
   557  	sortTypes := cmp.Options{
   558  		cmpopts.SortSlices(func(x, y protoreflect.EnumType) bool {
   559  			return x.Descriptor().FullName() < y.Descriptor().FullName()
   560  		}),
   561  		cmpopts.SortSlices(func(x, y protoreflect.MessageType) bool {
   562  			return x.Descriptor().FullName() < y.Descriptor().FullName()
   563  		}),
   564  		cmpopts.SortSlices(func(x, y protoreflect.ExtensionType) bool {
   565  			return x.TypeDescriptor().FullName() < y.TypeDescriptor().FullName()
   566  		}),
   567  	}
   568  	compare := cmp.Options{
   569  		cmp.Comparer(func(x, y protoreflect.EnumType) bool {
   570  			return x == y
   571  		}),
   572  		cmp.Comparer(func(x, y protoreflect.ExtensionType) bool {
   573  			return x == y
   574  		}),
   575  		cmp.Comparer(func(x, y protoreflect.MessageType) bool {
   576  			return x == y
   577  		}),
   578  	}
   579  
   580  	t.Run("RangeEnums", func(t *testing.T) {
   581  		want := []protoreflect.EnumType{et1}
   582  		var got []protoreflect.EnumType
   583  		var gotCnt int
   584  		wantCnt := registry.NumEnums()
   585  		registry.RangeEnums(func(et protoreflect.EnumType) bool {
   586  			got = append(got, et)
   587  			gotCnt++
   588  			return true
   589  		})
   590  
   591  		if gotCnt != wantCnt {
   592  			t.Errorf("NumEnums() = %v, want %v", gotCnt, wantCnt)
   593  		}
   594  		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
   595  			t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
   596  		}
   597  	})
   598  
   599  	t.Run("RangeMessages", func(t *testing.T) {
   600  		want := []protoreflect.MessageType{mt1}
   601  		var got []protoreflect.MessageType
   602  		var gotCnt int
   603  		wantCnt := registry.NumMessages()
   604  		registry.RangeMessages(func(mt protoreflect.MessageType) bool {
   605  			got = append(got, mt)
   606  			gotCnt++
   607  			return true
   608  		})
   609  
   610  		if gotCnt != wantCnt {
   611  			t.Errorf("NumMessages() = %v, want %v", gotCnt, wantCnt)
   612  		}
   613  		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
   614  			t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
   615  		}
   616  	})
   617  
   618  	t.Run("RangeExtensions", func(t *testing.T) {
   619  		want := []protoreflect.ExtensionType{xt1, xt2}
   620  		var got []protoreflect.ExtensionType
   621  		var gotCnt int
   622  		wantCnt := registry.NumExtensions()
   623  		registry.RangeExtensions(func(xt protoreflect.ExtensionType) bool {
   624  			got = append(got, xt)
   625  			gotCnt++
   626  			return true
   627  		})
   628  
   629  		if gotCnt != wantCnt {
   630  			t.Errorf("NumExtensions() = %v, want %v", gotCnt, wantCnt)
   631  		}
   632  		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
   633  			t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
   634  		}
   635  	})
   636  
   637  	t.Run("RangeExtensionsByMessage", func(t *testing.T) {
   638  		want := []protoreflect.ExtensionType{xt1, xt2}
   639  		var got []protoreflect.ExtensionType
   640  		var gotCnt int
   641  		wantCnt := registry.NumExtensionsByMessage("testprotos.Message1")
   642  		registry.RangeExtensionsByMessage("testprotos.Message1", func(xt protoreflect.ExtensionType) bool {
   643  			got = append(got, xt)
   644  			gotCnt++
   645  			return true
   646  		})
   647  
   648  		if gotCnt != wantCnt {
   649  			t.Errorf("NumExtensionsByMessage() = %v, want %v", gotCnt, wantCnt)
   650  		}
   651  		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
   652  			t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff)
   653  		}
   654  	})
   655  }
   656  

View as plain text