1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "log"
15 "mime"
16 "net"
17 "net/http"
18 "net/http/httptrace"
19 "net/http/internal/ascii"
20 "net/textproto"
21 "net/url"
22 "strings"
23 "sync"
24 "time"
25
26 "golang.org/x/net/http/httpguts"
27 )
28
29
// A ProxyRequest contains a request to be rewritten by a ReverseProxy's
// Rewrite function.
type ProxyRequest struct {
	// In is the request received by the proxy, i.e. the request that was
	// passed to ServeHTTP.
	// NOTE(review): presumably Rewrite should treat In as read-only — confirm
	// against the documented Rewrite contract.
	In *http.Request

	// Out is the request that will be sent by the proxy. It starts as a clone
	// of In. Hop-by-hop headers and the client-supplied Forwarded and
	// X-Forwarded-* headers have been removed; Connection/Upgrade are put back
	// for protocol upgrades. Rewrite may modify Out or replace it, and
	// whatever Out holds when Rewrite returns is what gets sent.
	Out *http.Request
}
41
42
43
44
45
46
47
48
49
50
51
52
53
// SetURL routes the outbound request to target. The outbound URL takes
// target's scheme and host, its path is joined onto target's path, and
// target's query string is combined with the outbound one (target's
// parameters first).
//
// The outbound Host field is also cleared. The Host header sent upstream is
// therefore taken from the rewritten URL, not copied from the inbound
// request.
func (r *ProxyRequest) SetURL(target *url.URL) {
	out := r.Out
	rewriteRequestURL(out, target)
	out.Host = ""
}
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host and
// X-Forwarded-Proto headers on the outbound request.
//
//   - X-Forwarded-For: the client IP is taken from the inbound RemoteAddr. It
//     is appended to any X-Forwarded-For values already on the outbound
//     request, and everything is folded into one comma+space separated value.
//     If RemoteAddr cannot be split into host and port, the header is removed.
//   - X-Forwarded-Host: the Host of the inbound request.
//   - X-Forwarded-Proto: "https" when the inbound request arrived over TLS,
//     "http" otherwise.
func (r *ProxyRequest) SetXForwarded() {
	hdr := r.Out.Header

	host, _, splitErr := net.SplitHostPort(r.In.RemoteAddr)
	if splitErr != nil {
		// Without a usable client address, drop any existing value rather
		// than forwarding it unchanged.
		hdr.Del("X-Forwarded-For")
	} else {
		if prev := hdr["X-Forwarded-For"]; len(prev) > 0 {
			host = strings.Join(prev, ", ") + ", " + host
		}
		hdr.Set("X-Forwarded-For", host)
	}

	hdr.Set("X-Forwarded-Host", r.In.Host)

	proto := "http"
	if r.In.TLS != nil {
		proto = "https"
	}
	hdr.Set("X-Forwarded-Proto", proto)
}
96
97
98
99
100
101
102
// ReverseProxy is an HTTP handler that sends each incoming request to
// another server and copies that server's response back to the client.
//
// Exactly one of Rewrite or Director must be set. If both or neither are
// set, ServeHTTP reports an error through the error handler and sends
// nothing upstream.
type ReverseProxy struct {
	// Rewrite, if set, customizes the outbound request. It receives a
	// ProxyRequest holding the inbound request (In) and the outbound clone
	// (Out).
	//
	// Before Rewrite is called, ServeHTTP does three things to the outbound
	// request:
	//   - removes hop-by-hop headers,
	//   - deletes the Forwarded, X-Forwarded-For, X-Forwarded-Host and
	//     X-Forwarded-Proto headers,
	//   - drops query parameters that cannot be parsed.
	//
	// Rewrite may replace Out; the replacement is what gets sent.
	Rewrite func(*ProxyRequest)

	// Director, if set, modifies the outbound request in place. It runs
	// before hop-by-hop headers are removed.
	//
	// After it returns, ServeHTTP appends the client IP to X-Forwarded-For.
	// A header key present with a nil value suppresses this. Unparsable
	// query parameters are dropped only if Director left a parsed Form on
	// the request.
	Director func(*http.Request)

	// Transport performs the outbound round trip. If nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls how often the response body is flushed to the
	// client while it is being copied. Zero disables periodic flushing, and
	// a negative value flushes after every write.
	//
	// Responses with Content-Type text/event-stream or ContentLength -1 are
	// always flushed after every write, regardless of this value.
	FlushInterval time.Duration

	// ErrorLog receives the proxy's log messages. If nil, messages go to the
	// standard logger of package log.
	ErrorLog *log.Logger

	// BufferPool, if non-nil, supplies the byte slices used when copying
	// response bodies.
	BufferPool BufferPool

	// ModifyResponse, if non-nil, is called with the backend's response
	// before it is written to the client. This includes 101 Switching
	// Protocols responses. If it returns an error, the response body is
	// closed and the error handler is called with that error.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if non-nil, is called for proxying failures. These
	// include transport errors, ModifyResponse errors, invalid or mismatched
	// upgrade protocols, and misconfiguration. If nil, the error is logged
	// and a 502 Bad Gateway status is written.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
201
202
203
// A BufferPool supplies temporary byte slices that ReverseProxy uses when
// copying response bodies.
//
// Get returns a buffer to copy through. If it returns an empty slice, the
// proxy allocates its own 32 KiB buffer for that copy. Put receives the slice
// that Get returned once the copy has finished.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}
208
// singleJoiningSlash concatenates a and b so that exactly one "/" separates
// them. If a ends with "/" and b starts with "/", one of the two is dropped.
// If neither side has a slash at the boundary, one is inserted.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		return a + b[1:]
	}
	if trailing || leading {
		return a + b
	}
	return a + "/" + b
}
220
// joinURLPath joins the paths of a and b with a single slash. It returns
// the decoded Path, plus an escaped RawPath when either URL carries a
// RawPath. When neither does, rawpath is "" and the join is delegated to
// singleJoiningSlash.
//
// In the escaped case, the slash decision looks at the escaped forms. An
// encoded "%2F" at the boundary therefore does not count as a slash.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}

	escA, escB := a.EscapedPath(), b.EscapedPath()
	aEnds, bStarts := strings.HasSuffix(escA, "/"), strings.HasPrefix(escB, "/")

	switch {
	case aEnds && bStarts:
		return a.Path + b.Path[1:], escA + escB[1:]
	case aEnds || bStarts:
		return a.Path + b.Path, escA + escB
	default:
		return a.Path + "/" + b.Path, escA + "/" + escB
	}
}
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
// NewSingleHostReverseProxy returns a ReverseProxy that uses the Director
// hook to send every request to target. The Director does three things:
//   - replaces each request's scheme and host with target's,
//   - joins target's path onto the request path,
//   - prepends target's query to the request's query.
//
// The outbound request's Host field is not touched, so the Host header seen
// by the backend is whatever the inbound request carried.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
	return &ReverseProxy{
		Director: func(req *http.Request) {
			rewriteRequestURL(req, target)
		},
	}
}
268
// rewriteRequestURL points req.URL at target. It makes these changes:
//   - the scheme and host are replaced by target's,
//   - the paths are joined with joinURLPath,
//   - target's raw query is placed in front of the request's raw query,
//     separated by "&" when both are non-empty.
func rewriteRequestURL(req *http.Request, target *url.URL) {
	// Read the target query up front, before req.URL is modified.
	targetQuery := target.RawQuery

	u := req.URL
	u.Scheme = target.Scheme
	u.Host = target.Host
	u.Path, u.RawPath = joinURLPath(target, u)

	if targetQuery != "" {
		if u.RawQuery == "" {
			u.RawQuery = targetQuery
		} else {
			u.RawQuery = targetQuery + "&" + u.RawQuery
		}
	}
}
280
// copyHeader adds every value of every header in src to dst. Values already
// in dst are kept, and new values are appended after them. Keys are
// canonicalized by http.Header.Add.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
288
289
290
291
292
293
// hopHeaders lists the hop-by-hop headers that removeHopByHopHeaders always
// deletes. This is in addition to any header named in a Connection header
// value. These headers apply to a single connection, so a proxy should not
// forward them.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // non-standard; presumably listed because some clients still send it
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonical MIME key form of "TE"
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
305
306 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
307 p.logf("http: proxy error: %v", err)
308 rw.WriteHeader(http.StatusBadGateway)
309 }
310
// getErrorHandler returns the function used to report proxy errors. That is
// the user-supplied p.ErrorHandler when set, and p.defaultErrorHandler
// otherwise.
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
	if h := p.ErrorHandler; h != nil {
		return h
	}
	return p.defaultErrorHandler
}
317
318
319
// modifyResponse runs the optional p.ModifyResponse hook on res and reports
// whether proxying should continue. If no hook is configured, or the hook
// succeeds, it returns true.
//
// If the hook returns an error, res.Body is closed and the error handler is
// called with req and that error. modifyResponse then returns false, and the
// caller must stop writing to rw.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
	hook := p.ModifyResponse
	if hook == nil {
		return true
	}
	err := hook(res)
	if err == nil {
		return true
	}
	res.Body.Close()
	p.getErrorHandler()(rw, req, err)
	return false
}
331
// ServeHTTP proxies req to a backend and copies the backend's response to rw.
//
// The outbound request is a clone of req. Exactly one of p.Director or
// p.Rewrite customizes it; if both or neither are set, the error handler is
// called and nothing is sent. During the exchange:
//   - hop-by-hop headers are removed from the outbound request and from the
//     response,
//   - 1xx informational responses are relayed to the client as they arrive,
//   - 101 Switching Protocols responses are handed to handleUpgradeResponse,
//   - response trailers are copied after the body.
//
// Errors that occur before the response header is written go to the error
// handler. If copying the body fails afterwards, ServeHTTP panics with
// http.ErrAbortHandler when shouldPanicOnCopyError(req) reports true, and
// otherwise only logs the error.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() == nil {
		// This context can never be canceled (its Done channel is nil). Fall
		// back to the legacy CloseNotifier, if the writer has one, so that a
		// client disconnect still cancels the outbound request. A context
		// with a real cancellation signal makes this unnecessary.
		if cn, ok := rw.(http.CloseNotifier); ok {
			var cancel context.CancelFunc
			ctx, cancel = context.WithCancel(ctx)
			defer cancel()
			notifyChan := cn.CloseNotify()
			go func() {
				select {
				case <-notifyChan:
					cancel()
				case <-ctx.Done():
				}
			}()
		}
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// Treat a known-empty inbound body as no body at all.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// The transport may still be reading the body when this handler
		// returns. Close it on return so no reads continue against a request
		// whose handler has finished.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		outreq.Header = make(http.Header)
	}

	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		if outreq.Form != nil {
			// The Director parsed the form. Drop any query parameters the
			// parser rejected, so the backend sees the same parameters.
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	outreq.Close = false

	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// Advertise trailer support upstream only if the client did. This checks
	// req.Header, because Te has just been stripped from outreq.Header.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Put back the headers a protocol upgrade needs. They were removed above
	// along with the other hop-by-hop headers.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Strip client-supplied forwarding headers. Rewrite may set fresh
		// values, for example via SetXForwarded.
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// Remove unparsable query parameters from the outbound request.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		outreq = pr.Out
	} else if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		// Append the client IP to any existing X-Forwarded-For values, and
		// fold them into one comma+space separated header. A key present
		// with a nil value means the header must not be populated.
		prior, ok := outreq.Header["X-Forwarded-For"]
		omit := ok && prior == nil
		if len(prior) > 0 {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		if !omit {
			outreq.Header.Set("X-Forwarded-For", clientIP)
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// Send an explicitly empty User-Agent so the client's default
		// User-Agent is not added.
		outreq.Header.Set("User-Agent", "")
	}

	// The transport can invoke Got1xxResponse from its own goroutine. This
	// can happen even after RoundTrip has returned, e.g. when the request was
	// canceled. From that point on, rw belongs to this handler (or the error
	// handler) again. The mutex and flag stop a late 1xx callback from racing
	// on rw's header map or calling WriteHeader.
	var (
		roundTripMutex sync.Mutex
		roundTripDone  bool
	)
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			roundTripMutex.Lock()
			defer roundTripMutex.Unlock()
			if roundTripDone {
				return nil
			}
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// Clear the copied 1xx headers so they are not repeated in the
			// final response.
			clear(h)
			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	roundTripMutex.Lock()
	roundTripDone = true
	roundTripMutex.Unlock()
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// 101 Switching Protocols (WebSocket, h2c, etc.) is handled separately.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Announce the trailer keys known up front in a "Trailer" header, built
	// from res.Trailer.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The status line is already sent. The only way left to signal a
		// truncated body is to abort the handler.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	res.Body.Close() // close now, not deferred, so res.Trailer is populated below

	if len(res.Trailer) > 0 {
		// Flush before setting trailers so the response is sent chunked
		// instead of getting a computed Content-Length.
		http.NewResponseController(rw).Flush()
	}

	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Trailers that were not announced up front must use the TrailerPrefix
	// form.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
539
540 var inOurTests bool
541
542
543
544
545
546
547 func shouldPanicOnCopyError(req *http.Request) bool {
548 if inOurTests {
549
550 return true
551 }
552 if req.Context().Value(http.ServerContextKey) != nil {
553
554
555 return true
556 }
557
558
559 return false
560 }
561
562
// removeHopByHopHeaders deletes hop-by-hop headers from h in place. It
// proceeds in two steps:
//  1. It deletes every header named in the comma-separated values of the
//     Connection header (names are whitespace-trimmed, empty ones skipped).
//  2. It deletes each fixed name in hopHeaders, which includes Connection
//     itself.
func removeHopByHopHeaders(h http.Header) {
	for _, connValue := range h["Connection"] {
		for _, name := range strings.Split(connValue, ",") {
			name = textproto.TrimString(name)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}

	for _, name := range hopHeaders {
		h.Del(name)
	}
}
579
580
581
// flushInterval chooses how often copyResponse flushes res's body to the
// client. It returns -1, meaning flush after every write, in two cases:
//   - Server-Sent Events responses (media type text/event-stream),
//   - responses of unknown length (ContentLength == -1).
//
// Otherwise it returns p.FlushInterval. An unparsable Content-Type is treated
// as not being an event stream.
func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
	mediaType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type"))
	switch {
	case mediaType == "text/event-stream", res.ContentLength == -1:
		return -1
	default:
		return p.FlushInterval
	}
}
598
// copyResponse streams src to dst and returns the first copy error.
//
// When flushInterval is zero, src is copied straight to dst. Otherwise writes
// go through a maxLatencyWriter:
//   - a negative interval flushes after every write,
//   - a positive interval flushes at most flushInterval after data is
//     written.
//
// An initial flush is scheduled immediately, so the response headers are
// flushed even if the first body write is delayed. If p.BufferPool is set,
// the copy buffer is borrowed from it and returned when the copy ends.
func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
	// copyTo performs the copy into w, holding a pooled buffer (if any) only
	// for the duration of the copy.
	copyTo := func(w io.Writer) error {
		var buf []byte
		if pool := p.BufferPool; pool != nil {
			buf = pool.Get()
			defer pool.Put(buf)
		}
		_, err := p.copyBuffer(w, src, buf)
		return err
	}

	if flushInterval == 0 {
		return copyTo(dst)
	}

	mlw := &maxLatencyWriter{
		dst:     dst,
		flush:   http.NewResponseController(dst).Flush,
		latency: flushInterval,
		// Start with a flush already pending (see the timer below).
		flushPending: true,
	}
	mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
	defer mlw.stop()

	return copyTo(mlw)
}
625
626
627
// copyBuffer copies src to dst using buf as scratch space and returns the
// number of bytes written. If buf is empty, a 32 KiB buffer is allocated.
//
// Read errors other than io.EOF and context.Canceled are logged via p.logf.
// Any bytes returned together with a read error are still written before
// that error is returned. Reaching io.EOF ends the copy with a nil error,
// and a short write yields io.ErrShortWrite.
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
	if len(buf) == 0 {
		buf = make([]byte, 32*1024)
	}
	var total int64
	for {
		n, readErr := src.Read(buf)
		if readErr != nil && readErr != io.EOF && readErr != context.Canceled {
			p.logf("httputil: ReverseProxy read error during body copy: %v", readErr)
		}
		if n > 0 {
			m, writeErr := dst.Write(buf[:n])
			if m > 0 {
				total += int64(m)
			}
			switch {
			case writeErr != nil:
				return total, writeErr
			case m != n:
				return total, io.ErrShortWrite
			}
		}
		switch readErr {
		case nil:
			continue
		case io.EOF:
			return total, nil
		default:
			return total, readErr
		}
	}
}
658
// logf formats a message and writes it to p.ErrorLog. If ErrorLog is nil,
// the message goes to the standard logger of package log.
func (p *ReverseProxy) logf(format string, args ...any) {
	if p.ErrorLog == nil {
		log.Printf(format, args...)
		return
	}
	p.ErrorLog.Printf(format, args...)
}
666
667 type maxLatencyWriter struct {
668 dst io.Writer
669 flush func() error
670 latency time.Duration
671
672 mu sync.Mutex
673 t *time.Timer
674 flushPending bool
675 }
676
677 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
678 m.mu.Lock()
679 defer m.mu.Unlock()
680 n, err = m.dst.Write(p)
681 if m.latency < 0 {
682 m.flush()
683 return
684 }
685 if m.flushPending {
686 return
687 }
688 if m.t == nil {
689 m.t = time.AfterFunc(m.latency, m.delayedFlush)
690 } else {
691 m.t.Reset(m.latency)
692 }
693 m.flushPending = true
694 return
695 }
696
697 func (m *maxLatencyWriter) delayedFlush() {
698 m.mu.Lock()
699 defer m.mu.Unlock()
700 if !m.flushPending {
701 return
702 }
703 m.flush()
704 m.flushPending = false
705 }
706
707 func (m *maxLatencyWriter) stop() {
708 m.mu.Lock()
709 defer m.mu.Unlock()
710 m.flushPending = false
711 if m.t != nil {
712 m.t.Stop()
713 }
714 }
715
// upgradeType returns the protocol named in h's Upgrade header. It returns ""
// when the Connection header does not contain the "Upgrade" token, as
// checked by httpguts.HeaderValuesContainsToken.
func upgradeType(h http.Header) string {
	if httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
		return h.Get("Upgrade")
	}
	return ""
}
722
// handleUpgradeResponse completes a 101 Switching Protocols exchange:
//  1. It checks that the backend switched to the protocol that was requested.
//  2. It hijacks the client connection and writes the backend's 101 response
//     to it.
//  3. It copies bytes in both directions between the client connection and
//     the backend connection (res.Body) until either direction finishes.
//
// Any failure before the copying starts is reported once through the error
// handler, and processing stops.
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
	reqUpType := upgradeType(req.Header)
	resUpType := upgradeType(res.Header)
	if !ascii.IsPrint(resUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
		// Stop here. Falling through would call the error handler a second
		// time for the mismatch below and write the response twice.
		return
	}
	if !ascii.EqualFold(reqUpType, resUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
		return
	}

	// The upgraded backend connection is expected to be exposed as a
	// writable body.
	backConn, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		p.getErrorHandler()(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"))
		return
	}

	rc := http.NewResponseController(rw)
	conn, brw, hijackErr := rc.Hijack()
	if errors.Is(hijackErr, http.ErrNotSupported) {
		p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
		return
	}

	// Close the backend connection when the request context is done or when
	// this function returns, whichever comes first.
	backConnCloseCh := make(chan bool)
	go func() {
		select {
		case <-req.Context().Done():
		case <-backConnCloseCh:
		}
		backConn.Close()
	}()
	defer close(backConnCloseCh)

	if hijackErr != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %w", hijackErr))
		return
	}
	defer conn.Close()

	copyHeader(rw.Header(), res.Header)

	res.Header = rw.Header()
	res.Body = nil // write only the status line and headers; the stream itself is backConn
	if err := res.Write(brw); err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("response write: %w", err))
		return
	}
	if err := brw.Flush(); err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %w", err))
		return
	}

	// Buffered so that the copier finishing second can still deliver its
	// result after this function stops listening.
	errc := make(chan error, 1)
	spc := switchProtocolCopier{user: conn, backend: backConn}
	go spc.copyToBackend(errc)
	go spc.copyFromBackend(errc)
	<-errc
}
783
784
785
786 type switchProtocolCopier struct {
787 user, backend io.ReadWriter
788 }
789
790 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
791 _, err := io.Copy(c.user, c.backend)
792 errc <- err
793 }
794
795 func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
796 _, err := io.Copy(c.backend, c.user)
797 errc <- err
798 }
799
// cleanQueryParams returns s unchanged if it is a well-formed raw query,
// meaning it contains no ';' and every '%' is followed by two hex digits.
// Otherwise s is parsed with url.ParseQuery, which drops the parameters it
// cannot decode, and the surviving parameters are re-encoded. Because
// url.Values.Encode sorts by key, the re-encoded result is sorted by key.
func cleanQueryParams(s string) string {
	needsCleaning := false
scan:
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case ';':
			needsCleaning = true
			break scan
		case '%':
			if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
				needsCleaning = true
				break scan
			}
			i += 2 // skip the two hex digits of the escape
		}
	}
	if !needsCleaning {
		return s
	}
	// ParseQuery's error only describes the pairs it skipped. The valid
	// pairs it still returns are exactly what is wanted, so the error is
	// ignored.
	v, _ := url.ParseQuery(s)
	return v.Encode()
}

// ishex reports whether c is an ASCII hexadecimal digit (0-9, a-f, A-F).
func ishex(c byte) bool {
	return ('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')
}
832
View as plain text