1 package encoder
2
3 import (
4 "context"
5 "encoding"
6 "encoding/json"
7 "reflect"
8 "sync/atomic"
9 "unsafe"
10
11 "github.com/goccy/go-json/internal/errors"
12 "github.com/goccy/go-json/internal/runtime"
13 )
14
15 type marshalerContext interface {
16 MarshalJSON(context.Context) ([]byte, error)
17 }
18
19 var (
20 marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
21 marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
22 marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
23 jsonNumberType = reflect.TypeOf(json.Number(""))
24 cachedOpcodeSets []*OpcodeSet
25 cachedOpcodeMap unsafe.Pointer
26 typeAddr *runtime.TypeAddr
27 )
28
29 func init() {
30 typeAddr = runtime.AnalyzeTypeAddr()
31 if typeAddr == nil {
32 typeAddr = &runtime.TypeAddr{}
33 }
34 cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift+1)
35 }
36
37 func loadOpcodeMap() map[uintptr]*OpcodeSet {
38 p := atomic.LoadPointer(&cachedOpcodeMap)
39 return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p))
40 }
41
42 func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) {
43 newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1)
44 newOpcodeMap[typ] = set
45
46 for k, v := range m {
47 newOpcodeMap[k] = v
48 }
49
50 atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap)))
51 }
52
53 func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
54 opcodeMap := loadOpcodeMap()
55 if codeSet, exists := opcodeMap[typeptr]; exists {
56 return codeSet, nil
57 }
58 codeSet, err := newCompiler().compile(typeptr)
59 if err != nil {
60 return nil, err
61 }
62 storeOpcodeSet(typeptr, codeSet, opcodeMap)
63 return codeSet, nil
64 }
65
66 func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) {
67 if (ctx.Option.Flag & ContextOption) == 0 {
68 return codeSet, nil
69 }
70 query := FieldQueryFromContext(ctx.Option.Context)
71 if query == nil {
72 return codeSet, nil
73 }
74 ctx.Option.Flag |= FieldQueryOption
75 cacheCodeSet := codeSet.getQueryCache(query.Hash())
76 if cacheCodeSet != nil {
77 return cacheCodeSet, nil
78 }
79 queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query))
80 if err != nil {
81 return nil, err
82 }
83 codeSet.setQueryCache(query.Hash(), queryCodeSet)
84 return queryCodeSet, nil
85 }
86
87 type Compiler struct {
88 structTypeToCode map[uintptr]*StructCode
89 }
90
91 func newCompiler() *Compiler {
92 return &Compiler{
93 structTypeToCode: map[uintptr]*StructCode{},
94 }
95 }
96
97 func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
98
99 typ := *(**runtime.Type)(unsafe.Pointer(&typeptr))
100 code, err := c.typeToCode(typ)
101 if err != nil {
102 return nil, err
103 }
104 return c.codeToOpcodeSet(typ, code)
105 }
106
107 func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) {
108 noescapeKeyCode := c.codeToOpcode(&compileContext{
109 structTypeToCodes: map[uintptr]Opcodes{},
110 recursiveCodes: &Opcodes{},
111 }, typ, code)
112 if err := noescapeKeyCode.Validate(); err != nil {
113 return nil, err
114 }
115 escapeKeyCode := c.codeToOpcode(&compileContext{
116 structTypeToCodes: map[uintptr]Opcodes{},
117 recursiveCodes: &Opcodes{},
118 escapeKey: true,
119 }, typ, code)
120 noescapeKeyCode = copyOpcode(noescapeKeyCode)
121 escapeKeyCode = copyOpcode(escapeKeyCode)
122 setTotalLengthToInterfaceOp(noescapeKeyCode)
123 setTotalLengthToInterfaceOp(escapeKeyCode)
124 interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode)
125 interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode)
126 codeLength := noescapeKeyCode.TotalLength()
127 return &OpcodeSet{
128 Type: typ,
129 NoescapeKeyCode: noescapeKeyCode,
130 EscapeKeyCode: escapeKeyCode,
131 InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode,
132 InterfaceEscapeKeyCode: interfaceEscapeKeyCode,
133 CodeLength: codeLength,
134 EndCode: ToEndCode(interfaceNoescapeKeyCode),
135 Code: code,
136 QueryCache: map[string]*OpcodeSet{},
137 }, nil
138 }
139
// typeToCode builds the Code tree for typ. It is the entry point used by
// compile.
//
// Resolution order:
//  1. If typ itself should use its MarshalJSON or MarshalText (as reported by
//     implementsMarshalJSON / implementsMarshalText), the marshaler code for
//     typ is returned.
//  2. If typ is a pointer, one pointer level is stripped and the same checks
//     are repeated on the element. A hit still compiles the marshaler code for
//     the original (pointer) type.
//  3. Otherwise the kind of the (possibly stripped) type selects the encoder,
//     and isPtr records whether a pointer level was stripped.
func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) {
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(typ)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(typ)
	}

	// Strip exactly one pointer level; orgType keeps the type as passed in.
	isPtr := false
	orgType := typ
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
		isPtr = true
	}
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(orgType)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(orgType)
	}
	switch typ.Kind() {
	case reflect.Slice:
		elem := typ.Elem()
		if elem.Kind() == reflect.Uint8 {
			// Byte slices use the bytes encoder unless *elem has its own
			// JSON or text marshaler.
			p := runtime.PtrTo(elem)
			if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
				return c.bytesCode(typ, isPtr)
			}
		}
		// NOTE(review): unlike bytesCode, sliceCode does not receive isPtr;
		// confirm that this is intended.
		return c.sliceCode(typ)
	case reflect.Map:
		if isPtr {
			// runtime.PtrTo(typ) rebuilds the original pointer-to-map type.
			return c.ptrCode(runtime.PtrTo(typ))
		}
		return c.mapCode(typ)
	case reflect.Struct:
		return c.structCode(typ, isPtr)
	case reflect.Int:
		return c.intCode(typ, isPtr)
	case reflect.Int8:
		return c.int8Code(typ, isPtr)
	case reflect.Int16:
		return c.int16Code(typ, isPtr)
	case reflect.Int32:
		return c.int32Code(typ, isPtr)
	case reflect.Int64:
		return c.int64Code(typ, isPtr)
	case reflect.Uint, reflect.Uintptr:
		return c.uintCode(typ, isPtr)
	case reflect.Uint8:
		return c.uint8Code(typ, isPtr)
	case reflect.Uint16:
		return c.uint16Code(typ, isPtr)
	case reflect.Uint32:
		return c.uint32Code(typ, isPtr)
	case reflect.Uint64:
		return c.uint64Code(typ, isPtr)
	case reflect.Float32:
		return c.float32Code(typ, isPtr)
	case reflect.Float64:
		return c.float64Code(typ, isPtr)
	case reflect.String:
		return c.stringCode(typ, isPtr)
	case reflect.Bool:
		return c.boolCode(typ, isPtr)
	case reflect.Interface:
		return c.interfaceCode(typ, isPtr)
	default:
		// All other kinds (for example Array, or a second pointer level) go
		// through typeToCodeWithPtr. This condition can only hold when typ is
		// itself a pointer whose element also implements TextMarshaler,
		// because the check above already returned in every other case. In
		// that case, compilation restarts from the original type.
		if isPtr && typ.Implements(marshalTextType) {
			typ = orgType
		}
		return c.typeToCodeWithPtr(typ, isPtr)
	}
}
214
215 func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) {
216 switch {
217 case c.implementsMarshalJSON(typ):
218 return c.marshalJSONCode(typ)
219 case c.implementsMarshalText(typ):
220 return c.marshalTextCode(typ)
221 }
222 switch typ.Kind() {
223 case reflect.Ptr:
224 return c.ptrCode(typ)
225 case reflect.Slice:
226 elem := typ.Elem()
227 if elem.Kind() == reflect.Uint8 {
228 p := runtime.PtrTo(elem)
229 if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
230 return c.bytesCode(typ, false)
231 }
232 }
233 return c.sliceCode(typ)
234 case reflect.Array:
235 return c.arrayCode(typ)
236 case reflect.Map:
237 return c.mapCode(typ)
238 case reflect.Struct:
239 return c.structCode(typ, isPtr)
240 case reflect.Interface:
241 return c.interfaceCode(typ, false)
242 case reflect.Int:
243 return c.intCode(typ, false)
244 case reflect.Int8:
245 return c.int8Code(typ, false)
246 case reflect.Int16:
247 return c.int16Code(typ, false)
248 case reflect.Int32:
249 return c.int32Code(typ, false)
250 case reflect.Int64:
251 return c.int64Code(typ, false)
252 case reflect.Uint:
253 return c.uintCode(typ, false)
254 case reflect.Uint8:
255 return c.uint8Code(typ, false)
256 case reflect.Uint16:
257 return c.uint16Code(typ, false)
258 case reflect.Uint32:
259 return c.uint32Code(typ, false)
260 case reflect.Uint64:
261 return c.uint64Code(typ, false)
262 case reflect.Uintptr:
263 return c.uintCode(typ, false)
264 case reflect.Float32:
265 return c.float32Code(typ, false)
266 case reflect.Float64:
267 return c.float64Code(typ, false)
268 case reflect.String:
269 return c.stringCode(typ, false)
270 case reflect.Bool:
271 return c.boolCode(typ, false)
272 }
273 return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
274 }
275
276 const intSize = 32 << (^uint(0) >> 63)
277
278
279 func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) {
280 return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
281 }
282
283
284 func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
285 return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
286 }
287
288
289 func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
290 return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
291 }
292
293
294 func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
295 return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
296 }
297
298
299 func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
300 return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
301 }
302
303
304 func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) {
305 return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
306 }
307
308
309 func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
310 return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
311 }
312
313
314 func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
315 return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
316 }
317
318
319 func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
320 return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
321 }
322
323
324 func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
325 return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
326 }
327
328
329 func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
330 return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
331 }
332
333
334 func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
335 return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
336 }
337
338
339 func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) {
340 return &StringCode{typ: typ, isPtr: isPtr}, nil
341 }
342
343
344 func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) {
345 return &BoolCode{typ: typ, isPtr: isPtr}, nil
346 }
347
348
349 func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) {
350 return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil
351 }
352
353
354 func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) {
355 return &IntCode{typ: typ, bitSize: 8, isString: true}, nil
356 }
357
358
359 func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) {
360 return &IntCode{typ: typ, bitSize: 16, isString: true}, nil
361 }
362
363
364 func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) {
365 return &IntCode{typ: typ, bitSize: 32, isString: true}, nil
366 }
367
368
369 func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) {
370 return &IntCode{typ: typ, bitSize: 64, isString: true}, nil
371 }
372
373
374 func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) {
375 return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil
376 }
377
378
379 func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) {
380 return &UintCode{typ: typ, bitSize: 8, isString: true}, nil
381 }
382
383
384 func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) {
385 return &UintCode{typ: typ, bitSize: 16, isString: true}, nil
386 }
387
388
389 func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) {
390 return &UintCode{typ: typ, bitSize: 32, isString: true}, nil
391 }
392
393
394 func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) {
395 return &UintCode{typ: typ, bitSize: 64, isString: true}, nil
396 }
397
398
399 func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) {
400 return &BytesCode{typ: typ, isPtr: isPtr}, nil
401 }
402
403
404 func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) {
405 return &InterfaceCode{typ: typ, isPtr: isPtr}, nil
406 }
407
408
409 func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) {
410 return &MarshalJSONCode{
411 typ: typ,
412 isAddrForMarshaler: c.isPtrMarshalJSONType(typ),
413 isNilableType: c.isNilableType(typ),
414 isMarshalerContext: typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType),
415 }, nil
416 }
417
418
419 func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) {
420 return &MarshalTextCode{
421 typ: typ,
422 isAddrForMarshaler: c.isPtrMarshalTextType(typ),
423 isNilableType: c.isNilableType(typ),
424 }, nil
425 }
426
427 func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) {
428 code, err := c.typeToCodeWithPtr(typ.Elem(), true)
429 if err != nil {
430 return nil, err
431 }
432 ptr, ok := code.(*PtrCode)
433 if ok {
434 return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil
435 }
436 return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil
437 }
438
439 func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) {
440 elem := typ.Elem()
441 code, err := c.listElemCode(elem)
442 if err != nil {
443 return nil, err
444 }
445 if code.Kind() == CodeKindStruct {
446 structCode := code.(*StructCode)
447 structCode.enableIndirect()
448 }
449 return &SliceCode{typ: typ, value: code}, nil
450 }
451
452 func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) {
453 elem := typ.Elem()
454 code, err := c.listElemCode(elem)
455 if err != nil {
456 return nil, err
457 }
458 if code.Kind() == CodeKindStruct {
459 structCode := code.(*StructCode)
460 structCode.enableIndirect()
461 }
462 return &ArrayCode{typ: typ, value: code}, nil
463 }
464
465 func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) {
466 keyCode, err := c.mapKeyCode(typ.Key())
467 if err != nil {
468 return nil, err
469 }
470 valueCode, err := c.mapValueCode(typ.Elem())
471 if err != nil {
472 return nil, err
473 }
474 if valueCode.Kind() == CodeKindStruct {
475 structCode := valueCode.(*StructCode)
476 structCode.enableIndirect()
477 }
478 return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil
479 }
480
481 func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) {
482 switch {
483 case c.isPtrMarshalJSONType(typ):
484 return c.marshalJSONCode(typ)
485 case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
486 return c.marshalTextCode(typ)
487 case typ.Kind() == reflect.Map:
488 return c.ptrCode(runtime.PtrTo(typ))
489 default:
490
491
492
493 code, err := c.typeToCodeWithPtr(typ, true)
494 if err != nil {
495 return nil, err
496 }
497 ptr, ok := code.(*PtrCode)
498 if ok {
499 if ptr.value.Kind() == CodeKindMap {
500 ptr.ptrNum++
501 }
502 }
503 return code, nil
504 }
505 }
506
507 func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) {
508 switch {
509 case c.implementsMarshalText(typ):
510 return c.marshalTextCode(typ)
511 }
512 switch typ.Kind() {
513 case reflect.Ptr:
514 return c.ptrCode(typ)
515 case reflect.String:
516 return c.stringCode(typ, false)
517 case reflect.Int:
518 return c.intStringCode(typ)
519 case reflect.Int8:
520 return c.int8StringCode(typ)
521 case reflect.Int16:
522 return c.int16StringCode(typ)
523 case reflect.Int32:
524 return c.int32StringCode(typ)
525 case reflect.Int64:
526 return c.int64StringCode(typ)
527 case reflect.Uint:
528 return c.uintStringCode(typ)
529 case reflect.Uint8:
530 return c.uint8StringCode(typ)
531 case reflect.Uint16:
532 return c.uint16StringCode(typ)
533 case reflect.Uint32:
534 return c.uint32StringCode(typ)
535 case reflect.Uint64:
536 return c.uint64StringCode(typ)
537 case reflect.Uintptr:
538 return c.uintStringCode(typ)
539 }
540 return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
541 }
542
543 func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) {
544 switch typ.Kind() {
545 case reflect.Map:
546 return c.ptrCode(runtime.PtrTo(typ))
547 default:
548 code, err := c.typeToCodeWithPtr(typ, false)
549 if err != nil {
550 return nil, err
551 }
552 ptr, ok := code.(*PtrCode)
553 if ok {
554 if ptr.value.Kind() == CodeKindMap {
555 ptr.ptrNum++
556 }
557 }
558 return code, nil
559 }
560 }
561
562 func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) {
563 typeptr := uintptr(unsafe.Pointer(typ))
564 if code, exists := c.structTypeToCode[typeptr]; exists {
565 derefCode := *code
566 derefCode.isRecursive = true
567 return &derefCode, nil
568 }
569 indirect := runtime.IfaceIndir(typ)
570 code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect}
571 c.structTypeToCode[typeptr] = code
572
573 fieldNum := typ.NumField()
574 tags := c.typeToStructTags(typ)
575 fields := []*StructFieldCode{}
576 for i, tag := range tags {
577 isOnlyOneFirstField := i == 0 && fieldNum == 1
578 field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField)
579 if err != nil {
580 return nil, err
581 }
582 if field.isAnonymous {
583 structCode := field.getAnonymousStruct()
584 if structCode != nil {
585 structCode.removeFieldsByTags(tags)
586 if c.isAssignableIndirect(field, isPtr) {
587 if indirect {
588 structCode.isIndirect = true
589 } else {
590 structCode.isIndirect = false
591 }
592 }
593 }
594 } else {
595 structCode := field.getStruct()
596 if structCode != nil {
597 if indirect {
598
599 structCode.isIndirect = true
600 } else {
601
602
603
604 structCode.isIndirect = false
605 }
606 }
607 }
608 fields = append(fields, field)
609 }
610 fieldMap := c.getFieldMap(fields)
611 duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap)
612 code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap)
613 if !code.disableIndirectConversion && !indirect && isPtr {
614 code.enableIndirect()
615 }
616 delete(c.structTypeToCode, typeptr)
617 return code, nil
618 }
619
620 func toElemType(t *runtime.Type) *runtime.Type {
621 for t.Kind() == reflect.Ptr {
622 t = t.Elem()
623 }
624 return t
625 }
626
// structFieldCode builds the StructFieldCode for the struct field described by
// tag.
//
// structCode is the struct that owns the field. The first two cases below
// also modify it, clearing isIndirect and setting disableIndirectConversion.
// isPtr reports whether the owning struct is compiled behind a pointer.
// isOnlyOneFirstField reports whether this is the struct's only field.
//
// The order of the cases matters:
//   - Cases 1-2: the struct is behind a pointer and this is its only field.
//     The field type is not nilable, and its JSON or text marshaler is
//     implemented only on the pointer type.
//   - Cases 3-4: the same pointer-receiver marshalers, for any field of a
//     struct behind a pointer.
//   - Default: the field type is compiled by kind through typeToCodeWithPtr.
func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) {
	field := tag.Field
	fieldType := runtime.Type2RType(field.Type)
	isIndirectSpecialCase := isPtr && isOnlyOneFirstField
	fieldCode := &StructFieldCode{
		typ:    fieldType,
		key:    tag.Key,
		tag:    tag,
		offset: field.Offset,
		// Only untagged embedded fields whose pointer-stripped type is a
		// struct are treated as anonymous (promoted) fields.
		isAnonymous:   field.Anonymous && !tag.IsTaggedKey && toElemType(fieldType).Kind() == reflect.Struct,
		isTaggedKey:   tag.IsTaggedKey,
		isNilableType: c.isNilableType(fieldType),
		isNilCheck:    true,
	}
	switch {
	case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		// The owning struct is forced to direct access and opted out of the
		// later enableIndirect conversion in structCode.
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case isPtr && c.isPtrMarshalJSONType(fieldType):
		// Only *fieldType implements the JSON marshaler; the field is
		// marshaled through its address.
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	case isPtr && c.isPtrMarshalTextType(fieldType):
		// Only *fieldType implements encoding.TextMarshaler; the field is
		// marshaled through its address.
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	default:
		code, err := c.typeToCodeWithPtr(fieldType, isPtr)
		if err != nil {
			return nil, err
		}
		// Pointer and interface values mark the next op as pointer-typed.
		switch code.Kind() {
		case CodeKindPtr, CodeKindInterface:
			fieldCode.isNextOpPtrType = true
		}
		fieldCode.value = code
	}
	return fieldCode, nil
}
695
696 func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool {
697 if isPtr {
698 return false
699 }
700 codeType := fieldCode.value.Kind()
701 if codeType == CodeKindMarshalJSON {
702 return false
703 }
704 if codeType == CodeKindMarshalText {
705 return false
706 }
707 return true
708 }
709
710 func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode {
711 fieldMap := map[string][]*StructFieldCode{}
712 for _, field := range fields {
713 if field.isAnonymous {
714 for k, v := range c.getAnonymousFieldMap(field) {
715 fieldMap[k] = append(fieldMap[k], v...)
716 }
717 continue
718 }
719 fieldMap[field.key] = append(fieldMap[field.key], field)
720 }
721 return fieldMap
722 }
723
724 func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode {
725 fieldMap := map[string][]*StructFieldCode{}
726 structCode := field.getAnonymousStruct()
727 if structCode == nil || structCode.isRecursive {
728 fieldMap[field.key] = append(fieldMap[field.key], field)
729 return fieldMap
730 }
731 for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) {
732 fieldMap[k] = append(fieldMap[k], v...)
733 }
734 return fieldMap
735 }
736
737 func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode {
738 fieldMap := map[string][]*StructFieldCode{}
739 for _, field := range fields {
740 if field.isAnonymous {
741 for k, v := range c.getAnonymousFieldMap(field) {
742
743 for _, vv := range v {
744 vv.isTaggedKey = false
745 }
746 fieldMap[k] = append(fieldMap[k], v...)
747 }
748 continue
749 }
750 fieldMap[field.key] = append(fieldMap[field.key], field)
751 }
752 return fieldMap
753 }
754
755 func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} {
756 duplicatedFieldMap := map[*StructFieldCode]struct{}{}
757 for _, fields := range fieldMap {
758 if len(fields) == 1 {
759 continue
760 }
761 if c.isTaggedKeyOnly(fields) {
762 for _, field := range fields {
763 if field.isTaggedKey {
764 continue
765 }
766 duplicatedFieldMap[field] = struct{}{}
767 }
768 } else {
769 for _, field := range fields {
770 duplicatedFieldMap[field] = struct{}{}
771 }
772 }
773 }
774 return duplicatedFieldMap
775 }
776
777 func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode {
778 filteredFields := make([]*StructFieldCode, 0, len(fields))
779 for _, field := range fields {
780 if field.isAnonymous {
781 structCode := field.getAnonymousStruct()
782 if structCode != nil && !structCode.isRecursive {
783 structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap)
784 if len(structCode.fields) > 0 {
785 filteredFields = append(filteredFields, field)
786 }
787 continue
788 }
789 }
790 if _, exists := duplicatedFieldMap[field]; exists {
791 continue
792 }
793 filteredFields = append(filteredFields, field)
794 }
795 return filteredFields
796 }
797
798 func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool {
799 var taggedKeyFieldCount int
800 for _, field := range fields {
801 if field.isTaggedKey {
802 taggedKeyFieldCount++
803 }
804 }
805 return taggedKeyFieldCount == 1
806 }
807
808 func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags {
809 tags := runtime.StructTags{}
810 fieldNum := typ.NumField()
811 for i := 0; i < fieldNum; i++ {
812 field := typ.Field(i)
813 if runtime.IsIgnoredStructField(field) {
814 continue
815 }
816 tags = append(tags, runtime.StructTagFromField(field))
817 }
818 return tags
819 }
820
821
822
823 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
824 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ)
825 }
826
827
828
829 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
830 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ)
831 }
832
833 func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool {
834 if !c.implementsMarshalJSONType(typ) {
835 return false
836 }
837 if typ.Kind() != reflect.Ptr {
838 return true
839 }
840
841 if !c.implementsMarshalJSONType(typ.Elem()) {
842 return true
843 }
844
845 return false
846 }
847
848 func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool {
849 if !typ.Implements(marshalTextType) {
850 return false
851 }
852 if typ.Kind() != reflect.Ptr {
853 return true
854 }
855
856 if !typ.Elem().Implements(marshalTextType) {
857 return true
858 }
859
860 return false
861 }
862
863 func (c *Compiler) isNilableType(typ *runtime.Type) bool {
864 if !runtime.IfaceIndir(typ) {
865 return true
866 }
867 switch typ.Kind() {
868 case reflect.Ptr:
869 return true
870 case reflect.Map:
871 return true
872 case reflect.Func:
873 return true
874 default:
875 return false
876 }
877 }
878
879 func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool {
880 return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
881 }
882
883 func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool {
884 return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ))
885 }
886
887 func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool {
888 return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType)
889 }
890
891 func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode {
892 codes := code.ToOpcode(ctx)
893 codes.Last().Next = newEndOp(ctx, typ)
894 c.linkRecursiveCode(ctx)
895 return codes.First()
896 }
897
// linkRecursiveCode fills in the recursive-call placeholders listed in
// ctx.recursiveCodes (presumably recorded while lowering self-referencing
// structs; confirm in the opcode builders).
//
// For the first placeholder of each struct type, the CompiledCode that
// recursive.Jmp points to is filled with a private copy of that struct's
// opcodes, terminated by an OpRecursiveEnd op, plus the length values stored
// in CurLen/NextLen. Later placeholders for the same type receive a by-value
// copy of that CompiledCode.
func (c *Compiler) linkRecursiveCode(ctx *compileContext) {
	// Per-call memo: struct type address -> CompiledCode already linked.
	recursiveCodes := map[uintptr]*CompiledCode{}
	for _, recursive := range *ctx.recursiveCodes {
		typeptr := uintptr(unsafe.Pointer(recursive.Type))
		codes := ctx.structTypeToCodes[typeptr]
		if recursiveCode, ok := recursiveCodes[typeptr]; ok {
			*recursive.Jmp = *recursiveCode
			continue
		}

		// Copy the struct's opcodes and rewrite the head op with
		// PtrHeadToHead (by its name, turning a pointer-head op into the
		// plain head op).
		code := copyOpcode(codes.First())
		code.Op = code.Op.PtrHeadToHead()
		lastCode := newEndOp(&compileContext{}, recursive.Type)
		lastCode.Op = OpRecursiveEnd

		// Terminate the copied program with the recursive-end op.
		code.End.Next = lastCode

		totalLength := code.TotalLength()

		// Give lastCode three consecutive uintptr-sized slots (Idx, ElemIdx,
		// Length) that start just past the copied program's totalLength.
		lastCode.Idx = uint32((totalLength + 1) * uintptrSize)
		lastCode.ElemIdx = lastCode.Idx + uintptrSize
		lastCode.Length = lastCode.Idx + 2*uintptrSize

		// NOTE(review): the +3 presumably reserves room for those three extra
		// slots; confirm against the VM's use of CurLen/NextLen.
		curTotalLength := uintptr(recursive.TotalLength()) + 3
		nextTotalLength := uintptr(totalLength) + 3

		compiled := recursive.Jmp
		compiled.Code = code
		compiled.CurLen = curTotalLength
		compiled.NextLen = nextTotalLength
		compiled.Linked = true

		recursiveCodes[typeptr] = compiled
	}
}
936
View as plain text