/* * Copyright 2021 ByteDance Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ast import ( `bufio` `encoding/json` `fmt` `io` `os` `sort` `strings` `testing` `github.com/stretchr/testify/assert` `github.com/stretchr/testify/require` ) type visitorNodeDiffTest struct { t *testing.T str string tracer io.Writer cursor Node stk visitorNodeStack sp uint8 } type visitorNodeStack = [256]struct { Node Node Object map[string]Node Array []Node ObjectKey string } func (self *visitorNodeDiffTest) incrSP() { self.t.Helper() self.sp++ require.NotZero(self.t, self.sp, "stack overflow") } func (self *visitorNodeDiffTest) debugStack() string { var buf strings.Builder buf.WriteString("[") for i := uint8(0); i < self.sp; i++ { if i != 0 { buf.WriteString(", ") } if self.stk[i].Array != nil { buf.WriteString("Array") } else if self.stk[i].Object != nil { buf.WriteString("Object") } else { fmt.Fprintf(&buf, "Key(%q)", self.stk[i].ObjectKey) } } buf.WriteString("]") return buf.String() } func (self *visitorNodeDiffTest) requireType(got int) { self.t.Helper() want := self.cursor.Type() require.EqualValues(self.t, want, got) } func (self *visitorNodeDiffTest) toArrayIndex(array Node, i int) { // set cursor to next Value if existed self.t.Helper() n, err := array.Len() require.NoError(self.t, err) if i < n { self.cursor = *array.Index(i) require.NoError(self.t, self.cursor.Check()) } } func (self *visitorNodeDiffTest) onValueEnd() { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnValueEnd: %s\n", self.debugStack()) } // cursor should point to the Value now self.t.Helper() if self.sp == 0 { if self.tracer != nil { fmt.Fprintf(self.tracer, "EOF\n\n") } return } // [..., Array, sp] if array := self.stk[self.sp-1].Array; array != nil { array = append(array, self.cursor) self.stk[self.sp-1].Array = array self.toArrayIndex(self.stk[self.sp-1].Node, len(array)) return } // [..., Object, ObjectKey, sp] require.GreaterOrEqual(self.t, self.sp, uint8(2)) require.NotNil(self.t, self.stk[self.sp-2].Object) require.Nil(self.t, self.stk[self.sp-1].Object) require.Nil(self.t, self.stk[self.sp-1].Array) self.stk[self.sp-2].Object[self.stk[self.sp-1].ObjectKey] = self.cursor self.cursor = self.stk[self.sp-2].Node // reset cursor to Object self.sp-- // pop ObjectKey } func (self *visitorNodeDiffTest) OnNull() error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnNull\n") } self.requireType(V_NULL) self.onValueEnd() return nil } func (self *visitorNodeDiffTest) OnBool(v bool) error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnBool: %t\n", v) } if v { self.requireType(V_TRUE) } else { self.requireType(V_FALSE) } self.onValueEnd() return nil } func (self *visitorNodeDiffTest) OnString(v string) error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnString: %q\n", v) } self.requireType(V_STRING) want, err := self.cursor.StrictString() require.NoError(self.t, err) require.EqualValues(self.t, want, v) self.onValueEnd() return nil } func (self *visitorNodeDiffTest) OnInt64(v int64, n json.Number) error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnInt64: %d (%q)\n", v, n) } self.requireType(V_NUMBER) want, err := self.cursor.StrictInt64() require.NoError(self.t, err) require.EqualValues(self.t, want, v) nv, err := n.Int64() require.NoError(self.t, err) require.EqualValues(self.t, want, nv) self.onValueEnd() return nil } func (self *visitorNodeDiffTest) OnFloat64(v float64, n json.Number) error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnFloat64: %f (%q)\n", v, n) } self.requireType(V_NUMBER) want, err := self.cursor.StrictFloat64() require.NoError(self.t, err) require.EqualValues(self.t, want, v) nv, err := n.Float64() require.NoError(self.t, err) require.EqualValues(self.t, want, nv) self.onValueEnd() return nil } func (self *visitorNodeDiffTest) OnObjectBegin(capacity int) error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnObjectBegin: %d\n", capacity) } self.requireType(V_OBJECT) self.stk[self.sp].Node = self.cursor self.stk[self.sp].Object = make(map[string]Node, capacity) self.incrSP() return nil } func (self *visitorNodeDiffTest) OnObjectKey(key string) error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnObjectKey: %q %s\n", key, self.debugStack()) } require.NotNil(self.t, self.stk[self.sp-1].Object) node := self.stk[self.sp-1].Node self.stk[self.sp].ObjectKey = key self.incrSP() self.cursor = *node.Get(key) // set cursor to Value require.NoError(self.t, self.cursor.Check()) return nil } func (self *visitorNodeDiffTest) OnObjectEnd() error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnObjectEnd\n") } object := self.stk[self.sp-1].Object require.NotNil(self.t, object) node := self.stk[self.sp-1].Node ps, err := node.unsafeMap() var pairs = make([]Pair, ps.Len()) ps.ToSlice(pairs) require.NoError(self.t, err) keysGot := make([]string, 0, len(object)) for key := range object { keysGot = append(keysGot, key) } keysWant := make([]string, 0, len(pairs)) for _, pair := range pairs { keysWant = append(keysWant, pair.Key) } sort.Strings(keysGot) sort.Strings(keysWant) require.EqualValues(self.t, keysWant, keysGot) for _, pair := range pairs { typeGot := object[pair.Key].Type() typeWant := pair.Value.Type() require.EqualValues(self.t, typeWant, typeGot) } // pop Object self.sp-- self.stk[self.sp].Node = Node{} self.stk[self.sp].Object = nil self.cursor = node // set cursor to this Object self.onValueEnd() return nil } func (self *visitorNodeDiffTest) OnArrayBegin(capacity int) error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnArrayBegin: %d\n", capacity) } self.requireType(V_ARRAY) self.stk[self.sp].Node = self.cursor self.stk[self.sp].Array = make([]Node, 0, capacity) self.incrSP() self.toArrayIndex(self.stk[self.sp-1].Node, 0) return nil } func (self *visitorNodeDiffTest) OnArrayEnd() error { if self.tracer != nil { fmt.Fprintf(self.tracer, "OnArrayEnd\n") } array := self.stk[self.sp-1].Array require.NotNil(self.t, array) node := self.stk[self.sp-1].Node vs, err := node.unsafeArray() require.NoError(self.t, err) var values = make([]Node, vs.Len()) vs.ToSlice(values) require.EqualValues(self.t, len(values), len(array)) for i, n := 0, len(values); i < n; i++ { typeGot := array[i].Type() typeWant := values[i].Type() require.EqualValues(self.t, typeWant, typeGot) } // pop Array self.sp-- self.stk[self.sp].Node = Node{} self.stk[self.sp].Array = nil self.cursor = node // set cursor to this Array self.onValueEnd() return nil } func (self *visitorNodeDiffTest) Run(t *testing.T, str string, tracer io.Writer) { self.t = t self.str = str self.tracer = tracer self.t.Helper() self.cursor = NewRaw(self.str) require.NoError(self.t, self.cursor.LoadAll()) self.stk = visitorNodeStack{} self.sp = 0 require.NoError(self.t, Preorder(self.str, self, nil)) } func TestVisitor_NodeDiff(t *testing.T) { var suite visitorNodeDiffTest newTracer := func(t *testing.T) io.Writer { const EnableTracer = false if !EnableTracer { return nil } basename := strings.ReplaceAll(t.Name(), "/", "_") fp, err := os.Create(fmt.Sprintf("../output/%s.log", basename)) require.NoError(t, err) writer := bufio.NewWriter(fp) t.Cleanup(func() { _ = writer.Flush() _ = fp.Close() }) return writer } t.Run("default", func(t *testing.T) { suite.Run(t, _TwitterJson, newTracer(t)) }) t.Run("issue_case01", func(t *testing.T) { suite.Run(t, `[1193.6419677734375]`, newTracer(t)) }) } type visitorUserNode interface { UserNode() } type ( visitorUserNull struct{} visitorUserBool struct{ Value bool } visitorUserInt64 struct{ Value int64 } visitorUserFloat64 struct{ Value float64 } visitorUserString struct{ Value string } visitorUserObject struct{ Value map[string]visitorUserNode } visitorUserArray struct{ Value []visitorUserNode } ) func (*visitorUserNull) UserNode() {} func (*visitorUserBool) UserNode() {} func (*visitorUserInt64) UserNode() {} func (*visitorUserFloat64) UserNode() {} func (*visitorUserString) UserNode() {} func (*visitorUserObject) UserNode() {} func (*visitorUserArray) UserNode() {} func compareUserNode(tb testing.TB, lhs, rhs visitorUserNode) bool { switch lhs := lhs.(type) { case *visitorUserNull: _, ok := rhs.(*visitorUserNull) return assert.True(tb, ok) case *visitorUserBool: rhs, ok := rhs.(*visitorUserBool) return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) case *visitorUserInt64: rhs, ok := rhs.(*visitorUserInt64) return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) case *visitorUserFloat64: rhs, ok := rhs.(*visitorUserFloat64) return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) case *visitorUserString: rhs, ok := rhs.(*visitorUserString) return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) case *visitorUserObject: rhs, ok := rhs.(*visitorUserObject) if !(assert.True(tb, ok) && assert.Equal(tb, len(lhs.Value), len(rhs.Value))) { return false } for key, lhs := range lhs.Value { rhs, ok := rhs.Value[key] if !(assert.True(tb, ok) && assert.True(tb, compareUserNode(tb, lhs, rhs))) { return false } } return true case *visitorUserArray: rhs, ok := rhs.(*visitorUserArray) if !(assert.True(tb, ok) && assert.Equal(tb, len(lhs.Value), len(rhs.Value))) { return false } for i, n := 0, len(lhs.Value); i < n; i++ { if !assert.True(tb, compareUserNode(tb, lhs.Value[i], rhs.Value[i])) { return false } } return true default: tb.Fatalf("unexpected type of UserNode: %T", lhs) return false } } type visitorUserNodeDecoder interface { Reset() Decode(str string) (visitorUserNode, error) } var _ visitorUserNodeDecoder = (*visitorUserNodeASTDecoder)(nil) type visitorUserNodeASTDecoder struct{} func (self *visitorUserNodeASTDecoder) Reset() {} func (self *visitorUserNodeASTDecoder) Decode(str string) (visitorUserNode, error) { root := NewRaw(str) if err := root.LoadAll(); err != nil { return nil, err } return self.decodeValue(&root) } func (self *visitorUserNodeASTDecoder) decodeValue(root *Node) (visitorUserNode, error) { switch typ := root.Type(); typ { // embed (*Node).Check case V_NONE: return nil, ErrNotExist case V_ERROR: return nil, root case V_NULL: return &visitorUserNull{}, nil case V_TRUE: return &visitorUserBool{Value: true}, nil case V_FALSE: return &visitorUserBool{Value: false}, nil case V_STRING: value, err := root.StrictString() if err != nil { return nil, err } return &visitorUserString{Value: value}, nil case V_NUMBER: value, err := root.StrictNumber() if err != nil { return nil, err } i64, ierr := value.Int64() if ierr == nil { return &visitorUserInt64{Value: i64}, nil } f64, ferr := value.Float64() if ferr == nil { return &visitorUserFloat64{Value: f64}, nil } return nil, fmt.Errorf("invalid number: %v, ierr: %v, ferr: %v", value, ierr, ferr) case V_ARRAY: nodes, err := root.unsafeArray() if err != nil { return nil, err } values := make([]visitorUserNode, nodes.Len()) for i := 0; i