...

Source file src/github.com/noirbizarre/gonja/exec/evaluate.go

Documentation: github.com/noirbizarre/gonja/exec

     1  package exec
     2  
     3  import (
     4  	"math"
     5  	"reflect"
     6  	"strings"
     7  
     8  	"github.com/pkg/errors"
     9  
    10  	"github.com/noirbizarre/gonja/nodes"
    11  )
    12  
    13  var (
    14  	typeOfValuePtr   = reflect.TypeOf(new(Value))
    15  	typeOfExecCtxPtr = reflect.TypeOf(new(Context))
    16  )
    17  
    18  type Evaluator struct {
    19  	*EvalConfig
    20  	Ctx *Context
    21  }
    22  
    23  func (r *Renderer) Evaluator() *Evaluator {
    24  	return &Evaluator{
    25  		EvalConfig: r.EvalConfig,
    26  		Ctx:        r.Ctx,
    27  	}
    28  }
    29  
    30  func (r *Renderer) Eval(node nodes.Expression) *Value {
    31  	e := r.Evaluator()
    32  	return e.Eval(node)
    33  }
    34  
    35  func (e *Evaluator) Eval(node nodes.Expression) *Value {
    36  	switch n := node.(type) {
    37  	case *nodes.String:
    38  		return AsValue(n.Val)
    39  	case *nodes.Integer:
    40  		return AsValue(n.Val)
    41  	case *nodes.Float:
    42  		return AsValue(n.Val)
    43  	case *nodes.Bool:
    44  		return AsValue(n.Val)
    45  	case *nodes.List:
    46  		return e.evalList(n)
    47  	case *nodes.Tuple:
    48  		return e.evalTuple(n)
    49  	case *nodes.Dict:
    50  		return e.evalDict(n)
    51  	case *nodes.Pair:
    52  		return e.evalPair(n)
    53  	case *nodes.Name:
    54  		return e.evalName(n)
    55  	case *nodes.Call:
    56  		return e.evalCall(n)
    57  	case *nodes.Getitem:
    58  		return e.evalGetitem(n)
    59  	case *nodes.Getattr:
    60  		return e.evalGetattr(n)
    61  	case *nodes.Negation:
    62  		result := e.Eval(n.Term)
    63  		if result.IsError() {
    64  			return result
    65  		}
    66  		return result.Negate()
    67  	case *nodes.BinaryExpression:
    68  		return e.evalBinaryExpression(n)
    69  	case *nodes.UnaryExpression:
    70  		return e.evalUnaryExpression(n)
    71  	case *nodes.FilteredExpression:
    72  		return e.EvaluateFiltered(n)
    73  	case *nodes.TestExpression:
    74  		return e.EvalTest(n)
    75  	default:
    76  		return AsValue(errors.Errorf(`Unknown expression type "%T"`, n))
    77  	}
    78  }
    79  
    80  func (e *Evaluator) evalBinaryExpression(node *nodes.BinaryExpression) *Value {
    81  	var (
    82  		left  *Value
    83  		right *Value
    84  	)
    85  	left = e.Eval(node.Left)
    86  	if left.IsError() {
    87  		return AsValue(errors.Wrapf(left, `Unable to evaluate left parameter %s`, node.Left))
    88  	}
    89  
    90  	switch node.Operator.Token.Val {
    91  	// These operators allow lazy right expression evluation
    92  	case "and", "or":
    93  	default:
    94  		right = e.Eval(node.Right)
    95  		if right.IsError() {
    96  			return AsValue(errors.Wrapf(right, `Unable to evaluate right parameter %s`, node.Right))
    97  		}
    98  	}
    99  
   100  	switch node.Operator.Token.Val {
   101  	case "+":
   102  		if left.IsList() {
   103  			if !right.IsList() {
   104  				return AsValue(errors.Wrapf(right, `Unable to concatenate list to %s`, node.Right))
   105  			}
   106  
   107  			v := &Value{Val: reflect.ValueOf([]interface{}{})}
   108  
   109  			for ix := 0; ix < left.getResolvedValue().Len(); ix ++ {
   110  				v.Val = reflect.Append(v.Val, left.getResolvedValue().Index(ix))
   111  			}
   112  
   113  			for ix := 0; ix < right.getResolvedValue().Len(); ix ++ {
   114  				v.Val = reflect.Append(v.Val, right.getResolvedValue().Index(ix))
   115  			}
   116  
   117  			return v
   118  		}
   119  		if left.IsFloat() || right.IsFloat() {
   120  			// Result will be a float
   121  			return AsValue(left.Float() + right.Float())
   122  		}
   123  		// Result will be an integer
   124  		return AsValue(left.Integer() + right.Integer())
   125  	case "-":
   126  		if left.IsFloat() || right.IsFloat() {
   127  			// Result will be a float
   128  			return AsValue(left.Float() - right.Float())
   129  		}
   130  		// Result will be an integer
   131  		return AsValue(left.Integer() - right.Integer())
   132  	case "*":
   133  		if left.IsFloat() || right.IsFloat() {
   134  			// Result will be float
   135  			return AsValue(left.Float() * right.Float())
   136  		}
   137  		if left.IsString() {
   138  			return AsValue(strings.Repeat(left.String(), right.Integer()))
   139  		}
   140  		// Result will be int
   141  		return AsValue(left.Integer() * right.Integer())
   142  	case "/":
   143  		// Float division
   144  		return AsValue(left.Float() / right.Float())
   145  	case "//":
   146  		// Int division
   147  		return AsValue(int(left.Float() / right.Float()))
   148  	case "%":
   149  		// Result will be int
   150  		return AsValue(left.Integer() % right.Integer())
   151  	case "**":
   152  		return AsValue(math.Pow(left.Float(), right.Float()))
   153  	case "~":
   154  		return AsValue(strings.Join([]string{left.String(), right.String()}, ""))
   155  	case "and":
   156  		if !left.IsTrue() {
   157  			return AsValue(false)
   158  		}
   159  		right = e.Eval(node.Right)
   160  		if right.IsError() {
   161  			return AsValue(errors.Wrapf(right, `Unable to evaluate right parameter %s`, node.Right))
   162  		}
   163  		return AsValue(right.IsTrue())
   164  	case "or":
   165  		if left.IsTrue() {
   166  			return AsValue(true)
   167  		}
   168  		right = e.Eval(node.Right)
   169  		if right.IsError() {
   170  			return AsValue(errors.Wrapf(right, `Unable to evaluate right parameter %s`, node.Right))
   171  		}
   172  		return AsValue(right.IsTrue())
   173  	case "<=":
   174  		if left.IsFloat() || right.IsFloat() {
   175  			return AsValue(left.Float() <= right.Float())
   176  		}
   177  		return AsValue(left.Integer() <= right.Integer())
   178  	case ">=":
   179  		if left.IsFloat() || right.IsFloat() {
   180  			return AsValue(left.Float() >= right.Float())
   181  		}
   182  		return AsValue(left.Integer() >= right.Integer())
   183  	case "==":
   184  		return AsValue(left.EqualValueTo(right))
   185  	case ">":
   186  		if left.IsFloat() || right.IsFloat() {
   187  			return AsValue(left.Float() > right.Float())
   188  		}
   189  		return AsValue(left.Integer() > right.Integer())
   190  	case "<":
   191  		if left.IsFloat() || right.IsFloat() {
   192  			return AsValue(left.Float() < right.Float())
   193  		}
   194  		return AsValue(left.Integer() < right.Integer())
   195  	case "!=", "<>":
   196  		return AsValue(!left.EqualValueTo(right))
   197  	case "in":
   198  		return AsValue(right.Contains(left))
   199  	case "is":
   200  		return nil
   201  	default:
   202  		return AsValue(errors.Errorf(`Unknown operator "%s"`, node.Operator.Token))
   203  	}
   204  }
   205  
   206  func (e *Evaluator) evalUnaryExpression(expr *nodes.UnaryExpression) *Value {
   207  	result := e.Eval(expr.Term)
   208  	if result.IsError() {
   209  		return AsValue(errors.Wrapf(result, `Unable to evaluate term %s`, expr.Term))
   210  	}
   211  	if expr.Negative {
   212  		if result.IsNumber() {
   213  			switch {
   214  			case result.IsFloat():
   215  				return AsValue(-1 * result.Float())
   216  			case result.IsInteger():
   217  				return AsValue(-1 * result.Integer())
   218  			default:
   219  				return AsValue(errors.New("Operation between a number and a non-(float/integer) is not possible"))
   220  			}
   221  		} else {
   222  			return AsValue(errors.Errorf("Negative sign on a non-number expression %s", expr.Position()))
   223  		}
   224  	}
   225  	return result
   226  }
   227  
   228  func (e *Evaluator) evalList(node *nodes.List) *Value {
   229  	values := ValuesList{}
   230  	for _, val := range node.Val {
   231  		value := e.Eval(val)
   232  		values = append(values, value)
   233  	}
   234  	return AsValue(values)
   235  }
   236  
   237  func (e *Evaluator) evalTuple(node *nodes.Tuple) *Value {
   238  	values := ValuesList{}
   239  	for _, val := range node.Val {
   240  		value := e.Eval(val)
   241  		values = append(values, value)
   242  	}
   243  	return AsValue(values)
   244  }
   245  
   246  func (e *Evaluator) evalDict(node *nodes.Dict) *Value {
   247  	pairs := []*Pair{}
   248  	for _, pair := range node.Pairs {
   249  		p := e.evalPair(pair)
   250  		if p.IsError() {
   251  			return AsValue(errors.Wrapf(p, `Unable to evaluate pair "%s"`, pair))
   252  		}
   253  		pairs = append(pairs, p.Interface().(*Pair))
   254  	}
   255  	return AsValue(&Dict{pairs})
   256  }
   257  
   258  func (e *Evaluator) evalPair(node *nodes.Pair) *Value {
   259  	key := e.Eval(node.Key)
   260  	if key.IsError() {
   261  		return AsValue(errors.Wrapf(key, `Unable to evaluate key "%s"`, node.Key))
   262  	}
   263  	value := e.Eval(node.Value)
   264  	if value.IsError() {
   265  		return AsValue(errors.Wrapf(value, `Unable to evaluate value "%s"`, node.Value))
   266  	}
   267  	return AsValue(&Pair{key, value})
   268  }
   269  
   270  func (e *Evaluator) evalName(node *nodes.Name) *Value {
   271  	val := e.Ctx.Get(node.Name.Val)
   272  	return ToValue(val)
   273  }
   274  
   275  func (e *Evaluator) evalGetitem(node *nodes.Getitem) *Value {
   276  	value := e.Eval(node.Node)
   277  	if value.IsError() {
   278  		return AsValue(errors.Wrapf(value, `Unable to evaluate target %s`, node.Node))
   279  	}
   280  
   281  	if node.Arg != "" {
   282  		item, found := value.Getitem(node.Arg)
   283  		if !found {
   284  			item, found = value.Getattr(node.Arg)
   285  		}
   286  		if !found {
   287  			if item.IsError() {
   288  				return AsValue(errors.Wrapf(item, `Unable to evaluate %s`, node))
   289  			}
   290  			return AsValue(nil)
   291  			// return AsValue(errors.Errorf(`Unable to evaluate %s: item '%s' not found`, node, node.Arg))
   292  		}
   293  		return item
   294  	} else {
   295  		item, found := value.Getitem(node.Index)
   296  		if !found {
   297  			if item.IsError() {
   298  				return AsValue(errors.Wrapf(item, `Unable to evaluate %s`, node))
   299  			}
   300  			return AsValue(nil)
   301  			// return AsValue(errors.Errorf(`Unable to evaluate %s: item %d not found`, node, node.Index))
   302  		}
   303  		return item
   304  	}
   305  	return AsValue(errors.Errorf(`Unable to evaluate %s`, node))
   306  }
   307  
   308  func (e *Evaluator) evalGetattr(node *nodes.Getattr) *Value {
   309  	value := e.Eval(node.Node)
   310  	if value.IsError() {
   311  		return AsValue(errors.Wrapf(value, `Unable to evaluate target %s`, node.Node))
   312  	}
   313  
   314  	if node.Attr != "" {
   315  		attr, found := value.Getattr(node.Attr)
   316  		if !found {
   317  			attr, found = value.Getitem(node.Attr)
   318  		}
   319  		if !found {
   320  			if attr.IsError() {
   321  				return AsValue(errors.Wrapf(attr, `Unable to evaluate %s`, node))
   322  			}
   323  			return AsValue(nil)
   324  			// return AsValue(errors.Errorf(`Unable to evaluate %s: attribute '%s' not found`, node, node.Attr))
   325  		}
   326  		return attr
   327  	} else {
   328  		item, found := value.Getitem(node.Index)
   329  		if !found {
   330  			if item.IsError() {
   331  				return AsValue(errors.Wrapf(item, `Unable to evaluate %s`, node))
   332  			}
   333  			return AsValue(nil)
   334  			// return AsValue(errors.Errorf(`Unable to evaluate %s: item %d not found`, node, node.Index))
   335  		}
   336  		return item
   337  	}
   338  	return AsValue(errors.Errorf(`Unable to evaluate %s`, node))
   339  }
   340  
   341  func (e *Evaluator) evalCall(node *nodes.Call) *Value {
   342  	fn := e.Eval(node.Func)
   343  	if fn.IsError() {
   344  		return AsValue(errors.Wrapf(fn, `Unable to evaluate function "%s"`, node.Func))
   345  	}
   346  
   347  	if !fn.IsCallable() {
   348  		return AsValue(errors.Errorf(`%s is not callable`, node.Func))
   349  	}
   350  
   351  	// current := reflect.ValueOf(fn) // Get the initial value
   352  
   353  	var current reflect.Value
   354  	var isSafe bool
   355  
   356  	var params []reflect.Value
   357  	var err error
   358  	t := fn.Val.Type()
   359  
   360  	if t.NumIn() == 1 && t.In(0) == reflect.TypeOf(&VarArgs{}) {
   361  		params, err = e.evalVarArgs(node)
   362  	} else {
   363  		params, err = e.evalParams(node, fn)
   364  	}
   365  	if err != nil {
   366  		return AsValue(errors.Wrapf(err, `Unable to evaluate parameters`))
   367  	}
   368  
   369  	// Call it and get first return parameter back
   370  	values := fn.Val.Call(params)
   371  	rv := values[0]
   372  	if t.NumOut() == 2 {
   373  		e := values[1].Interface()
   374  		if e != nil {
   375  			err, ok := e.(error)
   376  			if !ok {
   377  				return AsValue(errors.Errorf("The second return value is not an error"))
   378  			}
   379  			if err != nil {
   380  				return AsValue(err)
   381  			}
   382  		}
   383  	}
   384  
   385  	if rv.Type() != typeOfValuePtr {
   386  		current = reflect.ValueOf(rv.Interface())
   387  	} else {
   388  		// Return the function call value
   389  		current = rv.Interface().(*Value).Val
   390  		isSafe = rv.Interface().(*Value).Safe
   391  	}
   392  
   393  	if !current.IsValid() {
   394  		// Value is not valid (e. g. NIL value)
   395  		return AsValue(nil)
   396  	}
   397  
   398  	return &Value{Val: current, Safe: isSafe}
   399  }
   400  
   401  func (e *Evaluator) evalVariable(node *nodes.Variable) (*Value, error) {
   402  	var current reflect.Value
   403  	var isSafe bool
   404  
   405  	for idx, part := range node.Parts {
   406  		if idx == 0 {
   407  			val := e.Ctx.Get(node.Parts[0].S)
   408  			current = reflect.ValueOf(val) // Get the initial value
   409  		} else {
   410  			// Next parts, resolve it from current
   411  
   412  			// Before resolving the pointer, let's see if we have a method to call
   413  			// Problem with resolving the pointer is we're changing the receiver
   414  			isFunc := false
   415  			if part.Type == nodes.VarTypeIdent {
   416  				funcValue := current.MethodByName(part.S)
   417  				if funcValue.IsValid() {
   418  					current = funcValue
   419  					isFunc = true
   420  				}
   421  			}
   422  
   423  			if !isFunc {
   424  				// If current a pointer, resolve it
   425  				if current.Kind() == reflect.Ptr {
   426  					current = current.Elem()
   427  					if !current.IsValid() {
   428  						// Value is not valid (anymore)
   429  						return AsValue(nil), nil
   430  					}
   431  				}
   432  
   433  				// Look up which part must be called now
   434  				switch part.Type {
   435  				case nodes.VarTypeInt:
   436  					// Calling an index is only possible for:
   437  					// * slices/arrays/strings
   438  					switch current.Kind() {
   439  					case reflect.String, reflect.Array, reflect.Slice:
   440  						if part.I >= 0 && current.Len() > part.I {
   441  							current = current.Index(part.I)
   442  						} else {
   443  							// In Django, exceeding the length of a list is just empty.
   444  							return AsValue(nil), nil
   445  						}
   446  					default:
   447  						return nil, errors.Errorf("Can't access an index on type %s (variable %s)",
   448  							current.Kind().String(), node.String())
   449  					}
   450  				case nodes.VarTypeIdent:
   451  					// debugging:
   452  					// fmt.Printf("now = %s (kind: %s)\n", part.s, current.Kind().String())
   453  
   454  					// Calling a field or key
   455  					switch current.Kind() {
   456  					case reflect.Struct:
   457  						current = current.FieldByName(part.S)
   458  					case reflect.Map:
   459  						current = current.MapIndex(reflect.ValueOf(part.S))
   460  					default:
   461  						return nil, errors.Errorf("Can't access a field by name on type %s (variable %s)",
   462  							current.Kind().String(), node.String())
   463  					}
   464  				default:
   465  					panic("unimplemented")
   466  				}
   467  			}
   468  		}
   469  
   470  		if !current.IsValid() {
   471  			// Value is not valid (anymore)
   472  			return AsValue(nil), nil
   473  		}
   474  
   475  		// If current is a reflect.ValueOf(gonja.Value), then unpack it
   476  		// Happens in function calls (as a return value) or by injecting
   477  		// into the execution context (e.g. in a for-loop)
   478  		if current.Type() == typeOfValuePtr {
   479  			tmpValue := current.Interface().(*Value)
   480  			current = tmpValue.Val
   481  			isSafe = tmpValue.Safe
   482  		}
   483  
   484  		// Check whether this is an interface and resolve it where required
   485  		if current.Kind() == reflect.Interface {
   486  			current = reflect.ValueOf(current.Interface())
   487  		}
   488  
   489  		// Check if the part is a function call
   490  		if part.IsFunctionCall {
   491  
   492  			var params []reflect.Value
   493  			var err error
   494  			t := current.Type()
   495  
   496  			if t.NumIn() == 1 && t.In(0) == reflect.TypeOf(&VarArgs{}) {
   497  				// params, err = e.evalVarArgs(node, t, part)
   498  			} else {
   499  				// params, err = e.evalParams(node, t, part)
   500  			}
   501  			if err != nil {
   502  				return nil, err
   503  			}
   504  
   505  			// Call it and get first return parameter back
   506  			values := current.Call(params)
   507  			rv := values[0]
   508  			if t.NumOut() == 2 {
   509  				e := values[1].Interface()
   510  				if e != nil {
   511  					err, ok := e.(error)
   512  					if !ok {
   513  						return nil, errors.Errorf("The second return value is not an error")
   514  					}
   515  					if err != nil {
   516  						return nil, err
   517  					}
   518  				}
   519  			}
   520  
   521  			if rv.Type() != typeOfValuePtr {
   522  				current = reflect.ValueOf(rv.Interface())
   523  			} else {
   524  				// Return the function call value
   525  				current = rv.Interface().(*Value).Val
   526  				isSafe = rv.Interface().(*Value).Safe
   527  			}
   528  		}
   529  
   530  		if !current.IsValid() {
   531  			// Value is not valid (e. g. NIL value)
   532  			return AsValue(nil), nil
   533  		}
   534  	}
   535  
   536  	return &Value{Val: current, Safe: isSafe}, nil
   537  }
   538  
   539  func (e *Evaluator) evalVarArgs(node *nodes.Call) ([]reflect.Value, error) {
   540  	params := &VarArgs{
   541  		Args:   []*Value{},
   542  		KwArgs: map[string]*Value{},
   543  	}
   544  	for _, param := range node.Args {
   545  		value := e.Eval(param)
   546  		if value.IsError() {
   547  			return nil, value
   548  		}
   549  		params.Args = append(params.Args, value)
   550  	}
   551  
   552  	for key, param := range node.Kwargs {
   553  		value := e.Eval(param)
   554  		if value.IsError() {
   555  			return nil, value
   556  		}
   557  		params.KwArgs[key] = value
   558  	}
   559  	// va := AsValue(VarArgs{})
   560  	return []reflect.Value{reflect.ValueOf(params)}, nil
   561  }
   562  
   563  func (e *Evaluator) evalParams(node *nodes.Call, fn *Value) ([]reflect.Value, error) {
   564  	args := node.Args
   565  	t := fn.Val.Type()
   566  
   567  	if len(args) != t.NumIn() && !(len(args) >= t.NumIn()-1 && t.IsVariadic()) {
   568  		msg := "Function input argument count (%d) of '%s' must be equal to the calling argument count (%d)."
   569  		return nil, errors.Errorf(msg, t.NumIn(), node.String(), len(args))
   570  	}
   571  
   572  	// Output arguments
   573  	if t.NumOut() != 1 && t.NumOut() != 2 {
   574  		msg := "'%s' must have exactly 1 or 2 output arguments, the second argument must be of type error"
   575  		return nil, errors.Errorf(msg, node.String())
   576  	}
   577  
   578  	// Evaluate all parameters
   579  	var parameters []reflect.Value
   580  
   581  	numArgs := t.NumIn()
   582  	isVariadic := t.IsVariadic()
   583  	var fnArg reflect.Type
   584  
   585  	for idx, arg := range args {
   586  		pv := e.Eval(arg)
   587  		if pv.IsError() {
   588  			return nil, pv
   589  		}
   590  
   591  		if isVariadic {
   592  			if idx >= numArgs-1 {
   593  				fnArg = t.In(numArgs - 1).Elem()
   594  			} else {
   595  				fnArg = t.In(idx)
   596  			}
   597  		} else {
   598  			fnArg = t.In(idx)
   599  		}
   600  
   601  		if fnArg != typeOfValuePtr {
   602  			// Function's argument is not a *gonja.Value, then we have to check whether input argument is of the same type as the function's argument
   603  			if !isVariadic {
   604  				if fnArg != reflect.TypeOf(pv.Interface()) && fnArg.Kind() != reflect.Interface {
   605  					msg := "Function input argument %d of '%s' must be of type %s or *gonja.Value (not %T)."
   606  					return nil, errors.Errorf(msg, idx, node.String(), fnArg.String(), pv.Interface())
   607  				}
   608  				// Function's argument has another type, using the interface-value
   609  				parameters = append(parameters, reflect.ValueOf(pv.Interface()))
   610  			} else {
   611  				if fnArg != reflect.TypeOf(pv.Interface()) && fnArg.Kind() != reflect.Interface {
   612  					msg := "Function variadic input argument of '%s' must be of type %s or *gonja.Value (not %T)."
   613  					return nil, errors.Errorf(msg, node.String(), fnArg.String(), pv.Interface())
   614  				}
   615  				// Function's argument has another type, using the interface-value
   616  				parameters = append(parameters, reflect.ValueOf(pv.Interface()))
   617  			}
   618  		} else {
   619  			// Function's argument is a *gonja.Value
   620  			parameters = append(parameters, reflect.ValueOf(pv))
   621  		}
   622  	}
   623  
   624  	// Check if any of the values are invalid
   625  	for _, p := range parameters {
   626  		if p.Kind() == reflect.Invalid {
   627  			return nil, errors.Errorf("Calling a function using an invalid parameter")
   628  		}
   629  	}
   630  
   631  	return parameters, nil
   632  }
   633  

View as plain text