1  
     2  
     3  
     4  
     5  package textproto
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"math"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  )
    18  
    19  
    20  
    21  var errMessageTooLarge = errors.New("message too large")
    22  
    23  
    24  
    25  type Reader struct {
    26  	R   *bufio.Reader
    27  	dot *dotReader
    28  	buf []byte 
    29  }
    30  
    31  
    32  
    33  
    34  
    35  
    36  func NewReader(r *bufio.Reader) *Reader {
    37  	return &Reader{R: r}
    38  }
    39  
    40  
    41  
    42  func (r *Reader) ReadLine() (string, error) {
    43  	line, err := r.readLineSlice(-1)
    44  	return string(line), err
    45  }
    46  
    47  
    48  func (r *Reader) ReadLineBytes() ([]byte, error) {
    49  	line, err := r.readLineSlice(-1)
    50  	if line != nil {
    51  		line = bytes.Clone(line)
    52  	}
    53  	return line, err
    54  }
    55  
    56  
    57  
    58  
    59  func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
    60  	r.closeDot()
    61  	var line []byte
    62  	for {
    63  		l, more, err := r.R.ReadLine()
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  		if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
    68  			return nil, errMessageTooLarge
    69  		}
    70  		
    71  		if line == nil && !more {
    72  			return l, nil
    73  		}
    74  		line = append(line, l...)
    75  		if !more {
    76  			break
    77  		}
    78  	}
    79  	return line, nil
    80  }
    81  
    82  
    83  
    84  
    85  
    86  
    87  
    88  
    89  
    90  
    91  
    92  
    93  
    94  
    95  
    96  
    97  
    98  
    99  
   100  func (r *Reader) ReadContinuedLine() (string, error) {
   101  	line, err := r.readContinuedLineSlice(-1, noValidation)
   102  	return string(line), err
   103  }
   104  
   105  
   106  
   107  func trim(s []byte) []byte {
   108  	i := 0
   109  	for i < len(s) && (s[i] == ' ' || s[i] == '\t') {
   110  		i++
   111  	}
   112  	n := len(s)
   113  	for n > i && (s[n-1] == ' ' || s[n-1] == '\t') {
   114  		n--
   115  	}
   116  	return s[i:n]
   117  }
   118  
   119  
   120  
   121  func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
   122  	line, err := r.readContinuedLineSlice(-1, noValidation)
   123  	if line != nil {
   124  		line = bytes.Clone(line)
   125  	}
   126  	return line, err
   127  }
   128  
   129  
   130  
   131  
   132  
   133  
   134  func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
   135  	if validateFirstLine == nil {
   136  		return nil, fmt.Errorf("missing validateFirstLine func")
   137  	}
   138  
   139  	
   140  	line, err := r.readLineSlice(lim)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	if len(line) == 0 { 
   145  		return line, nil
   146  	}
   147  
   148  	if err := validateFirstLine(line); err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	
   153  	
   154  	
   155  	
   156  	if r.R.Buffered() > 1 {
   157  		peek, _ := r.R.Peek(2)
   158  		if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
   159  			len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
   160  			return trim(line), nil
   161  		}
   162  	}
   163  
   164  	
   165  	
   166  	r.buf = append(r.buf[:0], trim(line)...)
   167  
   168  	if lim < 0 {
   169  		lim = math.MaxInt64
   170  	}
   171  	lim -= int64(len(r.buf))
   172  
   173  	
   174  	for r.skipSpace() > 0 {
   175  		r.buf = append(r.buf, ' ')
   176  		if int64(len(r.buf)) >= lim {
   177  			return nil, errMessageTooLarge
   178  		}
   179  		line, err := r.readLineSlice(lim - int64(len(r.buf)))
   180  		if err != nil {
   181  			break
   182  		}
   183  		r.buf = append(r.buf, trim(line)...)
   184  	}
   185  	return r.buf, nil
   186  }
   187  
   188  
   189  func (r *Reader) skipSpace() int {
   190  	n := 0
   191  	for {
   192  		c, err := r.R.ReadByte()
   193  		if err != nil {
   194  			
   195  			break
   196  		}
   197  		if c != ' ' && c != '\t' {
   198  			r.R.UnreadByte()
   199  			break
   200  		}
   201  		n++
   202  	}
   203  	return n
   204  }
   205  
   206  func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
   207  	line, err := r.ReadLine()
   208  	if err != nil {
   209  		return
   210  	}
   211  	return parseCodeLine(line, expectCode)
   212  }
   213  
   214  func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
   215  	if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
   216  		err = ProtocolError("short response: " + line)
   217  		return
   218  	}
   219  	continued = line[3] == '-'
   220  	code, err = strconv.Atoi(line[0:3])
   221  	if err != nil || code < 100 {
   222  		err = ProtocolError("invalid response code: " + line)
   223  		return
   224  	}
   225  	message = line[4:]
   226  	if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
   227  		10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
   228  		100 <= expectCode && expectCode < 1000 && code != expectCode {
   229  		err = &Error{code, message}
   230  	}
   231  	return
   232  }
   233  
   234  
   235  
   236  
   237  
   238  
   239  
   240  
   241  
   242  
   243  
   244  
   245  
   246  
   247  
   248  
   249  
   250  
   251  func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
   252  	code, continued, message, err := r.readCodeLine(expectCode)
   253  	if err == nil && continued {
   254  		err = ProtocolError("unexpected multi-line response: " + message)
   255  	}
   256  	return
   257  }
   258  
   259  
   260  
   261  
   262  
   263  
   264  
   265  
   266  
   267  
   268  
   269  
   270  
   271  
   272  
   273  
   274  
   275  
   276  
   277  
   278  
   279  
   280  
   281  
   282  
   283  
   284  
   285  func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
   286  	code, continued, message, err := r.readCodeLine(expectCode)
   287  	multi := continued
   288  	for continued {
   289  		line, err := r.ReadLine()
   290  		if err != nil {
   291  			return 0, "", err
   292  		}
   293  
   294  		var code2 int
   295  		var moreMessage string
   296  		code2, continued, moreMessage, err = parseCodeLine(line, 0)
   297  		if err != nil || code2 != code {
   298  			message += "\n" + strings.TrimRight(line, "\r\n")
   299  			continued = true
   300  			continue
   301  		}
   302  		message += "\n" + moreMessage
   303  	}
   304  	if err != nil && multi && message != "" {
   305  		
   306  		err = &Error{code, message}
   307  	}
   308  	return
   309  }
   310  
   311  
   312  
   313  
   314  
   315  
   316  
   317  
   318  
   319  
   320  
   321  
   322  
   323  
   324  
   325  
   326  
   327  func (r *Reader) DotReader() io.Reader {
   328  	r.closeDot()
   329  	r.dot = &dotReader{r: r}
   330  	return r.dot
   331  }
   332  
   333  type dotReader struct {
   334  	r     *Reader
   335  	state int
   336  }
   337  
   338  
   339  func (d *dotReader) Read(b []byte) (n int, err error) {
   340  	
   341  	
   342  	
   343  	const (
   344  		stateBeginLine = iota 
   345  		stateDot              
   346  		stateDotCR            
   347  		stateCR               
   348  		stateData             
   349  		stateEOF              
   350  	)
   351  	br := d.r.R
   352  	for n < len(b) && d.state != stateEOF {
   353  		var c byte
   354  		c, err = br.ReadByte()
   355  		if err != nil {
   356  			if err == io.EOF {
   357  				err = io.ErrUnexpectedEOF
   358  			}
   359  			break
   360  		}
   361  		switch d.state {
   362  		case stateBeginLine:
   363  			if c == '.' {
   364  				d.state = stateDot
   365  				continue
   366  			}
   367  			if c == '\r' {
   368  				d.state = stateCR
   369  				continue
   370  			}
   371  			d.state = stateData
   372  
   373  		case stateDot:
   374  			if c == '\r' {
   375  				d.state = stateDotCR
   376  				continue
   377  			}
   378  			if c == '\n' {
   379  				d.state = stateEOF
   380  				continue
   381  			}
   382  			d.state = stateData
   383  
   384  		case stateDotCR:
   385  			if c == '\n' {
   386  				d.state = stateEOF
   387  				continue
   388  			}
   389  			
   390  			
   391  			br.UnreadByte()
   392  			c = '\r'
   393  			d.state = stateData
   394  
   395  		case stateCR:
   396  			if c == '\n' {
   397  				d.state = stateBeginLine
   398  				break
   399  			}
   400  			
   401  			br.UnreadByte()
   402  			c = '\r'
   403  			d.state = stateData
   404  
   405  		case stateData:
   406  			if c == '\r' {
   407  				d.state = stateCR
   408  				continue
   409  			}
   410  			if c == '\n' {
   411  				d.state = stateBeginLine
   412  			}
   413  		}
   414  		b[n] = c
   415  		n++
   416  	}
   417  	if err == nil && d.state == stateEOF {
   418  		err = io.EOF
   419  	}
   420  	if err != nil && d.r.dot == d {
   421  		d.r.dot = nil
   422  	}
   423  	return
   424  }
   425  
   426  
   427  
   428  func (r *Reader) closeDot() {
   429  	if r.dot == nil {
   430  		return
   431  	}
   432  	buf := make([]byte, 128)
   433  	for r.dot != nil {
   434  		
   435  		
   436  		r.dot.Read(buf)
   437  	}
   438  }
   439  
   440  
   441  
   442  
   443  func (r *Reader) ReadDotBytes() ([]byte, error) {
   444  	return io.ReadAll(r.DotReader())
   445  }
   446  
   447  
   448  
   449  
   450  
   451  func (r *Reader) ReadDotLines() ([]string, error) {
   452  	
   453  	
   454  	
   455  	var v []string
   456  	var err error
   457  	for {
   458  		var line string
   459  		line, err = r.ReadLine()
   460  		if err != nil {
   461  			if err == io.EOF {
   462  				err = io.ErrUnexpectedEOF
   463  			}
   464  			break
   465  		}
   466  
   467  		
   468  		if len(line) > 0 && line[0] == '.' {
   469  			if len(line) == 1 {
   470  				break
   471  			}
   472  			line = line[1:]
   473  		}
   474  		v = append(v, line)
   475  	}
   476  	return v, err
   477  }
   478  
   479  var colon = []byte(":")
   480  
   481  
   482  
   483  
   484  
   485  
   486  
   487  
   488  
   489  
   490  
   491  
   492  
   493  
   494  
   495  
   496  
   497  
   498  
   499  
   500  func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
   501  	return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
   502  }
   503  
   504  
   505  
   506  func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
   507  	
   508  	
   509  	
   510  	var strs []string
   511  	hint := r.upcomingHeaderKeys()
   512  	if hint > 0 {
   513  		if hint > 1000 {
   514  			hint = 1000 
   515  		}
   516  		strs = make([]string, hint)
   517  	}
   518  
   519  	m := make(MIMEHeader, hint)
   520  
   521  	
   522  	
   523  	
   524  	maxMemory -= 400
   525  	const mapEntryOverhead = 200
   526  
   527  	
   528  	if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
   529  		const errorLimit = 80 
   530  		line, err := r.readLineSlice(errorLimit)
   531  		if err != nil {
   532  			return m, err
   533  		}
   534  		return m, ProtocolError("malformed MIME header initial line: " + string(line))
   535  	}
   536  
   537  	for {
   538  		kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
   539  		if len(kv) == 0 {
   540  			return m, err
   541  		}
   542  
   543  		
   544  		k, v, ok := bytes.Cut(kv, colon)
   545  		if !ok {
   546  			return m, ProtocolError("malformed MIME header line: " + string(kv))
   547  		}
   548  		key, ok := canonicalMIMEHeaderKey(k)
   549  		if !ok {
   550  			return m, ProtocolError("malformed MIME header line: " + string(kv))
   551  		}
   552  		for _, c := range v {
   553  			if !validHeaderValueByte(c) {
   554  				return m, ProtocolError("malformed MIME header line: " + string(kv))
   555  			}
   556  		}
   557  
   558  		
   559  		
   560  		
   561  		if key == "" {
   562  			continue
   563  		}
   564  
   565  		maxHeaders--
   566  		if maxHeaders < 0 {
   567  			return nil, errMessageTooLarge
   568  		}
   569  
   570  		
   571  		value := string(bytes.TrimLeft(v, " \t"))
   572  
   573  		vv := m[key]
   574  		if vv == nil {
   575  			maxMemory -= int64(len(key))
   576  			maxMemory -= mapEntryOverhead
   577  		}
   578  		maxMemory -= int64(len(value))
   579  		if maxMemory < 0 {
   580  			return m, errMessageTooLarge
   581  		}
   582  		if vv == nil && len(strs) > 0 {
   583  			
   584  			
   585  			
   586  			
   587  			vv, strs = strs[:1:1], strs[1:]
   588  			vv[0] = value
   589  			m[key] = vv
   590  		} else {
   591  			m[key] = append(vv, value)
   592  		}
   593  
   594  		if err != nil {
   595  			return m, err
   596  		}
   597  	}
   598  }
   599  
   600  
   601  
   602  func noValidation(_ []byte) error { return nil }
   603  
   604  
   605  
   606  
   607  func mustHaveFieldNameColon(line []byte) error {
   608  	if bytes.IndexByte(line, ':') < 0 {
   609  		return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
   610  	}
   611  	return nil
   612  }
   613  
   614  var nl = []byte("\n")
   615  
   616  
   617  
   618  func (r *Reader) upcomingHeaderKeys() (n int) {
   619  	
   620  	r.R.Peek(1) 
   621  	s := r.R.Buffered()
   622  	if s == 0 {
   623  		return
   624  	}
   625  	peek, _ := r.R.Peek(s)
   626  	for len(peek) > 0 && n < 1000 {
   627  		var line []byte
   628  		line, peek, _ = bytes.Cut(peek, nl)
   629  		if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
   630  			
   631  			break
   632  		}
   633  		if line[0] == ' ' || line[0] == '\t' {
   634  			
   635  			continue
   636  		}
   637  		n++
   638  	}
   639  	return n
   640  }
   641  
   642  
   643  
   644  
   645  
   646  
   647  
   648  
   649  
   650  func CanonicalMIMEHeaderKey(s string) string {
   651  	
   652  	upper := true
   653  	for i := 0; i < len(s); i++ {
   654  		c := s[i]
   655  		if !validHeaderFieldByte(c) {
   656  			return s
   657  		}
   658  		if upper && 'a' <= c && c <= 'z' {
   659  			s, _ = canonicalMIMEHeaderKey([]byte(s))
   660  			return s
   661  		}
   662  		if !upper && 'A' <= c && c <= 'Z' {
   663  			s, _ = canonicalMIMEHeaderKey([]byte(s))
   664  			return s
   665  		}
   666  		upper = c == '-'
   667  	}
   668  	return s
   669  }
   670  
   671  const toLower = 'a' - 'A'
   672  
   673  
   674  
   675  
   676  
   677  
   678  
   679  
   680  
   681  func validHeaderFieldByte(c byte) bool {
   682  	
   683  	
   684  	
   685  	
   686  	const mask = 0 |
   687  		(1<<(10)-1)<<'0' |
   688  		(1<<(26)-1)<<'a' |
   689  		(1<<(26)-1)<<'A' |
   690  		1<<'!' |
   691  		1<<'#' |
   692  		1<<'$' |
   693  		1<<'%' |
   694  		1<<'&' |
   695  		1<<'\'' |
   696  		1<<'*' |
   697  		1<<'+' |
   698  		1<<'-' |
   699  		1<<'.' |
   700  		1<<'^' |
   701  		1<<'_' |
   702  		1<<'`' |
   703  		1<<'|' |
   704  		1<<'~'
   705  	return ((uint64(1)<<c)&(mask&(1<<64-1)) |
   706  		(uint64(1)<<(c-64))&(mask>>64)) != 0
   707  }
   708  
   709  
   710  
   711  
   712  
   713  
   714  
   715  
   716  
   717  
   718  
   719  
   720  
   721  func validHeaderValueByte(c byte) bool {
   722  	
   723  	
   724  	
   725  	
   726  	
   727  	const mask = 0 |
   728  		(1<<(0x7f-0x21)-1)<<0x21 | 
   729  		1<<0x20 | 
   730  		1<<0x09 
   731  	return ((uint64(1)<<c)&^(mask&(1<<64-1)) |
   732  		(uint64(1)<<(c-64))&^(mask>>64)) == 0
   733  }
   734  
   735  
   736  
   737  
   738  
   739  
   740  
   741  
   742  
   743  
   744  
   745  func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
   746  	
   747  	noCanon := false
   748  	for _, c := range a {
   749  		if validHeaderFieldByte(c) {
   750  			continue
   751  		}
   752  		
   753  		if c == ' ' {
   754  			
   755  			
   756  			
   757  			noCanon = true
   758  			continue
   759  		}
   760  		return string(a), false
   761  	}
   762  	if noCanon {
   763  		return string(a), true
   764  	}
   765  
   766  	upper := true
   767  	for i, c := range a {
   768  		
   769  		
   770  		
   771  		
   772  		if upper && 'a' <= c && c <= 'z' {
   773  			c -= toLower
   774  		} else if !upper && 'A' <= c && c <= 'Z' {
   775  			c += toLower
   776  		}
   777  		a[i] = c
   778  		upper = c == '-' 
   779  	}
   780  	commonHeaderOnce.Do(initCommonHeader)
   781  	
   782  	
   783  	
   784  	if v := commonHeader[string(a)]; v != "" {
   785  		return v, true
   786  	}
   787  	return string(a), true
   788  }
   789  
   790  
   791  var commonHeader map[string]string
   792  
   793  var commonHeaderOnce sync.Once
   794  
   795  func initCommonHeader() {
   796  	commonHeader = make(map[string]string)
   797  	for _, v := range []string{
   798  		"Accept",
   799  		"Accept-Charset",
   800  		"Accept-Encoding",
   801  		"Accept-Language",
   802  		"Accept-Ranges",
   803  		"Cache-Control",
   804  		"Cc",
   805  		"Connection",
   806  		"Content-Id",
   807  		"Content-Language",
   808  		"Content-Length",
   809  		"Content-Transfer-Encoding",
   810  		"Content-Type",
   811  		"Cookie",
   812  		"Date",
   813  		"Dkim-Signature",
   814  		"Etag",
   815  		"Expires",
   816  		"From",
   817  		"Host",
   818  		"If-Modified-Since",
   819  		"If-None-Match",
   820  		"In-Reply-To",
   821  		"Last-Modified",
   822  		"Location",
   823  		"Message-Id",
   824  		"Mime-Version",
   825  		"Pragma",
   826  		"Received",
   827  		"Return-Path",
   828  		"Server",
   829  		"Set-Cookie",
   830  		"Subject",
   831  		"To",
   832  		"User-Agent",
   833  		"Via",
   834  		"X-Forwarded-For",
   835  		"X-Imforwards",
   836  		"X-Powered-By",
   837  	} {
   838  		commonHeader[v] = v
   839  	}
   840  }
   841  
View as plain text