1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "errors"
11 "fmt"
12 "io"
13 "math"
14 "strconv"
15 "strings"
16 "sync"
17 )
18
19
20
// errMessageTooLarge is the error reported by the internal line and header
// readers in this file when the data read exceeds the size limit (or the
// header-count limit) supplied by the caller.
var errMessageTooLarge = errors.New("message too large")
22
23
24
25 type Reader struct {
26 R *bufio.Reader
27 dot *dotReader
28 buf []byte
29 }
30
31
32
33
34
35
36 func NewReader(r *bufio.Reader) *Reader {
37 return &Reader{R: r}
38 }
39
40
41
42 func (r *Reader) ReadLine() (string, error) {
43 line, err := r.readLineSlice(-1)
44 return string(line), err
45 }
46
47
48 func (r *Reader) ReadLineBytes() ([]byte, error) {
49 line, err := r.readLineSlice(-1)
50 if line != nil {
51 line = bytes.Clone(line)
52 }
53 return line, err
54 }
55
56
57
58
59 func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
60 r.closeDot()
61 var line []byte
62 for {
63 l, more, err := r.R.ReadLine()
64 if err != nil {
65 return nil, err
66 }
67 if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
68 return nil, errMessageTooLarge
69 }
70
71 if line == nil && !more {
72 return l, nil
73 }
74 line = append(line, l...)
75 if !more {
76 break
77 }
78 }
79 return line, nil
80 }
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100 func (r *Reader) ReadContinuedLine() (string, error) {
101 line, err := r.readContinuedLineSlice(-1, noValidation)
102 return string(line), err
103 }
104
105
106
// trim returns s with leading and trailing spaces and horizontal tabs
// removed. The result is a sub-slice of s; no copy is made.
func trim(s []byte) []byte {
	lo, hi := 0, len(s)
	// Advance past leading blanks.
	for ; lo < hi; lo++ {
		if c := s[lo]; c != ' ' && c != '\t' {
			break
		}
	}
	// Back up over trailing blanks, never crossing lo.
	for ; hi > lo; hi-- {
		if c := s[hi-1]; c != ' ' && c != '\t' {
			break
		}
	}
	return s[lo:hi]
}
118
119
120
121 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
122 line, err := r.readContinuedLineSlice(-1, noValidation)
123 if line != nil {
124 line = bytes.Clone(line)
125 }
126 return line, err
127 }
128
129
130
131
132
133
134 func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
135 if validateFirstLine == nil {
136 return nil, fmt.Errorf("missing validateFirstLine func")
137 }
138
139
140 line, err := r.readLineSlice(lim)
141 if err != nil {
142 return nil, err
143 }
144 if len(line) == 0 {
145 return line, nil
146 }
147
148 if err := validateFirstLine(line); err != nil {
149 return nil, err
150 }
151
152
153
154
155
156 if r.R.Buffered() > 1 {
157 peek, _ := r.R.Peek(2)
158 if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
159 len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
160 return trim(line), nil
161 }
162 }
163
164
165
166 r.buf = append(r.buf[:0], trim(line)...)
167
168 if lim < 0 {
169 lim = math.MaxInt64
170 }
171 lim -= int64(len(r.buf))
172
173
174 for r.skipSpace() > 0 {
175 r.buf = append(r.buf, ' ')
176 if int64(len(r.buf)) >= lim {
177 return nil, errMessageTooLarge
178 }
179 line, err := r.readLineSlice(lim - int64(len(r.buf)))
180 if err != nil {
181 break
182 }
183 r.buf = append(r.buf, trim(line)...)
184 }
185 return r.buf, nil
186 }
187
188
189 func (r *Reader) skipSpace() int {
190 n := 0
191 for {
192 c, err := r.R.ReadByte()
193 if err != nil {
194
195 break
196 }
197 if c != ' ' && c != '\t' {
198 r.R.UnreadByte()
199 break
200 }
201 n++
202 }
203 return n
204 }
205
206 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
207 line, err := r.ReadLine()
208 if err != nil {
209 return
210 }
211 return parseCodeLine(line, expectCode)
212 }
213
214 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
215 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
216 err = ProtocolError("short response: " + line)
217 return
218 }
219 continued = line[3] == '-'
220 code, err = strconv.Atoi(line[0:3])
221 if err != nil || code < 100 {
222 err = ProtocolError("invalid response code: " + line)
223 return
224 }
225 message = line[4:]
226 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
227 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
228 100 <= expectCode && expectCode < 1000 && code != expectCode {
229 err = &Error{code, message}
230 }
231 return
232 }
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
252 code, continued, message, err := r.readCodeLine(expectCode)
253 if err == nil && continued {
254 err = ProtocolError("unexpected multi-line response: " + message)
255 }
256 return
257 }
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285 func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
286 code, continued, message, err := r.readCodeLine(expectCode)
287 multi := continued
288 for continued {
289 line, err := r.ReadLine()
290 if err != nil {
291 return 0, "", err
292 }
293
294 var code2 int
295 var moreMessage string
296 code2, continued, moreMessage, err = parseCodeLine(line, 0)
297 if err != nil || code2 != code {
298 message += "\n" + strings.TrimRight(line, "\r\n")
299 continued = true
300 continue
301 }
302 message += "\n" + moreMessage
303 }
304 if err != nil && multi && message != "" {
305
306 err = &Error{code, message}
307 }
308 return
309 }
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327 func (r *Reader) DotReader() io.Reader {
328 r.closeDot()
329 r.dot = &dotReader{r: r}
330 return r.dot
331 }
332
333 type dotReader struct {
334 r *Reader
335 state int
336 }
337
338
339 func (d *dotReader) Read(b []byte) (n int, err error) {
340
341
342
343 const (
344 stateBeginLine = iota
345 stateDot
346 stateDotCR
347 stateCR
348 stateData
349 stateEOF
350 )
351 br := d.r.R
352 for n < len(b) && d.state != stateEOF {
353 var c byte
354 c, err = br.ReadByte()
355 if err != nil {
356 if err == io.EOF {
357 err = io.ErrUnexpectedEOF
358 }
359 break
360 }
361 switch d.state {
362 case stateBeginLine:
363 if c == '.' {
364 d.state = stateDot
365 continue
366 }
367 if c == '\r' {
368 d.state = stateCR
369 continue
370 }
371 d.state = stateData
372
373 case stateDot:
374 if c == '\r' {
375 d.state = stateDotCR
376 continue
377 }
378 if c == '\n' {
379 d.state = stateEOF
380 continue
381 }
382 d.state = stateData
383
384 case stateDotCR:
385 if c == '\n' {
386 d.state = stateEOF
387 continue
388 }
389
390
391 br.UnreadByte()
392 c = '\r'
393 d.state = stateData
394
395 case stateCR:
396 if c == '\n' {
397 d.state = stateBeginLine
398 break
399 }
400
401 br.UnreadByte()
402 c = '\r'
403 d.state = stateData
404
405 case stateData:
406 if c == '\r' {
407 d.state = stateCR
408 continue
409 }
410 if c == '\n' {
411 d.state = stateBeginLine
412 }
413 }
414 b[n] = c
415 n++
416 }
417 if err == nil && d.state == stateEOF {
418 err = io.EOF
419 }
420 if err != nil && d.r.dot == d {
421 d.r.dot = nil
422 }
423 return
424 }
425
426
427
428 func (r *Reader) closeDot() {
429 if r.dot == nil {
430 return
431 }
432 buf := make([]byte, 128)
433 for r.dot != nil {
434
435
436 r.dot.Read(buf)
437 }
438 }
439
440
441
442
443 func (r *Reader) ReadDotBytes() ([]byte, error) {
444 return io.ReadAll(r.DotReader())
445 }
446
447
448
449
450
451 func (r *Reader) ReadDotLines() ([]string, error) {
452
453
454
455 var v []string
456 var err error
457 for {
458 var line string
459 line, err = r.ReadLine()
460 if err != nil {
461 if err == io.EOF {
462 err = io.ErrUnexpectedEOF
463 }
464 break
465 }
466
467
468 if len(line) > 0 && line[0] == '.' {
469 if len(line) == 1 {
470 break
471 }
472 line = line[1:]
473 }
474 v = append(v, line)
475 }
476 return v, err
477 }
478
// colon is the separator between a header key and its value, kept as a
// package-level slice so readMIMEHeader can pass it to bytes.Cut.
var colon = []byte(":")
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
501 return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
502 }
503
504
505
506 func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
507
508
509
510 var strs []string
511 hint := r.upcomingHeaderKeys()
512 if hint > 0 {
513 if hint > 1000 {
514 hint = 1000
515 }
516 strs = make([]string, hint)
517 }
518
519 m := make(MIMEHeader, hint)
520
521
522
523
524 maxMemory -= 400
525 const mapEntryOverhead = 200
526
527
528 if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
529 const errorLimit = 80
530 line, err := r.readLineSlice(errorLimit)
531 if err != nil {
532 return m, err
533 }
534 return m, ProtocolError("malformed MIME header initial line: " + string(line))
535 }
536
537 for {
538 kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
539 if len(kv) == 0 {
540 return m, err
541 }
542
543
544 k, v, ok := bytes.Cut(kv, colon)
545 if !ok {
546 return m, ProtocolError("malformed MIME header line: " + string(kv))
547 }
548 key, ok := canonicalMIMEHeaderKey(k)
549 if !ok {
550 return m, ProtocolError("malformed MIME header line: " + string(kv))
551 }
552 for _, c := range v {
553 if !validHeaderValueByte(c) {
554 return m, ProtocolError("malformed MIME header line: " + string(kv))
555 }
556 }
557
558 maxHeaders--
559 if maxHeaders < 0 {
560 return nil, errMessageTooLarge
561 }
562
563
564 value := string(bytes.TrimLeft(v, " \t"))
565
566 vv := m[key]
567 if vv == nil {
568 maxMemory -= int64(len(key))
569 maxMemory -= mapEntryOverhead
570 }
571 maxMemory -= int64(len(value))
572 if maxMemory < 0 {
573 return m, errMessageTooLarge
574 }
575 if vv == nil && len(strs) > 0 {
576
577
578
579
580 vv, strs = strs[:1:1], strs[1:]
581 vv[0] = value
582 m[key] = vv
583 } else {
584 m[key] = append(vv, value)
585 }
586
587 if err != nil {
588 return m, err
589 }
590 }
591 }
592
593
594
// noValidation is a first-line validation func for readContinuedLineSlice
// that accepts every line.
func noValidation([]byte) error {
	return nil
}
596
597
598
599
600 func mustHaveFieldNameColon(line []byte) error {
601 if bytes.IndexByte(line, ':') < 0 {
602 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
603 }
604 return nil
605 }
606
// nl is the line separator used by upcomingHeaderKeys when it splits the
// buffered data into lines with bytes.Cut.
var nl = []byte("\n")
608
609
610
611 func (r *Reader) upcomingHeaderKeys() (n int) {
612
613 r.R.Peek(1)
614 s := r.R.Buffered()
615 if s == 0 {
616 return
617 }
618 peek, _ := r.R.Peek(s)
619 for len(peek) > 0 && n < 1000 {
620 var line []byte
621 line, peek, _ = bytes.Cut(peek, nl)
622 if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
623
624 break
625 }
626 if line[0] == ' ' || line[0] == '\t' {
627
628 continue
629 }
630 n++
631 }
632 return n
633 }
634
635
636
637
638
639
640
641
642
643 func CanonicalMIMEHeaderKey(s string) string {
644
645 upper := true
646 for i := 0; i < len(s); i++ {
647 c := s[i]
648 if !validHeaderFieldByte(c) {
649 return s
650 }
651 if upper && 'a' <= c && c <= 'z' {
652 s, _ = canonicalMIMEHeaderKey([]byte(s))
653 return s
654 }
655 if !upper && 'A' <= c && c <= 'Z' {
656 s, _ = canonicalMIMEHeaderKey([]byte(s))
657 return s
658 }
659 upper = c == '-'
660 }
661 return s
662 }
663
// toLower is the distance between an ASCII upper-case letter and its
// lower-case counterpart; adding it lowers a letter, subtracting it raises
// one.
const toLower = 'a' - 'A'
665
666
667
668
669
670
671
672
673
// validHeaderFieldByte reports whether c may appear in a header field name:
// an ASCII letter, an ASCII digit, or one of
//
//	! # $ % & ' * + - . ^ _ ` | ~
func validHeaderFieldByte(c byte) bool {
	// tokenBits is a 128-bit set in which bit i is set iff byte i is
	// allowed. It only exists as an untyped constant; the two 64-bit halves
	// below are what the generated code actually uses.
	const tokenBits = 0 |
		(1<<10-1)<<'0' |
		(1<<26-1)<<'A' |
		(1<<26-1)<<'a' |
		1<<'!' | 1<<'#' | 1<<'$' | 1<<'%' | 1<<'&' |
		1<<'\'' | 1<<'*' | 1<<'+' | 1<<'-' | 1<<'.' |
		1<<'^' | 1<<'_' | 1<<'`' | 1<<'|' | 1<<'~'
	const (
		lowHalf  = tokenBits & (1<<64 - 1) // bytes 0..63
		highHalf = tokenBits >> 64         // bytes 64..127
	)

	// Branch-free membership test. A shift count of 64 or more yields zero,
	// and c-64 wraps to a large count when c < 64, so for any c at most one
	// of the two probes can be non-zero; for c >= 128 both are zero.
	lo := (uint64(1) << c) & lowHalf
	hi := (uint64(1) << (c - 64)) & highHalf
	return lo|hi != 0
}
701
702
703
704
705
706
707
708
709
710
711
712
713
// validHeaderValueByte reports whether c may appear in a header field value.
// Accepted bytes are horizontal tab (0x09), space (0x20), the visible ASCII
// range 0x21..0x7e, and every byte 0x80 or above. The remaining control
// characters, including DEL (0x7f), are rejected.
func validHeaderValueByte(c byte) bool {
	// asciiBits is a 128-bit set in which bit i is set iff the ASCII byte i
	// is allowed. Bytes >= 0x80 have no bit here; they are accepted by the
	// way the test below is phrased.
	const asciiBits = 0 |
		(1<<(0x7f-0x21)-1)<<0x21 |
		1<<0x20 |
		1<<0x09
	const (
		lowHalf  = asciiBits & (1<<64 - 1) // bytes 0..63
		highHalf = asciiBits >> 64         // bytes 64..127
	)

	// Branch-free test: each probe keeps c's bit only if it is NOT in the
	// allowed set. A shift count of 64 or more yields zero, so for c >= 128
	// both probes are zero and the byte is accepted.
	lo := (uint64(1) << c) &^ lowHalf
	hi := (uint64(1) << (c - 64)) &^ highHalf
	return lo|hi == 0
}
727
728
729
730
731
732
733
734
735
736
737
738 func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
739 if len(a) == 0 {
740 return "", false
741 }
742
743
744 noCanon := false
745 for _, c := range a {
746 if validHeaderFieldByte(c) {
747 continue
748 }
749
750 if c == ' ' {
751
752
753
754 noCanon = true
755 continue
756 }
757 return string(a), false
758 }
759 if noCanon {
760 return string(a), true
761 }
762
763 upper := true
764 for i, c := range a {
765
766
767
768
769 if upper && 'a' <= c && c <= 'z' {
770 c -= toLower
771 } else if !upper && 'A' <= c && c <= 'Z' {
772 c += toLower
773 }
774 a[i] = c
775 upper = c == '-'
776 }
777 commonHeaderOnce.Do(initCommonHeader)
778
779
780
781 if v := commonHeader[string(a)]; v != "" {
782 return v, true
783 }
784 return string(a), true
785 }
786
787
var (
	// commonHeader maps each well-known canonical header key to itself so
	// that canonicalMIMEHeaderKey can return a shared string for it. It is
	// nil until initCommonHeader has run.
	commonHeader map[string]string

	// commonHeaderOnce guards the one-time initialization of commonHeader.
	commonHeaderOnce sync.Once
)

// initCommonHeader populates commonHeader. It is meant to be run through
// commonHeaderOnce.
func initCommonHeader() {
	names := [...]string{
		"Accept",
		"Accept-Charset",
		"Accept-Encoding",
		"Accept-Language",
		"Accept-Ranges",
		"Cache-Control",
		"Cc",
		"Connection",
		"Content-Id",
		"Content-Language",
		"Content-Length",
		"Content-Transfer-Encoding",
		"Content-Type",
		"Cookie",
		"Date",
		"Dkim-Signature",
		"Etag",
		"Expires",
		"From",
		"Host",
		"If-Modified-Since",
		"If-None-Match",
		"In-Reply-To",
		"Last-Modified",
		"Location",
		"Message-Id",
		"Mime-Version",
		"Pragma",
		"Received",
		"Return-Path",
		"Server",
		"Set-Cookie",
		"Subject",
		"To",
		"User-Agent",
		"Via",
		"X-Forwarded-For",
		"X-Imforwards",
		"X-Powered-By",
	}

	commonHeader = make(map[string]string, len(names))
	for _, name := range names {
		commonHeader[name] = name
	}
}
838
View as plain text