1
2
3
4
5 package template
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "net/url"
12 "reflect"
13 "strings"
14 "sync"
15 "unicode"
16 "unicode/utf8"
17 )
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// FuncMap maps template function names to Go functions. When converted by
// addValueFuncs, every name must pass goodName (a letter or underscore
// followed by letters, digits or underscores) and every value must be a
// function with one result, or two results the second of which is an error;
// otherwise addValueFuncs panics.
type FuncMap map[string]any
34
35
36
37
38
39 func builtins() FuncMap {
40 return FuncMap{
41 "and": and,
42 "call": call,
43 "html": HTMLEscaper,
44 "index": index,
45 "slice": slice,
46 "js": JSEscaper,
47 "len": length,
48 "not": not,
49 "or": or,
50 "print": fmt.Sprint,
51 "printf": fmt.Sprintf,
52 "println": fmt.Sprintln,
53 "urlquery": URLQueryEscaper,
54
55
56 "eq": eq,
57 "ge": ge,
58 "gt": gt,
59 "le": le,
60 "lt": lt,
61 "ne": ne,
62 }
63 }
64
// builtinFuncsOnce guards the one-time construction of v, the reflect.Value
// form of builtins(); builtinFuncs fills it in on first use and returns it.
var builtinFuncsOnce struct {
	sync.Once
	v map[string]reflect.Value
}
69
70
71
72 func builtinFuncs() map[string]reflect.Value {
73 builtinFuncsOnce.Do(func() {
74 builtinFuncsOnce.v = createValueFuncs(builtins())
75 })
76 return builtinFuncsOnce.v
77 }
78
79
80 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
81 m := make(map[string]reflect.Value)
82 addValueFuncs(m, funcMap)
83 return m
84 }
85
86
87 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
88 for name, fn := range in {
89 if !goodName(name) {
90 panic(fmt.Errorf("function name %q is not a valid identifier", name))
91 }
92 v := reflect.ValueOf(fn)
93 if v.Kind() != reflect.Func {
94 panic("value for " + name + " not a function")
95 }
96 if !goodFunc(v.Type()) {
97 panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
98 }
99 out[name] = v
100 }
101 }
102
103
104
105 func addFuncs(out, in FuncMap) {
106 for name, fn := range in {
107 out[name] = fn
108 }
109 }
110
111
112 func goodFunc(typ reflect.Type) bool {
113
114 switch {
115 case typ.NumOut() == 1:
116 return true
117 case typ.NumOut() == 2 && typ.Out(1) == errorType:
118 return true
119 }
120 return false
121 }
122
123
// goodName reports whether name is usable as a template function name:
// non-empty, starting with a letter or underscore, and continuing with
// letters, digits or underscores (Unicode categories, not just ASCII).
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' || unicode.IsLetter(ch) {
			continue
		}
		// A non-letter is allowed only as a digit, and never first.
		if pos == 0 || !unicode.IsDigit(ch) {
			return false
		}
	}
	return true
}
139
140
141 func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
142 if tmpl != nil && tmpl.common != nil {
143 tmpl.muFuncs.RLock()
144 defer tmpl.muFuncs.RUnlock()
145 if fn := tmpl.execFuncs[name]; fn.IsValid() {
146 return fn, false, true
147 }
148 }
149 if fn := builtinFuncs()[name]; fn.IsValid() {
150 return fn, true, true
151 }
152 return reflect.Value{}, false, false
153 }
154
155
156
157 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
158 if !value.IsValid() {
159 if !canBeNil(argType) {
160 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
161 }
162 value = reflect.Zero(argType)
163 }
164 if value.Type().AssignableTo(argType) {
165 return value, nil
166 }
167 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
168 value = value.Convert(argType)
169 return value, nil
170 }
171 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
172 }
173
// intLike reports whether typ is one of the signed or unsigned integer
// kinds, including Uintptr.
func intLike(typ reflect.Kind) bool {
	switch typ {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
183
184
// indexArg converts an integer-kinded index to int and checks that it lies
// in [0, cap]. The upper bound is inclusive, as a slice expression needs;
// element indexing must additionally reject index == length. A uint64 too
// large for int64 wraps negative and is reported as out of range.
func indexArg(index reflect.Value, cap int) (int, error) {
	var n int64
	switch index.Kind() {
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n = index.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		n = int64(index.Uint())
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	// The int(n) < 0 test catches truncation on platforms where int is narrower than int64.
	asInt := int(n)
	if n < 0 || asInt < 0 || asInt > cap {
		return 0, fmt.Errorf("index out of range: %d", n)
	}
	return asInt, nil
}
202
203
204
205
206
207
208 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
209 item = indirectInterface(item)
210 if !item.IsValid() {
211 return reflect.Value{}, fmt.Errorf("index of untyped nil")
212 }
213 for _, index := range indexes {
214 index = indirectInterface(index)
215 var isNil bool
216 if item, isNil = indirect(item); isNil {
217 return reflect.Value{}, fmt.Errorf("index of nil pointer")
218 }
219 switch item.Kind() {
220 case reflect.Array, reflect.Slice, reflect.String:
221 x, err := indexArg(index, item.Len())
222 if err != nil {
223 return reflect.Value{}, err
224 }
225 item = item.Index(x)
226 case reflect.Map:
227 index, err := prepareArg(index, item.Type().Key())
228 if err != nil {
229 return reflect.Value{}, err
230 }
231 if x := item.MapIndex(index); x.IsValid() {
232 item = x
233 } else {
234 item = reflect.Zero(item.Type().Elem())
235 }
236 case reflect.Invalid:
237
238 panic("unreachable")
239 default:
240 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
241 }
242 }
243 return item, nil
244 }
245
246
247
248
249
250
251
252 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
253 item = indirectInterface(item)
254 if !item.IsValid() {
255 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
256 }
257 if len(indexes) > 3 {
258 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
259 }
260 var cap int
261 switch item.Kind() {
262 case reflect.String:
263 if len(indexes) == 3 {
264 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
265 }
266 cap = item.Len()
267 case reflect.Array, reflect.Slice:
268 cap = item.Cap()
269 default:
270 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
271 }
272
273 idx := [3]int{0, item.Len()}
274 for i, index := range indexes {
275 x, err := indexArg(index, cap)
276 if err != nil {
277 return reflect.Value{}, err
278 }
279 idx[i] = x
280 }
281
282 if idx[0] > idx[1] {
283 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
284 }
285 if len(indexes) < 3 {
286 return item.Slice(idx[0], idx[1]), nil
287 }
288
289 if idx[1] > idx[2] {
290 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
291 }
292 return item.Slice3(idx[0], idx[1], idx[2]), nil
293 }
294
295
296
297
298 func length(item reflect.Value) (int, error) {
299 item, isNil := indirect(item)
300 if isNil {
301 return 0, fmt.Errorf("len of nil pointer")
302 }
303 switch item.Kind() {
304 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
305 return item.Len(), nil
306 }
307 return 0, fmt.Errorf("len of type %s", item.Type())
308 }
309
310
311
312
313
314 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
315 fn = indirectInterface(fn)
316 if !fn.IsValid() {
317 return reflect.Value{}, fmt.Errorf("call of nil")
318 }
319 typ := fn.Type()
320 if typ.Kind() != reflect.Func {
321 return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
322 }
323 if !goodFunc(typ) {
324 return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
325 }
326 numIn := typ.NumIn()
327 var dddType reflect.Type
328 if typ.IsVariadic() {
329 if len(args) < numIn-1 {
330 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
331 }
332 dddType = typ.In(numIn - 1).Elem()
333 } else {
334 if len(args) != numIn {
335 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
336 }
337 }
338 argv := make([]reflect.Value, len(args))
339 for i, arg := range args {
340 arg = indirectInterface(arg)
341
342 argType := dddType
343 if !typ.IsVariadic() || i < numIn-1 {
344 argType = typ.In(i)
345 }
346
347 var err error
348 if argv[i], err = prepareArg(arg, argType); err != nil {
349 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
350 }
351 }
352 return safeCall(fn, argv)
353 }
354
355
356
// safeCall calls fun with args and returns its first result. If fun has a
// second result that is a non-nil error, that error is returned too. A panic
// during the call (including one from fun having zero results) is recovered
// and returned as err: directly if the panic value is an error, otherwise
// formatted with %v.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if asErr, isErr := r.(error); isErr {
			err = asErr
			return
		}
		err = fmt.Errorf("%v", r)
	}()
	results := fun.Call(args)
	if len(results) == 2 {
		if errVal := results[1]; !errVal.IsNil() {
			return results[0], errVal.Interface().(error)
		}
	}
	return results[0], nil
}
373
374
375
376 func truth(arg reflect.Value) bool {
377 t, _ := isTrue(indirectInterface(arg))
378 return t
379 }
380
381
382
// and is registered as the "and" builtin in builtins, but its body only
// panics. NOTE(review): "and" is presumably evaluated elsewhere (e.g.
// special-cased by the executor) so that this function is never called;
// confirm against the caller.
func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
386
387
388
// or is registered as the "or" builtin in builtins, but its body only
// panics. NOTE(review): "or" is presumably evaluated elsewhere (e.g.
// special-cased by the executor) so that this function is never called;
// confirm against the caller.
func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
392
393
394 func not(arg reflect.Value) bool {
395 return !truth(arg)
396 }
397
398
399
400
401
// Errors reported by the comparison builtins.
var (
	// errNoComparison: eq was given nothing to compare against.
	errNoComparison = errors.New("missing argument for comparison")
	// errBadComparison: the operands' kinds cannot be compared with each other.
	errBadComparison = errors.New("incompatible types for comparison")
	// errBadComparisonType: an operand's kind does not support the comparison.
	errBadComparisonType = errors.New("invalid type for comparison")
)
407
// kind classifies a reflect.Value into the coarse categories the comparison
// builtins operate on; basicKind performs the mapping.
type kind int

const (
	invalidKind kind = 0 // not a basic kind
	boolKind    kind = 1
	complexKind kind = 2
	intKind     kind = 3 // any signed integer kind
	floatKind   kind = 4
	stringKind  kind = 5
	uintKind    kind = 6 // any unsigned integer kind, including uintptr
)
419
420 func basicKind(v reflect.Value) (kind, error) {
421 switch v.Kind() {
422 case reflect.Bool:
423 return boolKind, nil
424 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
425 return intKind, nil
426 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
427 return uintKind, nil
428 case reflect.Float32, reflect.Float64:
429 return floatKind, nil
430 case reflect.Complex64, reflect.Complex128:
431 return complexKind, nil
432 case reflect.String:
433 return stringKind, nil
434 }
435 return invalidKind, errBadComparisonType
436 }
437
438
// isNil reports whether v is the invalid zero Value, or a nil value of a
// nillable kind (chan, func, interface, map, pointer, slice). Values of any
// other kind are never nil.
func isNil(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Invalid:
		return true
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return v.IsNil()
	default:
		return false
	}
}
449
450
451
// canCompare reports whether v1 and v2 share the same reflect kind, or
// whether at least one of them is the invalid zero Value (i.e. a bare nil).
func canCompare(v1, v2 reflect.Value) bool {
	switch k1, k2 := v1.Kind(), v2.Kind(); {
	case k1 == k2, k1 == reflect.Invalid, k2 == reflect.Invalid:
		return true
	}
	return false
}
461
462
463 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
464 arg1 = indirectInterface(arg1)
465 if len(arg2) == 0 {
466 return false, errNoComparison
467 }
468 k1, _ := basicKind(arg1)
469 for _, arg := range arg2 {
470 arg = indirectInterface(arg)
471 k2, _ := basicKind(arg)
472 truth := false
473 if k1 != k2 {
474
475 switch {
476 case k1 == intKind && k2 == uintKind:
477 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
478 case k1 == uintKind && k2 == intKind:
479 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
480 default:
481 if arg1.IsValid() && arg.IsValid() {
482 return false, errBadComparison
483 }
484 }
485 } else {
486 switch k1 {
487 case boolKind:
488 truth = arg1.Bool() == arg.Bool()
489 case complexKind:
490 truth = arg1.Complex() == arg.Complex()
491 case floatKind:
492 truth = arg1.Float() == arg.Float()
493 case intKind:
494 truth = arg1.Int() == arg.Int()
495 case stringKind:
496 truth = arg1.String() == arg.String()
497 case uintKind:
498 truth = arg1.Uint() == arg.Uint()
499 default:
500 if !canCompare(arg1, arg) {
501 return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
502 }
503 if isNil(arg1) || isNil(arg) {
504 truth = isNil(arg) == isNil(arg1)
505 } else {
506 if !arg.Type().Comparable() {
507 return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
508 }
509 truth = arg1.Interface() == arg.Interface()
510 }
511 }
512 }
513 if truth {
514 return true, nil
515 }
516 }
517 return false, nil
518 }
519
520
521 func ne(arg1, arg2 reflect.Value) (bool, error) {
522
523 equal, err := eq(arg1, arg2)
524 return !equal, err
525 }
526
527
528 func lt(arg1, arg2 reflect.Value) (bool, error) {
529 arg1 = indirectInterface(arg1)
530 k1, err := basicKind(arg1)
531 if err != nil {
532 return false, err
533 }
534 arg2 = indirectInterface(arg2)
535 k2, err := basicKind(arg2)
536 if err != nil {
537 return false, err
538 }
539 truth := false
540 if k1 != k2 {
541
542 switch {
543 case k1 == intKind && k2 == uintKind:
544 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
545 case k1 == uintKind && k2 == intKind:
546 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
547 default:
548 return false, errBadComparison
549 }
550 } else {
551 switch k1 {
552 case boolKind, complexKind:
553 return false, errBadComparisonType
554 case floatKind:
555 truth = arg1.Float() < arg2.Float()
556 case intKind:
557 truth = arg1.Int() < arg2.Int()
558 case stringKind:
559 truth = arg1.String() < arg2.String()
560 case uintKind:
561 truth = arg1.Uint() < arg2.Uint()
562 default:
563 panic("invalid kind")
564 }
565 }
566 return truth, nil
567 }
568
569
570 func le(arg1, arg2 reflect.Value) (bool, error) {
571
572 lessThan, err := lt(arg1, arg2)
573 if lessThan || err != nil {
574 return lessThan, err
575 }
576 return eq(arg1, arg2)
577 }
578
579
580 func gt(arg1, arg2 reflect.Value) (bool, error) {
581
582 lessOrEqual, err := le(arg1, arg2)
583 if err != nil {
584 return false, err
585 }
586 return !lessOrEqual, nil
587 }
588
589
590 func ge(arg1, arg2 reflect.Value) (bool, error) {
591
592 lessThan, err := lt(arg1, arg2)
593 if err != nil {
594 return false, err
595 }
596 return !lessThan, nil
597 }
598
599
600
// Replacement byte sequences used by HTMLEscape. The entity forms were lost
// in extraction (the quote entry had become the non-compiling `""")`) and
// are restored here.
var (
	htmlQuot = []byte("&#34;") // shorter than "&quot;"
	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD") // replacement for NUL bytes
)
609
610
611 func HTMLEscape(w io.Writer, b []byte) {
612 last := 0
613 for i, c := range b {
614 var html []byte
615 switch c {
616 case '\000':
617 html = htmlNull
618 case '"':
619 html = htmlQuot
620 case '\'':
621 html = htmlApos
622 case '&':
623 html = htmlAmp
624 case '<':
625 html = htmlLt
626 case '>':
627 html = htmlGt
628 default:
629 continue
630 }
631 w.Write(b[last:i])
632 w.Write(html)
633 last = i + 1
634 }
635 w.Write(b[last:])
636 }
637
638
639 func HTMLEscapeString(s string) string {
640
641 if !strings.ContainsAny(s, "'\"&<>\000") {
642 return s
643 }
644 var b strings.Builder
645 HTMLEscape(&b, []byte(s))
646 return b.String()
647 }
648
649
650
651 func HTMLEscaper(args ...any) string {
652 return HTMLEscapeString(evalArgs(args))
653 }
654
655
656
// Byte sequences used by JSEscape.
var (
	// jsLowUni prefixes the \u00XX escape of an ASCII control character;
	// the two trailing digits are taken from hex.
	jsLowUni = []byte("\\u00")
	hex      = []byte("0123456789ABCDEF")

	jsBackslash = []byte("\\\\")
	jsApos      = []byte("\\'")
	jsQuot      = []byte("\\\"")
	jsLt        = []byte("\\u003C")
	jsGt        = []byte("\\u003E")
	jsAmp       = []byte("\\u0026")
	jsEq        = []byte("\\u003D")
)
669
670
671 func JSEscape(w io.Writer, b []byte) {
672 last := 0
673 for i := 0; i < len(b); i++ {
674 c := b[i]
675
676 if !jsIsSpecial(rune(c)) {
677
678 continue
679 }
680 w.Write(b[last:i])
681
682 if c < utf8.RuneSelf {
683
684
685 switch c {
686 case '\\':
687 w.Write(jsBackslash)
688 case '\'':
689 w.Write(jsApos)
690 case '"':
691 w.Write(jsQuot)
692 case '<':
693 w.Write(jsLt)
694 case '>':
695 w.Write(jsGt)
696 case '&':
697 w.Write(jsAmp)
698 case '=':
699 w.Write(jsEq)
700 default:
701 w.Write(jsLowUni)
702 t, b := c>>4, c&0x0f
703 w.Write(hex[t : t+1])
704 w.Write(hex[b : b+1])
705 }
706 } else {
707
708 r, size := utf8.DecodeRune(b[i:])
709 if unicode.IsPrint(r) {
710 w.Write(b[i : i+size])
711 } else {
712 fmt.Fprintf(w, "\\u%04X", r)
713 }
714 i += size - 1
715 }
716 last = i + 1
717 }
718 w.Write(b[last:])
719 }
720
721
722 func JSEscapeString(s string) string {
723
724 if strings.IndexFunc(s, jsIsSpecial) < 0 {
725 return s
726 }
727 var b strings.Builder
728 JSEscape(&b, []byte(s))
729 return b.String()
730 }
731
// jsIsSpecial reports whether r may need escaping by JSEscape: ASCII
// control characters (below space), anything at or above utf8.RuneSelf,
// and the characters \ ' " < > & =.
func jsIsSpecial(r rune) bool {
	if r < ' ' || r >= utf8.RuneSelf {
		return true
	}
	switch r {
	case '\\', '\'', '"', '<', '>', '&', '=':
		return true
	default:
		return false
	}
}
739
740
741
742 func JSEscaper(args ...any) string {
743 return JSEscapeString(evalArgs(args))
744 }
745
746
747
748 func URLQueryEscaper(args ...any) string {
749 return url.QueryEscape(evalArgs(args))
750 }
751
752
753
754
755
756
757
758
759 func evalArgs(args []any) string {
760 ok := false
761 var s string
762
763 if len(args) == 1 {
764 s, ok = args[0].(string)
765 }
766 if !ok {
767 for i, arg := range args {
768 a, ok := printableValue(reflect.ValueOf(arg))
769 if ok {
770 args[i] = a
771 }
772 }
773 s = fmt.Sprint(args...)
774 }
775 return s
776 }
777
View as plain text