1 package exec
2
3 import (
4 "math"
5 "reflect"
6 "strings"
7
8 "github.com/pkg/errors"
9
10 "github.com/noirbizarre/gonja/nodes"
11 )
12
13 var (
14 typeOfValuePtr = reflect.TypeOf(new(Value))
15 typeOfExecCtxPtr = reflect.TypeOf(new(Context))
16 )
17
// Evaluator evaluates expression nodes (see Eval) against a Context.
type Evaluator struct {
	// EvalConfig is embedded, so its fields and methods are promoted onto
	// Evaluator.
	*EvalConfig
	// Ctx is the context consulted when resolving names (see evalName).
	Ctx *Context
}
22
23 func (r *Renderer) Evaluator() *Evaluator {
24 return &Evaluator{
25 EvalConfig: r.EvalConfig,
26 Ctx: r.Ctx,
27 }
28 }
29
30 func (r *Renderer) Eval(node nodes.Expression) *Value {
31 e := r.Evaluator()
32 return e.Eval(node)
33 }
34
35 func (e *Evaluator) Eval(node nodes.Expression) *Value {
36 switch n := node.(type) {
37 case *nodes.String:
38 return AsValue(n.Val)
39 case *nodes.Integer:
40 return AsValue(n.Val)
41 case *nodes.Float:
42 return AsValue(n.Val)
43 case *nodes.Bool:
44 return AsValue(n.Val)
45 case *nodes.List:
46 return e.evalList(n)
47 case *nodes.Tuple:
48 return e.evalTuple(n)
49 case *nodes.Dict:
50 return e.evalDict(n)
51 case *nodes.Pair:
52 return e.evalPair(n)
53 case *nodes.Name:
54 return e.evalName(n)
55 case *nodes.Call:
56 return e.evalCall(n)
57 case *nodes.Getitem:
58 return e.evalGetitem(n)
59 case *nodes.Getattr:
60 return e.evalGetattr(n)
61 case *nodes.Negation:
62 result := e.Eval(n.Term)
63 if result.IsError() {
64 return result
65 }
66 return result.Negate()
67 case *nodes.BinaryExpression:
68 return e.evalBinaryExpression(n)
69 case *nodes.UnaryExpression:
70 return e.evalUnaryExpression(n)
71 case *nodes.FilteredExpression:
72 return e.EvaluateFiltered(n)
73 case *nodes.TestExpression:
74 return e.EvalTest(n)
75 default:
76 return AsValue(errors.Errorf(`Unknown expression type "%T"`, n))
77 }
78 }
79
80 func (e *Evaluator) evalBinaryExpression(node *nodes.BinaryExpression) *Value {
81 var (
82 left *Value
83 right *Value
84 )
85 left = e.Eval(node.Left)
86 if left.IsError() {
87 return AsValue(errors.Wrapf(left, `Unable to evaluate left parameter %s`, node.Left))
88 }
89
90 switch node.Operator.Token.Val {
91
92 case "and", "or":
93 default:
94 right = e.Eval(node.Right)
95 if right.IsError() {
96 return AsValue(errors.Wrapf(right, `Unable to evaluate right parameter %s`, node.Right))
97 }
98 }
99
100 switch node.Operator.Token.Val {
101 case "+":
102 if left.IsList() {
103 if !right.IsList() {
104 return AsValue(errors.Wrapf(right, `Unable to concatenate list to %s`, node.Right))
105 }
106
107 v := &Value{Val: reflect.ValueOf([]interface{}{})}
108
109 for ix := 0; ix < left.getResolvedValue().Len(); ix ++ {
110 v.Val = reflect.Append(v.Val, left.getResolvedValue().Index(ix))
111 }
112
113 for ix := 0; ix < right.getResolvedValue().Len(); ix ++ {
114 v.Val = reflect.Append(v.Val, right.getResolvedValue().Index(ix))
115 }
116
117 return v
118 }
119 if left.IsFloat() || right.IsFloat() {
120
121 return AsValue(left.Float() + right.Float())
122 }
123
124 return AsValue(left.Integer() + right.Integer())
125 case "-":
126 if left.IsFloat() || right.IsFloat() {
127
128 return AsValue(left.Float() - right.Float())
129 }
130
131 return AsValue(left.Integer() - right.Integer())
132 case "*":
133 if left.IsFloat() || right.IsFloat() {
134
135 return AsValue(left.Float() * right.Float())
136 }
137 if left.IsString() {
138 return AsValue(strings.Repeat(left.String(), right.Integer()))
139 }
140
141 return AsValue(left.Integer() * right.Integer())
142 case "/":
143
144 return AsValue(left.Float() / right.Float())
145 case "//":
146
147 return AsValue(int(left.Float() / right.Float()))
148 case "%":
149
150 return AsValue(left.Integer() % right.Integer())
151 case "**":
152 return AsValue(math.Pow(left.Float(), right.Float()))
153 case "~":
154 return AsValue(strings.Join([]string{left.String(), right.String()}, ""))
155 case "and":
156 if !left.IsTrue() {
157 return AsValue(false)
158 }
159 right = e.Eval(node.Right)
160 if right.IsError() {
161 return AsValue(errors.Wrapf(right, `Unable to evaluate right parameter %s`, node.Right))
162 }
163 return AsValue(right.IsTrue())
164 case "or":
165 if left.IsTrue() {
166 return AsValue(true)
167 }
168 right = e.Eval(node.Right)
169 if right.IsError() {
170 return AsValue(errors.Wrapf(right, `Unable to evaluate right parameter %s`, node.Right))
171 }
172 return AsValue(right.IsTrue())
173 case "<=":
174 if left.IsFloat() || right.IsFloat() {
175 return AsValue(left.Float() <= right.Float())
176 }
177 return AsValue(left.Integer() <= right.Integer())
178 case ">=":
179 if left.IsFloat() || right.IsFloat() {
180 return AsValue(left.Float() >= right.Float())
181 }
182 return AsValue(left.Integer() >= right.Integer())
183 case "==":
184 return AsValue(left.EqualValueTo(right))
185 case ">":
186 if left.IsFloat() || right.IsFloat() {
187 return AsValue(left.Float() > right.Float())
188 }
189 return AsValue(left.Integer() > right.Integer())
190 case "<":
191 if left.IsFloat() || right.IsFloat() {
192 return AsValue(left.Float() < right.Float())
193 }
194 return AsValue(left.Integer() < right.Integer())
195 case "!=", "<>":
196 return AsValue(!left.EqualValueTo(right))
197 case "in":
198 return AsValue(right.Contains(left))
199 case "is":
200 return nil
201 default:
202 return AsValue(errors.Errorf(`Unknown operator "%s"`, node.Operator.Token))
203 }
204 }
205
206 func (e *Evaluator) evalUnaryExpression(expr *nodes.UnaryExpression) *Value {
207 result := e.Eval(expr.Term)
208 if result.IsError() {
209 return AsValue(errors.Wrapf(result, `Unable to evaluate term %s`, expr.Term))
210 }
211 if expr.Negative {
212 if result.IsNumber() {
213 switch {
214 case result.IsFloat():
215 return AsValue(-1 * result.Float())
216 case result.IsInteger():
217 return AsValue(-1 * result.Integer())
218 default:
219 return AsValue(errors.New("Operation between a number and a non-(float/integer) is not possible"))
220 }
221 } else {
222 return AsValue(errors.Errorf("Negative sign on a non-number expression %s", expr.Position()))
223 }
224 }
225 return result
226 }
227
228 func (e *Evaluator) evalList(node *nodes.List) *Value {
229 values := ValuesList{}
230 for _, val := range node.Val {
231 value := e.Eval(val)
232 values = append(values, value)
233 }
234 return AsValue(values)
235 }
236
237 func (e *Evaluator) evalTuple(node *nodes.Tuple) *Value {
238 values := ValuesList{}
239 for _, val := range node.Val {
240 value := e.Eval(val)
241 values = append(values, value)
242 }
243 return AsValue(values)
244 }
245
246 func (e *Evaluator) evalDict(node *nodes.Dict) *Value {
247 pairs := []*Pair{}
248 for _, pair := range node.Pairs {
249 p := e.evalPair(pair)
250 if p.IsError() {
251 return AsValue(errors.Wrapf(p, `Unable to evaluate pair "%s"`, pair))
252 }
253 pairs = append(pairs, p.Interface().(*Pair))
254 }
255 return AsValue(&Dict{pairs})
256 }
257
258 func (e *Evaluator) evalPair(node *nodes.Pair) *Value {
259 key := e.Eval(node.Key)
260 if key.IsError() {
261 return AsValue(errors.Wrapf(key, `Unable to evaluate key "%s"`, node.Key))
262 }
263 value := e.Eval(node.Value)
264 if value.IsError() {
265 return AsValue(errors.Wrapf(value, `Unable to evaluate value "%s"`, node.Value))
266 }
267 return AsValue(&Pair{key, value})
268 }
269
270 func (e *Evaluator) evalName(node *nodes.Name) *Value {
271 val := e.Ctx.Get(node.Name.Val)
272 return ToValue(val)
273 }
274
275 func (e *Evaluator) evalGetitem(node *nodes.Getitem) *Value {
276 value := e.Eval(node.Node)
277 if value.IsError() {
278 return AsValue(errors.Wrapf(value, `Unable to evaluate target %s`, node.Node))
279 }
280
281 if node.Arg != "" {
282 item, found := value.Getitem(node.Arg)
283 if !found {
284 item, found = value.Getattr(node.Arg)
285 }
286 if !found {
287 if item.IsError() {
288 return AsValue(errors.Wrapf(item, `Unable to evaluate %s`, node))
289 }
290 return AsValue(nil)
291
292 }
293 return item
294 } else {
295 item, found := value.Getitem(node.Index)
296 if !found {
297 if item.IsError() {
298 return AsValue(errors.Wrapf(item, `Unable to evaluate %s`, node))
299 }
300 return AsValue(nil)
301
302 }
303 return item
304 }
305 return AsValue(errors.Errorf(`Unable to evaluate %s`, node))
306 }
307
308 func (e *Evaluator) evalGetattr(node *nodes.Getattr) *Value {
309 value := e.Eval(node.Node)
310 if value.IsError() {
311 return AsValue(errors.Wrapf(value, `Unable to evaluate target %s`, node.Node))
312 }
313
314 if node.Attr != "" {
315 attr, found := value.Getattr(node.Attr)
316 if !found {
317 attr, found = value.Getitem(node.Attr)
318 }
319 if !found {
320 if attr.IsError() {
321 return AsValue(errors.Wrapf(attr, `Unable to evaluate %s`, node))
322 }
323 return AsValue(nil)
324
325 }
326 return attr
327 } else {
328 item, found := value.Getitem(node.Index)
329 if !found {
330 if item.IsError() {
331 return AsValue(errors.Wrapf(item, `Unable to evaluate %s`, node))
332 }
333 return AsValue(nil)
334
335 }
336 return item
337 }
338 return AsValue(errors.Errorf(`Unable to evaluate %s`, node))
339 }
340
341 func (e *Evaluator) evalCall(node *nodes.Call) *Value {
342 fn := e.Eval(node.Func)
343 if fn.IsError() {
344 return AsValue(errors.Wrapf(fn, `Unable to evaluate function "%s"`, node.Func))
345 }
346
347 if !fn.IsCallable() {
348 return AsValue(errors.Errorf(`%s is not callable`, node.Func))
349 }
350
351
352
353 var current reflect.Value
354 var isSafe bool
355
356 var params []reflect.Value
357 var err error
358 t := fn.Val.Type()
359
360 if t.NumIn() == 1 && t.In(0) == reflect.TypeOf(&VarArgs{}) {
361 params, err = e.evalVarArgs(node)
362 } else {
363 params, err = e.evalParams(node, fn)
364 }
365 if err != nil {
366 return AsValue(errors.Wrapf(err, `Unable to evaluate parameters`))
367 }
368
369
370 values := fn.Val.Call(params)
371 rv := values[0]
372 if t.NumOut() == 2 {
373 e := values[1].Interface()
374 if e != nil {
375 err, ok := e.(error)
376 if !ok {
377 return AsValue(errors.Errorf("The second return value is not an error"))
378 }
379 if err != nil {
380 return AsValue(err)
381 }
382 }
383 }
384
385 if rv.Type() != typeOfValuePtr {
386 current = reflect.ValueOf(rv.Interface())
387 } else {
388
389 current = rv.Interface().(*Value).Val
390 isSafe = rv.Interface().(*Value).Safe
391 }
392
393 if !current.IsValid() {
394
395 return AsValue(nil)
396 }
397
398 return &Value{Val: current, Safe: isSafe}
399 }
400
// evalVariable resolves a dotted variable path (node.Parts) step by step
// using reflection.
//
// The first part is looked up in the evaluation context; each following part
// is resolved against the value obtained so far, either as a method, an
// integer index, or a struct field / map key. Whenever an intermediate value
// turns out to be invalid, a nil value is returned with a nil error.
//
// Returns the resolved value (carrying the Safe flag of the last *Value that
// was unwrapped) or an error when a part cannot be applied to the current
// value's kind or a called function reports an error.
//
// NOTE(review): no caller of this method is visible in this file — confirm
// whether it is still in use.
func (e *Evaluator) evalVariable(node *nodes.Variable) (*Value, error) {
	var current reflect.Value
	var isSafe bool

	for idx, part := range node.Parts {
		if idx == 0 {
			// The first part names the root variable in the context.
			val := e.Ctx.Get(node.Parts[0].S)
			current = reflect.ValueOf(val)
		} else {
			// Subsequent parts are resolved against the current value.
			// A method with the part's name takes precedence over a
			// field or map key of the same name.
			isFunc := false
			if part.Type == nodes.VarTypeIdent {
				funcValue := current.MethodByName(part.S)
				if funcValue.IsValid() {
					current = funcValue
					isFunc = true
				}
			}

			if !isFunc {
				// Dereference a pointer before indexing or field access.
				if current.Kind() == reflect.Ptr {
					current = current.Elem()
					if !current.IsValid() {
						// Nil pointer: nothing further can be resolved.
						return AsValue(nil), nil
					}
				}

				// Apply the part according to its type.
				switch part.Type {
				case nodes.VarTypeInt:
					// Integer subscript: only strings, arrays and slices
					// can be indexed.
					switch current.Kind() {
					case reflect.String, reflect.Array, reflect.Slice:
						if part.I >= 0 && current.Len() > part.I {
							current = current.Index(part.I)
						} else {
							// Out-of-range index resolves to nil, not an error.
							return AsValue(nil), nil
						}
					default:
						return nil, errors.Errorf("Can't access an index on type %s (variable %s)",
							current.Kind().String(), node.String())
					}
				case nodes.VarTypeIdent:
					// Identifier: struct field or map key lookup.
					switch current.Kind() {
					case reflect.Struct:
						current = current.FieldByName(part.S)
					case reflect.Map:
						current = current.MapIndex(reflect.ValueOf(part.S))
					default:
						return nil, errors.Errorf("Can't access a field by name on type %s (variable %s)",
							current.Kind().String(), node.String())
					}
				default:
					panic("unimplemented")
				}
			}
		}

		if !current.IsValid() {
			// Unknown name, missing field or missing map key.
			return AsValue(nil), nil
		}

		// Unwrap a *Value so that the following steps operate on the
		// underlying reflect.Value, remembering its Safe flag.
		if current.Type() == typeOfValuePtr {
			tmpValue := current.Interface().(*Value)
			current = tmpValue.Val
			isSafe = tmpValue.Safe
		}

		// Replace an interface value with the concrete value it holds.
		if current.Kind() == reflect.Interface {
			current = reflect.ValueOf(current.Interface())
		}

		// If this part is marked as a function call, invoke the current value.
		if part.IsFunctionCall {
			// NOTE(review): current may be invalid here if the unwrapped
			// *Value above held an invalid Val; current.Type() would then
			// panic — confirm whether that can happen.
			var params []reflect.Value
			var err error
			t := current.Type()

			// NOTE(review): both branches are empty, so params stays nil and
			// err is never assigned; the call below therefore passes no
			// arguments. Presumably the argument evaluation was removed —
			// confirm.
			if t.NumIn() == 1 && t.In(0) == reflect.TypeOf(&VarArgs{}) {

			} else {

			}
			if err != nil {
				return nil, err
			}

			// Invoke the function; the first result is the value.
			values := current.Call(params)
			rv := values[0]
			if t.NumOut() == 2 {
				// With two results, a non-nil second one must be an error.
				// Note that e and err shadow the receiver and the outer err
				// inside this block.
				e := values[1].Interface()
				if e != nil {
					err, ok := e.(error)
					if !ok {
						return nil, errors.Errorf("The second return value is not an error")
					}
					if err != nil {
						return nil, err
					}
				}
			}

			if rv.Type() != typeOfValuePtr {
				current = reflect.ValueOf(rv.Interface())
			} else {
				// The function returned a *Value: unwrap it and keep its
				// Safe flag.
				current = rv.Interface().(*Value).Val
				isSafe = rv.Interface().(*Value).Safe
			}
		}

		if !current.IsValid() {
			// Unwrapping or the function call produced no usable value.
			return AsValue(nil), nil
		}
	}

	return &Value{Val: current, Safe: isSafe}, nil
}
538
539 func (e *Evaluator) evalVarArgs(node *nodes.Call) ([]reflect.Value, error) {
540 params := &VarArgs{
541 Args: []*Value{},
542 KwArgs: map[string]*Value{},
543 }
544 for _, param := range node.Args {
545 value := e.Eval(param)
546 if value.IsError() {
547 return nil, value
548 }
549 params.Args = append(params.Args, value)
550 }
551
552 for key, param := range node.Kwargs {
553 value := e.Eval(param)
554 if value.IsError() {
555 return nil, value
556 }
557 params.KwArgs[key] = value
558 }
559
560 return []reflect.Value{reflect.ValueOf(params)}, nil
561 }
562
563 func (e *Evaluator) evalParams(node *nodes.Call, fn *Value) ([]reflect.Value, error) {
564 args := node.Args
565 t := fn.Val.Type()
566
567 if len(args) != t.NumIn() && !(len(args) >= t.NumIn()-1 && t.IsVariadic()) {
568 msg := "Function input argument count (%d) of '%s' must be equal to the calling argument count (%d)."
569 return nil, errors.Errorf(msg, t.NumIn(), node.String(), len(args))
570 }
571
572
573 if t.NumOut() != 1 && t.NumOut() != 2 {
574 msg := "'%s' must have exactly 1 or 2 output arguments, the second argument must be of type error"
575 return nil, errors.Errorf(msg, node.String())
576 }
577
578
579 var parameters []reflect.Value
580
581 numArgs := t.NumIn()
582 isVariadic := t.IsVariadic()
583 var fnArg reflect.Type
584
585 for idx, arg := range args {
586 pv := e.Eval(arg)
587 if pv.IsError() {
588 return nil, pv
589 }
590
591 if isVariadic {
592 if idx >= numArgs-1 {
593 fnArg = t.In(numArgs - 1).Elem()
594 } else {
595 fnArg = t.In(idx)
596 }
597 } else {
598 fnArg = t.In(idx)
599 }
600
601 if fnArg != typeOfValuePtr {
602
603 if !isVariadic {
604 if fnArg != reflect.TypeOf(pv.Interface()) && fnArg.Kind() != reflect.Interface {
605 msg := "Function input argument %d of '%s' must be of type %s or *gonja.Value (not %T)."
606 return nil, errors.Errorf(msg, idx, node.String(), fnArg.String(), pv.Interface())
607 }
608
609 parameters = append(parameters, reflect.ValueOf(pv.Interface()))
610 } else {
611 if fnArg != reflect.TypeOf(pv.Interface()) && fnArg.Kind() != reflect.Interface {
612 msg := "Function variadic input argument of '%s' must be of type %s or *gonja.Value (not %T)."
613 return nil, errors.Errorf(msg, node.String(), fnArg.String(), pv.Interface())
614 }
615
616 parameters = append(parameters, reflect.ValueOf(pv.Interface()))
617 }
618 } else {
619
620 parameters = append(parameters, reflect.ValueOf(pv))
621 }
622 }
623
624
625 for _, p := range parameters {
626 if p.Kind() == reflect.Invalid {
627 return nil, errors.Errorf("Calling a function using an invalid parameter")
628 }
629 }
630
631 return parameters, nil
632 }
633
View as plain text