...

Source file src/github.com/goccy/go-json/internal/encoder/compiler.go

Documentation: github.com/goccy/go-json/internal/encoder

     1  package encoder
     2  
     3  import (
     4  	"context"
     5  	"encoding"
     6  	"encoding/json"
     7  	"reflect"
     8  	"sync/atomic"
     9  	"unsafe"
    10  
    11  	"github.com/goccy/go-json/internal/errors"
    12  	"github.com/goccy/go-json/internal/runtime"
    13  )
    14  
// marshalerContext is the context-aware variant of json.Marshaler: a type
// implementing it receives a context.Context in MarshalJSON.
type marshalerContext interface {
	MarshalJSON(context.Context) ([]byte, error)
}

// Package-level compiler state.
//
//   - marshalJSONType, marshalJSONContextType and marshalTextType are the
//     interface types passed to Implements to detect user-defined
//     marshalers; jsonNumberType is the reflect type of json.Number.
//   - cachedOpcodeSets is allocated in init with
//     typeAddr.AddrRange>>typeAddr.AddrShift+1 slots; it is not read or
//     written elsewhere in this file (presumably a type-address-indexed
//     fast path lives in another file — TODO confirm).
//   - cachedOpcodeMap holds a map[uintptr]*OpcodeSet behind an
//     unsafe.Pointer so it can be read and replaced atomically; see
//     loadOpcodeMap and storeOpcodeSet.
//   - typeAddr is assigned in init and is never nil afterwards.
var (
	marshalJSONType        = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
	marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
	marshalTextType        = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
	jsonNumberType         = reflect.TypeOf(json.Number(""))
	cachedOpcodeSets       []*OpcodeSet
	cachedOpcodeMap        unsafe.Pointer // map[uintptr]*OpcodeSet
	typeAddr               *runtime.TypeAddr
)
    28  
    29  func init() {
    30  	typeAddr = runtime.AnalyzeTypeAddr()
    31  	if typeAddr == nil {
    32  		typeAddr = &runtime.TypeAddr{}
    33  	}
    34  	cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift+1)
    35  }
    36  
// loadOpcodeMap atomically reads the current compiled-opcode cache.
//
// cachedOpcodeMap stores the map value's underlying pointer as an
// unsafe.Pointer (see storeOpcodeSet); reinterpreting the loaded pointer as a
// map value recovers the map. Before anything has been stored the pointer is
// nil, so the result is a nil map, which is safe to index.
func loadOpcodeMap() map[uintptr]*OpcodeSet {
	p := atomic.LoadPointer(&cachedOpcodeMap)
	// Reinterpret the raw pointer as a map header pointer (map values are
	// pointer-shaped).
	return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p))
}
    41  
    42  func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) {
    43  	newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1)
    44  	newOpcodeMap[typ] = set
    45  
    46  	for k, v := range m {
    47  		newOpcodeMap[k] = v
    48  	}
    49  
    50  	atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap)))
    51  }
    52  
    53  func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
    54  	opcodeMap := loadOpcodeMap()
    55  	if codeSet, exists := opcodeMap[typeptr]; exists {
    56  		return codeSet, nil
    57  	}
    58  	codeSet, err := newCompiler().compile(typeptr)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	storeOpcodeSet(typeptr, codeSet, opcodeMap)
    63  	return codeSet, nil
    64  }
    65  
    66  func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) {
    67  	if (ctx.Option.Flag & ContextOption) == 0 {
    68  		return codeSet, nil
    69  	}
    70  	query := FieldQueryFromContext(ctx.Option.Context)
    71  	if query == nil {
    72  		return codeSet, nil
    73  	}
    74  	ctx.Option.Flag |= FieldQueryOption
    75  	cacheCodeSet := codeSet.getQueryCache(query.Hash())
    76  	if cacheCodeSet != nil {
    77  		return cacheCodeSet, nil
    78  	}
    79  	queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query))
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	codeSet.setQueryCache(query.Hash(), queryCodeSet)
    84  	return queryCodeSet, nil
    85  }
    86  
    87  type Compiler struct {
    88  	structTypeToCode map[uintptr]*StructCode
    89  }
    90  
    91  func newCompiler() *Compiler {
    92  	return &Compiler{
    93  		structTypeToCode: map[uintptr]*StructCode{},
    94  	}
    95  }
    96  
// compile builds the OpcodeSet for the type whose *runtime.Type address is
// typeptr. The type is first lowered to a Code tree (typeToCode) and then
// compiled into opcodes (codeToOpcodeSet).
func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
	// noescape trick for header.typ ( reflect.*rtype )
	// The uintptr is reinterpreted through unsafe rather than converted, so
	// the resulting *runtime.Type is not derived from a value the compiler
	// tracks as escaping.
	typ := *(**runtime.Type)(unsafe.Pointer(&typeptr))
	code, err := c.typeToCode(typ)
	if err != nil {
		return nil, err
	}
	return c.codeToOpcodeSet(typ, code)
}
   106  
// codeToOpcodeSet compiles code (the Code tree for typ) into an OpcodeSet.
//
// Two opcode programs are produced from the same Code: one with escapeKey
// unset and one with escapeKey set. Only the first is validated. Both
// programs are then:
//
//   - copied with copyOpcode,
//   - passed to setTotalLengthToInterfaceOp, and
//   - turned into interface variants with copyToInterfaceOpcode.
//
// CodeLength is the no-escape program's TotalLength, EndCode comes from the
// no-escape interface variant, and QueryCache starts empty.
func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) {
	// Each program gets its own compileContext: codeToOpcode collects struct
	// and recursive codes into it and links them via linkRecursiveCode.
	noescapeKeyCode := c.codeToOpcode(&compileContext{
		structTypeToCodes: map[uintptr]Opcodes{},
		recursiveCodes:    &Opcodes{},
	}, typ, code)
	if err := noescapeKeyCode.Validate(); err != nil {
		return nil, err
	}
	escapeKeyCode := c.codeToOpcode(&compileContext{
		structTypeToCodes: map[uintptr]Opcodes{},
		recursiveCodes:    &Opcodes{},
		escapeKey:         true,
	}, typ, code)
	noescapeKeyCode = copyOpcode(noescapeKeyCode)
	escapeKeyCode = copyOpcode(escapeKeyCode)
	setTotalLengthToInterfaceOp(noescapeKeyCode)
	setTotalLengthToInterfaceOp(escapeKeyCode)
	interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode)
	interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode)
	codeLength := noescapeKeyCode.TotalLength()
	return &OpcodeSet{
		Type:                     typ,
		NoescapeKeyCode:          noescapeKeyCode,
		EscapeKeyCode:            escapeKeyCode,
		InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode,
		InterfaceEscapeKeyCode:   interfaceEscapeKeyCode,
		CodeLength:               codeLength,
		EndCode:                  ToEndCode(interfaceNoescapeKeyCode),
		Code:                     code,
		QueryCache:               map[string]*OpcodeSet{},
	}, nil
}
   139  
// typeToCode lowers a top-level type to its Code tree.
//
// Marshaler implementations win first. They are checked on typ itself and
// then, after peeling one pointer level, on its element; in that second case
// the marshaler code is still built for the original pointer type. Otherwise
// the element's kind decides, with isPtr recording whether a pointer level
// was peeled. Kinds not listed in the kind switch (e.g. arrays, or a pointer
// left over after peeling) fall through to typeToCodeWithPtr.
func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) {
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(typ)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(typ)
	}

	isPtr := false
	orgType := typ
	// Peel a single pointer level; orgType keeps the original type.
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
		isPtr = true
	}
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(orgType)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(orgType)
	}
	switch typ.Kind() {
	case reflect.Slice:
		elem := typ.Elem()
		// A slice of uint8 uses BytesCode unless *elem implements a JSON
		// or text marshaler, in which case it goes through sliceCode.
		if elem.Kind() == reflect.Uint8 {
			p := runtime.PtrTo(elem)
			if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
				return c.bytesCode(typ, isPtr)
			}
		}
		return c.sliceCode(typ)
	case reflect.Map:
		if isPtr {
			// Re-wrap the map element in a pointer type and lower it as a pointer.
			return c.ptrCode(runtime.PtrTo(typ))
		}
		return c.mapCode(typ)
	case reflect.Struct:
		return c.structCode(typ, isPtr)
	case reflect.Int:
		return c.intCode(typ, isPtr)
	case reflect.Int8:
		return c.int8Code(typ, isPtr)
	case reflect.Int16:
		return c.int16Code(typ, isPtr)
	case reflect.Int32:
		return c.int32Code(typ, isPtr)
	case reflect.Int64:
		return c.int64Code(typ, isPtr)
	case reflect.Uint, reflect.Uintptr:
		return c.uintCode(typ, isPtr)
	case reflect.Uint8:
		return c.uint8Code(typ, isPtr)
	case reflect.Uint16:
		return c.uint16Code(typ, isPtr)
	case reflect.Uint32:
		return c.uint32Code(typ, isPtr)
	case reflect.Uint64:
		return c.uint64Code(typ, isPtr)
	case reflect.Float32:
		return c.float32Code(typ, isPtr)
	case reflect.Float64:
		return c.float64Code(typ, isPtr)
	case reflect.String:
		return c.stringCode(typ, isPtr)
	case reflect.Bool:
		return c.boolCode(typ, isPtr)
	case reflect.Interface:
		return c.interfaceCode(typ, isPtr)
	default:
		// Re-attach the peeled pointer only when the element implements
		// encoding.TextMarshaler; otherwise lower the element itself.
		if isPtr && typ.Implements(marshalTextType) {
			typ = orgType
		}
		return c.typeToCodeWithPtr(typ, isPtr)
	}
}
   214  
   215  func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) {
   216  	switch {
   217  	case c.implementsMarshalJSON(typ):
   218  		return c.marshalJSONCode(typ)
   219  	case c.implementsMarshalText(typ):
   220  		return c.marshalTextCode(typ)
   221  	}
   222  	switch typ.Kind() {
   223  	case reflect.Ptr:
   224  		return c.ptrCode(typ)
   225  	case reflect.Slice:
   226  		elem := typ.Elem()
   227  		if elem.Kind() == reflect.Uint8 {
   228  			p := runtime.PtrTo(elem)
   229  			if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
   230  				return c.bytesCode(typ, false)
   231  			}
   232  		}
   233  		return c.sliceCode(typ)
   234  	case reflect.Array:
   235  		return c.arrayCode(typ)
   236  	case reflect.Map:
   237  		return c.mapCode(typ)
   238  	case reflect.Struct:
   239  		return c.structCode(typ, isPtr)
   240  	case reflect.Interface:
   241  		return c.interfaceCode(typ, false)
   242  	case reflect.Int:
   243  		return c.intCode(typ, false)
   244  	case reflect.Int8:
   245  		return c.int8Code(typ, false)
   246  	case reflect.Int16:
   247  		return c.int16Code(typ, false)
   248  	case reflect.Int32:
   249  		return c.int32Code(typ, false)
   250  	case reflect.Int64:
   251  		return c.int64Code(typ, false)
   252  	case reflect.Uint:
   253  		return c.uintCode(typ, false)
   254  	case reflect.Uint8:
   255  		return c.uint8Code(typ, false)
   256  	case reflect.Uint16:
   257  		return c.uint16Code(typ, false)
   258  	case reflect.Uint32:
   259  		return c.uint32Code(typ, false)
   260  	case reflect.Uint64:
   261  		return c.uint64Code(typ, false)
   262  	case reflect.Uintptr:
   263  		return c.uintCode(typ, false)
   264  	case reflect.Float32:
   265  		return c.float32Code(typ, false)
   266  	case reflect.Float64:
   267  		return c.float64Code(typ, false)
   268  	case reflect.String:
   269  		return c.stringCode(typ, false)
   270  	case reflect.Bool:
   271  		return c.boolCode(typ, false)
   272  	}
   273  	return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
   274  }
   275  
// intSize is the width in bits of int and uint on the target platform.
// ^uint(0)>>63 is 1 when uint is 64 bits wide and 0 when it is 32 bits, so
// the shift yields 64 or 32 (the same value as strconv.IntSize).
const intSize = 32 << (^uint(0) >> 63)
   277  
   278  //nolint:unparam
   279  func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) {
   280  	return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
   281  }
   282  
   283  //nolint:unparam
   284  func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
   285  	return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
   286  }
   287  
   288  //nolint:unparam
   289  func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
   290  	return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
   291  }
   292  
   293  //nolint:unparam
   294  func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
   295  	return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
   296  }
   297  
   298  //nolint:unparam
   299  func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
   300  	return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
   301  }
   302  
   303  //nolint:unparam
   304  func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) {
   305  	return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
   306  }
   307  
   308  //nolint:unparam
   309  func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
   310  	return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
   311  }
   312  
   313  //nolint:unparam
   314  func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
   315  	return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
   316  }
   317  
   318  //nolint:unparam
   319  func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
   320  	return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
   321  }
   322  
   323  //nolint:unparam
   324  func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
   325  	return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
   326  }
   327  
   328  //nolint:unparam
   329  func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
   330  	return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
   331  }
   332  
   333  //nolint:unparam
   334  func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
   335  	return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
   336  }
   337  
   338  //nolint:unparam
   339  func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) {
   340  	return &StringCode{typ: typ, isPtr: isPtr}, nil
   341  }
   342  
   343  //nolint:unparam
   344  func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) {
   345  	return &BoolCode{typ: typ, isPtr: isPtr}, nil
   346  }
   347  
   348  //nolint:unparam
   349  func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) {
   350  	return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil
   351  }
   352  
   353  //nolint:unparam
   354  func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) {
   355  	return &IntCode{typ: typ, bitSize: 8, isString: true}, nil
   356  }
   357  
   358  //nolint:unparam
   359  func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) {
   360  	return &IntCode{typ: typ, bitSize: 16, isString: true}, nil
   361  }
   362  
   363  //nolint:unparam
   364  func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) {
   365  	return &IntCode{typ: typ, bitSize: 32, isString: true}, nil
   366  }
   367  
   368  //nolint:unparam
   369  func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) {
   370  	return &IntCode{typ: typ, bitSize: 64, isString: true}, nil
   371  }
   372  
   373  //nolint:unparam
   374  func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) {
   375  	return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil
   376  }
   377  
   378  //nolint:unparam
   379  func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) {
   380  	return &UintCode{typ: typ, bitSize: 8, isString: true}, nil
   381  }
   382  
   383  //nolint:unparam
   384  func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) {
   385  	return &UintCode{typ: typ, bitSize: 16, isString: true}, nil
   386  }
   387  
   388  //nolint:unparam
   389  func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) {
   390  	return &UintCode{typ: typ, bitSize: 32, isString: true}, nil
   391  }
   392  
   393  //nolint:unparam
   394  func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) {
   395  	return &UintCode{typ: typ, bitSize: 64, isString: true}, nil
   396  }
   397  
   398  //nolint:unparam
   399  func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) {
   400  	return &BytesCode{typ: typ, isPtr: isPtr}, nil
   401  }
   402  
   403  //nolint:unparam
   404  func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) {
   405  	return &InterfaceCode{typ: typ, isPtr: isPtr}, nil
   406  }
   407  
   408  //nolint:unparam
   409  func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) {
   410  	return &MarshalJSONCode{
   411  		typ:                typ,
   412  		isAddrForMarshaler: c.isPtrMarshalJSONType(typ),
   413  		isNilableType:      c.isNilableType(typ),
   414  		isMarshalerContext: typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType),
   415  	}, nil
   416  }
   417  
   418  //nolint:unparam
   419  func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) {
   420  	return &MarshalTextCode{
   421  		typ:                typ,
   422  		isAddrForMarshaler: c.isPtrMarshalTextType(typ),
   423  		isNilableType:      c.isNilableType(typ),
   424  	}, nil
   425  }
   426  
   427  func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) {
   428  	code, err := c.typeToCodeWithPtr(typ.Elem(), true)
   429  	if err != nil {
   430  		return nil, err
   431  	}
   432  	ptr, ok := code.(*PtrCode)
   433  	if ok {
   434  		return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil
   435  	}
   436  	return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil
   437  }
   438  
   439  func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) {
   440  	elem := typ.Elem()
   441  	code, err := c.listElemCode(elem)
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  	if code.Kind() == CodeKindStruct {
   446  		structCode := code.(*StructCode)
   447  		structCode.enableIndirect()
   448  	}
   449  	return &SliceCode{typ: typ, value: code}, nil
   450  }
   451  
   452  func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) {
   453  	elem := typ.Elem()
   454  	code, err := c.listElemCode(elem)
   455  	if err != nil {
   456  		return nil, err
   457  	}
   458  	if code.Kind() == CodeKindStruct {
   459  		structCode := code.(*StructCode)
   460  		structCode.enableIndirect()
   461  	}
   462  	return &ArrayCode{typ: typ, value: code}, nil
   463  }
   464  
   465  func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) {
   466  	keyCode, err := c.mapKeyCode(typ.Key())
   467  	if err != nil {
   468  		return nil, err
   469  	}
   470  	valueCode, err := c.mapValueCode(typ.Elem())
   471  	if err != nil {
   472  		return nil, err
   473  	}
   474  	if valueCode.Kind() == CodeKindStruct {
   475  		structCode := valueCode.(*StructCode)
   476  		structCode.enableIndirect()
   477  	}
   478  	return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil
   479  }
   480  
   481  func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) {
   482  	switch {
   483  	case c.isPtrMarshalJSONType(typ):
   484  		return c.marshalJSONCode(typ)
   485  	case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
   486  		return c.marshalTextCode(typ)
   487  	case typ.Kind() == reflect.Map:
   488  		return c.ptrCode(runtime.PtrTo(typ))
   489  	default:
   490  		// isPtr was originally used to indicate whether the type of top level is pointer.
   491  		// However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true.
   492  		// See here for related issues: https://github.com/goccy/go-json/issues/370
   493  		code, err := c.typeToCodeWithPtr(typ, true)
   494  		if err != nil {
   495  			return nil, err
   496  		}
   497  		ptr, ok := code.(*PtrCode)
   498  		if ok {
   499  			if ptr.value.Kind() == CodeKindMap {
   500  				ptr.ptrNum++
   501  			}
   502  		}
   503  		return code, nil
   504  	}
   505  }
   506  
   507  func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) {
   508  	switch {
   509  	case c.implementsMarshalText(typ):
   510  		return c.marshalTextCode(typ)
   511  	}
   512  	switch typ.Kind() {
   513  	case reflect.Ptr:
   514  		return c.ptrCode(typ)
   515  	case reflect.String:
   516  		return c.stringCode(typ, false)
   517  	case reflect.Int:
   518  		return c.intStringCode(typ)
   519  	case reflect.Int8:
   520  		return c.int8StringCode(typ)
   521  	case reflect.Int16:
   522  		return c.int16StringCode(typ)
   523  	case reflect.Int32:
   524  		return c.int32StringCode(typ)
   525  	case reflect.Int64:
   526  		return c.int64StringCode(typ)
   527  	case reflect.Uint:
   528  		return c.uintStringCode(typ)
   529  	case reflect.Uint8:
   530  		return c.uint8StringCode(typ)
   531  	case reflect.Uint16:
   532  		return c.uint16StringCode(typ)
   533  	case reflect.Uint32:
   534  		return c.uint32StringCode(typ)
   535  	case reflect.Uint64:
   536  		return c.uint64StringCode(typ)
   537  	case reflect.Uintptr:
   538  		return c.uintStringCode(typ)
   539  	}
   540  	return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
   541  }
   542  
   543  func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) {
   544  	switch typ.Kind() {
   545  	case reflect.Map:
   546  		return c.ptrCode(runtime.PtrTo(typ))
   547  	default:
   548  		code, err := c.typeToCodeWithPtr(typ, false)
   549  		if err != nil {
   550  			return nil, err
   551  		}
   552  		ptr, ok := code.(*PtrCode)
   553  		if ok {
   554  			if ptr.value.Kind() == CodeKindMap {
   555  				ptr.ptrNum++
   556  			}
   557  		}
   558  		return code, nil
   559  	}
   560  }
   561  
// structCode returns the StructCode for struct type typ. isPtr tells whether
// the struct is reached through a pointer.
//
// Recursion: while typ is being compiled it is registered in
// c.structTypeToCode. A nested reference to the same type returns a shallow
// copy of the in-progress code marked isRecursive, instead of recursing
// again. The registration is removed only on the successful return path.
//
// Indirection: the code starts with isIndirect = runtime.IfaceIndir(typ).
// Struct codes of non-embedded fields are set to that same value. Embedded
// structs are set only when isAssignableIndirect allows it. Finally, when
// isPtr is true and the struct is not indirect, indirect mode is enabled
// unless a field set disableIndirectConversion.
//
// Fields whose keys collide are filtered via getFieldMap,
// getDuplicatedFieldMap and filteredDuplicatedFields.
func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) {
	typeptr := uintptr(unsafe.Pointer(typ))
	if code, exists := c.structTypeToCode[typeptr]; exists {
		// typ is already being compiled further up the call stack: return a
		// shallow copy flagged as recursive.
		derefCode := *code
		derefCode.isRecursive = true
		return &derefCode, nil
	}
	indirect := runtime.IfaceIndir(typ)
	code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect}
	// Register before compiling fields so self-references are detected.
	c.structTypeToCode[typeptr] = code

	fieldNum := typ.NumField()
	tags := c.typeToStructTags(typ)
	fields := []*StructFieldCode{}
	for i, tag := range tags {
		// fieldNum counts every declared field (ignored ones included), so
		// this is true only for a struct that declares exactly one field.
		isOnlyOneFirstField := i == 0 && fieldNum == 1
		field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField)
		if err != nil {
			return nil, err
		}
		if field.isAnonymous {
			structCode := field.getAnonymousStruct()
			if structCode != nil {
				// Prune the embedded struct's fields against this struct's tags.
				structCode.removeFieldsByTags(tags)
				if c.isAssignableIndirect(field, isPtr) {
					if indirect {
						structCode.isIndirect = true
					} else {
						structCode.isIndirect = false
					}
				}
			}
		} else {
			structCode := field.getStruct()
			if structCode != nil {
				if indirect {
					// if parent is indirect type, set child indirect property to true
					structCode.isIndirect = true
				} else {
					// if parent is not indirect type, set child indirect property to false.
					// but if parent's indirect is false and isPtr is true, then indirect must be true.
					// Do this only if indirectConversion is enabled at the end of compileStruct.
					structCode.isIndirect = false
				}
			}
		}
		fields = append(fields, field)
	}
	fieldMap := c.getFieldMap(fields)
	duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap)
	code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap)
	if !code.disableIndirectConversion && !indirect && isPtr {
		code.enableIndirect()
	}
	delete(c.structTypeToCode, typeptr)
	return code, nil
}
   619  
   620  func toElemType(t *runtime.Type) *runtime.Type {
   621  	for t.Kind() == reflect.Ptr {
   622  		t = t.Elem()
   623  	}
   624  	return t
   625  }
   626  
// structFieldCode builds the StructFieldCode for the struct field described
// by tag.
//
// Parameters:
//
//   - structCode is the enclosing struct's code; its isIndirect and
//     disableIndirectConversion may be modified.
//   - isPtr tells whether the enclosing struct is reached through a pointer.
//   - isOnlyOneFirstField is true when the field is the struct's only
//     declared field.
//
// The value code is chosen in this order:
//
//  1. isPtr && isOnlyOneFirstField, the field type T is not nilable, and
//     MarshalJSON (then MarshalText) exists only on *T: use the marshaler
//     by address, and set the parent's isIndirect to false with
//     disableIndirectConversion.
//  2. isPtr and MarshalJSON (then MarshalText) exists only on *T: use the
//     marshaler by address.
//  3. Otherwise lower the field type with typeToCodeWithPtr; pointer and
//     interface values set isNextOpPtrType.
//
// Marshaler-by-address fields have isNilCheck cleared.
func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) {
	field := tag.Field
	fieldType := runtime.Type2RType(field.Type)
	isIndirectSpecialCase := isPtr && isOnlyOneFirstField
	fieldCode := &StructFieldCode{
		typ:           fieldType,
		key:           tag.Key,
		tag:           tag,
		offset:        field.Offset,
		isAnonymous:   field.Anonymous && !tag.IsTaggedKey && toElemType(fieldType).Kind() == reflect.Struct,
		isTaggedKey:   tag.IsTaggedKey,
		isNilableType: c.isNilableType(fieldType),
		isNilCheck:    true,
	}
	switch {
	case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case isPtr && c.isPtrMarshalJSONType(fieldType):
		// *struct{ field T }
		// func (*T) MarshalJSON() ([]byte, error)
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	case isPtr && c.isPtrMarshalTextType(fieldType):
		// *struct{ field T }
		// func (*T) MarshalText() ([]byte, error)
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	default:
		code, err := c.typeToCodeWithPtr(fieldType, isPtr)
		if err != nil {
			return nil, err
		}
		switch code.Kind() {
		case CodeKindPtr, CodeKindInterface:
			fieldCode.isNextOpPtrType = true
		}
		fieldCode.value = code
	}
	return fieldCode, nil
}
   695  
   696  func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool {
   697  	if isPtr {
   698  		return false
   699  	}
   700  	codeType := fieldCode.value.Kind()
   701  	if codeType == CodeKindMarshalJSON {
   702  		return false
   703  	}
   704  	if codeType == CodeKindMarshalText {
   705  		return false
   706  	}
   707  	return true
   708  }
   709  
   710  func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode {
   711  	fieldMap := map[string][]*StructFieldCode{}
   712  	for _, field := range fields {
   713  		if field.isAnonymous {
   714  			for k, v := range c.getAnonymousFieldMap(field) {
   715  				fieldMap[k] = append(fieldMap[k], v...)
   716  			}
   717  			continue
   718  		}
   719  		fieldMap[field.key] = append(fieldMap[field.key], field)
   720  	}
   721  	return fieldMap
   722  }
   723  
   724  func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode {
   725  	fieldMap := map[string][]*StructFieldCode{}
   726  	structCode := field.getAnonymousStruct()
   727  	if structCode == nil || structCode.isRecursive {
   728  		fieldMap[field.key] = append(fieldMap[field.key], field)
   729  		return fieldMap
   730  	}
   731  	for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) {
   732  		fieldMap[k] = append(fieldMap[k], v...)
   733  	}
   734  	return fieldMap
   735  }
   736  
   737  func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode {
   738  	fieldMap := map[string][]*StructFieldCode{}
   739  	for _, field := range fields {
   740  		if field.isAnonymous {
   741  			for k, v := range c.getAnonymousFieldMap(field) {
   742  				// Do not handle tagged key when embedding more than once
   743  				for _, vv := range v {
   744  					vv.isTaggedKey = false
   745  				}
   746  				fieldMap[k] = append(fieldMap[k], v...)
   747  			}
   748  			continue
   749  		}
   750  		fieldMap[field.key] = append(fieldMap[field.key], field)
   751  	}
   752  	return fieldMap
   753  }
   754  
   755  func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} {
   756  	duplicatedFieldMap := map[*StructFieldCode]struct{}{}
   757  	for _, fields := range fieldMap {
   758  		if len(fields) == 1 {
   759  			continue
   760  		}
   761  		if c.isTaggedKeyOnly(fields) {
   762  			for _, field := range fields {
   763  				if field.isTaggedKey {
   764  					continue
   765  				}
   766  				duplicatedFieldMap[field] = struct{}{}
   767  			}
   768  		} else {
   769  			for _, field := range fields {
   770  				duplicatedFieldMap[field] = struct{}{}
   771  			}
   772  		}
   773  	}
   774  	return duplicatedFieldMap
   775  }
   776  
   777  func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode {
   778  	filteredFields := make([]*StructFieldCode, 0, len(fields))
   779  	for _, field := range fields {
   780  		if field.isAnonymous {
   781  			structCode := field.getAnonymousStruct()
   782  			if structCode != nil && !structCode.isRecursive {
   783  				structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap)
   784  				if len(structCode.fields) > 0 {
   785  					filteredFields = append(filteredFields, field)
   786  				}
   787  				continue
   788  			}
   789  		}
   790  		if _, exists := duplicatedFieldMap[field]; exists {
   791  			continue
   792  		}
   793  		filteredFields = append(filteredFields, field)
   794  	}
   795  	return filteredFields
   796  }
   797  
   798  func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool {
   799  	var taggedKeyFieldCount int
   800  	for _, field := range fields {
   801  		if field.isTaggedKey {
   802  			taggedKeyFieldCount++
   803  		}
   804  	}
   805  	return taggedKeyFieldCount == 1
   806  }
   807  
   808  func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags {
   809  	tags := runtime.StructTags{}
   810  	fieldNum := typ.NumField()
   811  	for i := 0; i < fieldNum; i++ {
   812  		field := typ.Field(i)
   813  		if runtime.IsIgnoredStructField(field) {
   814  			continue
   815  		}
   816  		tags = append(tags, runtime.StructTagFromField(field))
   817  	}
   818  	return tags
   819  }
   820  
   821  // *struct{ field T } => struct { field *T }
   822  // func (*T) MarshalJSON() ([]byte, error)
   823  func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
   824  	return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ)
   825  }
   826  
   827  // *struct{ field T } => struct { field *T }
   828  // func (*T) MarshalText() ([]byte, error)
   829  func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
   830  	return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ)
   831  }
   832  
   833  func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool {
   834  	if !c.implementsMarshalJSONType(typ) {
   835  		return false
   836  	}
   837  	if typ.Kind() != reflect.Ptr {
   838  		return true
   839  	}
   840  	// type kind is reflect.Ptr
   841  	if !c.implementsMarshalJSONType(typ.Elem()) {
   842  		return true
   843  	}
   844  	// needs to dereference
   845  	return false
   846  }
   847  
   848  func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool {
   849  	if !typ.Implements(marshalTextType) {
   850  		return false
   851  	}
   852  	if typ.Kind() != reflect.Ptr {
   853  		return true
   854  	}
   855  	// type kind is reflect.Ptr
   856  	if !typ.Elem().Implements(marshalTextType) {
   857  		return true
   858  	}
   859  	// needs to dereference
   860  	return false
   861  }
   862  
   863  func (c *Compiler) isNilableType(typ *runtime.Type) bool {
   864  	if !runtime.IfaceIndir(typ) {
   865  		return true
   866  	}
   867  	switch typ.Kind() {
   868  	case reflect.Ptr:
   869  		return true
   870  	case reflect.Map:
   871  		return true
   872  	case reflect.Func:
   873  		return true
   874  	default:
   875  		return false
   876  	}
   877  }
   878  
   879  func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool {
   880  	return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
   881  }
   882  
   883  func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool {
   884  	return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ))
   885  }
   886  
   887  func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool {
   888  	return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType)
   889  }
   890  
   891  func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode {
   892  	codes := code.ToOpcode(ctx)
   893  	codes.Last().Next = newEndOp(ctx, typ)
   894  	c.linkRecursiveCode(ctx)
   895  	return codes.First()
   896  }
   897  
// linkRecursiveCode resolves the recursive struct references recorded in
// ctx.recursiveCodes.
//
// For the first reference to each struct type, it:
//
//   - copies the opcode list compiled for that type
//     (ctx.structTypeToCodes);
//   - converts the head op with PtrHeadToHead;
//   - terminates the copy with an OpRecursiveEnd op;
//   - stores the result into the reference's Jmp together with the
//     current/next slot lengths, and marks it Linked.
//
// Later references to the same type copy that finished CompiledCode by
// value.
func (c *Compiler) linkRecursiveCode(ctx *compileContext) {
	// Per-call cache: type address -> already linked CompiledCode.
	recursiveCodes := map[uintptr]*CompiledCode{}
	for _, recursive := range *ctx.recursiveCodes {
		typeptr := uintptr(unsafe.Pointer(recursive.Type))
		codes := ctx.structTypeToCodes[typeptr]
		if recursiveCode, ok := recursiveCodes[typeptr]; ok {
			*recursive.Jmp = *recursiveCode
			continue
		}

		code := copyOpcode(codes.First())
		code.Op = code.Op.PtrHeadToHead()
		lastCode := newEndOp(&compileContext{}, recursive.Type)
		lastCode.Op = OpRecursiveEnd

		// OpRecursiveEnd must set before call TotalLength
		code.End.Next = lastCode

		totalLength := code.TotalLength()

		// Idx, ElemIdx, Length must set after call TotalLength
		// They are consecutive uintptr-sized offsets starting at slot
		// totalLength+1.
		lastCode.Idx = uint32((totalLength + 1) * uintptrSize)
		lastCode.ElemIdx = lastCode.Idx + uintptrSize
		lastCode.Length = lastCode.Idx + 2*uintptrSize

		// extend length to alloc slot for elemIdx + length
		curTotalLength := uintptr(recursive.TotalLength()) + 3
		nextTotalLength := uintptr(totalLength) + 3

		compiled := recursive.Jmp
		compiled.Code = code
		compiled.CurLen = curTotalLength
		compiled.NextLen = nextTotalLength
		compiled.Linked = true

		recursiveCodes[typeptr] = compiled
	}
}
   936  

View as plain text