1
2
3
4
5
6 package cookiejar
7
8 import (
9 "errors"
10 "fmt"
11 "net"
12 "net/http"
13 "net/http/internal/ascii"
14 "net/url"
15 "sort"
16 "strings"
17 "sync"
18 "time"
19 )
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// PublicSuffixList provides the public suffix of a domain.
//
// The jar consults it in two places: jarKey uses it to choose the bucket
// under which a host's cookies are stored, and domainAndType uses it to
// refuse cookies whose Domain attribute is a public suffix (unless that
// attribute equals the request host). jarKey is called before the jar's
// mutex is taken, so PublicSuffix may be invoked concurrently from several
// goroutines; implementations must be safe for concurrent use.
type PublicSuffixList interface {
	// PublicSuffix returns the public suffix of domain.
	//
	// The jar expects the result to be a dot-separated suffix of domain.
	// If it is not (or it equals the whole host), jarKey falls back to
	// keying cookies by the full host name. An empty result skips the
	// public-suffix check in domainAndType.
	PublicSuffix(domain string) string

	// String returns a description of this list. It is not called by the
	// code visible here; presumably it identifies the list's source or
	// version (NOTE(review): confirm against implementations).
	String() string
}
48
49
// Options are the options for creating a new Jar with New.
type Options struct {
	// PublicSuffixList is consulted when grouping cookies into per-domain
	// buckets and when deciding whether a server may set a cookie for a
	// given Domain attribute.
	//
	// It may be nil. Then no public-suffix check is applied to Domain
	// attributes and cookies are grouped by the last two labels of the
	// host. That is permissive: for example a server at foo.co.uk could
	// set a cookie with Domain=co.uk, and it would be stored in the same
	// bucket as (and sent to) bar.co.uk.
	PublicSuffixList PublicSuffixList
}
59
60
// Jar implements the http.CookieJar interface from the net/http package.
//
// Create one with New. A zero Jar has a nil entries map, so storing a
// cookie in it would panic.
type Jar struct {
	// psList is the optional public suffix list; it may be nil.
	psList PublicSuffixList

	// mu guards entries and nextSeqNum. Cookies and SetCookies hold it
	// for all access to those fields.
	mu sync.Mutex

	// entries maps a bucket key (see jarKey; normally the registrable
	// domain of the host) to that bucket's cookies, keyed by entry.id().
	// Empty buckets are removed rather than kept.
	entries map[string]map[string]entry

	// nextSeqNum is the sequence number given to the next newly created
	// entry; it breaks ordering ties between otherwise equal cookies.
	nextSeqNum uint64
}
75
76
77
78 func New(o *Options) (*Jar, error) {
79 jar := &Jar{
80 entries: make(map[string]map[string]entry),
81 }
82 if o != nil {
83 jar.psList = o.PublicSuffixList
84 }
85 return jar, nil
86 }
87
88
89
90
91
// entry is the jar's internal record for one stored cookie. newEntry builds
// it from an *http.Cookie, the jar stores it in Jar.entries under id(), and
// cookies converts it back into an *http.Cookie carrying only Name and
// Value.
type entry struct {
	Name  string
	Value string

	// Domain is the request host for host-only cookies, otherwise the
	// validated (lower-cased, leading-dot-stripped) Domain attribute.
	Domain string

	// Path is the cookie's Path attribute, or the request's default path
	// when the attribute is empty or does not start with "/".
	Path string

	// SameSite holds "SameSite", "SameSite=Strict" or "SameSite=Lax" for
	// the matching http.SameSite modes and stays empty otherwise. It is
	// not consulted by the matching logic here.
	SameSite string

	// Secure entries are only sent for https requests (see shouldSend).
	Secure bool

	// HttpOnly is recorded from the cookie; the matching logic here does
	// not consult it.
	HttpOnly bool

	// Persistent is true when the expiry came from Max-Age or Expires.
	// Only persistent entries are purged once Expires has passed.
	Persistent bool

	// HostOnly entries match only a request host equal to Domain; other
	// entries also match subdomains of Domain (see domainMatch).
	HostOnly bool

	// Expires is the expiry instant; session (non-persistent) entries
	// use endOfTime.
	Expires time.Time

	// Creation is set when an id is first stored and preserved when the
	// same id is overwritten; it orders cookies of equal path length.
	Creation time.Time

	// LastAccess is set to the current time when the entry is stored or
	// selected for sending.
	LastAccess time.Time

	// seqNum is a sequence number so that Cookies returns cookies in a
	// deterministic order, even for cookies that have equal Path length and
	// equal Creation time.
	seqNum uint64
}
111
112
113 func (e *entry) id() string {
114 return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
115 }
116
117
118
119
120 func (e *entry) shouldSend(https bool, host, path string) bool {
121 return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
122 }
123
124
125
126
127 func (e *entry) domainMatch(host string) bool {
128 if e.Domain == host {
129 return true
130 }
131 return !e.HostOnly && hasDotSuffix(host, e.Domain)
132 }
133
134
135 func (e *entry) pathMatch(requestPath string) bool {
136 if requestPath == e.Path {
137 return true
138 }
139 if strings.HasPrefix(requestPath, e.Path) {
140 if e.Path[len(e.Path)-1] == '/' {
141 return true
142 } else if requestPath[len(e.Path)] == '/' {
143 return true
144 }
145 }
146 return false
147 }
148
149
// hasDotSuffix reports whether s ends in "." followed by suffix, i.e. s is a
// strict subdomain of suffix.
func hasDotSuffix(s, suffix string) bool {
	cut := len(s) - len(suffix)
	if cut <= 0 {
		return false
	}
	return s[cut-1] == '.' && s[cut:] == suffix
}
153
154
155
156
157 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
158 return j.cookies(u, time.Now())
159 }
160
161
162 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
163 if u.Scheme != "http" && u.Scheme != "https" {
164 return cookies
165 }
166 host, err := canonicalHost(u.Host)
167 if err != nil {
168 return cookies
169 }
170 key := jarKey(host, j.psList)
171
172 j.mu.Lock()
173 defer j.mu.Unlock()
174
175 submap := j.entries[key]
176 if submap == nil {
177 return cookies
178 }
179
180 https := u.Scheme == "https"
181 path := u.Path
182 if path == "" {
183 path = "/"
184 }
185
186 modified := false
187 var selected []entry
188 for id, e := range submap {
189 if e.Persistent && !e.Expires.After(now) {
190 delete(submap, id)
191 modified = true
192 continue
193 }
194 if !e.shouldSend(https, host, path) {
195 continue
196 }
197 e.LastAccess = now
198 submap[id] = e
199 selected = append(selected, e)
200 modified = true
201 }
202 if modified {
203 if len(submap) == 0 {
204 delete(j.entries, key)
205 } else {
206 j.entries[key] = submap
207 }
208 }
209
210
211
212 sort.Slice(selected, func(i, j int) bool {
213 s := selected
214 if len(s[i].Path) != len(s[j].Path) {
215 return len(s[i].Path) > len(s[j].Path)
216 }
217 if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 {
218 return ret < 0
219 }
220 return s[i].seqNum < s[j].seqNum
221 })
222 for _, e := range selected {
223 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
224 }
225
226 return cookies
227 }
228
229
230
231
232 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
233 j.setCookies(u, cookies, time.Now())
234 }
235
236
237 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
238 if len(cookies) == 0 {
239 return
240 }
241 if u.Scheme != "http" && u.Scheme != "https" {
242 return
243 }
244 host, err := canonicalHost(u.Host)
245 if err != nil {
246 return
247 }
248 key := jarKey(host, j.psList)
249 defPath := defaultPath(u.Path)
250
251 j.mu.Lock()
252 defer j.mu.Unlock()
253
254 submap := j.entries[key]
255
256 modified := false
257 for _, cookie := range cookies {
258 e, remove, err := j.newEntry(cookie, now, defPath, host)
259 if err != nil {
260 continue
261 }
262 id := e.id()
263 if remove {
264 if submap != nil {
265 if _, ok := submap[id]; ok {
266 delete(submap, id)
267 modified = true
268 }
269 }
270 continue
271 }
272 if submap == nil {
273 submap = make(map[string]entry)
274 }
275
276 if old, ok := submap[id]; ok {
277 e.Creation = old.Creation
278 e.seqNum = old.seqNum
279 } else {
280 e.Creation = now
281 e.seqNum = j.nextSeqNum
282 j.nextSeqNum++
283 }
284 e.LastAccess = now
285 submap[id] = e
286 modified = true
287 }
288
289 if modified {
290 if len(submap) == 0 {
291 delete(j.entries, key)
292 } else {
293 j.entries[key] = submap
294 }
295 }
296 }
297
298
299
300 func canonicalHost(host string) (string, error) {
301 var err error
302 if hasPort(host) {
303 host, _, err = net.SplitHostPort(host)
304 if err != nil {
305 return "", err
306 }
307 }
308
309 host = strings.TrimSuffix(host, ".")
310 encoded, err := toASCII(host)
311 if err != nil {
312 return "", err
313 }
314
315 lower, _ := ascii.ToLower(encoded)
316 return lower, nil
317 }
318
319
320
// hasPort reports whether host carries a port. host may be a host name, an
// IPv4 address, or an IPv6 address; an IPv6 address only has a port when
// written in brackets as "[addr]:port".
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	default:
		// Several colons: a bare IPv6 literal unless bracketed with a port.
		return host[0] == '[' && strings.Contains(host, "]:")
	}
}
331
332
333 func jarKey(host string, psl PublicSuffixList) string {
334 if isIP(host) {
335 return host
336 }
337
338 var i int
339 if psl == nil {
340 i = strings.LastIndex(host, ".")
341 if i <= 0 {
342 return host
343 }
344 } else {
345 suffix := psl.PublicSuffix(host)
346 if suffix == host {
347 return host
348 }
349 i = len(host) - len(suffix)
350 if i <= 0 || host[i-1] != '.' {
351
352
353 return host
354 }
355
356
357
358 }
359 prevDot := strings.LastIndex(host[:i-1], ".")
360 return host[prevDot+1:]
361 }
362
363
// isIP reports whether host should be treated as an IP address.
//
// Host names cannot contain ':' or '%', so any such string is conservatively
// treated as an (IPv6, possibly zoned) address. This avoids treating e.g.
// "::1%.www.example.com" as a subdomain of www.example.com. Otherwise host
// must parse with net.ParseIP.
func isIP(host string) bool {
	return strings.ContainsAny(host, ":%") || net.ParseIP(host) != nil
}
374
375
376
// defaultPath computes the RFC 6265 section 5.1.4 default-path of a request
// path: "/" for an empty or non-absolute path, or one with only a leading
// slash; otherwise everything before the right-most "/".
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		return "/"
	}
	if last := strings.LastIndexByte(path, '/'); last > 0 {
		return path[:last] // "/abc/xyz" or "/abc/xyz/"
	}
	return "/" // "/abc"
}
388
389
390
391
392
393
394
395
396
397
398 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
399 e.Name = c.Name
400
401 if c.Path == "" || c.Path[0] != '/' {
402 e.Path = defPath
403 } else {
404 e.Path = c.Path
405 }
406
407 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
408 if err != nil {
409 return e, false, err
410 }
411
412
413 if c.MaxAge < 0 {
414 return e, true, nil
415 } else if c.MaxAge > 0 {
416 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
417 e.Persistent = true
418 } else {
419 if c.Expires.IsZero() {
420 e.Expires = endOfTime
421 e.Persistent = false
422 } else {
423 if !c.Expires.After(now) {
424 return e, true, nil
425 }
426 e.Expires = c.Expires
427 e.Persistent = true
428 }
429 }
430
431 e.Value = c.Value
432 e.Secure = c.Secure
433 e.HttpOnly = c.HttpOnly
434
435 switch c.SameSite {
436 case http.SameSiteDefaultMode:
437 e.SameSite = "SameSite"
438 case http.SameSiteStrictMode:
439 e.SameSite = "SameSite=Strict"
440 case http.SameSiteLaxMode:
441 e.SameSite = "SameSite=Lax"
442 }
443
444 return e, false, nil
445 }
446
var (
	// errIllegalDomain is returned by domainAndType when a Domain
	// attribute is well-formed but not acceptable for the request host:
	// it differs from an IP host, is a public suffix other than the host,
	// or the host does not domain-match it.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")

	// errMalformedDomain is returned by domainAndType for Domain
	// attributes that are syntactically invalid: empty after stripping a
	// leading dot, starting with "..", non-ASCII, or ending in a dot.
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")

	// endOfTime is the expiry assigned to session (non-persistent)
	// cookies: a far-future instant that is still representable in common
	// date formats.
	endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
)
456
457
458 func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
459 if domain == "" {
460
461
462 return host, true, nil
463 }
464
465 if isIP(host) {
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482 if host != domain {
483 return "", false, errIllegalDomain
484 }
485
486
487
488
489
490
491
492
493
494
495 return host, true, nil
496 }
497
498
499
500
501 if domain[0] == '.' {
502 domain = domain[1:]
503 }
504
505 if len(domain) == 0 || domain[0] == '.' {
506
507
508 return "", false, errMalformedDomain
509 }
510
511 domain, isASCII := ascii.ToLower(domain)
512 if !isASCII {
513
514 return "", false, errMalformedDomain
515 }
516
517 if domain[len(domain)-1] == '.' {
518
519
520
521
522
523
524 return "", false, errMalformedDomain
525 }
526
527
528 if j.psList != nil {
529 if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
530 if host == domain {
531
532
533 return host, true, nil
534 }
535 return "", false, errIllegalDomain
536 }
537 }
538
539
540
541 if host != domain && !hasDotSuffix(host, domain) {
542 return "", false, errIllegalDomain
543 }
544
545 return domain, false, nil
546 }
547
View as plain text