1 package validator
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "reflect"
8 "strings"
9 "sync"
10 "time"
11
12 ut "github.com/go-playground/universal-translator"
13 )
14
// Tag-expression syntax.
const (
	// defaultTagName is the struct tag key used unless SetTagName changes it.
	defaultTagName = "validate"

	tagSeparator    = "," // separates the validations in a tag expression
	orSeparator     = "|" // separates "or" alternatives
	tagKeySeparator = "=" // separates a validation name from its parameter

	// Hex spellings of ',' (0x2C) and '|' (0x7C); presumably used to write
	// those characters inside tag parameters without them being read as
	// separators.
	utf8HexComma = "0x2C"
	utf8Pipe     = "0x7C"
)

// Built-in tag names the package refers to directly. skipValidationTag ("-")
// makes VarCtx/VarWithValueCtx return without validating.
const (
	skipValidationTag = "-"
	structOnlyTag     = "structonly"
	noStructLevelTag  = "nostructlevel"
	omitempty         = "omitempty"
	omitnil           = "omitnil"
	isdefault         = "isdefault"
	diveTag           = "dive"
	keysTag           = "keys"
	endKeysTag        = "endkeys"
	requiredTag       = "required"
)

// Conditional tags. New registers the baked-in validators for all of these
// with the nil-checkable flag set, so they run even for nil values.
const (
	requiredWithoutAllTag = "required_without_all"
	requiredWithoutTag    = "required_without"
	requiredWithTag       = "required_with"
	requiredWithAllTag    = "required_with_all"
	requiredIfTag         = "required_if"
	requiredUnlessTag     = "required_unless"
	skipUnlessTag         = "skip_unless"
	excludedWithoutAllTag = "excluded_without_all"
	excludedWithoutTag    = "excluded_without"
	excludedWithTag       = "excluded_with"
	excludedWithAllTag    = "excluded_with_all"
	excludedIfTag         = "excluded_if"
	excludedUnlessTag     = "excluded_unless"
)

// Field-namespace syntax, e.g. "Outer.Inner" or "Arr[0]".
const (
	namespaceSeparator = "."
	leftBracket        = "["
	rightBracket       = "]"
)

// Checks applied to user-supplied tag and alias names. A name containing
// any character of restrictedTagChars is rejected; the *Err values are the
// panic message formats, filled in with the offending name.
const (
	restrictedTagChars = ".[],|=+()`~!@#$%^&*\\\"/?<>{}"
	restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
	restrictedTagErr   = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
)
52
53 var (
54 timeDurationType = reflect.TypeOf(time.Duration(0))
55 timeType = reflect.TypeOf(time.Time{})
56
57 byteSliceType = reflect.TypeOf([]byte{})
58
59 defaultCField = &cField{namesEqual: true}
60 )
61
62
63
64
65
// FilterFunc is the per-field filter passed to StructFiltered and
// StructFilteredCtx. It receives the field's namespace as a byte slice.
//
// NOTE(review): the meaning of the returned bool (skip vs. keep) is decided
// by the struct walker, which is not in this file. Confirm it there before
// relying on it.
type FilterFunc func(ns []byte) bool
67
68
69
70
// CustomTypeFunc is registered with RegisterCustomTypeFunc for one or more
// types. It receives the field's reflect.Value and returns an interface{}.
// Presumably the returned value is validated in place of the original field
// value; confirm in the field traversal code.
type CustomTypeFunc func(field reflect.Value) interface{}
72
73
// TagNameFunc is registered with RegisterTagNameFunc. It receives a struct
// field and returns a name for it. Presumably that name replaces the Go
// field name in error namespaces; confirm where tagNameFunc is called.
type TagNameFunc func(field reflect.StructField) string
75
// internalValidationFuncWrapper is the value registerValidation stores in
// Validate.validations for each tag.
type internalValidationFuncWrapper struct {
	// fn is the context-aware validation function.
	fn FuncCtx
	// runValidatinOnNil is the nilCheckable flag given to registerValidation.
	// New sets it for the conditional required_*/excluded_*/skip_unless
	// baked-in validators, and RegisterValidation(Ctx) sets it from
	// callValidationEvenIfNull. The misspelling is kept on purpose; the field
	// is presumably read elsewhere in the package.
	runValidatinOnNil bool
}
80
81
// Validate holds a validator's configuration and working state:
//   - which struct tag to read;
//   - the registered validations, aliases, custom type functions,
//     struct-level functions, translations and map-based rules;
//   - the caches and object pool used while validating.
//
// Create one with New.
//
// NOTE(review): the Register*/Set* methods write these maps and fields
// without any locking. Presumably all registration has to finish before the
// instance is shared between goroutines; confirm against the package docs.
type Validate struct {
	// tagName is the struct tag key rules are read from ("validate" by default; see SetTagName).
	tagName string
	// pool recycles *validate scratch objects between calls; it is created in New.
	pool *sync.Pool
	// tagNameFunc is set by RegisterTagNameFunc and is guarded by hasTagNameFunc.
	tagNameFunc TagNameFunc
	// structLevelFuncs maps a struct type to its struct-level validation.
	structLevelFuncs map[reflect.Type]StructLevelFuncCtx
	// customFuncs maps a type to the CustomTypeFunc registered for it.
	customFuncs map[reflect.Type]CustomTypeFunc
	// aliases maps an alias name to the tag expression it stands for.
	aliases map[string]string
	// validations maps a tag name to its validation function and nil-handling flag.
	validations map[string]internalValidationFuncWrapper
	// transTagFunc maps translator -> tag -> TranslationFunc (see RegisterTranslation).
	transTagFunc map[ut.Translator]map[string]TranslationFunc
	// rules maps a struct type to map-based field rules (see RegisterStructValidationMapRules).
	rules map[reflect.Type]map[string]string
	// tagCache and structCache start empty in New; presumably they cache parsed
	// tag expressions and per-struct metadata respectively.
	tagCache    *tagCache
	structCache *structCache
	// hasCustomFuncs is set once RegisterCustomTypeFunc has been called.
	hasCustomFuncs bool
	// hasTagNameFunc is set by RegisterTagNameFunc.
	hasTagNameFunc bool
	// requiredStructEnabled is not set anywhere in this file; presumably an
	// Option enables it. Confirm.
	requiredStructEnabled bool
}
98
99
100
101
102
103
104 func New(options ...Option) *Validate {
105
106 tc := new(tagCache)
107 tc.m.Store(make(map[string]*cTag))
108
109 sc := new(structCache)
110 sc.m.Store(make(map[reflect.Type]*cStruct))
111
112 v := &Validate{
113 tagName: defaultTagName,
114 aliases: make(map[string]string, len(bakedInAliases)),
115 validations: make(map[string]internalValidationFuncWrapper, len(bakedInValidators)),
116 tagCache: tc,
117 structCache: sc,
118 }
119
120
121 for k, val := range bakedInAliases {
122 v.RegisterAlias(k, val)
123 }
124
125
126 for k, val := range bakedInValidators {
127
128 switch k {
129
130 case requiredIfTag, requiredUnlessTag, requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag,
131 excludedIfTag, excludedUnlessTag, excludedWithTag, excludedWithAllTag, excludedWithoutTag, excludedWithoutAllTag,
132 skipUnlessTag:
133 _ = v.registerValidation(k, wrapFunc(val), true, true)
134 default:
135
136 _ = v.registerValidation(k, wrapFunc(val), true, false)
137 }
138 }
139
140 v.pool = &sync.Pool{
141 New: func() interface{} {
142 return &validate{
143 v: v,
144 ns: make([]byte, 0, 64),
145 actualNs: make([]byte, 0, 64),
146 misc: make([]byte, 32),
147 }
148 },
149 }
150
151 for _, o := range options {
152 o(v)
153 }
154 return v
155 }
156
157
158 func (v *Validate) SetTagName(name string) {
159 v.tagName = name
160 }
161
162
163
164 func (v Validate) ValidateMapCtx(ctx context.Context, data map[string]interface{}, rules map[string]interface{}) map[string]interface{} {
165 errs := make(map[string]interface{})
166 for field, rule := range rules {
167 if ruleObj, ok := rule.(map[string]interface{}); ok {
168 if dataObj, ok := data[field].(map[string]interface{}); ok {
169 err := v.ValidateMapCtx(ctx, dataObj, ruleObj)
170 if len(err) > 0 {
171 errs[field] = err
172 }
173 } else if dataObjs, ok := data[field].([]map[string]interface{}); ok {
174 for _, obj := range dataObjs {
175 err := v.ValidateMapCtx(ctx, obj, ruleObj)
176 if len(err) > 0 {
177 errs[field] = err
178 }
179 }
180 } else {
181 errs[field] = errors.New("The field: '" + field + "' is not a map to dive")
182 }
183 } else if ruleStr, ok := rule.(string); ok {
184 err := v.VarCtx(ctx, data[field], ruleStr)
185 if err != nil {
186 errs[field] = err
187 }
188 }
189 }
190 return errs
191 }
192
193
194 func (v *Validate) ValidateMap(data map[string]interface{}, rules map[string]interface{}) map[string]interface{} {
195 return v.ValidateMapCtx(context.Background(), data, rules)
196 }
197
198
199
200
201
202
203
204
205
206
207
208
209
210 func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
211 v.tagNameFunc = fn
212 v.hasTagNameFunc = true
213 }
214
215
216
217
218
219
220 func (v *Validate) RegisterValidation(tag string, fn Func, callValidationEvenIfNull ...bool) error {
221 return v.RegisterValidationCtx(tag, wrapFunc(fn), callValidationEvenIfNull...)
222 }
223
224
225
226 func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx, callValidationEvenIfNull ...bool) error {
227 var nilCheckable bool
228 if len(callValidationEvenIfNull) > 0 {
229 nilCheckable = callValidationEvenIfNull[0]
230 }
231 return v.registerValidation(tag, fn, false, nilCheckable)
232 }
233
234 func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool, nilCheckable bool) error {
235 if len(tag) == 0 {
236 return errors.New("function Key cannot be empty")
237 }
238
239 if fn == nil {
240 return errors.New("function cannot be empty")
241 }
242
243 _, ok := restrictedTags[tag]
244 if !bakedIn && (ok || strings.ContainsAny(tag, restrictedTagChars)) {
245 panic(fmt.Sprintf(restrictedTagErr, tag))
246 }
247 v.validations[tag] = internalValidationFuncWrapper{fn: fn, runValidatinOnNil: nilCheckable}
248 return nil
249 }
250
251
252
253
254
255
256 func (v *Validate) RegisterAlias(alias, tags string) {
257
258 _, ok := restrictedTags[alias]
259
260 if ok || strings.ContainsAny(alias, restrictedTagChars) {
261 panic(fmt.Sprintf(restrictedAliasErr, alias))
262 }
263
264 v.aliases[alias] = tags
265 }
266
267
268
269
270
271 func (v *Validate) RegisterStructValidation(fn StructLevelFunc, types ...interface{}) {
272 v.RegisterStructValidationCtx(wrapStructLevelFunc(fn), types...)
273 }
274
275
276
277
278
279
280 func (v *Validate) RegisterStructValidationCtx(fn StructLevelFuncCtx, types ...interface{}) {
281
282 if v.structLevelFuncs == nil {
283 v.structLevelFuncs = make(map[reflect.Type]StructLevelFuncCtx)
284 }
285
286 for _, t := range types {
287 tv := reflect.ValueOf(t)
288 if tv.Kind() == reflect.Ptr {
289 t = reflect.Indirect(tv).Interface()
290 }
291
292 v.structLevelFuncs[reflect.TypeOf(t)] = fn
293 }
294 }
295
296
297
298
299
300 func (v *Validate) RegisterStructValidationMapRules(rules map[string]string, types ...interface{}) {
301 if v.rules == nil {
302 v.rules = make(map[reflect.Type]map[string]string)
303 }
304
305 deepCopyRules := make(map[string]string)
306 for i, rule := range rules {
307 deepCopyRules[i] = rule
308 }
309
310 for _, t := range types {
311 typ := reflect.TypeOf(t)
312
313 if typ.Kind() == reflect.Ptr {
314 typ = typ.Elem()
315 }
316
317 if typ.Kind() != reflect.Struct {
318 continue
319 }
320 v.rules[typ] = deepCopyRules
321 }
322 }
323
324
325
326
327 func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) {
328
329 if v.customFuncs == nil {
330 v.customFuncs = make(map[reflect.Type]CustomTypeFunc)
331 }
332
333 for _, t := range types {
334 v.customFuncs[reflect.TypeOf(t)] = fn
335 }
336
337 v.hasCustomFuncs = true
338 }
339
340
341 func (v *Validate) RegisterTranslation(tag string, trans ut.Translator, registerFn RegisterTranslationsFunc, translationFn TranslationFunc) (err error) {
342
343 if v.transTagFunc == nil {
344 v.transTagFunc = make(map[ut.Translator]map[string]TranslationFunc)
345 }
346
347 if err = registerFn(trans); err != nil {
348 return
349 }
350
351 m, ok := v.transTagFunc[trans]
352 if !ok {
353 m = make(map[string]TranslationFunc)
354 v.transTagFunc[trans] = m
355 }
356
357 m[tag] = translationFn
358
359 return
360 }
361
362
363
364
365
366 func (v *Validate) Struct(s interface{}) error {
367 return v.StructCtx(context.Background(), s)
368 }
369
370
371
372
373
374
375 func (v *Validate) StructCtx(ctx context.Context, s interface{}) (err error) {
376
377 val := reflect.ValueOf(s)
378 top := val
379
380 if val.Kind() == reflect.Ptr && !val.IsNil() {
381 val = val.Elem()
382 }
383
384 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
385 return &InvalidValidationError{Type: reflect.TypeOf(s)}
386 }
387
388
389 vd := v.pool.Get().(*validate)
390 vd.top = top
391 vd.isPartial = false
392
393
394 vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
395
396 if len(vd.errs) > 0 {
397 err = vd.errs
398 vd.errs = nil
399 }
400
401 v.pool.Put(vd)
402
403 return
404 }
405
406
407
408
409
410
411 func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) error {
412 return v.StructFilteredCtx(context.Background(), s, fn)
413 }
414
415
416
417
418
419
420
421 func (v *Validate) StructFilteredCtx(ctx context.Context, s interface{}, fn FilterFunc) (err error) {
422 val := reflect.ValueOf(s)
423 top := val
424
425 if val.Kind() == reflect.Ptr && !val.IsNil() {
426 val = val.Elem()
427 }
428
429 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
430 return &InvalidValidationError{Type: reflect.TypeOf(s)}
431 }
432
433
434 vd := v.pool.Get().(*validate)
435 vd.top = top
436 vd.isPartial = true
437 vd.ffn = fn
438
439
440 vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
441
442 if len(vd.errs) > 0 {
443 err = vd.errs
444 vd.errs = nil
445 }
446
447 v.pool.Put(vd)
448
449 return
450 }
451
452
453
454
455
456
457
458 func (v *Validate) StructPartial(s interface{}, fields ...string) error {
459 return v.StructPartialCtx(context.Background(), s, fields...)
460 }
461
462
463
464
465
466
467
468
469 func (v *Validate) StructPartialCtx(ctx context.Context, s interface{}, fields ...string) (err error) {
470 val := reflect.ValueOf(s)
471 top := val
472
473 if val.Kind() == reflect.Ptr && !val.IsNil() {
474 val = val.Elem()
475 }
476
477 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
478 return &InvalidValidationError{Type: reflect.TypeOf(s)}
479 }
480
481
482 vd := v.pool.Get().(*validate)
483 vd.top = top
484 vd.isPartial = true
485 vd.ffn = nil
486 vd.hasExcludes = false
487 vd.includeExclude = make(map[string]struct{})
488
489 typ := val.Type()
490 name := typ.Name()
491
492 for _, k := range fields {
493
494 flds := strings.Split(k, namespaceSeparator)
495 if len(flds) > 0 {
496
497 vd.misc = append(vd.misc[0:0], name...)
498
499 if len(vd.misc) != 0 {
500 vd.misc = append(vd.misc, '.')
501 }
502
503 for _, s := range flds {
504
505 idx := strings.Index(s, leftBracket)
506
507 if idx != -1 {
508 for idx != -1 {
509 vd.misc = append(vd.misc, s[:idx]...)
510 vd.includeExclude[string(vd.misc)] = struct{}{}
511
512 idx2 := strings.Index(s, rightBracket)
513 idx2++
514 vd.misc = append(vd.misc, s[idx:idx2]...)
515 vd.includeExclude[string(vd.misc)] = struct{}{}
516 s = s[idx2:]
517 idx = strings.Index(s, leftBracket)
518 }
519 } else {
520
521 vd.misc = append(vd.misc, s...)
522 vd.includeExclude[string(vd.misc)] = struct{}{}
523 }
524
525 vd.misc = append(vd.misc, '.')
526 }
527 }
528 }
529
530 vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
531
532 if len(vd.errs) > 0 {
533 err = vd.errs
534 vd.errs = nil
535 }
536
537 v.pool.Put(vd)
538
539 return
540 }
541
542
543
544
545
546
547
548 func (v *Validate) StructExcept(s interface{}, fields ...string) error {
549 return v.StructExceptCtx(context.Background(), s, fields...)
550 }
551
552
553
554
555
556
557
558
559 func (v *Validate) StructExceptCtx(ctx context.Context, s interface{}, fields ...string) (err error) {
560 val := reflect.ValueOf(s)
561 top := val
562
563 if val.Kind() == reflect.Ptr && !val.IsNil() {
564 val = val.Elem()
565 }
566
567 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
568 return &InvalidValidationError{Type: reflect.TypeOf(s)}
569 }
570
571
572 vd := v.pool.Get().(*validate)
573 vd.top = top
574 vd.isPartial = true
575 vd.ffn = nil
576 vd.hasExcludes = true
577 vd.includeExclude = make(map[string]struct{})
578
579 typ := val.Type()
580 name := typ.Name()
581
582 for _, key := range fields {
583
584 vd.misc = vd.misc[0:0]
585
586 if len(name) > 0 {
587 vd.misc = append(vd.misc, name...)
588 vd.misc = append(vd.misc, '.')
589 }
590
591 vd.misc = append(vd.misc, key...)
592 vd.includeExclude[string(vd.misc)] = struct{}{}
593 }
594
595 vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
596
597 if len(vd.errs) > 0 {
598 err = vd.errs
599 vd.errs = nil
600 }
601
602 v.pool.Put(vd)
603
604 return
605 }
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620 func (v *Validate) Var(field interface{}, tag string) error {
621 return v.VarCtx(context.Background(), field, tag)
622 }
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638 func (v *Validate) VarCtx(ctx context.Context, field interface{}, tag string) (err error) {
639 if len(tag) == 0 || tag == skipValidationTag {
640 return nil
641 }
642
643 ctag := v.fetchCacheTag(tag)
644
645 val := reflect.ValueOf(field)
646 vd := v.pool.Get().(*validate)
647 vd.top = val
648 vd.isPartial = false
649 vd.traverseField(ctx, val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
650
651 if len(vd.errs) > 0 {
652 err = vd.errs
653 vd.errs = nil
654 }
655 v.pool.Put(vd)
656 return
657 }
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673 func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string) error {
674 return v.VarWithValueCtx(context.Background(), field, other, tag)
675 }
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692 func (v *Validate) VarWithValueCtx(ctx context.Context, field interface{}, other interface{}, tag string) (err error) {
693 if len(tag) == 0 || tag == skipValidationTag {
694 return nil
695 }
696 ctag := v.fetchCacheTag(tag)
697 otherVal := reflect.ValueOf(other)
698 vd := v.pool.Get().(*validate)
699 vd.top = otherVal
700 vd.isPartial = false
701 vd.traverseField(ctx, otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
702
703 if len(vd.errs) > 0 {
704 err = vd.errs
705 vd.errs = nil
706 }
707 v.pool.Put(vd)
708 return
709 }
710
View as plain text