...

Source file src/google.golang.org/protobuf/internal/impl/checkinit.go

Documentation: google.golang.org/protobuf/internal/impl

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package impl
     6  
     7  import (
     8  	"sync"
     9  
    10  	"google.golang.org/protobuf/internal/errors"
    11  	"google.golang.org/protobuf/reflect/protoreflect"
    12  	"google.golang.org/protobuf/runtime/protoiface"
    13  )
    14  
    15  func (mi *MessageInfo) checkInitialized(in protoiface.CheckInitializedInput) (protoiface.CheckInitializedOutput, error) {
    16  	var p pointer
    17  	if ms, ok := in.Message.(*messageState); ok {
    18  		p = ms.pointer()
    19  	} else {
    20  		p = in.Message.(*messageReflectWrapper).pointer()
    21  	}
    22  	return protoiface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
    23  }
    24  
    25  func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
    26  	mi.init()
    27  	if !mi.needsInitCheck {
    28  		return nil
    29  	}
    30  	if p.IsNil() {
    31  		for _, f := range mi.orderedCoderFields {
    32  			if f.isRequired {
    33  				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
    34  			}
    35  		}
    36  		return nil
    37  	}
    38  	if mi.extensionOffset.IsValid() {
    39  		e := p.Apply(mi.extensionOffset).Extensions()
    40  		if err := mi.isInitExtensions(e); err != nil {
    41  			return err
    42  		}
    43  	}
    44  	for _, f := range mi.orderedCoderFields {
    45  		if !f.isRequired && f.funcs.isInit == nil {
    46  			continue
    47  		}
    48  		fptr := p.Apply(f.offset)
    49  		if f.isPointer && fptr.Elem().IsNil() {
    50  			if f.isRequired {
    51  				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
    52  			}
    53  			continue
    54  		}
    55  		if f.funcs.isInit == nil {
    56  			continue
    57  		}
    58  		if err := f.funcs.isInit(fptr, f); err != nil {
    59  			return err
    60  		}
    61  	}
    62  	return nil
    63  }
    64  
    65  func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
    66  	if ext == nil {
    67  		return nil
    68  	}
    69  	for _, x := range *ext {
    70  		ei := getExtensionFieldInfo(x.Type())
    71  		if ei.funcs.isInit == nil {
    72  			continue
    73  		}
    74  		v := x.Value()
    75  		if !v.IsValid() {
    76  			continue
    77  		}
    78  		if err := ei.funcs.isInit(v); err != nil {
    79  			return err
    80  		}
    81  	}
    82  	return nil
    83  }
    84  
    // needsInitCheckMu serializes the slow path of needsInitCheck, so at most
// one goroutine computes needsInitCheckLocked at a time.
var needsInitCheckMu sync.Mutex

// needsInitCheckMap caches needsInitCheck answers keyed by
// protoreflect.MessageDescriptor. It is read without holding
// needsInitCheckMu; readers must treat only bool values as final answers.
var needsInitCheckMap sync.Map
    89  
    90  // needsInitCheck reports whether a message needs to be checked for partial initialization.
    91  //
    92  // It returns true if the message transitively includes any required or extension fields.
    93  func needsInitCheck(md protoreflect.MessageDescriptor) bool {
    94  	if v, ok := needsInitCheckMap.Load(md); ok {
    95  		if has, ok := v.(bool); ok {
    96  			return has
    97  		}
    98  	}
    99  	needsInitCheckMu.Lock()
   100  	defer needsInitCheckMu.Unlock()
   101  	return needsInitCheckLocked(md)
   102  }
   103  
   104  func needsInitCheckLocked(md protoreflect.MessageDescriptor) (has bool) {
   105  	if v, ok := needsInitCheckMap.Load(md); ok {
   106  		// If has is true, we've previously determined that this message
   107  		// needs init checks.
   108  		//
   109  		// If has is false, we've previously determined that it can never
   110  		// be uninitialized.
   111  		//
   112  		// If has is not a bool, we've just encountered a cycle in the
   113  		// message graph. In this case, it is safe to return false: If
   114  		// the message does have required fields, we'll detect them later
   115  		// in the graph traversal.
   116  		has, ok := v.(bool)
   117  		return ok && has
   118  	}
   119  	needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
   120  	defer func() {
   121  		needsInitCheckMap.Store(md, has)
   122  	}()
   123  	if md.RequiredNumbers().Len() > 0 {
   124  		return true
   125  	}
   126  	if md.ExtensionRanges().Len() > 0 {
   127  		return true
   128  	}
   129  	for i := 0; i < md.Fields().Len(); i++ {
   130  		fd := md.Fields().Get(i)
   131  		// Map keys are never messages, so just consider the map value.
   132  		if fd.IsMap() {
   133  			fd = fd.MapValue()
   134  		}
   135  		fmd := fd.Message()
   136  		if fmd != nil && needsInitCheckLocked(fmd) {
   137  			return true
   138  		}
   139  	}
   140  	return false
   141  }
   142  

View as plain text