1 package validator
2
3 import (
4 "fmt"
5 "reflect"
6 "strings"
7 "sync"
8 "sync/atomic"
9 )
10
// tagType identifies what kind of tag entry a cTag node was parsed from. The
// values are assigned in the switch inside parseFieldTagsRecursive.
type tagType uint8

// tagType values. They are numbered with iota, so reordering this list changes
// their numeric values.
const (
	// typeDefault is the zero value; plain validation-function nodes keep it.
	typeDefault tagType = iota

	// typeOmitEmpty marks an omitempty entry.
	typeOmitEmpty

	// typeIsDefault marks an isdefault entry.
	typeIsDefault

	// typeNoStructLevel marks a noStructLevelTag entry.
	typeNoStructLevel

	// typeStructOnly marks a structOnlyTag entry.
	typeStructOnly

	// typeDive marks a diveTag entry.
	typeDive

	// typeOr marks each node produced from an orSeparator-separated entry.
	typeOr

	// typeKeys marks a keysTag entry; its key validations are stored in cTag.keys.
	typeKeys

	// typeEndKeys marks an endKeysTag entry, which terminates a keys chain.
	typeEndKeys

	// typeOmitNil marks an omitnil entry.
	typeOmitNil
)
25
// Panic messages raised by parseFieldTagsRecursive for malformed tags.
const (
	// invalidValidation is used when an entry has an empty validation name.
	// Its single verb is the field name.
	invalidValidation = "Invalid validation tag on field '%s'"

	// undefinedValidation is used when no validation is registered under the
	// entry's name. Its verbs are the validation name and the field name.
	undefinedValidation = "Undefined validation function '%s' on field '%s'"

	// keysTagNotDefined is used when an endKeysTag entry is not the last entry
	// of the tag string being parsed.
	keysTagNotDefined = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
)
31
32 type structCache struct {
33 lock sync.Mutex
34 m atomic.Value
35 }
36
37 func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
38 c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
39 return
40 }
41
42 func (sc *structCache) Set(key reflect.Type, value *cStruct) {
43 m := sc.m.Load().(map[reflect.Type]*cStruct)
44 nm := make(map[reflect.Type]*cStruct, len(m)+1)
45 for k, v := range m {
46 nm[k] = v
47 }
48 nm[key] = value
49 sc.m.Store(nm)
50 }
51
52 type tagCache struct {
53 lock sync.Mutex
54 m atomic.Value
55 }
56
57 func (tc *tagCache) Get(key string) (c *cTag, found bool) {
58 c, found = tc.m.Load().(map[string]*cTag)[key]
59 return
60 }
61
62 func (tc *tagCache) Set(key string, value *cTag) {
63 m := tc.m.Load().(map[string]*cTag)
64 nm := make(map[string]*cTag, len(m)+1)
65 for k, v := range m {
66 nm[k] = v
67 }
68 nm[key] = value
69 tc.m.Store(nm)
70 }
71
// cStruct is the cached validation metadata for one struct type, built by
// extractStructCache.
type cStruct struct {
	// name is the sName passed to extractStructCache.
	name string

	// fields has one entry per validated field (exported or embedded, and not
	// tagged with skipValidationTag), in declaration order.
	fields []*cField

	// fn is the value registered for the type in v.structLevelFuncs (the zero
	// value when nothing is registered). Presumably the struct-level
	// validation hook — confirm against StructLevelFuncCtx.
	fn StructLevelFuncCtx
}
77
// cField is the cached metadata for one validated struct field.
type cField struct {
	// idx is the field's index, as passed to reflect.Type.Field.
	idx int

	// name is the Go field name.
	name string

	// altName is the name returned by v.tagNameFunc when it is set and returns
	// a non-empty string; otherwise it equals name.
	altName string

	// namesEqual reports whether name == altName.
	namesEqual bool

	// cTags is the head of the field's parsed tag chain; an empty cTag when
	// the field has no tag.
	cTags *cTag
}
85
// cTag is one node of a parsed validation tag chain, as produced by
// parseFieldTagsRecursive.
type cTag struct {
	// tag is the validation name: the entry text before any tagKeySeparator.
	tag string

	// aliasTag is the alias name when the node came from an alias expansion,
	// otherwise the node's own validation name.
	aliasTag string

	// actualAliasTag is the full entry text (including any param) when the
	// node came from an alias expansion.
	actualAliasTag string

	// param is the text after tagKeySeparator, with utf8HexComma and utf8Pipe
	// placeholders decoded back to ',' and '|'.
	param string

	// keys is the parsed chain of key validations for a typeKeys node.
	keys *cTag

	// next is the following node in the chain, or nil at the end.
	next *cTag

	// fn is the validation function looked up in v.validations by tag.
	fn FuncCtx

	// typeof is the kind of entry this node was parsed from.
	typeof tagType

	// hasTag is set on every parsed node; the placeholder cTag built for an
	// untagged field leaves it false.
	hasTag bool

	// hasAlias reports whether the node came from an alias expansion.
	hasAlias bool

	// hasParam reports whether the entry contained tagKeySeparator.
	hasParam bool

	// isBlockEnd is set on the last node produced for a validation entry,
	// i.e. the last branch of an orSeparator-separated group.
	isBlockEnd bool

	// runValidationWhenNil is copied from the registered validation wrapper.
	runValidationWhenNil bool
}
101
102 func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
103 v.structCache.lock.Lock()
104 defer v.structCache.lock.Unlock()
105
106 typ := current.Type()
107
108
109
110 cs, ok := v.structCache.Get(typ)
111 if ok {
112 return cs
113 }
114
115 cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
116
117 numFields := current.NumField()
118 rules := v.rules[typ]
119
120 var ctag *cTag
121 var fld reflect.StructField
122 var tag string
123 var customName string
124
125 for i := 0; i < numFields; i++ {
126
127 fld = typ.Field(i)
128
129 if !fld.Anonymous && len(fld.PkgPath) > 0 {
130 continue
131 }
132
133 if rtag, ok := rules[fld.Name]; ok {
134 tag = rtag
135 } else {
136 tag = fld.Tag.Get(v.tagName)
137 }
138
139 if tag == skipValidationTag {
140 continue
141 }
142
143 customName = fld.Name
144
145 if v.hasTagNameFunc {
146 name := v.tagNameFunc(fld)
147 if len(name) > 0 {
148 customName = name
149 }
150 }
151
152
153
154
155 if len(tag) > 0 {
156 ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
157 } else {
158
159
160 ctag = new(cTag)
161 }
162
163 cs.fields = append(cs.fields, &cField{
164 idx: i,
165 name: fld.Name,
166 altName: customName,
167 cTags: ctag,
168 namesEqual: fld.Name == customName,
169 })
170 }
171 v.structCache.Set(typ, cs)
172 return cs
173 }
174
175 func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
176 var t string
177 noAlias := len(alias) == 0
178 tags := strings.Split(tag, tagSeparator)
179
180 for i := 0; i < len(tags); i++ {
181 t = tags[i]
182 if noAlias {
183 alias = t
184 }
185
186
187 if tagsVal, found := v.aliases[t]; found {
188 if i == 0 {
189 firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
190 } else {
191 next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
192 current.next, current = next, curr
193
194 }
195 continue
196 }
197
198 var prevTag tagType
199
200 if i == 0 {
201 current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
202 firstCtag = current
203 } else {
204 prevTag = current.typeof
205 current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
206 current = current.next
207 }
208
209 switch t {
210 case diveTag:
211 current.typeof = typeDive
212 continue
213
214 case keysTag:
215 current.typeof = typeKeys
216
217 if i == 0 || prevTag != typeDive {
218 panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
219 }
220
221 current.typeof = typeKeys
222
223
224
225 b := make([]byte, 0, 64)
226
227 i++
228
229 for ; i < len(tags); i++ {
230
231 b = append(b, tags[i]...)
232 b = append(b, ',')
233
234 if tags[i] == endKeysTag {
235 break
236 }
237 }
238
239 current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
240 continue
241
242 case endKeysTag:
243 current.typeof = typeEndKeys
244
245
246
247 if i != len(tags)-1 {
248 panic(keysTagNotDefined)
249 }
250 return
251
252 case omitempty:
253 current.typeof = typeOmitEmpty
254 continue
255
256 case omitnil:
257 current.typeof = typeOmitNil
258 continue
259
260 case structOnlyTag:
261 current.typeof = typeStructOnly
262 continue
263
264 case noStructLevelTag:
265 current.typeof = typeNoStructLevel
266 continue
267
268 default:
269 if t == isdefault {
270 current.typeof = typeIsDefault
271 }
272
273 orVals := strings.Split(t, orSeparator)
274
275 for j := 0; j < len(orVals); j++ {
276 vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
277 if noAlias {
278 alias = vals[0]
279 current.aliasTag = alias
280 } else {
281 current.actualAliasTag = t
282 }
283
284 if j > 0 {
285 current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
286 current = current.next
287 }
288 current.hasParam = len(vals) > 1
289
290 current.tag = vals[0]
291 if len(current.tag) == 0 {
292 panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
293 }
294
295 if wrapper, ok := v.validations[current.tag]; ok {
296 current.fn = wrapper.fn
297 current.runValidationWhenNil = wrapper.runValidatinOnNil
298 } else {
299 panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
300 }
301
302 if len(orVals) > 1 {
303 current.typeof = typeOr
304 }
305
306 if len(vals) > 1 {
307 current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
308 }
309 }
310 current.isBlockEnd = true
311 }
312 }
313 return
314 }
315
316 func (v *Validate) fetchCacheTag(tag string) *cTag {
317
318 ctag, found := v.tagCache.Get(tag)
319 if !found {
320 v.tagCache.lock.Lock()
321 defer v.tagCache.lock.Unlock()
322
323
324
325 ctag, found = v.tagCache.Get(tag)
326 if !found {
327 ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
328 v.tagCache.Set(tag, ctag)
329 }
330 }
331 return ctag
332 }
333
View as plain text