...

Source file src/github.com/json-iterator/go/reflect_extension.go

Documentation: github.com/json-iterator/go

     1  package jsoniter
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/modern-go/reflect2"
     6  	"reflect"
     7  	"sort"
     8  	"strings"
     9  	"unicode"
    10  	"unsafe"
    11  )
    12  
    13  var typeDecoders = map[string]ValDecoder{}
    14  var fieldDecoders = map[string]ValDecoder{}
    15  var typeEncoders = map[string]ValEncoder{}
    16  var fieldEncoders = map[string]ValEncoder{}
    17  var extensions = []Extension{}
    18  
    19  // StructDescriptor describe how should we encode/decode the struct
    20  type StructDescriptor struct {
    21  	Type   reflect2.Type
    22  	Fields []*Binding
    23  }
    24  
    25  // GetField get one field from the descriptor by its name.
    26  // Can not use map here to keep field orders.
    27  func (structDescriptor *StructDescriptor) GetField(fieldName string) *Binding {
    28  	for _, binding := range structDescriptor.Fields {
    29  		if binding.Field.Name() == fieldName {
    30  			return binding
    31  		}
    32  	}
    33  	return nil
    34  }
    35  
// Binding describes how a single struct field should be encoded and decoded.
type Binding struct {
	// levels is the path of field indexes leading to this field (outermost
	// struct first); bindings are sorted by it to restore declaration order
	// after embedded structs have been flattened.
	levels []int
	// Field is the reflected struct field this binding belongs to.
	Field reflect2.StructField
	// FromNames and ToNames both start out as the names computed from the
	// field name and its tag; presumably FromNames are the keys accepted when
	// decoding and ToNames the keys written when encoding (confirm against
	// the struct encoder/decoder).
	FromNames []string
	ToNames   []string
	// Encoder and Decoder perform the actual value conversion for the field.
	Encoder ValEncoder
	Decoder ValDecoder
}
    45  
// Extension is the single SPI for customizing encoding and decoding: it can
// supply alternate encoders/decoders, wrap the ones that were chosen, and
// rename or otherwise adjust struct fields via UpdateStructDescriptor.
type Extension interface {
	// UpdateStructDescriptor may modify the descriptor of a struct in place;
	// it is invoked before the field tags are processed.
	UpdateStructDescriptor(structDescriptor *StructDescriptor)
	// CreateMapKeyDecoder / CreateMapKeyEncoder are not called in this file;
	// presumably they return a codec for map keys of typ, or nil to decline.
	CreateMapKeyDecoder(typ reflect2.Type) ValDecoder
	CreateMapKeyEncoder(typ reflect2.Type) ValEncoder
	// CreateDecoder / CreateEncoder return a codec for typ, or nil to let the
	// next extension (or the registered type codecs) handle it.
	CreateDecoder(typ reflect2.Type) ValDecoder
	CreateEncoder(typ reflect2.Type) ValEncoder
	// DecorateDecoder / DecorateEncoder receive the codec that was selected
	// for typ and return the codec to use instead (possibly the same one).
	DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder
	DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder
}
    57  
    58  // DummyExtension embed this type get dummy implementation for all methods of Extension
    59  type DummyExtension struct {
    60  }
    61  
    62  // UpdateStructDescriptor No-op
    63  func (extension *DummyExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
    64  }
    65  
    66  // CreateMapKeyDecoder No-op
    67  func (extension *DummyExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
    68  	return nil
    69  }
    70  
    71  // CreateMapKeyEncoder No-op
    72  func (extension *DummyExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
    73  	return nil
    74  }
    75  
    76  // CreateDecoder No-op
    77  func (extension *DummyExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
    78  	return nil
    79  }
    80  
    81  // CreateEncoder No-op
    82  func (extension *DummyExtension) CreateEncoder(typ reflect2.Type) ValEncoder {
    83  	return nil
    84  }
    85  
    86  // DecorateDecoder No-op
    87  func (extension *DummyExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
    88  	return decoder
    89  }
    90  
    91  // DecorateEncoder No-op
    92  func (extension *DummyExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder {
    93  	return encoder
    94  }
    95  
    96  type EncoderExtension map[reflect2.Type]ValEncoder
    97  
    98  // UpdateStructDescriptor No-op
    99  func (extension EncoderExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
   100  }
   101  
   102  // CreateDecoder No-op
   103  func (extension EncoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
   104  	return nil
   105  }
   106  
   107  // CreateEncoder get encoder from map
   108  func (extension EncoderExtension) CreateEncoder(typ reflect2.Type) ValEncoder {
   109  	return extension[typ]
   110  }
   111  
   112  // CreateMapKeyDecoder No-op
   113  func (extension EncoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
   114  	return nil
   115  }
   116  
   117  // CreateMapKeyEncoder No-op
   118  func (extension EncoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
   119  	return nil
   120  }
   121  
   122  // DecorateDecoder No-op
   123  func (extension EncoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
   124  	return decoder
   125  }
   126  
   127  // DecorateEncoder No-op
   128  func (extension EncoderExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder {
   129  	return encoder
   130  }
   131  
   132  type DecoderExtension map[reflect2.Type]ValDecoder
   133  
   134  // UpdateStructDescriptor No-op
   135  func (extension DecoderExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
   136  }
   137  
   138  // CreateMapKeyDecoder No-op
   139  func (extension DecoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
   140  	return nil
   141  }
   142  
   143  // CreateMapKeyEncoder No-op
   144  func (extension DecoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
   145  	return nil
   146  }
   147  
   148  // CreateDecoder get decoder from map
   149  func (extension DecoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
   150  	return extension[typ]
   151  }
   152  
   153  // CreateEncoder No-op
   154  func (extension DecoderExtension) CreateEncoder(typ reflect2.Type) ValEncoder {
   155  	return nil
   156  }
   157  
   158  // DecorateDecoder No-op
   159  func (extension DecoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
   160  	return decoder
   161  }
   162  
   163  // DecorateEncoder No-op
   164  func (extension DecoderExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder {
   165  	return encoder
   166  }
   167  
   168  type funcDecoder struct {
   169  	fun DecoderFunc
   170  }
   171  
   172  func (decoder *funcDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
   173  	decoder.fun(ptr, iter)
   174  }
   175  
   176  type funcEncoder struct {
   177  	fun         EncoderFunc
   178  	isEmptyFunc func(ptr unsafe.Pointer) bool
   179  }
   180  
   181  func (encoder *funcEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
   182  	encoder.fun(ptr, stream)
   183  }
   184  
   185  func (encoder *funcEncoder) IsEmpty(ptr unsafe.Pointer) bool {
   186  	if encoder.isEmptyFunc == nil {
   187  		return false
   188  	}
   189  	return encoder.isEmptyFunc(ptr)
   190  }
   191  
   192  // DecoderFunc the function form of TypeDecoder
   193  type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator)
   194  
   195  // EncoderFunc the function form of TypeEncoder
   196  type EncoderFunc func(ptr unsafe.Pointer, stream *Stream)
   197  
   198  // RegisterTypeDecoderFunc register TypeDecoder for a type with function
   199  func RegisterTypeDecoderFunc(typ string, fun DecoderFunc) {
   200  	typeDecoders[typ] = &funcDecoder{fun}
   201  }
   202  
   203  // RegisterTypeDecoder register TypeDecoder for a typ
   204  func RegisterTypeDecoder(typ string, decoder ValDecoder) {
   205  	typeDecoders[typ] = decoder
   206  }
   207  
   208  // RegisterFieldDecoderFunc register TypeDecoder for a struct field with function
   209  func RegisterFieldDecoderFunc(typ string, field string, fun DecoderFunc) {
   210  	RegisterFieldDecoder(typ, field, &funcDecoder{fun})
   211  }
   212  
   213  // RegisterFieldDecoder register TypeDecoder for a struct field
   214  func RegisterFieldDecoder(typ string, field string, decoder ValDecoder) {
   215  	fieldDecoders[fmt.Sprintf("%s/%s", typ, field)] = decoder
   216  }
   217  
   218  // RegisterTypeEncoderFunc register TypeEncoder for a type with encode/isEmpty function
   219  func RegisterTypeEncoderFunc(typ string, fun EncoderFunc, isEmptyFunc func(unsafe.Pointer) bool) {
   220  	typeEncoders[typ] = &funcEncoder{fun, isEmptyFunc}
   221  }
   222  
   223  // RegisterTypeEncoder register TypeEncoder for a type
   224  func RegisterTypeEncoder(typ string, encoder ValEncoder) {
   225  	typeEncoders[typ] = encoder
   226  }
   227  
   228  // RegisterFieldEncoderFunc register TypeEncoder for a struct field with encode/isEmpty function
   229  func RegisterFieldEncoderFunc(typ string, field string, fun EncoderFunc, isEmptyFunc func(unsafe.Pointer) bool) {
   230  	RegisterFieldEncoder(typ, field, &funcEncoder{fun, isEmptyFunc})
   231  }
   232  
   233  // RegisterFieldEncoder register TypeEncoder for a struct field
   234  func RegisterFieldEncoder(typ string, field string, encoder ValEncoder) {
   235  	fieldEncoders[fmt.Sprintf("%s/%s", typ, field)] = encoder
   236  }
   237  
// RegisterExtension appends extension to the package-level extension list.
// Extensions are consulted in registration order, and the global list is
// consulted before the per-config extensions.
func RegisterExtension(extension Extension) {
	extensions = append(extensions, extension)
}
   242  
   243  func getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
   244  	decoder := _getTypeDecoderFromExtension(ctx, typ)
   245  	if decoder != nil {
   246  		for _, extension := range extensions {
   247  			decoder = extension.DecorateDecoder(typ, decoder)
   248  		}
   249  		decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder)
   250  		for _, extension := range ctx.extraExtensions {
   251  			decoder = extension.DecorateDecoder(typ, decoder)
   252  		}
   253  	}
   254  	return decoder
   255  }
   256  func _getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
   257  	for _, extension := range extensions {
   258  		decoder := extension.CreateDecoder(typ)
   259  		if decoder != nil {
   260  			return decoder
   261  		}
   262  	}
   263  	decoder := ctx.decoderExtension.CreateDecoder(typ)
   264  	if decoder != nil {
   265  		return decoder
   266  	}
   267  	for _, extension := range ctx.extraExtensions {
   268  		decoder := extension.CreateDecoder(typ)
   269  		if decoder != nil {
   270  			return decoder
   271  		}
   272  	}
   273  	typeName := typ.String()
   274  	decoder = typeDecoders[typeName]
   275  	if decoder != nil {
   276  		return decoder
   277  	}
   278  	if typ.Kind() == reflect.Ptr {
   279  		ptrType := typ.(*reflect2.UnsafePtrType)
   280  		decoder := typeDecoders[ptrType.Elem().String()]
   281  		if decoder != nil {
   282  			return &OptionalDecoder{ptrType.Elem(), decoder}
   283  		}
   284  	}
   285  	return nil
   286  }
   287  
   288  func getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
   289  	encoder := _getTypeEncoderFromExtension(ctx, typ)
   290  	if encoder != nil {
   291  		for _, extension := range extensions {
   292  			encoder = extension.DecorateEncoder(typ, encoder)
   293  		}
   294  		encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder)
   295  		for _, extension := range ctx.extraExtensions {
   296  			encoder = extension.DecorateEncoder(typ, encoder)
   297  		}
   298  	}
   299  	return encoder
   300  }
   301  
   302  func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
   303  	for _, extension := range extensions {
   304  		encoder := extension.CreateEncoder(typ)
   305  		if encoder != nil {
   306  			return encoder
   307  		}
   308  	}
   309  	encoder := ctx.encoderExtension.CreateEncoder(typ)
   310  	if encoder != nil {
   311  		return encoder
   312  	}
   313  	for _, extension := range ctx.extraExtensions {
   314  		encoder := extension.CreateEncoder(typ)
   315  		if encoder != nil {
   316  			return encoder
   317  		}
   318  	}
   319  	typeName := typ.String()
   320  	encoder = typeEncoders[typeName]
   321  	if encoder != nil {
   322  		return encoder
   323  	}
   324  	if typ.Kind() == reflect.Ptr {
   325  		typePtr := typ.(*reflect2.UnsafePtrType)
   326  		encoder := typeEncoders[typePtr.Elem().String()]
   327  		if encoder != nil {
   328  			return &OptionalEncoder{encoder}
   329  		}
   330  	}
   331  	return nil
   332  }
   333  
// describeStruct builds the StructDescriptor for the struct type typ.
//
// Every field becomes a Binding, except that untagged (or unnamed-tag)
// anonymous fields of struct or pointer-to-struct type are flattened: their
// own bindings are hoisted into this struct and wrapped so they are reached
// through the embedding field. Fields tagged "-" or named "_" are skipped, as
// are untagged non-anonymous fields when the config only wants tagged fields.
//
// typ must be a struct type; the type assertion below panics otherwise.
func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor {
	structType := typ.(*reflect2.UnsafeStructType)
	embeddedBindings := []*Binding{}
	bindings := []*Binding{}
	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)
		tag, hastag := field.Tag().Lookup(ctx.getTagKey())
		if ctx.onlyTaggedField && !hastag && !field.Anonymous() {
			continue
		}
		if tag == "-" || field.Name() == "_" {
			continue
		}
		tagParts := strings.Split(tag, ",")
		// Flatten only when the tag does not give the embedded field a name.
		if field.Anonymous() && (tag == "" || tagParts[0] == "") {
			if field.Type().Kind() == reflect.Struct {
				structDescriptor := describeStruct(ctx, field.Type())
				for _, binding := range structDescriptor.Fields {
					// Prefix the path with this field's index so sorting keeps
					// the hoisted fields at the embedding field's position.
					binding.levels = append([]int{i}, binding.levels...)
					// Bindings returned by describeStruct have already been
					// wrapped in *structFieldEncoder by processTags, so the
					// assertion holds; carry the omitempty flag outward.
					omitempty := binding.Encoder.(*structFieldEncoder).omitempty
					binding.Encoder = &structFieldEncoder{field, binding.Encoder, omitempty}
					binding.Decoder = &structFieldDecoder{field, binding.Decoder}
					embeddedBindings = append(embeddedBindings, binding)
				}
				continue
			} else if field.Type().Kind() == reflect.Ptr {
				ptrType := field.Type().(*reflect2.UnsafePtrType)
				if ptrType.Elem().Kind() == reflect.Struct {
					structDescriptor := describeStruct(ctx, ptrType.Elem())
					for _, binding := range structDescriptor.Fields {
						binding.levels = append([]int{i}, binding.levels...)
						omitempty := binding.Encoder.(*structFieldEncoder).omitempty
						// Same as the struct case, with an extra dereference
						// layer between the embedding field and the inner
						// binding.
						binding.Encoder = &dereferenceEncoder{binding.Encoder}
						binding.Encoder = &structFieldEncoder{field, binding.Encoder, omitempty}
						binding.Decoder = &dereferenceDecoder{ptrType.Elem(), binding.Decoder}
						binding.Decoder = &structFieldDecoder{field, binding.Decoder}
						embeddedBindings = append(embeddedBindings, binding)
					}
					continue
				}
			}
		}
		// Regular field (or an anonymous field that was not flattened).
		fieldNames := calcFieldNames(field.Name(), tagParts[0], tag)
		// Field-specific codecs registered via RegisterField* take precedence
		// over the codec derived from the field's type.
		fieldCacheKey := fmt.Sprintf("%s/%s", typ.String(), field.Name())
		decoder := fieldDecoders[fieldCacheKey]
		if decoder == nil {
			decoder = decoderOfType(ctx.append(field.Name()), field.Type())
		}
		encoder := fieldEncoders[fieldCacheKey]
		if encoder == nil {
			encoder = encoderOfType(ctx.append(field.Name()), field.Type())
		}
		binding := &Binding{
			Field:     field,
			FromNames: fieldNames,
			ToNames:   fieldNames,
			Decoder:   decoder,
			Encoder:   encoder,
		}
		binding.levels = []int{i}
		bindings = append(bindings, binding)
	}
	return createStructDescriptor(ctx, typ, bindings, embeddedBindings)
}
   398  func createStructDescriptor(ctx *ctx, typ reflect2.Type, bindings []*Binding, embeddedBindings []*Binding) *StructDescriptor {
   399  	structDescriptor := &StructDescriptor{
   400  		Type:   typ,
   401  		Fields: bindings,
   402  	}
   403  	for _, extension := range extensions {
   404  		extension.UpdateStructDescriptor(structDescriptor)
   405  	}
   406  	ctx.encoderExtension.UpdateStructDescriptor(structDescriptor)
   407  	ctx.decoderExtension.UpdateStructDescriptor(structDescriptor)
   408  	for _, extension := range ctx.extraExtensions {
   409  		extension.UpdateStructDescriptor(structDescriptor)
   410  	}
   411  	processTags(structDescriptor, ctx.frozenConfig)
   412  	// merge normal & embedded bindings & sort with original order
   413  	allBindings := sortableBindings(append(embeddedBindings, structDescriptor.Fields...))
   414  	sort.Sort(allBindings)
   415  	structDescriptor.Fields = allBindings
   416  	return structDescriptor
   417  }
   418  
   419  type sortableBindings []*Binding
   420  
   421  func (bindings sortableBindings) Len() int {
   422  	return len(bindings)
   423  }
   424  
   425  func (bindings sortableBindings) Less(i, j int) bool {
   426  	left := bindings[i].levels
   427  	right := bindings[j].levels
   428  	k := 0
   429  	for {
   430  		if left[k] < right[k] {
   431  			return true
   432  		} else if left[k] > right[k] {
   433  			return false
   434  		}
   435  		k++
   436  	}
   437  }
   438  
   439  func (bindings sortableBindings) Swap(i, j int) {
   440  	bindings[i], bindings[j] = bindings[j], bindings[i]
   441  }
   442  
   443  func processTags(structDescriptor *StructDescriptor, cfg *frozenConfig) {
   444  	for _, binding := range structDescriptor.Fields {
   445  		shouldOmitEmpty := false
   446  		tagParts := strings.Split(binding.Field.Tag().Get(cfg.getTagKey()), ",")
   447  		for _, tagPart := range tagParts[1:] {
   448  			if tagPart == "omitempty" {
   449  				shouldOmitEmpty = true
   450  			} else if tagPart == "string" {
   451  				if binding.Field.Type().Kind() == reflect.String {
   452  					binding.Decoder = &stringModeStringDecoder{binding.Decoder, cfg}
   453  					binding.Encoder = &stringModeStringEncoder{binding.Encoder, cfg}
   454  				} else {
   455  					binding.Decoder = &stringModeNumberDecoder{binding.Decoder}
   456  					binding.Encoder = &stringModeNumberEncoder{binding.Encoder}
   457  				}
   458  			}
   459  		}
   460  		binding.Decoder = &structFieldDecoder{binding.Field, binding.Decoder}
   461  		binding.Encoder = &structFieldEncoder{binding.Field, binding.Encoder, shouldOmitEmpty}
   462  	}
   463  }
   464  
   465  func calcFieldNames(originalFieldName string, tagProvidedFieldName string, wholeTag string) []string {
   466  	// ignore?
   467  	if wholeTag == "-" {
   468  		return []string{}
   469  	}
   470  	// rename?
   471  	var fieldNames []string
   472  	if tagProvidedFieldName == "" {
   473  		fieldNames = []string{originalFieldName}
   474  	} else {
   475  		fieldNames = []string{tagProvidedFieldName}
   476  	}
   477  	// private?
   478  	isNotExported := unicode.IsLower(rune(originalFieldName[0])) || originalFieldName[0] == '_'
   479  	if isNotExported {
   480  		fieldNames = []string{}
   481  	}
   482  	return fieldNames
   483  }
   484  

View as plain text