1
2
3
4
5
6 package cookiejar
7
8 import (
9 "errors"
10 "fmt"
11 "net"
12 "net/http"
13 "net/http/internal/ascii"
14 "net/url"
15 "sort"
16 "strings"
17 "sync"
18 "time"
19 )
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// PublicSuffixList provides the public suffix of a domain.
//
// The Jar consults it in two places: jarKey derives the storage key of a
// host from the length of the returned suffix, and domainAndType rejects
// cookie Domain attributes that do not lie strictly below their public
// suffix (unless the attribute equals the request host). jarKey calls
// PublicSuffix before the Jar's mutex is acquired, so concurrent Cookies /
// SetCookies calls may invoke it concurrently; implementations should be
// safe for concurrent use.
type PublicSuffixList interface {
	// PublicSuffix returns the public suffix of domain.
	//
	// jarKey treats a result that is not a dot-separated suffix of the
	// host as a broken list and falls back to keying by the full host.
	PublicSuffix(domain string) string

	// String returns a description of this public suffix list. It is not
	// called anywhere in this file; presumably it exists for diagnostics
	// (e.g. the list's source or version) — confirm with implementations.
	String() string
}
48
49
// Options are the options for creating a new Jar.
type Options struct {
	// PublicSuffixList is the public suffix list that determines whether
	// an HTTP server can set a cookie for a domain.
	//
	// A nil value is accepted. jarKey then keys entries by the last two
	// labels of the host, and domainAndType skips the public suffix check
	// entirely. That means, for example, a server for foo.co.uk can set a
	// Domain=co.uk cookie that is later sent to bar.co.uk, so nil is only
	// appropriate where that is acceptable (e.g. testing).
	PublicSuffixList PublicSuffixList
}
59
60
// Jar is an in-memory cookie jar. Its Cookies and SetCookies methods have
// the signatures of the net/http CookieJar interface.
type Jar struct {
	// psList is the public suffix list from Options; it may be nil.
	psList PublicSuffixList

	// mu guards the remaining fields.
	mu sync.Mutex

	// entries maps jarKey(host) to the cookies stored under that key. The
	// inner map is keyed by entry.id(), i.e. "domain;path;name". Empty
	// inner maps are removed rather than kept.
	entries map[string]map[string]entry

	// nextSeqNum is the sequence number given to the next newly created
	// entry in setCookies. cookies uses it as the final sort tie-breaker.
	nextSeqNum uint64
}
75
76
77
78 func New(o *Options) (*Jar, error) {
79 jar := &Jar{
80 entries: make(map[string]map[string]entry),
81 }
82 if o != nil {
83 jar.psList = o.PublicSuffixList
84 }
85 return jar, nil
86 }
87
88
89
90
91
// entry is the jar's internal record of a single stored cookie. Within one
// jar key there is at most one entry per (Domain, Path, Name); see id.
type entry struct {
	Name   string
	Value  string
	Quoted bool // copied from http.Cookie.Quoted and replayed by cookies

	// Domain is the request host for host-only cookies, otherwise the
	// lower-cased Domain attribute (see domainAndType).
	Domain string

	// Path always begins with '/' (newEntry falls back to defaultPath).
	Path string

	// SameSite is "", "SameSite", "SameSite=Strict" or "SameSite=Lax".
	SameSite string

	Secure   bool // only sent over https (see shouldSend)
	HttpOnly bool // recorded; not consulted in this file

	// Persistent is false for session cookies, whose Expires is endOfTime.
	Persistent bool

	// HostOnly restricts sending to an exact host match (see domainMatch).
	HostOnly bool

	Expires    time.Time
	Creation   time.Time // kept when an existing cookie is replaced
	LastAccess time.Time // refreshed on every set and every send

	// seqNum is assigned once per new entry from Jar.nextSeqNum, so that
	// cookies can order entries deterministically even when Path length
	// and Creation are equal.
	seqNum uint64
}
112
113
114 func (e *entry) id() string {
115 return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
116 }
117
118
119
120
121 func (e *entry) shouldSend(https bool, host, path string) bool {
122 return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
123 }
124
125
126
127
128 func (e *entry) domainMatch(host string) bool {
129 if e.Domain == host {
130 return true
131 }
132 return !e.HostOnly && hasDotSuffix(host, e.Domain)
133 }
134
135
136 func (e *entry) pathMatch(requestPath string) bool {
137 if requestPath == e.Path {
138 return true
139 }
140 if strings.HasPrefix(requestPath, e.Path) {
141 if e.Path[len(e.Path)-1] == '/' {
142 return true
143 } else if requestPath[len(e.Path)] == '/' {
144 return true
145 }
146 }
147 return false
148 }
149
150
// hasDotSuffix reports whether s ends in "." + suffix, with at least the
// dot preceding suffix.
func hasDotSuffix(s, suffix string) bool {
	cut := len(s) - len(suffix)
	if cut < 1 {
		return false
	}
	return s[cut-1] == '.' && s[cut:] == suffix
}
154
155
156
157
158 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
159 return j.cookies(u, time.Now())
160 }
161
162
163 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
164 if u.Scheme != "http" && u.Scheme != "https" {
165 return cookies
166 }
167 host, err := canonicalHost(u.Host)
168 if err != nil {
169 return cookies
170 }
171 key := jarKey(host, j.psList)
172
173 j.mu.Lock()
174 defer j.mu.Unlock()
175
176 submap := j.entries[key]
177 if submap == nil {
178 return cookies
179 }
180
181 https := u.Scheme == "https"
182 path := u.Path
183 if path == "" {
184 path = "/"
185 }
186
187 modified := false
188 var selected []entry
189 for id, e := range submap {
190 if e.Persistent && !e.Expires.After(now) {
191 delete(submap, id)
192 modified = true
193 continue
194 }
195 if !e.shouldSend(https, host, path) {
196 continue
197 }
198 e.LastAccess = now
199 submap[id] = e
200 selected = append(selected, e)
201 modified = true
202 }
203 if modified {
204 if len(submap) == 0 {
205 delete(j.entries, key)
206 } else {
207 j.entries[key] = submap
208 }
209 }
210
211
212
213 sort.Slice(selected, func(i, j int) bool {
214 s := selected
215 if len(s[i].Path) != len(s[j].Path) {
216 return len(s[i].Path) > len(s[j].Path)
217 }
218 if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 {
219 return ret < 0
220 }
221 return s[i].seqNum < s[j].seqNum
222 })
223 for _, e := range selected {
224 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value, Quoted: e.Quoted})
225 }
226
227 return cookies
228 }
229
230
231
232
233 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
234 j.setCookies(u, cookies, time.Now())
235 }
236
237
238 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
239 if len(cookies) == 0 {
240 return
241 }
242 if u.Scheme != "http" && u.Scheme != "https" {
243 return
244 }
245 host, err := canonicalHost(u.Host)
246 if err != nil {
247 return
248 }
249 key := jarKey(host, j.psList)
250 defPath := defaultPath(u.Path)
251
252 j.mu.Lock()
253 defer j.mu.Unlock()
254
255 submap := j.entries[key]
256
257 modified := false
258 for _, cookie := range cookies {
259 e, remove, err := j.newEntry(cookie, now, defPath, host)
260 if err != nil {
261 continue
262 }
263 id := e.id()
264 if remove {
265 if submap != nil {
266 if _, ok := submap[id]; ok {
267 delete(submap, id)
268 modified = true
269 }
270 }
271 continue
272 }
273 if submap == nil {
274 submap = make(map[string]entry)
275 }
276
277 if old, ok := submap[id]; ok {
278 e.Creation = old.Creation
279 e.seqNum = old.seqNum
280 } else {
281 e.Creation = now
282 e.seqNum = j.nextSeqNum
283 j.nextSeqNum++
284 }
285 e.LastAccess = now
286 submap[id] = e
287 modified = true
288 }
289
290 if modified {
291 if len(submap) == 0 {
292 delete(j.entries, key)
293 } else {
294 j.entries[key] = submap
295 }
296 }
297 }
298
299
300
301 func canonicalHost(host string) (string, error) {
302 var err error
303 if hasPort(host) {
304 host, _, err = net.SplitHostPort(host)
305 if err != nil {
306 return "", err
307 }
308 }
309
310 host = strings.TrimSuffix(host, ".")
311 encoded, err := toASCII(host)
312 if err != nil {
313 return "", err
314 }
315
316 lower, _ := ascii.ToLower(encoded)
317 return lower, nil
318 }
319
320
321
// hasPort reports whether host, which may be a host name, an IPv4 address
// or an IPv6 address, carries a ":port" suffix.
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	}
	// Two or more colons: only a bracketed IPv6 literal followed by ":port"
	// has a port; a bare IPv6 address does not.
	return host[0] == '[' && strings.Contains(host, "]:")
}
332
333
334 func jarKey(host string, psl PublicSuffixList) string {
335 if isIP(host) {
336 return host
337 }
338
339 var i int
340 if psl == nil {
341 i = strings.LastIndex(host, ".")
342 if i <= 0 {
343 return host
344 }
345 } else {
346 suffix := psl.PublicSuffix(host)
347 if suffix == host {
348 return host
349 }
350 i = len(host) - len(suffix)
351 if i <= 0 || host[i-1] != '.' {
352
353
354 return host
355 }
356
357
358
359 }
360 prevDot := strings.LastIndex(host[:i-1], ".")
361 return host[prevDot+1:]
362 }
363
364
// isIP reports whether host should be treated as an IP address.
func isIP(host string) bool {
	// ':' and '%' cannot occur in a host name, so anything containing them
	// is treated as a (probable IPv6) address. This conservative choice
	// avoids reading e.g. "::1%.www.example.com" as a subdomain of
	// www.example.com.
	return strings.ContainsAny(host, ":%") || net.ParseIP(host) != nil
}
375
376
377
// defaultPath returns the "default-path" of RFC 6265 section 5.1.4 for a
// request path: the directory part, or "/" for relative, empty or
// top-level paths.
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		return "/"
	}
	// path starts with '/', so LastIndexByte is >= 0 here.
	dir := path[:strings.LastIndexByte(path, '/')]
	if dir == "" {
		return "/"
	}
	return dir
}
389
390
391
392
393
394
395
396
397
398
399 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
400 e.Name = c.Name
401
402 if c.Path == "" || c.Path[0] != '/' {
403 e.Path = defPath
404 } else {
405 e.Path = c.Path
406 }
407
408 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
409 if err != nil {
410 return e, false, err
411 }
412
413
414 if c.MaxAge < 0 {
415 return e, true, nil
416 } else if c.MaxAge > 0 {
417 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
418 e.Persistent = true
419 } else {
420 if c.Expires.IsZero() {
421 e.Expires = endOfTime
422 e.Persistent = false
423 } else {
424 if !c.Expires.After(now) {
425 return e, true, nil
426 }
427 e.Expires = c.Expires
428 e.Persistent = true
429 }
430 }
431
432 e.Value = c.Value
433 e.Quoted = c.Quoted
434 e.Secure = c.Secure
435 e.HttpOnly = c.HttpOnly
436
437 switch c.SameSite {
438 case http.SameSiteDefaultMode:
439 e.SameSite = "SameSite"
440 case http.SameSiteStrictMode:
441 e.SameSite = "SameSite=Strict"
442 case http.SameSiteLaxMode:
443 e.SameSite = "SameSite=Lax"
444 }
445
446 return e, false, nil
447 }
448
var (
	// errIllegalDomain is returned by domainAndType when a Domain attribute
	// is well-formed but not acceptable for the request host: it differs
	// from an IP host, does not domain-match the host, or fails the public
	// suffix check.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
	// errMalformedDomain is returned by domainAndType when the Domain
	// attribute itself is unusable: "." or "..x", non-ASCII, or ending in
	// a dot.
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
)
453
454
455
456
// endOfTime is the Expires value newEntry gives session (non-persistent)
// cookies: a far-future instant that is also easy to represent in date
// formats other than Go's time.Time.
var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
458
459
// domainAndType determines a cookie's stored domain and whether it is
// host-only.
//
// host is the canonical host of the request URL and domain is the raw
// Domain attribute of the cookie. It returns the domain to store, the
// host-only flag, and errMalformedDomain or errIllegalDomain when the
// attribute must be rejected.
func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
	if domain == "" {
		// No domain attribute in the Set-Cookie header indicates a
		// host-only cookie.
		return host, true, nil
	}

	if isIP(host) {
		// For IP hosts the raw attribute must equal the host exactly. A
		// leading dot is deliberately not stripped here, so e.g.
		// "Domain=.127.0.0.1" is rejected.
		if host != domain {
			return "", false, errIllegalDomain
		}

		// An IP address has no subdomains, so a domain cookie would behave
		// exactly like a host-only one; store it as host-only.
		return host, true, nil
	}

	// From here on, a valid cookie is a domain cookie (with the one
	// public-suffix exception below). RFC 6265 section 5.2.3: ignore a
	// single leading dot.
	if domain[0] == '.' {
		domain = domain[1:]
	}

	if len(domain) == 0 || domain[0] == '.' {
		// Received either "Domain=." or "Domain=..some.thing",
		// both are illegal.
		return "", false, errMalformedDomain
	}

	domain, isASCII := ascii.ToLower(domain)
	if !isASCII {
		// Received a non-ASCII domain (e.g. a raw IDN instead of its
		// punycode form).
		return "", false, errMalformedDomain
	}

	if domain[len(domain)-1] == '.' {
		// Received e.g. "Domain=www.example.com."; a trailing dot is
		// rejected rather than normalized.
		return "", false, errMalformedDomain
	}

	// RFC 6265 section 5.3 step 5: reject a Domain that does not lie
	// strictly below its public suffix.
	if j.psList != nil {
		if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
			if host == domain {
				// This is the one exception in which a cookie
				// with a domain attribute is a host-only cookie.
				return host, true, nil
			}
			return "", false, errIllegalDomain
		}
	}

	// The domain must domain-match host: www.mycompany.com cannot
	// set cookies for .ourcompetitors.com.
	if host != domain && !hasDotSuffix(host, domain) {
		return "", false, errIllegalDomain
	}

	return domain, false, nil
}
549
View as plain text