1
16
17 package ast
18
19 import (
20 `bufio`
21 `encoding/json`
22 `fmt`
23 `io`
24 `os`
25 `sort`
26 `strings`
27 `testing`
28
29 `github.com/stretchr/testify/assert`
30 `github.com/stretchr/testify/require`
31 )
32
// visitorNodeDiffTest is the visitor passed to Preorder by Run. While the
// callbacks fire for str, it keeps a cursor into a Node tree loaded from the
// same string and asserts that every callback agrees with the node under the
// cursor.
type visitorNodeDiffTest struct {
	t   *testing.T // test that assertions are reported against
	str string     // JSON text under test

	tracer io.Writer // optional trace sink; nil disables tracing

	cursor Node             // node the next value callback is checked against
	stk    visitorNodeStack // frames for open objects, open arrays and pending keys
	sp     uint8            // number of frames currently in use
}
43
// visitorNodeStack is the fixed-size frame stack used by visitorNodeDiffTest.
// A frame with a non-nil Array or Object belongs to an open container and also
// stores that container's Node; a frame with neither holds a pending ObjectKey.
type visitorNodeStack = [256]struct {
	Node   Node            // container node, set by OnObjectBegin/OnArrayBegin
	Object map[string]Node // values collected so far for an open object
	Array  []Node          // values collected so far for an open array

	ObjectKey string // key pushed by OnObjectKey
}
51
52 func (self *visitorNodeDiffTest) incrSP() {
53 self.t.Helper()
54 self.sp++
55 require.NotZero(self.t, self.sp, "stack overflow")
56 }
57
58 func (self *visitorNodeDiffTest) debugStack() string {
59 var buf strings.Builder
60 buf.WriteString("[")
61 for i := uint8(0); i < self.sp; i++ {
62 if i != 0 {
63 buf.WriteString(", ")
64 }
65 if self.stk[i].Array != nil {
66 buf.WriteString("Array")
67 } else if self.stk[i].Object != nil {
68 buf.WriteString("Object")
69 } else {
70 fmt.Fprintf(&buf, "Key(%q)", self.stk[i].ObjectKey)
71 }
72 }
73 buf.WriteString("]")
74 return buf.String()
75 }
76
77 func (self *visitorNodeDiffTest) requireType(got int) {
78 self.t.Helper()
79 want := self.cursor.Type()
80 require.EqualValues(self.t, want, got)
81 }
82
83 func (self *visitorNodeDiffTest) toArrayIndex(array Node, i int) {
84
85 self.t.Helper()
86 n, err := array.Len()
87 require.NoError(self.t, err)
88 if i < n {
89 self.cursor = *array.Index(i)
90 require.NoError(self.t, self.cursor.Check())
91 }
92 }
93
94 func (self *visitorNodeDiffTest) onValueEnd() {
95 if self.tracer != nil {
96 fmt.Fprintf(self.tracer, "OnValueEnd: %s\n", self.debugStack())
97 }
98
99 self.t.Helper()
100 if self.sp == 0 {
101 if self.tracer != nil {
102 fmt.Fprintf(self.tracer, "EOF\n\n")
103 }
104 return
105 }
106
107 if array := self.stk[self.sp-1].Array; array != nil {
108 array = append(array, self.cursor)
109 self.stk[self.sp-1].Array = array
110 self.toArrayIndex(self.stk[self.sp-1].Node, len(array))
111 return
112 }
113
114 require.GreaterOrEqual(self.t, self.sp, uint8(2))
115 require.NotNil(self.t, self.stk[self.sp-2].Object)
116 require.Nil(self.t, self.stk[self.sp-1].Object)
117 require.Nil(self.t, self.stk[self.sp-1].Array)
118 self.stk[self.sp-2].Object[self.stk[self.sp-1].ObjectKey] = self.cursor
119 self.cursor = self.stk[self.sp-2].Node
120 self.sp--
121 }
122
123 func (self *visitorNodeDiffTest) OnNull() error {
124 if self.tracer != nil {
125 fmt.Fprintf(self.tracer, "OnNull\n")
126 }
127 self.requireType(V_NULL)
128 self.onValueEnd()
129 return nil
130 }
131
132 func (self *visitorNodeDiffTest) OnBool(v bool) error {
133 if self.tracer != nil {
134 fmt.Fprintf(self.tracer, "OnBool: %t\n", v)
135 }
136 if v {
137 self.requireType(V_TRUE)
138 } else {
139 self.requireType(V_FALSE)
140 }
141 self.onValueEnd()
142 return nil
143 }
144
145 func (self *visitorNodeDiffTest) OnString(v string) error {
146 if self.tracer != nil {
147 fmt.Fprintf(self.tracer, "OnString: %q\n", v)
148 }
149 self.requireType(V_STRING)
150 want, err := self.cursor.StrictString()
151 require.NoError(self.t, err)
152 require.EqualValues(self.t, want, v)
153 self.onValueEnd()
154 return nil
155 }
156
157 func (self *visitorNodeDiffTest) OnInt64(v int64, n json.Number) error {
158 if self.tracer != nil {
159 fmt.Fprintf(self.tracer, "OnInt64: %d (%q)\n", v, n)
160 }
161 self.requireType(V_NUMBER)
162 want, err := self.cursor.StrictInt64()
163 require.NoError(self.t, err)
164 require.EqualValues(self.t, want, v)
165 nv, err := n.Int64()
166 require.NoError(self.t, err)
167 require.EqualValues(self.t, want, nv)
168 self.onValueEnd()
169 return nil
170 }
171
172 func (self *visitorNodeDiffTest) OnFloat64(v float64, n json.Number) error {
173 if self.tracer != nil {
174 fmt.Fprintf(self.tracer, "OnFloat64: %f (%q)\n", v, n)
175 }
176 self.requireType(V_NUMBER)
177 want, err := self.cursor.StrictFloat64()
178 require.NoError(self.t, err)
179 require.EqualValues(self.t, want, v)
180 nv, err := n.Float64()
181 require.NoError(self.t, err)
182 require.EqualValues(self.t, want, nv)
183 self.onValueEnd()
184 return nil
185 }
186
187 func (self *visitorNodeDiffTest) OnObjectBegin(capacity int) error {
188 if self.tracer != nil {
189 fmt.Fprintf(self.tracer, "OnObjectBegin: %d\n", capacity)
190 }
191 self.requireType(V_OBJECT)
192 self.stk[self.sp].Node = self.cursor
193 self.stk[self.sp].Object = make(map[string]Node, capacity)
194 self.incrSP()
195 return nil
196 }
197
198 func (self *visitorNodeDiffTest) OnObjectKey(key string) error {
199 if self.tracer != nil {
200 fmt.Fprintf(self.tracer, "OnObjectKey: %q %s\n", key, self.debugStack())
201 }
202 require.NotNil(self.t, self.stk[self.sp-1].Object)
203 node := self.stk[self.sp-1].Node
204 self.stk[self.sp].ObjectKey = key
205 self.incrSP()
206 self.cursor = *node.Get(key)
207 require.NoError(self.t, self.cursor.Check())
208 return nil
209 }
210
211 func (self *visitorNodeDiffTest) OnObjectEnd() error {
212 if self.tracer != nil {
213 fmt.Fprintf(self.tracer, "OnObjectEnd\n")
214 }
215 object := self.stk[self.sp-1].Object
216 require.NotNil(self.t, object)
217
218 node := self.stk[self.sp-1].Node
219 ps, err := node.unsafeMap()
220 var pairs = make([]Pair, ps.Len())
221 ps.ToSlice(pairs)
222 require.NoError(self.t, err)
223
224 keysGot := make([]string, 0, len(object))
225 for key := range object {
226 keysGot = append(keysGot, key)
227 }
228 keysWant := make([]string, 0, len(pairs))
229 for _, pair := range pairs {
230 keysWant = append(keysWant, pair.Key)
231 }
232 sort.Strings(keysGot)
233 sort.Strings(keysWant)
234 require.EqualValues(self.t, keysWant, keysGot)
235
236 for _, pair := range pairs {
237 typeGot := object[pair.Key].Type()
238 typeWant := pair.Value.Type()
239 require.EqualValues(self.t, typeWant, typeGot)
240 }
241
242
243 self.sp--
244 self.stk[self.sp].Node = Node{}
245 self.stk[self.sp].Object = nil
246
247 self.cursor = node
248 self.onValueEnd()
249 return nil
250 }
251
252 func (self *visitorNodeDiffTest) OnArrayBegin(capacity int) error {
253 if self.tracer != nil {
254 fmt.Fprintf(self.tracer, "OnArrayBegin: %d\n", capacity)
255 }
256 self.requireType(V_ARRAY)
257 self.stk[self.sp].Node = self.cursor
258 self.stk[self.sp].Array = make([]Node, 0, capacity)
259 self.incrSP()
260 self.toArrayIndex(self.stk[self.sp-1].Node, 0)
261 return nil
262 }
263
264 func (self *visitorNodeDiffTest) OnArrayEnd() error {
265 if self.tracer != nil {
266 fmt.Fprintf(self.tracer, "OnArrayEnd\n")
267 }
268 array := self.stk[self.sp-1].Array
269 require.NotNil(self.t, array)
270
271 node := self.stk[self.sp-1].Node
272 vs, err := node.unsafeArray()
273 require.NoError(self.t, err)
274 var values = make([]Node, vs.Len())
275 vs.ToSlice(values)
276
277 require.EqualValues(self.t, len(values), len(array))
278
279 for i, n := 0, len(values); i < n; i++ {
280 typeGot := array[i].Type()
281 typeWant := values[i].Type()
282 require.EqualValues(self.t, typeWant, typeGot)
283 }
284
285
286 self.sp--
287 self.stk[self.sp].Node = Node{}
288 self.stk[self.sp].Array = nil
289
290 self.cursor = node
291 self.onValueEnd()
292 return nil
293 }
294
295 func (self *visitorNodeDiffTest) Run(t *testing.T, str string,
296 tracer io.Writer) {
297 self.t = t
298 self.str = str
299 self.tracer = tracer
300
301 self.t.Helper()
302
303 self.cursor = NewRaw(self.str)
304 require.NoError(self.t, self.cursor.LoadAll())
305
306 self.stk = visitorNodeStack{}
307 self.sp = 0
308
309 require.NoError(self.t, Preorder(self.str, self, nil))
310 }
311
312 func TestVisitor_NodeDiff(t *testing.T) {
313 var suite visitorNodeDiffTest
314
315 newTracer := func(t *testing.T) io.Writer {
316 const EnableTracer = false
317 if !EnableTracer {
318 return nil
319 }
320 basename := strings.ReplaceAll(t.Name(), "/", "_")
321 fp, err := os.Create(fmt.Sprintf("../output/%s.log", basename))
322 require.NoError(t, err)
323 writer := bufio.NewWriter(fp)
324 t.Cleanup(func() {
325 _ = writer.Flush()
326 _ = fp.Close()
327 })
328 return writer
329 }
330
331 t.Run("default", func(t *testing.T) {
332 suite.Run(t, _TwitterJson, newTracer(t))
333 })
334 t.Run("issue_case01", func(t *testing.T) {
335 suite.Run(t, `[1193.6419677734375]`, newTracer(t))
336 })
337 }
338
// visitorUserNode is the common interface of the user-defined JSON tree types
// in this file; UserNode is a marker method with no behaviour.
type visitorUserNode interface {
	UserNode()
}
342
// User-defined tree nodes, one per kind of JSON value produced by the decoders
// in this file. Numbers are split into an int64 and a float64 variant.
type (
	visitorUserNull    struct{}
	visitorUserBool    struct{ Value bool }
	visitorUserInt64   struct{ Value int64 }
	visitorUserFloat64 struct{ Value float64 }
	visitorUserString  struct{ Value string }
	visitorUserObject  struct{ Value map[string]visitorUserNode }
	visitorUserArray   struct{ Value []visitorUserNode }
)
352
// UserNode makes the pointer to each node type satisfy visitorUserNode. The
// bodies are intentionally empty.
func (*visitorUserNull) UserNode()    {}
func (*visitorUserBool) UserNode()    {}
func (*visitorUserInt64) UserNode()   {}
func (*visitorUserFloat64) UserNode() {}
func (*visitorUserString) UserNode()  {}
func (*visitorUserObject) UserNode()  {}
func (*visitorUserArray) UserNode()   {}
360
361 func compareUserNode(tb testing.TB, lhs, rhs visitorUserNode) bool {
362 switch lhs := lhs.(type) {
363 case *visitorUserNull:
364 _, ok := rhs.(*visitorUserNull)
365 return assert.True(tb, ok)
366 case *visitorUserBool:
367 rhs, ok := rhs.(*visitorUserBool)
368 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value)
369 case *visitorUserInt64:
370 rhs, ok := rhs.(*visitorUserInt64)
371 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value)
372 case *visitorUserFloat64:
373 rhs, ok := rhs.(*visitorUserFloat64)
374 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value)
375 case *visitorUserString:
376 rhs, ok := rhs.(*visitorUserString)
377 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value)
378 case *visitorUserObject:
379 rhs, ok := rhs.(*visitorUserObject)
380 if !(assert.True(tb, ok) && assert.Equal(tb, len(lhs.Value), len(rhs.Value))) {
381 return false
382 }
383 for key, lhs := range lhs.Value {
384 rhs, ok := rhs.Value[key]
385 if !(assert.True(tb, ok) && assert.True(tb, compareUserNode(tb, lhs, rhs))) {
386 return false
387 }
388 }
389 return true
390 case *visitorUserArray:
391 rhs, ok := rhs.(*visitorUserArray)
392 if !(assert.True(tb, ok) && assert.Equal(tb, len(lhs.Value), len(rhs.Value))) {
393 return false
394 }
395 for i, n := 0, len(lhs.Value); i < n; i++ {
396 if !assert.True(tb, compareUserNode(tb, lhs.Value[i], rhs.Value[i])) {
397 return false
398 }
399 }
400 return true
401 default:
402 tb.Fatalf("unexpected type of UserNode: %T", lhs)
403 return false
404 }
405 }
406
// visitorUserNodeDecoder decodes a JSON string into a visitorUserNode tree.
// The tests and benchmarks in this file call Reset before every Decode.
type visitorUserNodeDecoder interface {
	Reset()
	Decode(str string) (visitorUserNode, error)
}
411
// Compile-time check that the AST decoder implements visitorUserNodeDecoder.
var _ visitorUserNodeDecoder = (*visitorUserNodeASTDecoder)(nil)

// visitorUserNodeASTDecoder builds a visitorUserNode tree by first loading the
// whole input into a Node tree and then converting that tree.
type visitorUserNodeASTDecoder struct{}

// Reset does nothing: the decoder keeps no state between calls.
func (self *visitorUserNodeASTDecoder) Reset() {}
417
418 func (self *visitorUserNodeASTDecoder) Decode(str string) (visitorUserNode, error) {
419 root := NewRaw(str)
420 if err := root.LoadAll(); err != nil {
421 return nil, err
422 }
423 return self.decodeValue(&root)
424 }
425
426 func (self *visitorUserNodeASTDecoder) decodeValue(root *Node) (visitorUserNode, error) {
427 switch typ := root.Type(); typ {
428
429 case V_NONE:
430 return nil, ErrNotExist
431 case V_ERROR:
432 return nil, root
433
434 case V_NULL:
435 return &visitorUserNull{}, nil
436 case V_TRUE:
437 return &visitorUserBool{Value: true}, nil
438 case V_FALSE:
439 return &visitorUserBool{Value: false}, nil
440
441 case V_STRING:
442 value, err := root.StrictString()
443 if err != nil {
444 return nil, err
445 }
446 return &visitorUserString{Value: value}, nil
447
448 case V_NUMBER:
449 value, err := root.StrictNumber()
450 if err != nil {
451 return nil, err
452 }
453 i64, ierr := value.Int64()
454 if ierr == nil {
455 return &visitorUserInt64{Value: i64}, nil
456 }
457 f64, ferr := value.Float64()
458 if ferr == nil {
459 return &visitorUserFloat64{Value: f64}, nil
460 }
461 return nil, fmt.Errorf("invalid number: %v, ierr: %v, ferr: %v",
462 value, ierr, ferr)
463
464 case V_ARRAY:
465 nodes, err := root.unsafeArray()
466 if err != nil {
467 return nil, err
468 }
469 values := make([]visitorUserNode, nodes.Len())
470 for i := 0; i<nodes.Len(); i++ {
471 n := nodes.At(i)
472 value, err := self.decodeValue(n)
473 if err != nil {
474 return nil, err
475 }
476 values[i] = value
477 }
478 return &visitorUserArray{Value: values}, nil
479
480 case V_OBJECT:
481 pairs, err := root.unsafeMap()
482 if err != nil {
483 return nil, err
484 }
485 values := make(map[string]visitorUserNode, pairs.Len())
486 for i := 0; i < pairs.Len(); i++ {
487 value, err := self.decodeValue(&pairs.At(i).Value)
488 if err != nil {
489 return nil, err
490 }
491 values[pairs.At(i).Key] = value
492 }
493 return &visitorUserObject{Value: values}, nil
494
495 case V_ANY:
496 fallthrough
497 default:
498 return nil, fmt.Errorf("unexpected Node type: %v", typ)
499 }
500 }
501
// Compile-time check that the visitor decoder implements visitorUserNodeDecoder.
var _ visitorUserNodeDecoder = (*visitorUserNodeVisitorDecoder)(nil)

// visitorUserNodeVisitorDecoder builds a visitorUserNode tree directly from the
// Preorder callbacks, using an explicit fixed-size stack.
type visitorUserNodeVisitorDecoder struct {
	stk visitorUserNodeStack // frames for values, open containers and pending keys
	sp  uint8                // number of frames currently in use
}
508
// visitorUserNodeStack is the frame stack of visitorUserNodeVisitorDecoder.
// Each frame is used in one role at a time.
type visitorUserNodeStack = [256]struct {
	val visitorUserNode            // a completed value
	obj map[string]visitorUserNode // members collected for an open object
	arr []visitorUserNode          // elements collected for an open array
	key string                     // key pushed by OnObjectKey
}
515
516 func (self *visitorUserNodeVisitorDecoder) Reset() {
517 self.stk = visitorUserNodeStack{}
518 self.sp = 0
519 }
520
521 func (self *visitorUserNodeVisitorDecoder) Decode(str string) (visitorUserNode, error) {
522 if err := Preorder(str, self, nil); err != nil {
523 return nil, err
524 }
525 return self.result()
526 }
527
528 func (self *visitorUserNodeVisitorDecoder) result() (visitorUserNode, error) {
529 if self.sp != 1 {
530 return nil, fmt.Errorf("incorrect sp: %d", self.sp)
531 }
532 return self.stk[0].val, nil
533 }
534
535 func (self *visitorUserNodeVisitorDecoder) incrSP() error {
536 self.sp++
537 if self.sp == 0 {
538 return fmt.Errorf("reached max depth: %d", len(self.stk))
539 }
540 return nil
541 }
542
543 func (self *visitorUserNodeVisitorDecoder) OnNull() error {
544 self.stk[self.sp].val = &visitorUserNull{}
545 if err := self.incrSP(); err != nil {
546 return err
547 }
548 return self.onValueEnd()
549 }
550
551 func (self *visitorUserNodeVisitorDecoder) OnBool(v bool) error {
552 self.stk[self.sp].val = &visitorUserBool{Value: v}
553 if err := self.incrSP(); err != nil {
554 return err
555 }
556 return self.onValueEnd()
557 }
558
559 func (self *visitorUserNodeVisitorDecoder) OnString(v string) error {
560 self.stk[self.sp].val = &visitorUserString{Value: v}
561 if err := self.incrSP(); err != nil {
562 return err
563 }
564 return self.onValueEnd()
565 }
566
567 func (self *visitorUserNodeVisitorDecoder) OnInt64(v int64, n json.Number) error {
568 self.stk[self.sp].val = &visitorUserInt64{Value: v}
569 if err := self.incrSP(); err != nil {
570 return err
571 }
572 return self.onValueEnd()
573 }
574
575 func (self *visitorUserNodeVisitorDecoder) OnFloat64(v float64, n json.Number) error {
576 self.stk[self.sp].val = &visitorUserFloat64{Value: v}
577 if err := self.incrSP(); err != nil {
578 return err
579 }
580 return self.onValueEnd()
581 }
582
583 func (self *visitorUserNodeVisitorDecoder) OnObjectBegin(capacity int) error {
584 self.stk[self.sp].obj = make(map[string]visitorUserNode, capacity)
585 return self.incrSP()
586 }
587
588 func (self *visitorUserNodeVisitorDecoder) OnObjectKey(key string) error {
589 self.stk[self.sp].key = key
590 return self.incrSP()
591 }
592
593 func (self *visitorUserNodeVisitorDecoder) OnObjectEnd() error {
594 self.stk[self.sp-1].val = &visitorUserObject{Value: self.stk[self.sp-1].obj}
595 self.stk[self.sp-1].obj = nil
596 return self.onValueEnd()
597 }
598
599 func (self *visitorUserNodeVisitorDecoder) OnArrayBegin(capacity int) error {
600 self.stk[self.sp].arr = make([]visitorUserNode, 0, capacity)
601 return self.incrSP()
602 }
603
604 func (self *visitorUserNodeVisitorDecoder) OnArrayEnd() error {
605 self.stk[self.sp-1].val = &visitorUserArray{Value: self.stk[self.sp-1].arr}
606 self.stk[self.sp-1].arr = nil
607 return self.onValueEnd()
608 }
609
610 func (self *visitorUserNodeVisitorDecoder) onValueEnd() error {
611 if self.sp == 1 {
612 return nil
613 }
614
615 if self.stk[self.sp-2].arr != nil {
616 self.stk[self.sp-2].arr = append(self.stk[self.sp-2].arr, self.stk[self.sp-1].val)
617 self.sp--
618 return nil
619 }
620
621 self.stk[self.sp-3].obj[self.stk[self.sp-2].key] = self.stk[self.sp-1].val
622 self.sp -= 2
623 return nil
624 }
625
626 func testUserNodeDiff(t *testing.T, d1, d2 visitorUserNodeDecoder, str string) {
627 t.Helper()
628 d1.Reset()
629 n1, err := d1.Decode(_TwitterJson)
630 require.NoError(t, err)
631
632 d2.Reset()
633 n2, err := d2.Decode(_TwitterJson)
634 require.NoError(t, err)
635
636 require.True(t, compareUserNode(t, n1, n2))
637 }
638
639 func TestVisitor_UserNodeDiff(t *testing.T) {
640 var d1 visitorUserNodeASTDecoder
641 var d2 visitorUserNodeVisitorDecoder
642
643 t.Run("default", func(t *testing.T) {
644 testUserNodeDiff(t, &d1, &d2, _TwitterJson)
645 })
646 t.Run("issue_case01", func(t *testing.T) {
647 testUserNodeDiff(t, &d1, &d2, `[1193.6419677734375]`)
648 })
649 }
650
651 func BenchmarkVisitor_UserNode(b *testing.B) {
652 const str = _TwitterJson
653 b.Run("AST", func(b *testing.B) {
654 var d visitorUserNodeASTDecoder
655 b.ResetTimer()
656 for k := 0; k < b.N; k++ {
657 d.Reset()
658 _, err := d.Decode(str)
659 require.NoError(b, err)
660 b.SetBytes(int64(len(str)))
661 }
662 })
663 b.Run("Visitor", func(b *testing.B) {
664 var d visitorUserNodeVisitorDecoder
665 b.ResetTimer()
666 for k := 0; k < b.N; k++ {
667 d.Reset()
668 _, err := d.Decode(str)
669 require.NoError(b, err)
670 b.SetBytes(int64(len(str)))
671 }
672 })
673 }
674
View as plain text