1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "errors"
11 "fmt"
12 "io"
13 "math"
14 "strconv"
15 "strings"
16 "sync"
17 )
18
19
20
21 var errMessageTooLarge = errors.New("message too large")
22
23
24
25 type Reader struct {
26 R *bufio.Reader
27 dot *dotReader
28 buf []byte
29 }
30
31
32
33
34
35
36 func NewReader(r *bufio.Reader) *Reader {
37 return &Reader{R: r}
38 }
39
40
41
42 func (r *Reader) ReadLine() (string, error) {
43 line, err := r.readLineSlice(-1)
44 return string(line), err
45 }
46
47
48 func (r *Reader) ReadLineBytes() ([]byte, error) {
49 line, err := r.readLineSlice(-1)
50 if line != nil {
51 line = bytes.Clone(line)
52 }
53 return line, err
54 }
55
56
57
58
59 func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
60 r.closeDot()
61 var line []byte
62 for {
63 l, more, err := r.R.ReadLine()
64 if err != nil {
65 return nil, err
66 }
67 if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
68 return nil, errMessageTooLarge
69 }
70
71 if line == nil && !more {
72 return l, nil
73 }
74 line = append(line, l...)
75 if !more {
76 break
77 }
78 }
79 return line, nil
80 }
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100 func (r *Reader) ReadContinuedLine() (string, error) {
101 line, err := r.readContinuedLineSlice(-1, noValidation)
102 return string(line), err
103 }
104
105
106
// trim returns the subslice of s with leading and trailing spaces and tabs
// removed. Other whitespace such as '\r' or '\n' is kept. The result shares
// s's backing array; nothing is copied.
func trim(s []byte) []byte {
	isBlank := func(b byte) bool { return b == ' ' || b == '\t' }
	start, end := 0, len(s)
	for start < end && isBlank(s[start]) {
		start++
	}
	for end > start && isBlank(s[end-1]) {
		end--
	}
	return s[start:end]
}
118
119
120
121 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
122 line, err := r.readContinuedLineSlice(-1, noValidation)
123 if line != nil {
124 line = bytes.Clone(line)
125 }
126 return line, err
127 }
128
129
130
131
132
133
134 func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
135 if validateFirstLine == nil {
136 return nil, fmt.Errorf("missing validateFirstLine func")
137 }
138
139
140 line, err := r.readLineSlice(lim)
141 if err != nil {
142 return nil, err
143 }
144 if len(line) == 0 {
145 return line, nil
146 }
147
148 if err := validateFirstLine(line); err != nil {
149 return nil, err
150 }
151
152
153
154
155
156 if r.R.Buffered() > 1 {
157 peek, _ := r.R.Peek(2)
158 if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
159 len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
160 return trim(line), nil
161 }
162 }
163
164
165
166 r.buf = append(r.buf[:0], trim(line)...)
167
168 if lim < 0 {
169 lim = math.MaxInt64
170 }
171 lim -= int64(len(r.buf))
172
173
174 for r.skipSpace() > 0 {
175 r.buf = append(r.buf, ' ')
176 if int64(len(r.buf)) >= lim {
177 return nil, errMessageTooLarge
178 }
179 line, err := r.readLineSlice(lim - int64(len(r.buf)))
180 if err != nil {
181 break
182 }
183 r.buf = append(r.buf, trim(line)...)
184 }
185 return r.buf, nil
186 }
187
188
189 func (r *Reader) skipSpace() int {
190 n := 0
191 for {
192 c, err := r.R.ReadByte()
193 if err != nil {
194
195 break
196 }
197 if c != ' ' && c != '\t' {
198 r.R.UnreadByte()
199 break
200 }
201 n++
202 }
203 return n
204 }
205
206 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
207 line, err := r.ReadLine()
208 if err != nil {
209 return
210 }
211 return parseCodeLine(line, expectCode)
212 }
213
214 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
215 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
216 err = ProtocolError("short response: " + line)
217 return
218 }
219 continued = line[3] == '-'
220 code, err = strconv.Atoi(line[0:3])
221 if err != nil || code < 100 {
222 err = ProtocolError("invalid response code: " + line)
223 return
224 }
225 message = line[4:]
226 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
227 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
228 100 <= expectCode && expectCode < 1000 && code != expectCode {
229 err = &Error{code, message}
230 }
231 return
232 }
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
252 code, continued, message, err := r.readCodeLine(expectCode)
253 if err == nil && continued {
254 err = ProtocolError("unexpected multi-line response: " + message)
255 }
256 return
257 }
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285 func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
286 code, continued, message, err := r.readCodeLine(expectCode)
287 multi := continued
288 for continued {
289 line, err := r.ReadLine()
290 if err != nil {
291 return 0, "", err
292 }
293
294 var code2 int
295 var moreMessage string
296 code2, continued, moreMessage, err = parseCodeLine(line, 0)
297 if err != nil || code2 != code {
298 message += "\n" + strings.TrimRight(line, "\r\n")
299 continued = true
300 continue
301 }
302 message += "\n" + moreMessage
303 }
304 if err != nil && multi && message != "" {
305
306 err = &Error{code, message}
307 }
308 return
309 }
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327 func (r *Reader) DotReader() io.Reader {
328 r.closeDot()
329 r.dot = &dotReader{r: r}
330 return r.dot
331 }
332
// dotReader is the io.Reader returned by Reader.DotReader. state records
// where in the current line decoding stopped, so it can resume across Read
// calls; its zero value is stateBeginLine.
type dotReader struct {
	r     *Reader
	state int
}

// Read decodes dot-encoded data from d.r.R into b. It behaves as follows:
//   - A dot at the start of a line is dropped.
//   - "\r\n" becomes "\n"; a lone "\r" is passed through.
//   - Decoding stops after a line consisting of just "." (".\r\n" or ".\n").
//     At that point it returns io.EOF, possibly together with the final bytes.
//   - End of input before that line is reported as io.ErrUnexpectedEOF.
//
// Whenever Read returns a non-nil error it clears d.r.dot (if it still
// points at d); closeDot relies on that to know the block is consumed.
func (d *dotReader) Read(b []byte) (n int, err error) {
	// Run data through a simple state machine to elide leading dots,
	// rewrite trailing \r\n into \n, and detect the ending .\r\n line.
	const (
		stateBeginLine = iota // beginning of line; initial state; must be zero
		stateDot              // read . at beginning of line
		stateDotCR            // read .\r at beginning of line
		stateCR               // read \r (possibly at end of line)
		stateData             // reading data in middle of line
		stateEOF              // reached .\r\n end marker line
	)
	br := d.r.R
	for n < len(b) && d.state != stateEOF {
		var c byte
		c, err = br.ReadByte()
		if err != nil {
			// Input ended before the terminating "." line.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			break
		}
		switch d.state {
		case stateBeginLine:
			if c == '.' {
				d.state = stateDot
				continue
			}
			if c == '\r' {
				d.state = stateCR
				continue
			}
			d.state = stateData

		case stateDot:
			if c == '\r' {
				d.state = stateDotCR
				continue
			}
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// Dot-escaped line: the leading dot is dropped and c is emitted.
			d.state = stateData

		case stateDotCR:
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// Not part of .\r\n.
			// Consume leading dot and emit saved \r.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateCR:
			if c == '\n' {
				d.state = stateBeginLine
				// Leaves the switch only; the '\n' is emitted below and the
				// '\r' before it is dropped.
				break
			}
			// Not part of \r\n. Emit saved \r
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateData:
			if c == '\r' {
				d.state = stateCR
				continue
			}
			if c == '\n' {
				d.state = stateBeginLine
			}
		}
		b[n] = c
		n++
	}
	if err == nil && d.state == stateEOF {
		err = io.EOF
	}
	// Detach from the Reader once finished (or failed) so closeDot stops.
	if err != nil && d.r.dot == d {
		d.r.dot = nil
	}
	return
}
425
426
427
428 func (r *Reader) closeDot() {
429 if r.dot == nil {
430 return
431 }
432 buf := make([]byte, 128)
433 for r.dot != nil {
434
435
436 r.dot.Read(buf)
437 }
438 }
439
440
441
442
443 func (r *Reader) ReadDotBytes() ([]byte, error) {
444 return io.ReadAll(r.DotReader())
445 }
446
447
448
449
450
451 func (r *Reader) ReadDotLines() ([]string, error) {
452
453
454
455 var v []string
456 var err error
457 for {
458 var line string
459 line, err = r.ReadLine()
460 if err != nil {
461 if err == io.EOF {
462 err = io.ErrUnexpectedEOF
463 }
464 break
465 }
466
467
468 if len(line) > 0 && line[0] == '.' {
469 if len(line) == 1 {
470 break
471 }
472 line = line[1:]
473 }
474 v = append(v, line)
475 }
476 return v, err
477 }
478
479 var colon = []byte(":")
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
501 return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
502 }
503
504
505
// readMIMEHeader implements ReadMIMEHeader with explicit limits.
//
// maxMemory is an approximate byte budget. Each header line read is bounded
// by the remaining budget. Every new key (its length plus mapEntryOverhead)
// and every value is charged against it. maxHeaders bounds the number of
// header values accepted. Exceeding either limit returns errMessageTooLarge.
//
// Reading stops at the first blank line (or when a read error produces an
// empty line). On most error paths the headers parsed so far are returned
// alongside the error.
func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
	// Pre-allocate one []string backing array, sized from a peek at the
	// buffered input, and carve single-element slices out of it for new
	// keys. This avoids one small allocation per header.
	var strs []string
	hint := r.upcomingHeaderKeys()
	if hint > 0 {
		// upcomingHeaderKeys already stops counting at 1000; this cap is
		// defensive.
		if hint > 1000 {
			hint = 1000
		}
		strs = make([]string, hint)
	}

	m := make(MIMEHeader, hint)

	// Charge a fixed 400 bytes for the map itself and mapEntryOverhead for
	// each distinct key. NOTE(review): these look like rough estimates rather
	// than measured sizes.
	maxMemory -= 400
	const mapEntryOverhead = 200

	// The first line cannot start with a leading space or tab. Report it as
	// malformed, quoting at most errorLimit bytes. A longer line surfaces as
	// readLineSlice's errMessageTooLarge instead.
	if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
		const errorLimit = 80
		line, err := r.readLineSlice(errorLimit)
		if err != nil {
			return m, err
		}
		return m, ProtocolError("malformed MIME header initial line: " + string(line))
	}

	for {
		// kv is one unfolded header line, bounded by the remaining budget;
		// every line must contain a colon.
		kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
		if len(kv) == 0 {
			return m, err
		}

		// The key ends at the first colon.
		k, v, ok := bytes.Cut(kv, colon)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		// canonicalMIMEHeaderKey rewrites k in place. It rejects bytes
		// that are invalid in a field name; spaces are tolerated but
		// disable canonicalization.
		key, ok := canonicalMIMEHeaderKey(k)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		// Every value byte must be HTAB, SP, a visible ASCII char or >= 0x80.
		for _, c := range v {
			if !validHeaderValueByte(c) {
				return m, ProtocolError("malformed MIME header line: " + string(kv))
			}
		}

		// An empty field name (line starting with ':') is skipped rather
		// than rejected.
		if key == "" {
			continue
		}

		// NOTE(review): unlike the other error paths this returns a nil map.
		maxHeaders--
		if maxHeaders < 0 {
			return nil, errMessageTooLarge
		}

		// Skip leading spaces and tabs in the value.
		value := string(bytes.TrimLeft(v, " \t"))

		// Charge the key and per-entry overhead only the first time a key
		// is seen; every value is charged.
		vv := m[key]
		if vv == nil {
			maxMemory -= int64(len(key))
			maxMemory -= mapEntryOverhead
		}
		maxMemory -= int64(len(value))
		if maxMemory < 0 {
			return m, errMessageTooLarge
		}
		if vv == nil && len(strs) > 0 {
			// First value for this key: take one slot from strs. The
			// three-index slice caps capacity at 1, so a later append for
			// a repeated key reallocates instead of overwriting the next
			// slot.
			vv, strs = strs[:1:1], strs[1:]
			vv[0] = value
			m[key] = vv
		} else {
			m[key] = append(vv, value)
		}

		// NOTE(review): readContinuedLineSlice in this file only returns a
		// non-empty kv with a nil error, so this looks unreachable; kept as
		// a safeguard.
		if err != nil {
			return m, err
		}
	}
}
599
600
601
// noValidation is a first-line validator for readContinuedLineSlice that
// accepts every line.
func noValidation([]byte) error {
	return nil
}
603
604
605
606
607 func mustHaveFieldNameColon(line []byte) error {
608 if bytes.IndexByte(line, ':') < 0 {
609 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
610 }
611 return nil
612 }
613
614 var nl = []byte("\n")
615
616
617
618 func (r *Reader) upcomingHeaderKeys() (n int) {
619
620 r.R.Peek(1)
621 s := r.R.Buffered()
622 if s == 0 {
623 return
624 }
625 peek, _ := r.R.Peek(s)
626 for len(peek) > 0 && n < 1000 {
627 var line []byte
628 line, peek, _ = bytes.Cut(peek, nl)
629 if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
630
631 break
632 }
633 if line[0] == ' ' || line[0] == '\t' {
634
635 continue
636 }
637 n++
638 }
639 return n
640 }
641
642
643
644
645
646
647
648
649
650 func CanonicalMIMEHeaderKey(s string) string {
651
652 upper := true
653 for i := 0; i < len(s); i++ {
654 c := s[i]
655 if !validHeaderFieldByte(c) {
656 return s
657 }
658 if upper && 'a' <= c && c <= 'z' {
659 s, _ = canonicalMIMEHeaderKey([]byte(s))
660 return s
661 }
662 if !upper && 'A' <= c && c <= 'Z' {
663 s, _ = canonicalMIMEHeaderKey([]byte(s))
664 return s
665 }
666 upper = c == '-'
667 }
668 return s
669 }
670
671 const toLower = 'a' - 'A'
672
673
674
675
676
677
678
679
680
// validHeaderFieldByte reports whether c may appear in a header field name,
// i.e. whether it is an RFC 7230 "tchar": an ASCII letter, a digit, or one of
// !#$%&'*+-.^_`|~. Bytes >= 0x80 are never valid.
func validHeaderFieldByte(c byte) bool {
	// tchar is a 128-bit bitmap with a 1 for every allowed byte value. It is
	// split into two uint64 halves so c can be tested with one shift.
	const tchar = 0 |
		(1<<10-1)<<'0' |
		(1<<26-1)<<'a' |
		(1<<26-1)<<'A' |
		1<<'!' | 1<<'#' | 1<<'$' | 1<<'%' | 1<<'&' | 1<<'\'' |
		1<<'*' | 1<<'+' | 1<<'-' | 1<<'.' | 1<<'^' | 1<<'_' |
		1<<'`' | 1<<'|' | 1<<'~'
	const low, high = uint64(tchar & (1<<64 - 1)), uint64(tchar >> 64)
	switch {
	case c < 64:
		return low&(uint64(1)<<c) != 0
	case c < 128:
		return high&(uint64(1)<<(c-64)) != 0
	default:
		return false
	}
}
708
709
710
711
712
713
714
715
716
717
718
719
720
// validHeaderValueByte reports whether c may appear in a header field value.
// The following are accepted:
//   - HTAB (0x09),
//   - SP (0x20),
//   - VCHAR (0x21-0x7E),
//   - obs-text (0x80-0xFF).
//
// Other control characters and DEL (0x7F) are rejected.
func validHeaderValueByte(c byte) bool {
	// allowed is a 128-bit bitmap of permitted ASCII bytes, split into two
	// uint64 halves. Bytes >= 0x80 (obs-text) are handled separately.
	const allowed = 0 |
		(1<<(0x7f-0x21)-1)<<0x21 | // VCHAR: %x21-7E
		1<<0x20 | // SP: %x20
		1<<0x09 // HTAB: %x09
	const low, high = uint64(allowed & (1<<64 - 1)), uint64(allowed >> 64)
	switch {
	case c >= 128:
		return true
	case c >= 64:
		return high&(uint64(1)<<(c-64)) != 0
	default:
		return low&(uint64(1)<<c) != 0
	}
}
734
735
736
737
738
739
740
741
742
743
744
745 func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
746
747 noCanon := false
748 for _, c := range a {
749 if validHeaderFieldByte(c) {
750 continue
751 }
752
753 if c == ' ' {
754
755
756
757 noCanon = true
758 continue
759 }
760 return string(a), false
761 }
762 if noCanon {
763 return string(a), true
764 }
765
766 upper := true
767 for i, c := range a {
768
769
770
771
772 if upper && 'a' <= c && c <= 'z' {
773 c -= toLower
774 } else if !upper && 'A' <= c && c <= 'Z' {
775 c += toLower
776 }
777 a[i] = c
778 upper = c == '-'
779 }
780 commonHeaderOnce.Do(initCommonHeader)
781
782
783
784 if v := commonHeader[string(a)]; v != "" {
785 return v, true
786 }
787 return string(a), true
788 }
789
790
// commonHeader interns well-known canonical header names. Each maps to
// itself, so canonicalMIMEHeaderKey can reuse the existing string.
var commonHeader map[string]string

// commonHeaderOnce guards the lazy construction of commonHeader.
var commonHeaderOnce sync.Once

// initCommonHeader builds commonHeader. It is run through commonHeaderOnce.
func initCommonHeader() {
	names := []string{
		"Accept",
		"Accept-Charset",
		"Accept-Encoding",
		"Accept-Language",
		"Accept-Ranges",
		"Cache-Control",
		"Cc",
		"Connection",
		"Content-Id",
		"Content-Language",
		"Content-Length",
		"Content-Transfer-Encoding",
		"Content-Type",
		"Cookie",
		"Date",
		"Dkim-Signature",
		"Etag",
		"Expires",
		"From",
		"Host",
		"If-Modified-Since",
		"If-None-Match",
		"In-Reply-To",
		"Last-Modified",
		"Location",
		"Message-Id",
		"Mime-Version",
		"Pragma",
		"Received",
		"Return-Path",
		"Server",
		"Set-Cookie",
		"Subject",
		"To",
		"User-Agent",
		"Via",
		"X-Forwarded-For",
		"X-Imforwards",
		"X-Powered-By",
	}
	table := make(map[string]string, len(names))
	for _, name := range names {
		table[name] = name
	}
	commonHeader = table
}
841
View as plain text