1
2
3
4 package httpcommon
5
6 import (
7 "context"
8 "errors"
9 "fmt"
10 "net/http/httptrace"
11 "net/textproto"
12 "net/url"
13 "sort"
14 "strconv"
15 "strings"
16 "sync"
17
18 "golang.org/x/net/http/httpguts"
19 "golang.org/x/net/http2/hpack"
20 )
21
22
23
24
25
26
27
28 func asciiEqualFold(s, t string) bool {
29 if len(s) != len(t) {
30 return false
31 }
32 for i := 0; i < len(s); i++ {
33 if lower(s[i]) != lower(t[i]) {
34 return false
35 }
36 }
37 return true
38 }
39
40
// lower returns the ASCII lowercase form of b. Bytes outside 'A'-'Z',
// including all non-ASCII bytes, are returned unchanged.
func lower(b byte) byte {
	if b < 'A' || b > 'Z' {
		return b
	}
	// ASCII upper- and lowercase letters differ only in bit 0x20.
	return b | 0x20
}
47
48
49
// isASCIIPrint reports whether every byte of s is printable ASCII, that is,
// lies in the range ' ' (0x20) through '~' (0x7E). The empty string is
// considered printable.
func isASCIIPrint(s string) bool {
	for _, c := range []byte(s) {
		printable := ' ' <= c && c <= '~'
		if !printable {
			return false
		}
	}
	return true
}
58
59
60
61 func asciiToLower(s string) (lower string, ok bool) {
62 if !isASCIIPrint(s) {
63 return "", false
64 }
65 return strings.ToLower(s), true
66 }
67
// commonBuildOnce guards the one-time construction of the common header
// lookup tables below by buildCommonHeaderMaps.
var commonBuildOnce sync.Once

// Lookup tables for frequently used header names. Both are nil until
// buildCommonHeaderMaps has run.
var (
	// commonLowerHeader maps a canonical header name (e.g. "Content-Type")
	// to its lowercase form ("content-type").
	commonLowerHeader map[string]string

	// commonCanonHeader maps a lowercase header name to its canonical form.
	commonCanonHeader map[string]string
)
73
74 func buildCommonHeaderMapsOnce() {
75 commonBuildOnce.Do(buildCommonHeaderMaps)
76 }
77
78 func buildCommonHeaderMaps() {
79 common := []string{
80 "accept",
81 "accept-charset",
82 "accept-encoding",
83 "accept-language",
84 "accept-ranges",
85 "age",
86 "access-control-allow-credentials",
87 "access-control-allow-headers",
88 "access-control-allow-methods",
89 "access-control-allow-origin",
90 "access-control-expose-headers",
91 "access-control-max-age",
92 "access-control-request-headers",
93 "access-control-request-method",
94 "allow",
95 "authorization",
96 "cache-control",
97 "content-disposition",
98 "content-encoding",
99 "content-language",
100 "content-length",
101 "content-location",
102 "content-range",
103 "content-type",
104 "cookie",
105 "date",
106 "etag",
107 "expect",
108 "expires",
109 "from",
110 "host",
111 "if-match",
112 "if-modified-since",
113 "if-none-match",
114 "if-unmodified-since",
115 "last-modified",
116 "link",
117 "location",
118 "max-forwards",
119 "origin",
120 "proxy-authenticate",
121 "proxy-authorization",
122 "range",
123 "referer",
124 "refresh",
125 "retry-after",
126 "server",
127 "set-cookie",
128 "strict-transport-security",
129 "trailer",
130 "transfer-encoding",
131 "user-agent",
132 "vary",
133 "via",
134 "www-authenticate",
135 "x-forwarded-for",
136 "x-forwarded-proto",
137 }
138 commonLowerHeader = make(map[string]string, len(common))
139 commonCanonHeader = make(map[string]string, len(common))
140 for _, v := range common {
141 chk := textproto.CanonicalMIMEHeaderKey(v)
142 commonLowerHeader[chk] = v
143 commonCanonHeader[v] = chk
144 }
145 }
146
147
148
149 func LowerHeader(v string) (lower string, ascii bool) {
150 buildCommonHeaderMapsOnce()
151 if s, ok := commonLowerHeader[v]; ok {
152 return s, true
153 }
154 return asciiToLower(v)
155 }
156
157
158 func CanonicalHeader(v string) string {
159 buildCommonHeaderMapsOnce()
160 if s, ok := commonCanonHeader[v]; ok {
161 return s
162 }
163 return textproto.CanonicalMIMEHeaderKey(v)
164 }
165
166
167 func CachedCanonicalHeader(v string) (string, bool) {
168 buildCommonHeaderMapsOnce()
169 s, ok := commonCanonHeader[v]
170 return s, ok
171 }
172
// ErrRequestHeaderListSize is returned by EncodeHeaders when the total size of
// the request's header fields exceeds EncodeHeadersParam.PeerMaxHeaderListSize.
var ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit")
176
177
178
179
// Request describes the parts of an HTTP request that EncodeHeaders reads.
// Field names mirror net/http.Request (NOTE(review): presumably filled in from
// one by the caller — confirm).
type Request struct {
	// URL supplies the :path (via RequestURI) and :scheme values, and the
	// authority when Host is empty. EncodeHeaders rejects a nil URL.
	URL *url.URL
	// Method is the request method; an empty Method is sent as "GET".
	Method string
	// Host, when non-empty, is used for :authority instead of URL.Host.
	Host string
	// Header holds the request header fields. A ":protocol" key is read as
	// the extended-CONNECT :protocol pseudo-header value.
	Header map[string][]string
	// Trailer's keys are validated and announced in the "trailer" field;
	// EncodeHeaders does not send the trailer values themselves.
	Trailer map[string][]string
	// ActualContentLength is the body length in bytes. Zero means no body.
	// A negative value (presumably "unknown") suppresses the
	// content-length field but is still reported as HasBody.
	ActualContentLength int64
}
188
189
// EncodeHeadersParam holds the inputs to EncodeHeaders.
type EncodeHeadersParam struct {
	// Request is the request whose header fields are encoded.
	Request Request

	// AddGzipHeader makes EncodeHeaders append an "accept-encoding: gzip"
	// field after the request's own header fields.
	AddGzipHeader bool

	// PeerMaxHeaderListSize, when non-zero, is the largest total header
	// list size (summed hpack.HeaderField.Size of all fields) EncodeHeaders
	// may produce; larger requests fail with ErrRequestHeaderListSize.
	// Zero disables the check.
	PeerMaxHeaderListSize uint64

	// DefaultUserAgent is sent as "user-agent" when Request.Header has no
	// User-Agent key at all. A User-Agent key whose slice is empty or whose
	// first value is "" suppresses the field instead.
	DefaultUserAgent string
}
204
205
// EncodeHeadersResult reports properties of a request encoded by EncodeHeaders.
type EncodeHeadersResult struct {
	// HasBody is true when Request.ActualContentLength is non-zero,
	// including negative lengths.
	HasBody bool
	// HasTrailers is true when a non-empty "trailer" field was emitted,
	// i.e. Request.Trailer announced at least one key.
	HasTrailers bool
}
210
211
212
213
214
// EncodeHeaders validates the request in param and emits its header fields,
// in order, by calling headerf(name, value) once per field:
//
//   - the pseudo-header fields :authority and :method, then :path and :scheme
//     unless this is a plain (non-extended) CONNECT, then :protocol if set;
//   - a "trailer" field listing the trailer keys, if any;
//   - the request's own header fields, minus those it handles specially;
//   - synthesized content-length, accept-encoding and user-agent fields.
//
// Field names passed to headerf are lowercased with LowerHeader; fields whose
// names are not printable ASCII are skipped.
//
// All validation, including the optional PeerMaxHeaderListSize check
// (ErrRequestHeaderListSize), happens before headerf is first called, so on
// error headerf has not been invoked. If ctx carries an httptrace.ClientTrace
// with WroteHeaderField set, it is called after each field is passed to headerf.
func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) {
	req := param.Request

	// Reject connection-specific header values that cannot be carried;
	// acceptable ones are silently dropped in enumerateHeaders below.
	if err := checkConnHeaders(req.Header); err != nil {
		return res, err
	}

	if req.URL == nil {
		return res, errors.New("Request.URL is nil")
	}

	// :authority comes from req.Host, falling back to the URL's host,
	// converted with httpguts.PunycodeHostPort and then validated.
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	host, err := httpguts.PunycodeHostPort(host)
	if err != nil {
		return res, err
	}
	if !httpguts.ValidHostHeader(host) {
		return res, errors.New("invalid Host header")
	}

	// A ":protocol" entry in req.Header marks an extended CONNECT request
	// (RFC 8441) and is only legal with the CONNECT method. isNormalConnect
	// is true for a CONNECT without :protocol, which sends no :path/:scheme.
	isNormalConnect := false
	var protocol string
	if vv := req.Header[":protocol"]; len(vv) > 0 {
		protocol = vv[0]
	}
	if req.Method == "CONNECT" && protocol == "" {
		isNormalConnect = true
	} else if protocol != "" && req.Method != "CONNECT" {
		return res, errors.New("invalid :protocol header in non-CONNECT request")
	}

	// Derive :path from the request URI. If that is not a valid :path (for
	// instance an opaque URL rendering as "scheme://host/..."), retry with
	// the "scheme://host" prefix stripped before giving up.
	var path string
	if !isNormalConnect {
		path = req.URL.RequestURI()
		if !validPseudoPath(path) {
			orig := path
			path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
			if !validPseudoPath(path) {
				if req.URL.Opaque != "" {
					return res, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
				} else {
					return res, fmt.Errorf("invalid request :path %q", orig)
				}
			}
		}
	}

	// Validate every header and trailer field before headerf is first
	// called, so a failing request never reaches headerf (presumably because
	// headerf feeds a stateful encoder such as HPACK — confirm with callers).
	if err := validateHeaders(req.Header); err != "" {
		return res, fmt.Errorf("invalid HTTP header %s", err)
	}
	if err := validateHeaders(req.Trailer); err != "" {
		return res, fmt.Errorf("invalid HTTP trailer %s", err)
	}

	trailers, err := commaSeparatedTrailers(req.Trailer)
	if err != nil {
		return res, err
	}

	// enumerateHeaders calls f for every field of the request in wire order.
	// It only reads req and param, so it can be run twice below: once to
	// measure the header list size and once to emit the fields.
	enumerateHeaders := func(f func(name, value string)) {
		// Pseudo-header fields go first; RFC 9113 §8.3 requires them to
		// precede all regular fields.
		f(":authority", host)
		m := req.Method
		if m == "" {
			m = "GET"
		}
		f(":method", m)
		if !isNormalConnect {
			f(":path", path)
			f(":scheme", req.URL.Scheme)
		}
		if protocol != "" {
			f(":protocol", protocol)
		}
		if trailers != "" {
			f("trailer", trailers)
		}

		var didUA bool
		for k, vv := range req.Header {
			if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
				// Host was already sent as :authority; content-length is
				// synthesized from ActualContentLength below.
				continue
			} else if asciiEqualFold(k, "connection") ||
				asciiEqualFold(k, "proxy-connection") ||
				asciiEqualFold(k, "transfer-encoding") ||
				asciiEqualFold(k, "upgrade") ||
				asciiEqualFold(k, "keep-alive") {
				// Connection-specific fields are never sent; values that
				// would be errors were already rejected by checkConnHeaders.
				continue
			} else if asciiEqualFold(k, "user-agent") {
				// Send at most one User-Agent. A key with an empty slice or
				// an empty first value suppresses the field; only a missing
				// key falls back to DefaultUserAgent below.
				didUA = true
				if len(vv) < 1 {
					continue
				}
				vv = vv[:1]
				if vv[0] == "" {
					continue
				}
			} else if asciiEqualFold(k, "cookie") {
				// Split each Cookie value at ';' into separate "cookie"
				// fields (allowed by RFC 9113 §8.2.3), skipping spaces that
				// follow each ';'.
				// NOTE(review): an empty segment, as in "a=1;;b=2", is
				// emitted as an empty "cookie" field — confirm intended.
				for _, v := range vv {
					for {
						p := strings.IndexByte(v, ';')
						if p < 0 {
							break
						}
						f("cookie", v[:p])
						p++
						// Skip spaces after the semicolon.
						for p+1 <= len(v) && v[p] == ' ' {
							p++
						}
						v = v[p:]
					}
					if len(v) > 0 {
						f("cookie", v)
					}
				}
				continue
			} else if k == ":protocol" {
				// Already emitted as a pseudo-header above.
				continue
			}

			// Map iteration order is random, so regular fields may come out
			// in a different order on every call.
			for _, v := range vv {
				f(k, v)
			}
		}
		if shouldSendReqContentLength(req.Method, req.ActualContentLength) {
			f("content-length", strconv.FormatInt(req.ActualContentLength, 10))
		}
		if param.AddGzipHeader {
			f("accept-encoding", "gzip")
		}
		if !didUA {
			f("user-agent", param.DefaultUserAgent)
		}
	}

	// First pass: total the hpack.HeaderField.Size of every field and fail
	// before anything reaches headerf if the peer's limit is exceeded. The
	// two passes may visit regular fields in different orders; only the
	// total matters here.
	if param.PeerMaxHeaderListSize > 0 {
		hlSize := uint64(0)
		enumerateHeaders(func(name, value string) {
			hf := hpack.HeaderField{Name: name, Value: value}
			hlSize += uint64(hf.Size())
		})

		if hlSize > param.PeerMaxHeaderListSize {
			return res, ErrRequestHeaderListSize
		}
	}

	trace := httptrace.ContextClientTrace(ctx)

	// Second pass: lowercase each name, hand the field to headerf, and
	// report it to the client trace if one is installed.
	enumerateHeaders(func(name, value string) {
		name, ascii := LowerHeader(name)
		if !ascii {
			// Names that are not printable ASCII are skipped rather than
			// failing the whole request.
			return
		}

		headerf(name, value)

		if trace != nil && trace.WroteHeaderField != nil {
			trace.WroteHeaderField(name, []string{value})
		}
	})

	// Any non-zero length, including a negative (unknown) one, means a body
	// follows.
	res.HasBody = req.ActualContentLength != 0
	res.HasTrailers = trailers != ""
	return res, nil
}
416
417
418
// IsRequestGzip reports whether a transport should transparently ask for a
// gzip-encoded response. It returns false when compression is disabled,
// for HEAD requests, or when the caller already set Accept-Encoding or
// Range (presumably because a decoded slice of a gzip stream would be
// unusable — confirm). Header keys are matched exactly, so header must
// use canonical key forms.
func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool {
	switch {
	case disableCompression:
		return false
	case method == "HEAD":
		return false
	case len(header["Accept-Encoding"]) != 0, len(header["Range"]) != 0:
		return false
	}
	return true
}
441
442
443
444
445
446
447
448
449 func checkConnHeaders(h map[string][]string) error {
450 if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") {
451 return fmt.Errorf("invalid Upgrade request header: %q", vv)
452 }
453 if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
454 return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv)
455 }
456 if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
457 return fmt.Errorf("invalid Connection request header: %q", vv)
458 }
459 return nil
460 }
461
462 func commaSeparatedTrailers(trailer map[string][]string) (string, error) {
463 keys := make([]string, 0, len(trailer))
464 for k := range trailer {
465 k = CanonicalHeader(k)
466 switch k {
467 case "Transfer-Encoding", "Trailer", "Content-Length":
468 return "", fmt.Errorf("invalid Trailer key %q", k)
469 }
470 keys = append(keys, k)
471 }
472 if len(keys) > 0 {
473 sort.Strings(keys)
474 return strings.Join(keys, ","), nil
475 }
476 return "", nil
477 }
478
479
480
481
482
483
484
485
486
487
488
489
490
491
// validPseudoPath reports whether v is acceptable as a :path pseudo-header
// value: either exactly "*" or a non-empty string beginning with '/'.
// EncodeHeaders uses it to decide whether a request URI must have its
// "scheme://host" prefix stripped before it can be sent.
func validPseudoPath(v string) bool {
	if v == "*" {
		return true
	}
	return strings.HasPrefix(v, "/")
}
495
496 func validateHeaders(hdrs map[string][]string) string {
497 for k, vv := range hdrs {
498 if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" {
499 return fmt.Sprintf("name %q", k)
500 }
501 for _, v := range vv {
502 if !httpguts.ValidHeaderFieldValue(v) {
503
504
505 return fmt.Sprintf("value for header %q", k)
506 }
507 }
508 }
509 return ""
510 }
511
512
513
514
515
516
// shouldSendReqContentLength reports whether a "content-length" field should
// be sent for a request with the given method and body length. A positive
// length is always sent and a negative length (treated as unknown) never is.
// A zero length is announced only for the methods POST, PUT and PATCH
// (matched case-sensitively).
func shouldSendReqContentLength(method string, contentLength int64) bool {
	switch {
	case contentLength > 0:
		return true
	case contentLength < 0:
		return false
	}
	return method == "POST" || method == "PUT" || method == "PATCH"
}
533
View as plain text