...

Source file src/google.golang.org/protobuf/reflect/protoregistry/registry.go

Documentation: google.golang.org/protobuf/reflect/protoregistry

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// Package protoregistry provides data structures to register and look up
// protobuf descriptor types.
//
// The [Files] registry contains file descriptors and provides the ability
// to iterate over the files or look up a specific descriptor within the files.
// [Files] only contains protobuf descriptors and has no understanding of Go
// type information that may be associated with each descriptor.
//
// The [Types] registry contains descriptor types for which there is a known
// Go type associated with that descriptor. It provides the ability to iterate
// over the registered types or look up a type by name.
    16  package protoregistry
    17  
    18  import (
    19  	"fmt"
    20  	"os"
    21  	"strings"
    22  	"sync"
    23  
    24  	"google.golang.org/protobuf/internal/encoding/messageset"
    25  	"google.golang.org/protobuf/internal/errors"
    26  	"google.golang.org/protobuf/internal/flags"
    27  	"google.golang.org/protobuf/reflect/protoreflect"
    28  )
    29  
    30  // conflictPolicy configures the policy for handling registration conflicts.
    31  //
    32  // It can be over-written at compile time with a linker-initialized variable:
    33  //
    34  //	go build -ldflags "-X google.golang.org/protobuf/reflect/protoregistry.conflictPolicy=warn"
    35  //
    36  // It can be over-written at program execution with an environment variable:
    37  //
    38  //	GOLANG_PROTOBUF_REGISTRATION_CONFLICT=warn ./main
    39  //
    40  // Neither of the above are covered by the compatibility promise and
    41  // may be removed in a future release of this module.
    42  var conflictPolicy = "panic" // "panic" | "warn" | "ignore"
    43  
    44  // ignoreConflict reports whether to ignore a registration conflict
    45  // given the descriptor being registered and the error.
    46  // It is a variable so that the behavior is easily overridden in another file.
    47  var ignoreConflict = func(d protoreflect.Descriptor, err error) bool {
    48  	const env = "GOLANG_PROTOBUF_REGISTRATION_CONFLICT"
    49  	const faq = "https://protobuf.dev/reference/go/faq#namespace-conflict"
    50  	policy := conflictPolicy
    51  	if v := os.Getenv(env); v != "" {
    52  		policy = v
    53  	}
    54  	switch policy {
    55  	case "panic":
    56  		panic(fmt.Sprintf("%v\nSee %v\n", err, faq))
    57  	case "warn":
    58  		fmt.Fprintf(os.Stderr, "WARNING: %v\nSee %v\n\n", err, faq)
    59  		return true
    60  	case "ignore":
    61  		return true
    62  	default:
    63  		panic("invalid " + env + " value: " + os.Getenv(env))
    64  	}
    65  }
    66  
    67  var globalMutex sync.RWMutex
    68  
    69  // GlobalFiles is a global registry of file descriptors.
    70  var GlobalFiles *Files = new(Files)
    71  
    72  // GlobalTypes is the registry used by default for type lookups
    73  // unless a local registry is provided by the user.
    74  var GlobalTypes *Types = new(Types)
    75  
    76  // NotFound is a sentinel error value to indicate that the type was not found.
    77  //
    78  // Since registry lookup can happen in the critical performance path, resolvers
    79  // must return this exact error value, not an error wrapping it.
    80  var NotFound = errors.New("not found")
    81  
    82  // Files is a registry for looking up or iterating over files and the
    83  // descriptors contained within them.
    84  // The Find and Range methods are safe for concurrent use.
    85  type Files struct {
    86  	// The map of descsByName contains:
    87  	//	EnumDescriptor
    88  	//	EnumValueDescriptor
    89  	//	MessageDescriptor
    90  	//	ExtensionDescriptor
    91  	//	ServiceDescriptor
    92  	//	*packageDescriptor
    93  	//
    94  	// Note that files are stored as a slice, since a package may contain
    95  	// multiple files. Only top-level declarations are registered.
    96  	// Note that enum values are in the top-level since that are in the same
    97  	// scope as the parent enum.
    98  	descsByName map[protoreflect.FullName]interface{}
    99  	filesByPath map[string][]protoreflect.FileDescriptor
   100  	numFiles    int
   101  }
   102  
   103  type packageDescriptor struct {
   104  	files []protoreflect.FileDescriptor
   105  }
   106  
   107  // RegisterFile registers the provided file descriptor.
   108  //
   109  // If any descriptor within the file conflicts with the descriptor of any
   110  // previously registered file (e.g., two enums with the same full name),
   111  // then the file is not registered and an error is returned.
   112  //
   113  // It is permitted for multiple files to have the same file path.
   114  func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error {
   115  	if r == GlobalFiles {
   116  		globalMutex.Lock()
   117  		defer globalMutex.Unlock()
   118  	}
   119  	if r.descsByName == nil {
   120  		r.descsByName = map[protoreflect.FullName]interface{}{
   121  			"": &packageDescriptor{},
   122  		}
   123  		r.filesByPath = make(map[string][]protoreflect.FileDescriptor)
   124  	}
   125  	path := file.Path()
   126  	if prev := r.filesByPath[path]; len(prev) > 0 {
   127  		r.checkGenProtoConflict(path)
   128  		err := errors.New("file %q is already registered", file.Path())
   129  		err = amendErrorWithCaller(err, prev[0], file)
   130  		if !(r == GlobalFiles && ignoreConflict(file, err)) {
   131  			return err
   132  		}
   133  	}
   134  
   135  	for name := file.Package(); name != ""; name = name.Parent() {
   136  		switch prev := r.descsByName[name]; prev.(type) {
   137  		case nil, *packageDescriptor:
   138  		default:
   139  			err := errors.New("file %q has a package name conflict over %v", file.Path(), name)
   140  			err = amendErrorWithCaller(err, prev, file)
   141  			if r == GlobalFiles && ignoreConflict(file, err) {
   142  				err = nil
   143  			}
   144  			return err
   145  		}
   146  	}
   147  	var err error
   148  	var hasConflict bool
   149  	rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
   150  		if prev := r.descsByName[d.FullName()]; prev != nil {
   151  			hasConflict = true
   152  			err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName())
   153  			err = amendErrorWithCaller(err, prev, file)
   154  			if r == GlobalFiles && ignoreConflict(d, err) {
   155  				err = nil
   156  			}
   157  		}
   158  	})
   159  	if hasConflict {
   160  		return err
   161  	}
   162  
   163  	for name := file.Package(); name != ""; name = name.Parent() {
   164  		if r.descsByName[name] == nil {
   165  			r.descsByName[name] = &packageDescriptor{}
   166  		}
   167  	}
   168  	p := r.descsByName[file.Package()].(*packageDescriptor)
   169  	p.files = append(p.files, file)
   170  	rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
   171  		r.descsByName[d.FullName()] = d
   172  	})
   173  	r.filesByPath[path] = append(r.filesByPath[path], file)
   174  	r.numFiles++
   175  	return nil
   176  }
   177  
   178  // Several well-known types were hosted in the google.golang.org/genproto module
   179  // but were later moved to this module. To avoid a weak dependency on the
   180  // genproto module (and its relatively large set of transitive dependencies),
   181  // we rely on a registration conflict to determine whether the genproto version
   182  // is too old (i.e., does not contain aliases to the new type declarations).
   183  func (r *Files) checkGenProtoConflict(path string) {
   184  	if r != GlobalFiles {
   185  		return
   186  	}
   187  	var prevPath string
   188  	const prevModule = "google.golang.org/genproto"
   189  	const prevVersion = "cb27e3aa (May 26th, 2020)"
   190  	switch path {
   191  	case "google/protobuf/field_mask.proto":
   192  		prevPath = prevModule + "/protobuf/field_mask"
   193  	case "google/protobuf/api.proto":
   194  		prevPath = prevModule + "/protobuf/api"
   195  	case "google/protobuf/type.proto":
   196  		prevPath = prevModule + "/protobuf/ptype"
   197  	case "google/protobuf/source_context.proto":
   198  		prevPath = prevModule + "/protobuf/source_context"
   199  	default:
   200  		return
   201  	}
   202  	pkgName := strings.TrimSuffix(strings.TrimPrefix(path, "google/protobuf/"), ".proto")
   203  	pkgName = strings.Replace(pkgName, "_", "", -1) + "pb" // e.g., "field_mask" => "fieldmaskpb"
   204  	currPath := "google.golang.org/protobuf/types/known/" + pkgName
   205  	panic(fmt.Sprintf(""+
   206  		"duplicate registration of %q\n"+
   207  		"\n"+
   208  		"The generated definition for this file has moved:\n"+
   209  		"\tfrom: %q\n"+
   210  		"\tto:   %q\n"+
   211  		"A dependency on the %q module must\n"+
   212  		"be at version %v or higher.\n"+
   213  		"\n"+
   214  		"Upgrade the dependency by running:\n"+
   215  		"\tgo get -u %v\n",
   216  		path, prevPath, currPath, prevModule, prevVersion, prevPath))
   217  }
   218  
   219  // FindDescriptorByName looks up a descriptor by the full name.
   220  //
   221  // This returns (nil, [NotFound]) if not found.
   222  func (r *Files) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
   223  	if r == nil {
   224  		return nil, NotFound
   225  	}
   226  	if r == GlobalFiles {
   227  		globalMutex.RLock()
   228  		defer globalMutex.RUnlock()
   229  	}
   230  	prefix := name
   231  	suffix := nameSuffix("")
   232  	for prefix != "" {
   233  		if d, ok := r.descsByName[prefix]; ok {
   234  			switch d := d.(type) {
   235  			case protoreflect.EnumDescriptor:
   236  				if d.FullName() == name {
   237  					return d, nil
   238  				}
   239  			case protoreflect.EnumValueDescriptor:
   240  				if d.FullName() == name {
   241  					return d, nil
   242  				}
   243  			case protoreflect.MessageDescriptor:
   244  				if d.FullName() == name {
   245  					return d, nil
   246  				}
   247  				if d := findDescriptorInMessage(d, suffix); d != nil && d.FullName() == name {
   248  					return d, nil
   249  				}
   250  			case protoreflect.ExtensionDescriptor:
   251  				if d.FullName() == name {
   252  					return d, nil
   253  				}
   254  			case protoreflect.ServiceDescriptor:
   255  				if d.FullName() == name {
   256  					return d, nil
   257  				}
   258  				if d := d.Methods().ByName(suffix.Pop()); d != nil && d.FullName() == name {
   259  					return d, nil
   260  				}
   261  			}
   262  			return nil, NotFound
   263  		}
   264  		prefix = prefix.Parent()
   265  		suffix = nameSuffix(name[len(prefix)+len("."):])
   266  	}
   267  	return nil, NotFound
   268  }
   269  
   270  func findDescriptorInMessage(md protoreflect.MessageDescriptor, suffix nameSuffix) protoreflect.Descriptor {
   271  	name := suffix.Pop()
   272  	if suffix == "" {
   273  		if ed := md.Enums().ByName(name); ed != nil {
   274  			return ed
   275  		}
   276  		for i := md.Enums().Len() - 1; i >= 0; i-- {
   277  			if vd := md.Enums().Get(i).Values().ByName(name); vd != nil {
   278  				return vd
   279  			}
   280  		}
   281  		if xd := md.Extensions().ByName(name); xd != nil {
   282  			return xd
   283  		}
   284  		if fd := md.Fields().ByName(name); fd != nil {
   285  			return fd
   286  		}
   287  		if od := md.Oneofs().ByName(name); od != nil {
   288  			return od
   289  		}
   290  	}
   291  	if md := md.Messages().ByName(name); md != nil {
   292  		if suffix == "" {
   293  			return md
   294  		}
   295  		return findDescriptorInMessage(md, suffix)
   296  	}
   297  	return nil
   298  }
   299  
   300  type nameSuffix string
   301  
   302  func (s *nameSuffix) Pop() (name protoreflect.Name) {
   303  	if i := strings.IndexByte(string(*s), '.'); i >= 0 {
   304  		name, *s = protoreflect.Name((*s)[:i]), (*s)[i+1:]
   305  	} else {
   306  		name, *s = protoreflect.Name((*s)), ""
   307  	}
   308  	return name
   309  }
   310  
   311  // FindFileByPath looks up a file by the path.
   312  //
   313  // This returns (nil, [NotFound]) if not found.
   314  // This returns an error if multiple files have the same path.
   315  func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
   316  	if r == nil {
   317  		return nil, NotFound
   318  	}
   319  	if r == GlobalFiles {
   320  		globalMutex.RLock()
   321  		defer globalMutex.RUnlock()
   322  	}
   323  	fds := r.filesByPath[path]
   324  	switch len(fds) {
   325  	case 0:
   326  		return nil, NotFound
   327  	case 1:
   328  		return fds[0], nil
   329  	default:
   330  		return nil, errors.New("multiple files named %q", path)
   331  	}
   332  }
   333  
   334  // NumFiles reports the number of registered files,
   335  // including duplicate files with the same name.
   336  func (r *Files) NumFiles() int {
   337  	if r == nil {
   338  		return 0
   339  	}
   340  	if r == GlobalFiles {
   341  		globalMutex.RLock()
   342  		defer globalMutex.RUnlock()
   343  	}
   344  	return r.numFiles
   345  }
   346  
   347  // RangeFiles iterates over all registered files while f returns true.
   348  // If multiple files have the same name, RangeFiles iterates over all of them.
   349  // The iteration order is undefined.
   350  func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
   351  	if r == nil {
   352  		return
   353  	}
   354  	if r == GlobalFiles {
   355  		globalMutex.RLock()
   356  		defer globalMutex.RUnlock()
   357  	}
   358  	for _, files := range r.filesByPath {
   359  		for _, file := range files {
   360  			if !f(file) {
   361  				return
   362  			}
   363  		}
   364  	}
   365  }
   366  
   367  // NumFilesByPackage reports the number of registered files in a proto package.
   368  func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
   369  	if r == nil {
   370  		return 0
   371  	}
   372  	if r == GlobalFiles {
   373  		globalMutex.RLock()
   374  		defer globalMutex.RUnlock()
   375  	}
   376  	p, ok := r.descsByName[name].(*packageDescriptor)
   377  	if !ok {
   378  		return 0
   379  	}
   380  	return len(p.files)
   381  }
   382  
   383  // RangeFilesByPackage iterates over all registered files in a given proto package
   384  // while f returns true. The iteration order is undefined.
   385  func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
   386  	if r == nil {
   387  		return
   388  	}
   389  	if r == GlobalFiles {
   390  		globalMutex.RLock()
   391  		defer globalMutex.RUnlock()
   392  	}
   393  	p, ok := r.descsByName[name].(*packageDescriptor)
   394  	if !ok {
   395  		return
   396  	}
   397  	for _, file := range p.files {
   398  		if !f(file) {
   399  			return
   400  		}
   401  	}
   402  }
   403  
   404  // rangeTopLevelDescriptors iterates over all top-level descriptors in a file
   405  // which will be directly entered into the registry.
   406  func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflect.Descriptor)) {
   407  	eds := fd.Enums()
   408  	for i := eds.Len() - 1; i >= 0; i-- {
   409  		f(eds.Get(i))
   410  		vds := eds.Get(i).Values()
   411  		for i := vds.Len() - 1; i >= 0; i-- {
   412  			f(vds.Get(i))
   413  		}
   414  	}
   415  	mds := fd.Messages()
   416  	for i := mds.Len() - 1; i >= 0; i-- {
   417  		f(mds.Get(i))
   418  	}
   419  	xds := fd.Extensions()
   420  	for i := xds.Len() - 1; i >= 0; i-- {
   421  		f(xds.Get(i))
   422  	}
   423  	sds := fd.Services()
   424  	for i := sds.Len() - 1; i >= 0; i-- {
   425  		f(sds.Get(i))
   426  	}
   427  }
   428  
   429  // MessageTypeResolver is an interface for looking up messages.
   430  //
   431  // A compliant implementation must deterministically return the same type
   432  // if no error is encountered.
   433  //
   434  // The [Types] type implements this interface.
   435  type MessageTypeResolver interface {
   436  	// FindMessageByName looks up a message by its full name.
   437  	// E.g., "google.protobuf.Any"
   438  	//
   439  	// This return (nil, NotFound) if not found.
   440  	FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error)
   441  
   442  	// FindMessageByURL looks up a message by a URL identifier.
   443  	// See documentation on google.protobuf.Any.type_url for the URL format.
   444  	//
   445  	// This returns (nil, NotFound) if not found.
   446  	FindMessageByURL(url string) (protoreflect.MessageType, error)
   447  }
   448  
   449  // ExtensionTypeResolver is an interface for looking up extensions.
   450  //
   451  // A compliant implementation must deterministically return the same type
   452  // if no error is encountered.
   453  //
   454  // The [Types] type implements this interface.
   455  type ExtensionTypeResolver interface {
   456  	// FindExtensionByName looks up a extension field by the field's full name.
   457  	// Note that this is the full name of the field as determined by
   458  	// where the extension is declared and is unrelated to the full name of the
   459  	// message being extended.
   460  	//
   461  	// This returns (nil, NotFound) if not found.
   462  	FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
   463  
   464  	// FindExtensionByNumber looks up a extension field by the field number
   465  	// within some parent message, identified by full name.
   466  	//
   467  	// This returns (nil, NotFound) if not found.
   468  	FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
   469  }
   470  
   471  var (
   472  	_ MessageTypeResolver   = (*Types)(nil)
   473  	_ ExtensionTypeResolver = (*Types)(nil)
   474  )
   475  
   476  // Types is a registry for looking up or iterating over descriptor types.
   477  // The Find and Range methods are safe for concurrent use.
   478  type Types struct {
   479  	typesByName         typesByName
   480  	extensionsByMessage extensionsByMessage
   481  
   482  	numEnums      int
   483  	numMessages   int
   484  	numExtensions int
   485  }
   486  
   487  type (
   488  	typesByName         map[protoreflect.FullName]interface{}
   489  	extensionsByMessage map[protoreflect.FullName]extensionsByNumber
   490  	extensionsByNumber  map[protoreflect.FieldNumber]protoreflect.ExtensionType
   491  )
   492  
   493  // RegisterMessage registers the provided message type.
   494  //
   495  // If a naming conflict occurs, the type is not registered and an error is returned.
   496  func (r *Types) RegisterMessage(mt protoreflect.MessageType) error {
   497  	// Under rare circumstances getting the descriptor might recursively
   498  	// examine the registry, so fetch it before locking.
   499  	md := mt.Descriptor()
   500  
   501  	if r == GlobalTypes {
   502  		globalMutex.Lock()
   503  		defer globalMutex.Unlock()
   504  	}
   505  
   506  	if err := r.register("message", md, mt); err != nil {
   507  		return err
   508  	}
   509  	r.numMessages++
   510  	return nil
   511  }
   512  
   513  // RegisterEnum registers the provided enum type.
   514  //
   515  // If a naming conflict occurs, the type is not registered and an error is returned.
   516  func (r *Types) RegisterEnum(et protoreflect.EnumType) error {
   517  	// Under rare circumstances getting the descriptor might recursively
   518  	// examine the registry, so fetch it before locking.
   519  	ed := et.Descriptor()
   520  
   521  	if r == GlobalTypes {
   522  		globalMutex.Lock()
   523  		defer globalMutex.Unlock()
   524  	}
   525  
   526  	if err := r.register("enum", ed, et); err != nil {
   527  		return err
   528  	}
   529  	r.numEnums++
   530  	return nil
   531  }
   532  
   533  // RegisterExtension registers the provided extension type.
   534  //
   535  // If a naming conflict occurs, the type is not registered and an error is returned.
   536  func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error {
   537  	// Under rare circumstances getting the descriptor might recursively
   538  	// examine the registry, so fetch it before locking.
   539  	//
   540  	// A known case where this can happen: Fetching the TypeDescriptor for a
   541  	// legacy ExtensionDesc can consult the global registry.
   542  	xd := xt.TypeDescriptor()
   543  
   544  	if r == GlobalTypes {
   545  		globalMutex.Lock()
   546  		defer globalMutex.Unlock()
   547  	}
   548  
   549  	field := xd.Number()
   550  	message := xd.ContainingMessage().FullName()
   551  	if prev := r.extensionsByMessage[message][field]; prev != nil {
   552  		err := errors.New("extension number %d is already registered on message %v", field, message)
   553  		err = amendErrorWithCaller(err, prev, xt)
   554  		if !(r == GlobalTypes && ignoreConflict(xd, err)) {
   555  			return err
   556  		}
   557  	}
   558  
   559  	if err := r.register("extension", xd, xt); err != nil {
   560  		return err
   561  	}
   562  	if r.extensionsByMessage == nil {
   563  		r.extensionsByMessage = make(extensionsByMessage)
   564  	}
   565  	if r.extensionsByMessage[message] == nil {
   566  		r.extensionsByMessage[message] = make(extensionsByNumber)
   567  	}
   568  	r.extensionsByMessage[message][field] = xt
   569  	r.numExtensions++
   570  	return nil
   571  }
   572  
   573  func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error {
   574  	name := desc.FullName()
   575  	prev := r.typesByName[name]
   576  	if prev != nil {
   577  		err := errors.New("%v %v is already registered", kind, name)
   578  		err = amendErrorWithCaller(err, prev, typ)
   579  		if !(r == GlobalTypes && ignoreConflict(desc, err)) {
   580  			return err
   581  		}
   582  	}
   583  	if r.typesByName == nil {
   584  		r.typesByName = make(typesByName)
   585  	}
   586  	r.typesByName[name] = typ
   587  	return nil
   588  }
   589  
   590  // FindEnumByName looks up an enum by its full name.
   591  // E.g., "google.protobuf.Field.Kind".
   592  //
   593  // This returns (nil, [NotFound]) if not found.
   594  func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) {
   595  	if r == nil {
   596  		return nil, NotFound
   597  	}
   598  	if r == GlobalTypes {
   599  		globalMutex.RLock()
   600  		defer globalMutex.RUnlock()
   601  	}
   602  	if v := r.typesByName[enum]; v != nil {
   603  		if et, _ := v.(protoreflect.EnumType); et != nil {
   604  			return et, nil
   605  		}
   606  		return nil, errors.New("found wrong type: got %v, want enum", typeName(v))
   607  	}
   608  	return nil, NotFound
   609  }
   610  
   611  // FindMessageByName looks up a message by its full name,
   612  // e.g. "google.protobuf.Any".
   613  //
   614  // This returns (nil, [NotFound]) if not found.
   615  func (r *Types) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
   616  	if r == nil {
   617  		return nil, NotFound
   618  	}
   619  	if r == GlobalTypes {
   620  		globalMutex.RLock()
   621  		defer globalMutex.RUnlock()
   622  	}
   623  	if v := r.typesByName[message]; v != nil {
   624  		if mt, _ := v.(protoreflect.MessageType); mt != nil {
   625  			return mt, nil
   626  		}
   627  		return nil, errors.New("found wrong type: got %v, want message", typeName(v))
   628  	}
   629  	return nil, NotFound
   630  }
   631  
   632  // FindMessageByURL looks up a message by a URL identifier.
   633  // See documentation on google.protobuf.Any.type_url for the URL format.
   634  //
   635  // This returns (nil, [NotFound]) if not found.
   636  func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
   637  	// This function is similar to FindMessageByName but
   638  	// truncates anything before and including '/' in the URL.
   639  	if r == nil {
   640  		return nil, NotFound
   641  	}
   642  	if r == GlobalTypes {
   643  		globalMutex.RLock()
   644  		defer globalMutex.RUnlock()
   645  	}
   646  	message := protoreflect.FullName(url)
   647  	if i := strings.LastIndexByte(url, '/'); i >= 0 {
   648  		message = message[i+len("/"):]
   649  	}
   650  
   651  	if v := r.typesByName[message]; v != nil {
   652  		if mt, _ := v.(protoreflect.MessageType); mt != nil {
   653  			return mt, nil
   654  		}
   655  		return nil, errors.New("found wrong type: got %v, want message", typeName(v))
   656  	}
   657  	return nil, NotFound
   658  }
   659  
   660  // FindExtensionByName looks up a extension field by the field's full name.
   661  // Note that this is the full name of the field as determined by
   662  // where the extension is declared and is unrelated to the full name of the
   663  // message being extended.
   664  //
   665  // This returns (nil, [NotFound]) if not found.
   666  func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
   667  	if r == nil {
   668  		return nil, NotFound
   669  	}
   670  	if r == GlobalTypes {
   671  		globalMutex.RLock()
   672  		defer globalMutex.RUnlock()
   673  	}
   674  	if v := r.typesByName[field]; v != nil {
   675  		if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
   676  			return xt, nil
   677  		}
   678  
   679  		// MessageSet extensions are special in that the name of the extension
   680  		// is the name of the message type used to extend the MessageSet.
   681  		// This naming scheme is used by text and JSON serialization.
   682  		//
   683  		// This feature is protected by the ProtoLegacy flag since MessageSets
   684  		// are a proto1 feature that is long deprecated.
   685  		if flags.ProtoLegacy {
   686  			if _, ok := v.(protoreflect.MessageType); ok {
   687  				field := field.Append(messageset.ExtensionName)
   688  				if v := r.typesByName[field]; v != nil {
   689  					if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
   690  						if messageset.IsMessageSetExtension(xt.TypeDescriptor()) {
   691  							return xt, nil
   692  						}
   693  					}
   694  				}
   695  			}
   696  		}
   697  
   698  		return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
   699  	}
   700  	return nil, NotFound
   701  }
   702  
   703  // FindExtensionByNumber looks up a extension field by the field number
   704  // within some parent message, identified by full name.
   705  //
   706  // This returns (nil, [NotFound]) if not found.
   707  func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
   708  	if r == nil {
   709  		return nil, NotFound
   710  	}
   711  	if r == GlobalTypes {
   712  		globalMutex.RLock()
   713  		defer globalMutex.RUnlock()
   714  	}
   715  	if xt, ok := r.extensionsByMessage[message][field]; ok {
   716  		return xt, nil
   717  	}
   718  	return nil, NotFound
   719  }
   720  
   721  // NumEnums reports the number of registered enums.
   722  func (r *Types) NumEnums() int {
   723  	if r == nil {
   724  		return 0
   725  	}
   726  	if r == GlobalTypes {
   727  		globalMutex.RLock()
   728  		defer globalMutex.RUnlock()
   729  	}
   730  	return r.numEnums
   731  }
   732  
   733  // RangeEnums iterates over all registered enums while f returns true.
   734  // Iteration order is undefined.
   735  func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
   736  	if r == nil {
   737  		return
   738  	}
   739  	if r == GlobalTypes {
   740  		globalMutex.RLock()
   741  		defer globalMutex.RUnlock()
   742  	}
   743  	for _, typ := range r.typesByName {
   744  		if et, ok := typ.(protoreflect.EnumType); ok {
   745  			if !f(et) {
   746  				return
   747  			}
   748  		}
   749  	}
   750  }
   751  
   752  // NumMessages reports the number of registered messages.
   753  func (r *Types) NumMessages() int {
   754  	if r == nil {
   755  		return 0
   756  	}
   757  	if r == GlobalTypes {
   758  		globalMutex.RLock()
   759  		defer globalMutex.RUnlock()
   760  	}
   761  	return r.numMessages
   762  }
   763  
   764  // RangeMessages iterates over all registered messages while f returns true.
   765  // Iteration order is undefined.
   766  func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
   767  	if r == nil {
   768  		return
   769  	}
   770  	if r == GlobalTypes {
   771  		globalMutex.RLock()
   772  		defer globalMutex.RUnlock()
   773  	}
   774  	for _, typ := range r.typesByName {
   775  		if mt, ok := typ.(protoreflect.MessageType); ok {
   776  			if !f(mt) {
   777  				return
   778  			}
   779  		}
   780  	}
   781  }
   782  
   783  // NumExtensions reports the number of registered extensions.
   784  func (r *Types) NumExtensions() int {
   785  	if r == nil {
   786  		return 0
   787  	}
   788  	if r == GlobalTypes {
   789  		globalMutex.RLock()
   790  		defer globalMutex.RUnlock()
   791  	}
   792  	return r.numExtensions
   793  }
   794  
   795  // RangeExtensions iterates over all registered extensions while f returns true.
   796  // Iteration order is undefined.
   797  func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
   798  	if r == nil {
   799  		return
   800  	}
   801  	if r == GlobalTypes {
   802  		globalMutex.RLock()
   803  		defer globalMutex.RUnlock()
   804  	}
   805  	for _, typ := range r.typesByName {
   806  		if xt, ok := typ.(protoreflect.ExtensionType); ok {
   807  			if !f(xt) {
   808  				return
   809  			}
   810  		}
   811  	}
   812  }
   813  
   814  // NumExtensionsByMessage reports the number of registered extensions for
   815  // a given message type.
   816  func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
   817  	if r == nil {
   818  		return 0
   819  	}
   820  	if r == GlobalTypes {
   821  		globalMutex.RLock()
   822  		defer globalMutex.RUnlock()
   823  	}
   824  	return len(r.extensionsByMessage[message])
   825  }
   826  
   827  // RangeExtensionsByMessage iterates over all registered extensions filtered
   828  // by a given message type while f returns true. Iteration order is undefined.
   829  func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
   830  	if r == nil {
   831  		return
   832  	}
   833  	if r == GlobalTypes {
   834  		globalMutex.RLock()
   835  		defer globalMutex.RUnlock()
   836  	}
   837  	for _, xt := range r.extensionsByMessage[message] {
   838  		if !f(xt) {
   839  			return
   840  		}
   841  	}
   842  }
   843  
   844  func typeName(t interface{}) string {
   845  	switch t.(type) {
   846  	case protoreflect.EnumType:
   847  		return "enum"
   848  	case protoreflect.MessageType:
   849  		return "message"
   850  	case protoreflect.ExtensionType:
   851  		return "extension"
   852  	default:
   853  		return fmt.Sprintf("%T", t)
   854  	}
   855  }
   856  
   857  func amendErrorWithCaller(err error, prev, curr interface{}) error {
   858  	prevPkg := goPackage(prev)
   859  	currPkg := goPackage(curr)
   860  	if prevPkg == "" || currPkg == "" || prevPkg == currPkg {
   861  		return err
   862  	}
   863  	return errors.New("%s\n\tpreviously from: %q\n\tcurrently from:  %q", err, prevPkg, currPkg)
   864  }
   865  
   866  func goPackage(v interface{}) string {
   867  	switch d := v.(type) {
   868  	case protoreflect.EnumType:
   869  		v = d.Descriptor()
   870  	case protoreflect.MessageType:
   871  		v = d.Descriptor()
   872  	case protoreflect.ExtensionType:
   873  		v = d.TypeDescriptor()
   874  	}
   875  	if d, ok := v.(protoreflect.Descriptor); ok {
   876  		v = d.ParentFile()
   877  	}
   878  	if d, ok := v.(interface{ GoPackagePath() string }); ok {
   879  		return d.GoPackagePath()
   880  	}
   881  	return ""
   882  }
   883  

View as plain text