...

Source file src/google.golang.org/protobuf/internal/impl/codec_map.go

Documentation: google.golang.org/protobuf/internal/impl

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package impl
     6  
     7  import (
     8  	"reflect"
     9  	"sort"
    10  
    11  	"google.golang.org/protobuf/encoding/protowire"
    12  	"google.golang.org/protobuf/internal/genid"
    13  	"google.golang.org/protobuf/reflect/protoreflect"
    14  )
    15  
// mapInfo holds the per-field state shared by the size, marshal, unmarshal,
// and isInit closures that encoderFuncsForMap builds for a single map field.
type mapInfo struct {
	goType     reflect.Type       // Go type of the map field; used with reflect.MakeMap
	keyWiretag uint64             // encoded tag written before each entry key (field 1)
	valWiretag uint64             // encoded tag written before each entry value (field 2)
	keyFuncs   valueCoderFuncs    // size/marshal/unmarshal funcs for the key field
	valFuncs   valueCoderFuncs    // size/marshal/unmarshal/isInit funcs for the value field
	keyZero    protoreflect.Value // key used when a decoded entry omits the key (key field default)
	keyKind    protoreflect.Kind  // protobuf kind of the key field
	conv       *mapConverter      // converts keys/values between Go and protoreflect forms
}
    26  
    27  func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
    28  	// TODO: Consider generating specialized map coders.
    29  	keyField := fd.MapKey()
    30  	valField := fd.MapValue()
    31  	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
    32  	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
    33  	keyFuncs := encoderFuncsForValue(keyField)
    34  	valFuncs := encoderFuncsForValue(valField)
    35  	conv := newMapConverter(ft, fd)
    36  
    37  	mapi := &mapInfo{
    38  		goType:     ft,
    39  		keyWiretag: keyWiretag,
    40  		valWiretag: valWiretag,
    41  		keyFuncs:   keyFuncs,
    42  		valFuncs:   valFuncs,
    43  		keyZero:    keyField.Default(),
    44  		keyKind:    keyField.Kind(),
    45  		conv:       conv,
    46  	}
    47  	if valField.Kind() == protoreflect.MessageKind {
    48  		valueMessage = getMessageInfo(ft.Elem())
    49  	}
    50  
    51  	funcs = pointerCoderFuncs{
    52  		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
    53  			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
    54  		},
    55  		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
    56  			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
    57  		},
    58  		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
    59  			mp := p.AsValueOf(ft)
    60  			if mp.Elem().IsNil() {
    61  				mp.Elem().Set(reflect.MakeMap(mapi.goType))
    62  			}
    63  			if f.mi == nil {
    64  				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
    65  			} else {
    66  				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
    67  			}
    68  		},
    69  	}
    70  	switch valField.Kind() {
    71  	case protoreflect.MessageKind:
    72  		funcs.merge = mergeMapOfMessage
    73  	case protoreflect.BytesKind:
    74  		funcs.merge = mergeMapOfBytes
    75  	default:
    76  		funcs.merge = mergeMap
    77  	}
    78  	if valFuncs.isInit != nil {
    79  		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
    80  			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
    81  		}
    82  	}
    83  	return valueMessage, funcs
    84  }
    85  
    86  const (
    87  	mapKeyTagSize = 1 // field 1, tag size 1.
    88  	mapValTagSize = 1 // field 2, tag size 2.
    89  )
    90  
    91  func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
    92  	if mapv.Len() == 0 {
    93  		return 0
    94  	}
    95  	n := 0
    96  	iter := mapRange(mapv)
    97  	for iter.Next() {
    98  		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
    99  		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
   100  		var valSize int
   101  		value := mapi.conv.valConv.PBValueOf(iter.Value())
   102  		if f.mi == nil {
   103  			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
   104  		} else {
   105  			p := pointerOfValue(iter.Value())
   106  			valSize += mapValTagSize
   107  			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
   108  		}
   109  		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
   110  	}
   111  	return n
   112  }
   113  
// consumeMap decodes one length-delimited map entry from b and stores it in
// mapv. It is used when the map's values are not decoded through a
// MessageInfo fast path (f.mi == nil); consumeMapOfMessage handles that case.
//
// It returns errUnknown when wtyp is not protowire.BytesType, and errDecode
// when the entry is malformed (bad length, bad tag, or an out-of-range field
// number). Fields other than the key and value inside the entry are skipped.
// If the key is absent, mapi.keyZero is used; if the value is absent, the
// value from mapi.conv.valConv.New() is stored. On success, out.n is the
// byte count reported by protowire.ConsumeBytes for the whole entry.
func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
	if wtyp != protowire.BytesType {
		return out, errUnknown
	}
	// From here on, b is the entry body and n is the size of the whole entry.
	b, n := protowire.ConsumeBytes(b)
	if n < 0 {
		return out, errDecode
	}
	var (
		key = mapi.keyZero
		val = mapi.conv.valConv.New()
	)
	for len(b) > 0 {
		// NOTE: n and wtyp are shadowed for the inner fields of the entry; the
		// outer n (entry size) is what is reported in out.n below.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return out, errDecode
		}
		if num > protowire.MaxValidNumber {
			return out, errDecode
		}
		b = b[n:]
		// Anything not handled by the switch falls through as unknown and is
		// skipped below.
		err := errUnknown
		switch num {
		case genid.MapEntry_Key_field_number:
			var v protoreflect.Value
			var o unmarshalOutput
			// The current key is passed in, so a repeated key field is decoded
			// against the previously seen value.
			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
			if err != nil {
				break
			}
			key = v
			n = o.n
		case genid.MapEntry_Value_field_number:
			var v protoreflect.Value
			var o unmarshalOutput
			// As with the key, the current value is passed in (presumably so
			// repeated occurrences can merge — verify against valFuncs).
			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
			if err != nil {
				break
			}
			val = v
			n = o.n
		}
		if err == errUnknown {
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
		} else if err != nil {
			return out, err
		}
		b = b[n:]
	}
	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
	out.n = n
	return out, nil
}
   170  
// consumeMapOfMessage decodes one length-delimited map entry from b and
// stores it in mapv. It is used when the map's values are messages with a
// MessageInfo (f.mi != nil): the value is decoded directly into a newly
// allocated message via f.mi.unmarshalPointer.
//
// It returns errUnknown when wtyp is not protowire.BytesType, and errDecode
// when the entry is malformed. If the key is absent, mapi.keyZero is used;
// if the value is absent, an empty newly allocated message is stored.
// out.initialized is set if any decoded value reported itself initialized.
// On success, out.n is the byte count reported by protowire.ConsumeBytes for
// the whole entry.
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
	if wtyp != protowire.BytesType {
		return out, errUnknown
	}
	// From here on, b is the entry body and n is the size of the whole entry.
	b, n := protowire.ConsumeBytes(b)
	if n < 0 {
		return out, errDecode
	}
	var (
		key = mapi.keyZero
		val = reflect.New(f.mi.GoReflectType.Elem())
	)
	for len(b) > 0 {
		// NOTE: n and wtyp are shadowed for the inner fields of the entry; the
		// outer n (entry size) is what is reported in out.n below.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return out, errDecode
		}
		if num > protowire.MaxValidNumber {
			return out, errDecode
		}
		b = b[n:]
		// Anything not handled by the switch falls through as unknown and is
		// skipped below.
		err := errUnknown
		switch num {
		case 1: // entry key
			var v protoreflect.Value
			var o unmarshalOutput
			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
			if err != nil {
				break
			}
			key = v
			n = o.n
		case 2: // entry value (a message)
			// A value with the wrong wire type is left as errUnknown and
			// skipped rather than rejected.
			if wtyp != protowire.BytesType {
				break
			}
			var v []byte
			v, n = protowire.ConsumeBytes(b)
			if n < 0 {
				return out, errDecode
			}
			var o unmarshalOutput
			// Decodes into the same val each time, so repeated value fields
			// are applied to one message.
			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
			if o.initialized {
				// Consider this map item initialized so long as we see
				// an initialized value.
				out.initialized = true
			}
		}
		if err == errUnknown {
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
		} else if err != nil {
			return out, err
		}
		b = b[n:]
	}
	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
	out.n = n
	return out, nil
}
   234  
   235  func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   236  	if f.mi == nil {
   237  		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
   238  		val := mapi.conv.valConv.PBValueOf(valrv)
   239  		size := 0
   240  		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
   241  		size += mapi.valFuncs.size(val, mapValTagSize, opts)
   242  		b = protowire.AppendVarint(b, uint64(size))
   243  		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
   244  		if err != nil {
   245  			return nil, err
   246  		}
   247  		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
   248  	} else {
   249  		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
   250  		val := pointerOfValue(valrv)
   251  		valSize := f.mi.sizePointer(val, opts)
   252  		size := 0
   253  		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
   254  		size += mapValTagSize + protowire.SizeBytes(valSize)
   255  		b = protowire.AppendVarint(b, uint64(size))
   256  		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
   257  		if err != nil {
   258  			return nil, err
   259  		}
   260  		b = protowire.AppendVarint(b, mapi.valWiretag)
   261  		b = protowire.AppendVarint(b, uint64(valSize))
   262  		return f.mi.marshalAppendPointer(b, val, opts)
   263  	}
   264  }
   265  
   266  func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   267  	if mapv.Len() == 0 {
   268  		return b, nil
   269  	}
   270  	if opts.Deterministic() {
   271  		return appendMapDeterministic(b, mapv, mapi, f, opts)
   272  	}
   273  	iter := mapRange(mapv)
   274  	for iter.Next() {
   275  		var err error
   276  		b = protowire.AppendVarint(b, f.wiretag)
   277  		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
   278  		if err != nil {
   279  			return b, err
   280  		}
   281  	}
   282  	return b, nil
   283  }
   284  
   285  func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   286  	keys := mapv.MapKeys()
   287  	sort.Slice(keys, func(i, j int) bool {
   288  		switch keys[i].Kind() {
   289  		case reflect.Bool:
   290  			return !keys[i].Bool() && keys[j].Bool()
   291  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   292  			return keys[i].Int() < keys[j].Int()
   293  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   294  			return keys[i].Uint() < keys[j].Uint()
   295  		case reflect.Float32, reflect.Float64:
   296  			return keys[i].Float() < keys[j].Float()
   297  		case reflect.String:
   298  			return keys[i].String() < keys[j].String()
   299  		default:
   300  			panic("invalid kind: " + keys[i].Kind().String())
   301  		}
   302  	})
   303  	for _, key := range keys {
   304  		var err error
   305  		b = protowire.AppendVarint(b, f.wiretag)
   306  		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
   307  		if err != nil {
   308  			return b, err
   309  		}
   310  	}
   311  	return b, nil
   312  }
   313  
   314  func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
   315  	if mi := f.mi; mi != nil {
   316  		mi.init()
   317  		if !mi.needsInitCheck {
   318  			return nil
   319  		}
   320  		iter := mapRange(mapv)
   321  		for iter.Next() {
   322  			val := pointerOfValue(iter.Value())
   323  			if err := mi.checkInitializedPointer(val); err != nil {
   324  				return err
   325  			}
   326  		}
   327  	} else {
   328  		iter := mapRange(mapv)
   329  		for iter.Next() {
   330  			val := mapi.conv.valConv.PBValueOf(iter.Value())
   331  			if err := mapi.valFuncs.isInit(val); err != nil {
   332  				return err
   333  			}
   334  		}
   335  	}
   336  	return nil
   337  }
   338  
   339  func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
   340  	dstm := dst.AsValueOf(f.ft).Elem()
   341  	srcm := src.AsValueOf(f.ft).Elem()
   342  	if srcm.Len() == 0 {
   343  		return
   344  	}
   345  	if dstm.IsNil() {
   346  		dstm.Set(reflect.MakeMap(f.ft))
   347  	}
   348  	iter := mapRange(srcm)
   349  	for iter.Next() {
   350  		dstm.SetMapIndex(iter.Key(), iter.Value())
   351  	}
   352  }
   353  
   354  func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
   355  	dstm := dst.AsValueOf(f.ft).Elem()
   356  	srcm := src.AsValueOf(f.ft).Elem()
   357  	if srcm.Len() == 0 {
   358  		return
   359  	}
   360  	if dstm.IsNil() {
   361  		dstm.Set(reflect.MakeMap(f.ft))
   362  	}
   363  	iter := mapRange(srcm)
   364  	for iter.Next() {
   365  		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
   366  	}
   367  }
   368  
   369  func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
   370  	dstm := dst.AsValueOf(f.ft).Elem()
   371  	srcm := src.AsValueOf(f.ft).Elem()
   372  	if srcm.Len() == 0 {
   373  		return
   374  	}
   375  	if dstm.IsNil() {
   376  		dstm.Set(reflect.MakeMap(f.ft))
   377  	}
   378  	iter := mapRange(srcm)
   379  	for iter.Next() {
   380  		val := reflect.New(f.ft.Elem().Elem())
   381  		if f.mi != nil {
   382  			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
   383  		} else {
   384  			opts.Merge(asMessage(val), asMessage(iter.Value()))
   385  		}
   386  		dstm.SetMapIndex(iter.Key(), val)
   387  	}
   388  }
   389  

View as plain text