Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 . "net/http"
30 "net/http/httptest"
31 "net/http/httptrace"
32 "net/http/httputil"
33 "net/http/internal/testcert"
34 "net/textproto"
35 "net/url"
36 "os"
37 "reflect"
38 "runtime"
39 "strconv"
40 "strings"
41 "sync"
42 "sync/atomic"
43 "testing"
44 "testing/iotest"
45 "time"
46
47 "golang.org/x/net/http/httpguts"
48 )
49
50
51
52
53
54 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
55 if r.FormValue("close") == "true" {
56 w.Header().Set("Connection", "close")
57 }
58 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
59 w.Write([]byte(r.RemoteAddr))
60
61
62
63 if c, ok := ResponseWriterConnForTesting(w); ok {
64 fmt.Fprintf(w, ", %T %p", c, c)
65 }
66 })
67
68
69 type testCloseConn struct {
70 net.Conn
71 set *testConnSet
72 }
73
74 func (c *testCloseConn) Close() error {
75 c.set.remove(c)
76 return c.Conn.Close()
77 }
78
79
80
// testConnSet records every connection handed out by a test dialer together
// with whether that connection has since been closed.
type testConnSet struct {
	t *testing.T

	mu     sync.Mutex        // guards closed and list
	closed map[net.Conn]bool // conn -> has Close been called
	list   []net.Conn        // all conns, in insertion order
}

// insert registers c as a newly opened (not yet closed) connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. The connection stays in list so that its
// position is still available for reporting.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}
100
101
102 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
103 connSet := &testConnSet{
104 t: t,
105 closed: make(map[net.Conn]bool),
106 }
107 dial := func(n, addr string) (net.Conn, error) {
108 c, err := net.Dial(n, addr)
109 if err != nil {
110 return nil, err
111 }
112 tc := &testCloseConn{c, connSet}
113 connSet.insert(tc)
114 return tc, nil
115 }
116 return connSet, dial
117 }
118
119 func (tcs *testConnSet) check(t *testing.T) {
120 tcs.mu.Lock()
121 defer tcs.mu.Unlock()
122 for i := 4; i >= 0; i-- {
123 for i, c := range tcs.list {
124 if tcs.closed[c] {
125 continue
126 }
127 if i != 0 {
128
129
130 tcs.mu.Unlock()
131 time.Sleep(50 * time.Millisecond)
132 tcs.mu.Lock()
133 continue
134 }
135 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
136 }
137 }
138 }
139
140 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
141 func testReuseRequest(t *testing.T, mode testMode) {
142 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
143 w.Write([]byte("{}"))
144 })).ts
145
146 c := ts.Client()
147 req, _ := NewRequest("GET", ts.URL, nil)
148 res, err := c.Do(req)
149 if err != nil {
150 t.Fatal(err)
151 }
152 err = res.Body.Close()
153 if err != nil {
154 t.Fatal(err)
155 }
156
157 res, err = c.Do(req)
158 if err != nil {
159 t.Fatal(err)
160 }
161 err = res.Body.Close()
162 if err != nil {
163 t.Fatal(err)
164 }
165 }
166
167
168
169 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
170 func testTransportKeepAlives(t *testing.T, mode testMode) {
171 ts := newClientServerTest(t, mode, hostPortHandler).ts
172
173 c := ts.Client()
174 for _, disableKeepAlive := range []bool{false, true} {
175 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
176 fetch := func(n int) string {
177 res, err := c.Get(ts.URL)
178 if err != nil {
179 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
180 }
181 body, err := io.ReadAll(res.Body)
182 if err != nil {
183 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
184 }
185 return string(body)
186 }
187
188 body1 := fetch(1)
189 body2 := fetch(2)
190
191 bodiesDiffer := body1 != body2
192 if bodiesDiffer != disableKeepAlive {
193 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
194 disableKeepAlive, bodiesDiffer, body1, body2)
195 }
196 }
197 }
198
199 func TestTransportConnectionCloseOnResponse(t *testing.T) {
200 run(t, testTransportConnectionCloseOnResponse)
201 }
202 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
203 ts := newClientServerTest(t, mode, hostPortHandler).ts
204
205 connSet, testDial := makeTestDial(t)
206
207 c := ts.Client()
208 tr := c.Transport.(*Transport)
209 tr.Dial = testDial
210
211 for _, connectionClose := range []bool{false, true} {
212 fetch := func(n int) string {
213 req := new(Request)
214 var err error
215 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
216 if err != nil {
217 t.Fatalf("URL parse error: %v", err)
218 }
219 req.Method = "GET"
220 req.Proto = "HTTP/1.1"
221 req.ProtoMajor = 1
222 req.ProtoMinor = 1
223
224 res, err := c.Do(req)
225 if err != nil {
226 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
227 }
228 defer res.Body.Close()
229 body, err := io.ReadAll(res.Body)
230 if err != nil {
231 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
232 }
233 return string(body)
234 }
235
236 body1 := fetch(1)
237 body2 := fetch(2)
238 bodiesDiffer := body1 != body2
239 if bodiesDiffer != connectionClose {
240 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
241 connectionClose, bodiesDiffer, body1, body2)
242 }
243
244 tr.CloseIdleConnections()
245 }
246
247 connSet.check(t)
248 }
249
250
251
252
253
254
255
256 func TestTransportConnectionCloseOnRequest(t *testing.T) {
257 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
258 }
259 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
260 ts := newClientServerTest(t, mode, hostPortHandler).ts
261
262 connSet, testDial := makeTestDial(t)
263
264 c := ts.Client()
265 tr := c.Transport.(*Transport)
266 tr.Dial = testDial
267 for _, reqClose := range []bool{false, true} {
268 fetch := func(n int) string {
269 req := new(Request)
270 var err error
271 req.URL, err = url.Parse(ts.URL)
272 if err != nil {
273 t.Fatalf("URL parse error: %v", err)
274 }
275 req.Method = "GET"
276 req.Proto = "HTTP/1.1"
277 req.ProtoMajor = 1
278 req.ProtoMinor = 1
279 req.Close = reqClose
280
281 res, err := c.Do(req)
282 if err != nil {
283 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
284 }
285 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
286 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
287 reqClose, got, !reqClose)
288 }
289 body, err := io.ReadAll(res.Body)
290 if err != nil {
291 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
292 }
293 return string(body)
294 }
295
296 body1 := fetch(1)
297 body2 := fetch(2)
298
299 got := 1
300 if body1 != body2 {
301 got++
302 }
303 want := 1
304 if reqClose {
305 want = 2
306 }
307 if got != want {
308 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
309 reqClose, got, want, body1, body2)
310 }
311
312 tr.CloseIdleConnections()
313 }
314
315 connSet.check(t)
316 }
317
318
319
320
321 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
322 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
323 }
324 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
325 ts := newClientServerTest(t, mode, hostPortHandler).ts
326
327 c := ts.Client()
328 c.Transport.(*Transport).DisableKeepAlives = true
329
330 res, err := c.Get(ts.URL)
331 if err != nil {
332 t.Fatal(err)
333 }
334 res.Body.Close()
335 if res.Header.Get("X-Saw-Close") != "true" {
336 t.Errorf("handler didn't see Connection: close ")
337 }
338 }
339
340
341
342 func TestTransportRespectRequestWantsClose(t *testing.T) {
343 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
344 }
345 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
346 tests := []struct {
347 disableKeepAlives bool
348 close bool
349 }{
350 {disableKeepAlives: false, close: false},
351 {disableKeepAlives: false, close: true},
352 {disableKeepAlives: true, close: false},
353 {disableKeepAlives: true, close: true},
354 }
355
356 for _, tc := range tests {
357 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
358 func(t *testing.T) {
359 ts := newClientServerTest(t, mode, hostPortHandler).ts
360
361 c := ts.Client()
362 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
363 req, err := NewRequest("GET", ts.URL, nil)
364 if err != nil {
365 t.Fatal(err)
366 }
367 count := 0
368 trace := &httptrace.ClientTrace{
369 WroteHeaderField: func(key string, field []string) {
370 if key != "Connection" {
371 return
372 }
373 if httpguts.HeaderValuesContainsToken(field, "close") {
374 count += 1
375 }
376 },
377 }
378 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
379 req.Close = tc.close
380 res, err := c.Do(req)
381 if err != nil {
382 t.Fatal(err)
383 }
384 defer res.Body.Close()
385 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
386 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
387 }
388 })
389 }
390
391 }
392
393 func TestTransportIdleCacheKeys(t *testing.T) {
394 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
395 }
396 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
397 ts := newClientServerTest(t, mode, hostPortHandler).ts
398 c := ts.Client()
399 tr := c.Transport.(*Transport)
400
401 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
402 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
403 }
404
405 resp, err := c.Get(ts.URL)
406 if err != nil {
407 t.Error(err)
408 }
409 io.ReadAll(resp.Body)
410
411 keys := tr.IdleConnKeysForTesting()
412 if e, g := 1, len(keys); e != g {
413 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
414 }
415
416 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
417 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
418 }
419
420 tr.CloseIdleConnections()
421 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
422 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
423 }
424 }
425
426
427
428 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
429 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
430 const msg = "foobar"
431
432 var addrSeen map[string]int
433 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
434 addrSeen[r.RemoteAddr]++
435 if r.URL.Path == "/chunked/" {
436 w.WriteHeader(200)
437 w.(Flusher).Flush()
438 } else {
439 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
440 w.WriteHeader(200)
441 }
442 w.Write([]byte(msg))
443 })).ts
444
445 for pi, path := range []string{"/content-length/", "/chunked/"} {
446 wantLen := []int{len(msg), -1}[pi]
447 addrSeen = make(map[string]int)
448 for i := 0; i < 3; i++ {
449 res, err := ts.Client().Get(ts.URL + path)
450 if err != nil {
451 t.Errorf("Get %s: %v", path, err)
452 continue
453 }
454
455
456
457
458
459 defer res.Body.Close()
460
461 if res.ContentLength != int64(wantLen) {
462 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
463 }
464 got, err := io.ReadAll(res.Body)
465 if string(got) != msg || err != nil {
466 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
467 }
468 }
469 if len(addrSeen) != 1 {
470 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
471 }
472 }
473 }
474
475 func TestTransportMaxPerHostIdleConns(t *testing.T) {
476 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
477 }
478 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
479 stop := make(chan struct{})
480 defer close(stop)
481
482 resch := make(chan string)
483 gotReq := make(chan bool)
484 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
485 gotReq <- true
486 var msg string
487 select {
488 case <-stop:
489 return
490 case msg = <-resch:
491 }
492 _, err := w.Write([]byte(msg))
493 if err != nil {
494 t.Errorf("Write: %v", err)
495 return
496 }
497 })).ts
498
499 c := ts.Client()
500 tr := c.Transport.(*Transport)
501 maxIdleConnsPerHost := 2
502 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
503
504
505
506 donech := make(chan bool)
507 doReq := func() {
508 defer func() {
509 select {
510 case <-stop:
511 return
512 case donech <- t.Failed():
513 }
514 }()
515 resp, err := c.Get(ts.URL)
516 if err != nil {
517 t.Error(err)
518 return
519 }
520 if _, err := io.ReadAll(resp.Body); err != nil {
521 t.Errorf("ReadAll: %v", err)
522 return
523 }
524 }
525 go doReq()
526 <-gotReq
527 go doReq()
528 <-gotReq
529 go doReq()
530 <-gotReq
531
532 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
533 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
534 }
535
536 resch <- "res1"
537 <-donech
538 keys := tr.IdleConnKeysForTesting()
539 if e, g := 1, len(keys); e != g {
540 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
541 }
542 addr := ts.Listener.Addr().String()
543 cacheKey := "|http|" + addr
544 if keys[0] != cacheKey {
545 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
546 }
547 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
548 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
549 }
550
551 resch <- "res2"
552 <-donech
553 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
554 t.Errorf("after second response, idle conns = %d; want %d", g, w)
555 }
556
557 resch <- "res3"
558 <-donech
559 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
560 t.Errorf("after third response, idle conns = %d; want %d", g, w)
561 }
562 }
563
564 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
565 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
566 }
567 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
568 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
569 _, err := w.Write([]byte("foo"))
570 if err != nil {
571 t.Fatalf("Write: %v", err)
572 }
573 })).ts
574 c := ts.Client()
575 tr := c.Transport.(*Transport)
576 dialStarted := make(chan struct{})
577 stallDial := make(chan struct{})
578 tr.Dial = func(network, addr string) (net.Conn, error) {
579 dialStarted <- struct{}{}
580 <-stallDial
581 return net.Dial(network, addr)
582 }
583
584 tr.DisableKeepAlives = true
585 tr.MaxConnsPerHost = 1
586
587 preDial := make(chan struct{})
588 reqComplete := make(chan struct{})
589 doReq := func(reqId string) {
590 req, _ := NewRequest("GET", ts.URL, nil)
591 trace := &httptrace.ClientTrace{
592 GetConn: func(hostPort string) {
593 preDial <- struct{}{}
594 },
595 }
596 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
597 resp, err := tr.RoundTrip(req)
598 if err != nil {
599 t.Errorf("unexpected error for request %s: %v", reqId, err)
600 }
601 _, err = io.ReadAll(resp.Body)
602 if err != nil {
603 t.Errorf("unexpected error for request %s: %v", reqId, err)
604 }
605 reqComplete <- struct{}{}
606 }
607
608 go doReq("req1")
609 <-preDial
610 <-dialStarted
611
612
613 go doReq("req2")
614 <-preDial
615 select {
616 case <-dialStarted:
617 t.Error("req2 dial started while req1 dial in progress")
618 return
619 default:
620 }
621
622
623 stallDial <- struct{}{}
624 <-reqComplete
625
626
627 <-dialStarted
628 stallDial <- struct{}{}
629 <-reqComplete
630 }
631
632 func TestTransportMaxConnsPerHost(t *testing.T) {
633 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
634 }
635 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
636 CondSkipHTTP2(t)
637
638 h := HandlerFunc(func(w ResponseWriter, r *Request) {
639 _, err := w.Write([]byte("foo"))
640 if err != nil {
641 t.Fatalf("Write: %v", err)
642 }
643 })
644
645 ts := newClientServerTest(t, mode, h).ts
646 c := ts.Client()
647 tr := c.Transport.(*Transport)
648 tr.MaxConnsPerHost = 1
649
650 mu := sync.Mutex{}
651 var conns []net.Conn
652 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
653 tr.Dial = func(network, addr string) (net.Conn, error) {
654 atomic.AddInt32(&dialCnt, 1)
655 c, err := net.Dial(network, addr)
656 mu.Lock()
657 defer mu.Unlock()
658 conns = append(conns, c)
659 return c, err
660 }
661
662 doReq := func() {
663 trace := &httptrace.ClientTrace{
664 GotConn: func(connInfo httptrace.GotConnInfo) {
665 if !connInfo.Reused {
666 atomic.AddInt32(&gotConnCnt, 1)
667 }
668 },
669 TLSHandshakeStart: func() {
670 atomic.AddInt32(&tlsHandshakeCnt, 1)
671 },
672 }
673 req, _ := NewRequest("GET", ts.URL, nil)
674 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
675
676 resp, err := c.Do(req)
677 if err != nil {
678 t.Fatalf("request failed: %v", err)
679 }
680 defer resp.Body.Close()
681 _, err = io.ReadAll(resp.Body)
682 if err != nil {
683 t.Fatalf("read body failed: %v", err)
684 }
685 }
686
687 wg := sync.WaitGroup{}
688 for i := 0; i < 10; i++ {
689 wg.Add(1)
690 go func() {
691 defer wg.Done()
692 doReq()
693 }()
694 }
695 wg.Wait()
696
697 expected := int32(tr.MaxConnsPerHost)
698 if dialCnt != expected {
699 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
700 }
701 if gotConnCnt != expected {
702 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
703 }
704 if ts.TLS != nil && tlsHandshakeCnt != expected {
705 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
706 }
707
708 if t.Failed() {
709 t.FailNow()
710 }
711
712 mu.Lock()
713 for _, c := range conns {
714 c.Close()
715 }
716 conns = nil
717 mu.Unlock()
718 tr.CloseIdleConnections()
719
720 doReq()
721 expected++
722 if dialCnt != expected {
723 t.Errorf("round 2: too many dials: %d", dialCnt)
724 }
725 if gotConnCnt != expected {
726 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
727 }
728 if ts.TLS != nil && tlsHandshakeCnt != expected {
729 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
730 }
731 }
732
733 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
734 run(t, testTransportMaxConnsPerHostDialCancellation,
735 testNotParallel,
736 []testMode{http1Mode, https1Mode, http2Mode},
737 )
738 }
739
740 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
741 CondSkipHTTP2(t)
742
743 h := HandlerFunc(func(w ResponseWriter, r *Request) {
744 _, err := w.Write([]byte("foo"))
745 if err != nil {
746 t.Fatalf("Write: %v", err)
747 }
748 })
749
750 cst := newClientServerTest(t, mode, h)
751 defer cst.close()
752 ts := cst.ts
753 c := ts.Client()
754 tr := c.Transport.(*Transport)
755 tr.MaxConnsPerHost = 1
756
757
758 ctx, cancel := context.WithCancel(context.Background())
759 defer cancel()
760 SetPendingDialHooks(cancel, nil)
761 defer SetPendingDialHooks(nil, nil)
762
763 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
764 _, err := c.Do(req)
765 if !errors.Is(err, context.Canceled) {
766 t.Errorf("expected error %v, got %v", context.Canceled, err)
767 }
768
769
770 SetPendingDialHooks(nil, nil)
771 req, _ = NewRequest("GET", ts.URL, nil)
772 resp, err := c.Do(req)
773 if err != nil {
774 t.Fatalf("request failed: %v", err)
775 }
776 defer resp.Body.Close()
777 _, err = io.ReadAll(resp.Body)
778 if err != nil {
779 t.Fatalf("read body failed: %v", err)
780 }
781 }
782
783 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
784 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
785 }
786 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
787 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
788 io.WriteString(w, r.RemoteAddr)
789 })).ts
790
791 c := ts.Client()
792 tr := c.Transport.(*Transport)
793
794 doReq := func(name string) {
795
796
797 res, err := c.Post(ts.URL, "", nil)
798 if err != nil {
799 t.Fatalf("%s: %v", name, err)
800 }
801 if res.StatusCode != 200 {
802 t.Fatalf("%s: %v", name, res.Status)
803 }
804 defer res.Body.Close()
805 slurp, err := io.ReadAll(res.Body)
806 if err != nil {
807 t.Fatalf("%s: %v", name, err)
808 }
809 t.Logf("%s: ok (%q)", name, slurp)
810 }
811
812 doReq("first")
813 keys1 := tr.IdleConnKeysForTesting()
814
815 ts.CloseClientConnections()
816
817 var keys2 []string
818 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
819 keys2 = tr.IdleConnKeysForTesting()
820 if len(keys2) != 0 {
821 if d > 0 {
822 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
823 }
824 return false
825 }
826 return true
827 })
828
829 doReq("second")
830 }
831
832
833
834 func TestTransportServerClosingUnexpectedly(t *testing.T) {
835 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
836 }
837 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
838 ts := newClientServerTest(t, mode, hostPortHandler).ts
839 c := ts.Client()
840
841 fetch := func(n, retries int) string {
842 condFatalf := func(format string, arg ...any) {
843 if retries <= 0 {
844 t.Fatalf(format, arg...)
845 }
846 t.Logf("retrying shortly after expected error: "+format, arg...)
847 time.Sleep(time.Second / time.Duration(retries))
848 }
849 for retries >= 0 {
850 retries--
851 res, err := c.Get(ts.URL)
852 if err != nil {
853 condFatalf("error in req #%d, GET: %v", n, err)
854 continue
855 }
856 body, err := io.ReadAll(res.Body)
857 if err != nil {
858 condFatalf("error in req #%d, ReadAll: %v", n, err)
859 continue
860 }
861 res.Body.Close()
862 return string(body)
863 }
864 panic("unreachable")
865 }
866
867 body1 := fetch(1, 0)
868 body2 := fetch(2, 0)
869
870
871
872
873
874
875
876
877 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
878
879 body3 := fetch(3, 5)
880
881 if body1 != body2 {
882 t.Errorf("expected body1 and body2 to be equal")
883 }
884 if body2 == body3 {
885 t.Errorf("expected body2 and body3 to be different")
886 }
887 }
888
889
890
891 func TestStressSurpriseServerCloses(t *testing.T) {
892 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
893 }
894 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
895 if testing.Short() {
896 t.Skip("skipping test in short mode")
897 }
898 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
899 w.Header().Set("Content-Length", "5")
900 w.Header().Set("Content-Type", "text/plain")
901 w.Write([]byte("Hello"))
902 w.(Flusher).Flush()
903 conn, buf, _ := w.(Hijacker).Hijack()
904 buf.Flush()
905 conn.Close()
906 })).ts
907 c := ts.Client()
908
909
910
911
912
913
914
915 const (
916 numClients = 20
917 reqsPerClient = 25
918 )
919 var wg sync.WaitGroup
920 wg.Add(numClients * reqsPerClient)
921 for i := 0; i < numClients; i++ {
922 go func() {
923 for i := 0; i < reqsPerClient; i++ {
924 res, err := c.Get(ts.URL)
925 if err == nil {
926
927
928
929
930
931
932 res.Body.Close()
933 }
934 wg.Done()
935 }
936 }()
937 }
938
939
940 wg.Wait()
941 }
942
943
944
945 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
946 func testTransportHeadResponses(t *testing.T, mode testMode) {
947 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
948 if r.Method != "HEAD" {
949 panic("expected HEAD; got " + r.Method)
950 }
951 w.Header().Set("Content-Length", "123")
952 w.WriteHeader(200)
953 })).ts
954 c := ts.Client()
955
956 for i := 0; i < 2; i++ {
957 res, err := c.Head(ts.URL)
958 if err != nil {
959 t.Errorf("error on loop %d: %v", i, err)
960 continue
961 }
962 if e, g := "123", res.Header.Get("Content-Length"); e != g {
963 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
964 }
965 if e, g := int64(123), res.ContentLength; e != g {
966 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
967 }
968 if all, err := io.ReadAll(res.Body); err != nil {
969 t.Errorf("loop %d: Body ReadAll: %v", i, err)
970 } else if len(all) != 0 {
971 t.Errorf("Bogus body %q", all)
972 }
973 }
974 }
975
976
977
978 func TestTransportHeadChunkedResponse(t *testing.T) {
979 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
980 }
981 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
982 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
983 if r.Method != "HEAD" {
984 panic("expected HEAD; got " + r.Method)
985 }
986 w.Header().Set("Transfer-Encoding", "chunked")
987 w.Header().Set("x-client-ipport", r.RemoteAddr)
988 w.WriteHeader(200)
989 })).ts
990 c := ts.Client()
991
992
993
994 didRead := make(chan bool)
995 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
996 defer SetReadLoopBeforeNextReadHook(nil)
997
998 res1, err := c.Head(ts.URL)
999 <-didRead
1000
1001 if err != nil {
1002 t.Fatalf("request 1 error: %v", err)
1003 }
1004
1005 res2, err := c.Head(ts.URL)
1006 <-didRead
1007
1008 if err != nil {
1009 t.Fatalf("request 2 error: %v", err)
1010 }
1011 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1012 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1013 }
1014 }
1015
// roundTripTests drives testRoundTripGzip. Each entry has three fields:
//   - accept: the Accept-Encoding header the test sets ("" leaves it unset).
//   - expectAccept: the Accept-Encoding value the handler must observe.
//   - compressed: whether the body the caller reads is still gzip-encoded.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No caller Accept-Encoding: the Transport asks for gzip itself and
	// hands the caller a decompressed body.
	{accept: "", expectAccept: "gzip", compressed: false},
	// A non-gzip Accept-Encoding is passed through unchanged.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip request is passed through; the body stays compressed.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1028
1029
1030 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1031 func testRoundTripGzip(t *testing.T, mode testMode) {
1032 const responseBody = "test response body"
1033 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1034 accept := req.Header.Get("Accept-Encoding")
1035 if expect := req.FormValue("expect_accept"); accept != expect {
1036 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1037 req.FormValue("testnum"), accept, expect)
1038 }
1039 if accept == "gzip" {
1040 rw.Header().Set("Content-Encoding", "gzip")
1041 gz := gzip.NewWriter(rw)
1042 gz.Write([]byte(responseBody))
1043 gz.Close()
1044 } else {
1045 rw.Header().Set("Content-Encoding", accept)
1046 rw.Write([]byte(responseBody))
1047 }
1048 })).ts
1049 tr := ts.Client().Transport.(*Transport)
1050
1051 for i, test := range roundTripTests {
1052
1053 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1054 if test.accept != "" {
1055 req.Header.Set("Accept-Encoding", test.accept)
1056 }
1057 res, err := tr.RoundTrip(req)
1058 if err != nil {
1059 t.Errorf("%d. RoundTrip: %v", i, err)
1060 continue
1061 }
1062 var body []byte
1063 if test.compressed {
1064 var r *gzip.Reader
1065 r, err = gzip.NewReader(res.Body)
1066 if err != nil {
1067 t.Errorf("%d. gzip NewReader: %v", i, err)
1068 continue
1069 }
1070 body, err = io.ReadAll(r)
1071 res.Body.Close()
1072 } else {
1073 body, err = io.ReadAll(res.Body)
1074 }
1075 if err != nil {
1076 t.Errorf("%d. Error: %q", i, err)
1077 continue
1078 }
1079 if g, e := string(body), responseBody; g != e {
1080 t.Errorf("%d. body = %q; want %q", i, g, e)
1081 }
1082 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1083 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1084 }
1085 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1086 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1087 }
1088 }
1089
1090 }
1091
1092 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1093 func testTransportGzip(t *testing.T, mode testMode) {
1094 if mode == http2Mode {
1095 t.Skip("https://go.dev/issue/56020")
1096 }
1097 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1098 const nRandBytes = 1024 * 1024
1099 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1100 if req.Method == "HEAD" {
1101 if g := req.Header.Get("Accept-Encoding"); g != "" {
1102 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1103 }
1104 return
1105 }
1106 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1107 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1108 }
1109 rw.Header().Set("Content-Encoding", "gzip")
1110
1111 var w io.Writer = rw
1112 var buf bytes.Buffer
1113 if req.FormValue("chunked") == "0" {
1114 w = &buf
1115 defer io.Copy(rw, &buf)
1116 defer func() {
1117 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1118 }()
1119 }
1120 gz := gzip.NewWriter(w)
1121 gz.Write([]byte(testString))
1122 if req.FormValue("body") == "large" {
1123 io.CopyN(gz, rand.Reader, nRandBytes)
1124 }
1125 gz.Close()
1126 })).ts
1127 c := ts.Client()
1128
1129 for _, chunked := range []string{"1", "0"} {
1130
1131 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1132 if err != nil {
1133 t.Fatalf("large get: %v", err)
1134 }
1135 buf := make([]byte, len(testString))
1136 n, err := io.ReadFull(res.Body, buf)
1137 if err != nil {
1138 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1139 }
1140 if e, g := testString, string(buf); e != g {
1141 t.Errorf("partial read got %q, expected %q", g, e)
1142 }
1143 res.Body.Close()
1144
1145 n, err = res.Body.Read(buf)
1146 if n != 0 || err == nil {
1147 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1148 }
1149
1150
1151 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1152 if err != nil {
1153 t.Fatal(err)
1154 }
1155 body, err := io.ReadAll(res.Body)
1156 if err != nil {
1157 t.Fatal(err)
1158 }
1159 if g, e := string(body), testString; g != e {
1160 t.Fatalf("body = %q; want %q", g, e)
1161 }
1162 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1163 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1164 }
1165
1166
1167 n, err = res.Body.Read(buf)
1168 if n != 0 || err == nil {
1169 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1170 }
1171 res.Body.Close()
1172 n, err = res.Body.Read(buf)
1173 if n != 0 || err == nil {
1174 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1175 }
1176 }
1177
1178
1179 res, err := c.Head(ts.URL)
1180 if err != nil {
1181 t.Fatalf("Head: %v", err)
1182 }
1183 if res.StatusCode != 200 {
1184 t.Errorf("Head status=%d; want=200", res.StatusCode)
1185 }
1186 }
1187
1188
1189
1190 func TestTransportExpect100Continue(t *testing.T) {
1191 run(t, testTransportExpect100Continue, []testMode{http1Mode})
1192 }
1193 func testTransportExpect100Continue(t *testing.T, mode testMode) {
1194 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1195 switch req.URL.Path {
1196 case "/100":
1197
1198 if _, err := io.Copy(io.Discard, req.Body); err != nil {
1199 t.Error("Failed to read Body", err)
1200 }
1201 rw.WriteHeader(StatusOK)
1202 case "/200":
1203
1204
1205 rw.WriteHeader(StatusOK)
1206 case "/500":
1207 rw.WriteHeader(StatusInternalServerError)
1208 case "/keepalive":
1209
1210 _, bufrw, err := rw.(Hijacker).Hijack()
1211 if err != nil {
1212 log.Fatal(err)
1213 }
1214 bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
1215 bufrw.WriteString("Content-Length: 0\r\n\r\n")
1216 bufrw.Flush()
1217 case "/timeout":
1218
1219
1220 conn, bufrw, err := rw.(Hijacker).Hijack()
1221 if err != nil {
1222 log.Fatal(err)
1223 }
1224 if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
1225 t.Error("Failed to read Body", err)
1226 }
1227 bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
1228 bufrw.Flush()
1229 conn.Close()
1230 }
1231
1232 })).ts
1233
1234 tests := []struct {
1235 path string
1236 body []byte
1237 sent int
1238 status int
1239 }{
1240 {path: "/100", body: []byte("hello"), sent: 5, status: 200},
1241 {path: "/200", body: []byte("hello"), sent: 0, status: 200},
1242 {path: "/500", body: []byte("hello"), sent: 0, status: 500},
1243 {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500},
1244 {path: "/timeout", body: []byte("hello"), sent: 5, status: 200},
1245 }
1246
1247 c := ts.Client()
1248 for i, v := range tests {
1249 tr := &Transport{
1250 ExpectContinueTimeout: 2 * time.Second,
1251 }
1252 defer tr.CloseIdleConnections()
1253 c.Transport = tr
1254 body := bytes.NewReader(v.body)
1255 req, err := NewRequest("PUT", ts.URL+v.path, body)
1256 if err != nil {
1257 t.Fatal(err)
1258 }
1259 req.Header.Set("Expect", "100-continue")
1260 req.ContentLength = int64(len(v.body))
1261
1262 resp, err := c.Do(req)
1263 if err != nil {
1264 t.Fatal(err)
1265 }
1266 resp.Body.Close()
1267
1268 sent := len(v.body) - body.Len()
1269 if v.status != resp.StatusCode {
1270 t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
1271 }
1272 if v.sent != sent {
1273 t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
1274 }
1275 }
1276 }
1277
1278 func TestSOCKS5Proxy(t *testing.T) {
1279 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1280 }
1281 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1282 ch := make(chan string, 1)
1283 l := newLocalListener(t)
1284 defer l.Close()
1285 defer close(ch)
1286 proxy := func(t *testing.T) {
1287 s, err := l.Accept()
1288 if err != nil {
1289 t.Errorf("socks5 proxy Accept(): %v", err)
1290 return
1291 }
1292 defer s.Close()
1293 var buf [22]byte
1294 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1295 t.Errorf("socks5 proxy initial read: %v", err)
1296 return
1297 }
1298 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1299 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1300 return
1301 }
1302 if _, err := s.Write([]byte{5, 0}); err != nil {
1303 t.Errorf("socks5 proxy initial write: %v", err)
1304 return
1305 }
1306 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1307 t.Errorf("socks5 proxy second read: %v", err)
1308 return
1309 }
1310 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1311 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1312 return
1313 }
1314 var ipLen int
1315 switch buf[3] {
1316 case 1:
1317 ipLen = net.IPv4len
1318 case 4:
1319 ipLen = net.IPv6len
1320 default:
1321 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1322 return
1323 }
1324 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1325 t.Errorf("socks5 proxy address read: %v", err)
1326 return
1327 }
1328 ip := net.IP(buf[4 : ipLen+4])
1329 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1330 copy(buf[:3], []byte{5, 0, 0})
1331 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1332 t.Errorf("socks5 proxy connect write: %v", err)
1333 return
1334 }
1335 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1336
1337
1338 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1339 targetConn, err := net.Dial("tcp", targetHost)
1340 if err != nil {
1341 t.Errorf("net.Dial failed")
1342 return
1343 }
1344 go io.Copy(targetConn, s)
1345 io.Copy(s, targetConn)
1346 targetConn.Close()
1347 }
1348
1349 pu, err := url.Parse("socks5://" + l.Addr().String())
1350 if err != nil {
1351 t.Fatal(err)
1352 }
1353
1354 sentinelHeader := "X-Sentinel"
1355 sentinelValue := "12345"
1356 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1357 w.Header().Set(sentinelHeader, sentinelValue)
1358 })
1359 for _, useTLS := range []bool{false, true} {
1360 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1361 ts := newClientServerTest(t, mode, h).ts
1362 go proxy(t)
1363 c := ts.Client()
1364 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1365 r, err := c.Head(ts.URL)
1366 if err != nil {
1367 t.Fatal(err)
1368 }
1369 if r.Header.Get(sentinelHeader) != sentinelValue {
1370 t.Errorf("Failed to retrieve sentinel value")
1371 }
1372 got := <-ch
1373 ts.Close()
1374 tsu, err := url.Parse(ts.URL)
1375 if err != nil {
1376 t.Fatal(err)
1377 }
1378 want := "proxy for " + tsu.Host
1379 if got != want {
1380 t.Errorf("got %q, want %q", got, want)
1381 }
1382 })
1383 }
1384 }
1385
1386 func TestTransportProxy(t *testing.T) {
1387 defer afterTest(t)
1388 testCases := []struct{ siteMode, proxyMode testMode }{
1389 {http1Mode, http1Mode},
1390 {http1Mode, https1Mode},
1391 {https1Mode, http1Mode},
1392 {https1Mode, https1Mode},
1393 }
1394 for _, testCase := range testCases {
1395 siteMode := testCase.siteMode
1396 proxyMode := testCase.proxyMode
1397 t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
1398 siteCh := make(chan *Request, 1)
1399 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1400 siteCh <- r
1401 })
1402 proxyCh := make(chan *Request, 1)
1403 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1404 proxyCh <- r
1405
1406 if r.Method == "CONNECT" {
1407 hijacker, ok := w.(Hijacker)
1408 if !ok {
1409 t.Errorf("hijack not allowed")
1410 return
1411 }
1412 clientConn, _, err := hijacker.Hijack()
1413 if err != nil {
1414 t.Errorf("hijacking failed")
1415 return
1416 }
1417 res := &Response{
1418 StatusCode: StatusOK,
1419 Proto: "HTTP/1.1",
1420 ProtoMajor: 1,
1421 ProtoMinor: 1,
1422 Header: make(Header),
1423 }
1424
1425 targetConn, err := net.Dial("tcp", r.URL.Host)
1426 if err != nil {
1427 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1428 return
1429 }
1430
1431 if err := res.Write(clientConn); err != nil {
1432 t.Errorf("Writing 200 OK failed: %v", err)
1433 return
1434 }
1435
1436 go io.Copy(targetConn, clientConn)
1437 go func() {
1438 io.Copy(clientConn, targetConn)
1439 targetConn.Close()
1440 }()
1441 }
1442 })
1443 ts := newClientServerTest(t, siteMode, h1).ts
1444 proxy := newClientServerTest(t, proxyMode, h2).ts
1445
1446 pu, err := url.Parse(proxy.URL)
1447 if err != nil {
1448 t.Fatal(err)
1449 }
1450
1451
1452
1453
1454 c := proxy.Client()
1455 if siteMode == https1Mode {
1456 c = ts.Client()
1457 }
1458
1459 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1460 if _, err := c.Head(ts.URL); err != nil {
1461 t.Error(err)
1462 }
1463 got := <-proxyCh
1464 c.Transport.(*Transport).CloseIdleConnections()
1465 ts.Close()
1466 proxy.Close()
1467 if siteMode == https1Mode {
1468
1469 if got.Method != "CONNECT" {
1470 t.Errorf("Wrong method for secure proxying: %q", got.Method)
1471 }
1472 gotHost := got.URL.Host
1473 pu, err := url.Parse(ts.URL)
1474 if err != nil {
1475 t.Fatal("Invalid site URL")
1476 }
1477 if wantHost := pu.Host; gotHost != wantHost {
1478 t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
1479 }
1480
1481
1482 next := <-siteCh
1483 if next.Method != "HEAD" {
1484 t.Errorf("Wrong method at destination: %s", next.Method)
1485 }
1486 if nextURL := next.URL.String(); nextURL != "/" {
1487 t.Errorf("Wrong URL at destination: %s", nextURL)
1488 }
1489 } else {
1490 if got.Method != "HEAD" {
1491 t.Errorf("Wrong method for destination: %q", got.Method)
1492 }
1493 gotURL := got.URL.String()
1494 wantURL := ts.URL + "/"
1495 if gotURL != wantURL {
1496 t.Errorf("Got URL %q, want %q", gotURL, wantURL)
1497 }
1498 }
1499 })
1500 }
1501 }
1502
1503 func TestOnProxyConnectResponse(t *testing.T) {
1504
1505 var tcases = []struct {
1506 proxyStatusCode int
1507 err error
1508 }{
1509 {
1510 StatusOK,
1511 nil,
1512 },
1513 {
1514 StatusForbidden,
1515 errors.New("403"),
1516 },
1517 }
1518 for _, tcase := range tcases {
1519 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1520
1521 })
1522
1523 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1524
1525 if r.Method == "CONNECT" {
1526 if tcase.proxyStatusCode != StatusOK {
1527 w.WriteHeader(tcase.proxyStatusCode)
1528 return
1529 }
1530 hijacker, ok := w.(Hijacker)
1531 if !ok {
1532 t.Errorf("hijack not allowed")
1533 return
1534 }
1535 clientConn, _, err := hijacker.Hijack()
1536 if err != nil {
1537 t.Errorf("hijacking failed")
1538 return
1539 }
1540 res := &Response{
1541 StatusCode: StatusOK,
1542 Proto: "HTTP/1.1",
1543 ProtoMajor: 1,
1544 ProtoMinor: 1,
1545 Header: make(Header),
1546 }
1547
1548 targetConn, err := net.Dial("tcp", r.URL.Host)
1549 if err != nil {
1550 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1551 return
1552 }
1553
1554 if err := res.Write(clientConn); err != nil {
1555 t.Errorf("Writing 200 OK failed: %v", err)
1556 return
1557 }
1558
1559 go io.Copy(targetConn, clientConn)
1560 go func() {
1561 io.Copy(clientConn, targetConn)
1562 targetConn.Close()
1563 }()
1564 }
1565 })
1566 ts := newClientServerTest(t, https1Mode, h1).ts
1567 proxy := newClientServerTest(t, https1Mode, h2).ts
1568
1569 pu, err := url.Parse(proxy.URL)
1570 if err != nil {
1571 t.Fatal(err)
1572 }
1573
1574 c := proxy.Client()
1575
1576 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1577 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1578 if proxyURL.String() != pu.String() {
1579 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1580 }
1581
1582 if "https://"+connectReq.URL.String() != ts.URL {
1583 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1584 }
1585 return tcase.err
1586 }
1587 if _, err := c.Head(ts.URL); err != nil {
1588 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1589 t.Errorf("got %v, want %v", err, tcase.err)
1590 }
1591 }
1592 }
1593 }
1594
1595
1596
// TestTransportProxyHTTPSConnectLeak sends an HTTPS request through an HTTP
// proxy that never answers the CONNECT. It cancels the request and checks
// that the Transport closes its connection to the proxy rather than leaving
// it open.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ln := newLocalListener(t)
	defer ln.Close()
	// listenerDone is closed once the fake proxy goroutine has finished its checks.
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Never answer the CONNECT. Instead, cancel the client's request and
		// expect the client to close its end of the connection. That shows up
		// here as io.EOF on the next read.
		cancel()
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
		return
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait for the fake proxy to observe the close (or report an error)
	// before the test returns.
	<-listenerDone
}
1660
1661
1662 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1663 defer afterTest(t)
1664
1665 var errDial = errors.New("some dial error")
1666
1667 tr := &Transport{
1668 Proxy: func(*Request) (*url.URL, error) {
1669 return url.Parse("http://proxy.fake.tld/")
1670 },
1671 Dial: func(string, string) (net.Conn, error) {
1672 return nil, errDial
1673 },
1674 }
1675 defer tr.CloseIdleConnections()
1676
1677 c := &Client{Transport: tr}
1678 req, _ := NewRequest("GET", "http://fake.tld", nil)
1679 res, err := c.Do(req)
1680 if err == nil {
1681 res.Body.Close()
1682 t.Fatal("wanted a non-nil error")
1683 }
1684
1685 uerr, ok := err.(*url.Error)
1686 if !ok {
1687 t.Fatalf("got %T, want *url.Error", err)
1688 }
1689 oe, ok := uerr.Err.(*net.OpError)
1690 if !ok {
1691 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1692 }
1693 want := &net.OpError{
1694 Op: "proxyconnect",
1695 Net: "tcp",
1696 Err: errDial,
1697 }
1698 if !reflect.DeepEqual(oe, want) {
1699 t.Errorf("Got error %#v; want %#v", oe, want)
1700 }
1701 }
1702
1703
1704
1705
1706
1707 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1708 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1709 }
1710 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1711 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1712 defer proxy.Close()
1713 c := proxy.Client()
1714
1715 tr := c.Transport.(*Transport)
1716 tr.Proxy = func(*Request) (*url.URL, error) {
1717 u, _ := url.Parse(proxy.URL)
1718 u.User = url.UserPassword("aladdin", "opensesame")
1719 return u, nil
1720 }
1721 h := tr.ProxyConnectHeader
1722 if h == nil {
1723 h = make(Header)
1724 }
1725 tr.ProxyConnectHeader = h.Clone()
1726
1727 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1728 if err != nil {
1729 t.Fatal(err)
1730 }
1731 _, err = c.Do(req)
1732 if err == nil {
1733 t.Errorf("unexpected Get success")
1734 }
1735
1736 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1737 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1738 }
1739 }
1740
1741
1742
1743
1744
1745 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1746 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1747 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1748 w.Header().Set("Content-Encoding", "gzip")
1749 w.Write(rgz)
1750 })).ts
1751
1752 c := ts.Client()
1753 res, err := c.Get(ts.URL)
1754 if err != nil {
1755 t.Fatal(err)
1756 }
1757 body, err := io.ReadAll(res.Body)
1758 if err != nil {
1759 t.Fatal(err)
1760 }
1761 if !bytes.Equal(body, rgz) {
1762 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1763 body, rgz)
1764 }
1765 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1766 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1767 }
1768 }
1769
1770
1771
1772 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1773 func testTransportGzipShort(t *testing.T, mode testMode) {
1774 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1775 w.Header().Set("Content-Encoding", "gzip")
1776 w.Write([]byte{0x1f, 0x8b})
1777 })).ts
1778
1779 c := ts.Client()
1780 res, err := c.Get(ts.URL)
1781 if err != nil {
1782 t.Fatal(err)
1783 }
1784 defer res.Body.Close()
1785 _, err = io.ReadAll(res.Body)
1786 if err == nil {
1787 t.Fatal("Expect an error from reading a body.")
1788 }
1789 if err != io.ErrUnexpectedEOF {
1790 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1791 }
1792 }
1793
1794
// waitNumGoroutine waits for the goroutine count to drop to nmax or below.
// It makes up to 10 attempts, each sleeping 50ms and running a GC. It returns
// the last observed count, which may still exceed nmax.
func waitNumGoroutine(nmax int) int {
	count := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if count <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
1804
1805
// TestTransportPersistConnLeak issues numReq concurrent requests whose
// handlers block until released. After every request completes and idle
// connections are closed, it checks that the goroutine count has returned
// close to its starting value.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}
	// Run with testNotParallel (see above): the test compares process-wide
	// goroutine counts.

	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until each request has either reached the handler or failed on
	// the client side.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// The request is now blocked in the handler.
		case <-failed:
			// The client gave up on this request (already logged above);
			// count it so the loop still terminates.
		}
	}

	nhigh := runtime.NumGoroutine()

	// Release every blocked handler so it can reply.
	close(unblockCh)

	// Wait for every client goroutine to receive its response or error.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// A growth of up to 5 goroutines is tolerated.
	// NOTE(review): the threshold looks like empirical slack — confirm.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
1877
1878
1879
// TestTransportPersistConnLeakShortBody sends requests whose body is longer
// than the declared ContentLength, so each request fails. It then checks that
// these failures do not leave extra goroutines behind.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
}
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	// Run with testNotParallel (see above): the test compares process-wide
	// goroutine counts.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	n0 := runtime.NumGoroutine()
	body := []byte("Hello")
	for i := 0; i < 20; i++ {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		// Declare two bytes fewer than the body holds, so the request write
		// is expected to fail.
		req.ContentLength = int64(len(body) - 2)
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	nhigh := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// A growth of up to 5 goroutines is tolerated.
	// NOTE(review): the threshold looks like empirical slack — confirm.
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
	if int(growth) > 5 {
		t.Error("too many new goroutines")
	}
}
1920
1921
// countedConn wraps a net.Conn so a finalizer can observe when the wrapper
// becomes unreachable.
type countedConn struct {
	net.Conn
}

// countingDialer dials through an embedded net.Dialer. It counts the
// connections it has returned in total, and how many of the returned
// *countedConn wrappers have not yet been finalized.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials address with d.dialer. On success it wraps the
// connection in a *countedConn and increments both counters. It also
// registers a finalizer that decrements the live count once the wrapper is
// collected.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	d.total++
	d.live++
	runtime.SetFinalizer(wrapped, d.decrement)
	d.mu.Unlock()
	return wrapped, nil
}

// decrement is the finalizer installed by DialContext; it lowers the live
// count by one.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns the number of connections dialed so far and the number of
// wrappers not yet finalized.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
1962
// TestTransportPersistConnLeakNeverIdle uses a server that closes every
// connection without responding, so no connection ever becomes idle. It
// checks that the client's connections can still be garbage collected: at
// least one dialed connection must eventually be finalized.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
}
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Close the connection without writing a response, so each client
		// request sees a broken connection (checked below).
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	})).ts

	var d countingDialer
	c := ts.Client()
	c.Transport.(*Transport).DialContext = d.DialContext

	body := []byte("Hello")
	for i := 0; ; i++ {
		// Done as soon as at least one dialed connection has been finalized.
		total, live := d.Read()
		if live < total {
			break
		}
		if i >= 1<<12 {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("expected broken connection")
		}

		// Give finalizers on unreachable countedConns a chance to run.
		runtime.GC()
	}
}
2003
// countedContext wraps a context.Context so a finalizer can observe when the
// wrapper becomes unreachable.
type countedContext struct {
	context.Context
}

// contextCounter counts how many contexts returned by Track have not yet
// been finalized. The zero value is ready to use.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a new *countedContext and increments the live count. It
// registers a finalizer that decrements the count once the wrapper is
// collected.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	cc.mu.Unlock()
	return wrapped
}

// decrement is the finalizer installed by Track; it lowers the live count by
// one.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read returns how many tracked contexts have not yet been finalized.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
2034
2035 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2036 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2037 }
2038 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2039 if mode == http2Mode {
2040 t.Skip("https://go.dev/issue/56021")
2041 }
2042
2043 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2044 runtime.Gosched()
2045 w.WriteHeader(StatusOK)
2046 })).ts
2047
2048 c := ts.Client()
2049 c.Transport.(*Transport).MaxConnsPerHost = 1
2050
2051 ctx := context.Background()
2052 body := []byte("Hello")
2053 doPosts := func(cc *contextCounter) {
2054 var wg sync.WaitGroup
2055 for n := 64; n > 0; n-- {
2056 wg.Add(1)
2057 go func() {
2058 defer wg.Done()
2059
2060 ctx := cc.Track(ctx)
2061 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2062 if err != nil {
2063 t.Error(err)
2064 }
2065
2066 _, err = c.Do(req.WithContext(ctx))
2067 if err != nil {
2068 t.Errorf("Do failed with error: %v", err)
2069 }
2070 }()
2071 }
2072 wg.Wait()
2073 }
2074
2075 var initialCC contextCounter
2076 doPosts(&initialCC)
2077
2078
2079
2080
2081 var flushCC contextCounter
2082 for i := 0; ; i++ {
2083 live := initialCC.Read()
2084 if live == 0 {
2085 break
2086 }
2087 if i >= 100 {
2088 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2089 }
2090 doPosts(&flushCC)
2091 runtime.GC()
2092 }
2093 }
2094
2095
// TestTransportIdleConnCrash calls Transport.CloseIdleConnections from inside
// a handler while that Transport's own request is in flight. The request must
// still complete without error.
func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
func testTransportIdleConnCrash(t *testing.T, mode testMode) {
	// tr is assigned after the server starts but before any request is sent,
	// so the handler always sees it set.
	var tr *Transport

	unblockCh := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockCh
		tr.CloseIdleConnections()
	})).ts
	c := ts.Client()
	tr = c.Transport.(*Transport)

	didreq := make(chan bool)
	go func() {
		res, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
		} else {
			res.Body.Close()
		}
		didreq <- true
	}()
	unblockCh <- true
	<-didreq
}
2121
2122
2123
2124
2125
2126 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2127 func testIssue3644(t *testing.T, mode testMode) {
2128 const numFoos = 5000
2129 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2130 w.Header().Set("Connection", "close")
2131 for i := 0; i < numFoos; i++ {
2132 w.Write([]byte("foo "))
2133 }
2134 })).ts
2135 c := ts.Client()
2136 res, err := c.Get(ts.URL)
2137 if err != nil {
2138 t.Fatal(err)
2139 }
2140 defer res.Body.Close()
2141 bs, err := io.ReadAll(res.Body)
2142 if err != nil {
2143 t.Fatal(err)
2144 }
2145 if len(bs) != numFoos*len("foo ") {
2146 t.Errorf("unexpected response length")
2147 }
2148 }
2149
2150
2151
2152 func TestIssue3595(t *testing.T) {
2153
2154 run(t, testIssue3595, testNotParallel)
2155 }
2156 func testIssue3595(t *testing.T, mode testMode) {
2157 runTimeSensitiveTest(t, []time.Duration{
2158 1 * time.Millisecond,
2159 5 * time.Millisecond,
2160 10 * time.Millisecond,
2161 50 * time.Millisecond,
2162 100 * time.Millisecond,
2163 500 * time.Millisecond,
2164 time.Second,
2165 5 * time.Second,
2166 }, func(t *testing.T, timeout time.Duration) error {
2167 SetRSTAvoidanceDelay(t, timeout)
2168 t.Logf("set RST avoidance delay to %v", timeout)
2169
2170 const deniedMsg = "sorry, denied."
2171 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2172 Error(w, deniedMsg, StatusUnauthorized)
2173 }))
2174
2175
2176 defer cst.close()
2177 ts := cst.ts
2178 c := ts.Client()
2179
2180 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2181 if err != nil {
2182 return fmt.Errorf("Post: %v", err)
2183 }
2184 got, err := io.ReadAll(res.Body)
2185 if err != nil {
2186 return fmt.Errorf("Body ReadAll: %v", err)
2187 }
2188 t.Logf("server response:\n%s", got)
2189 if !strings.Contains(string(got), deniedMsg) {
2190
2191
2192 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2193 }
2194 return nil
2195 })
2196 }
2197
2198
2199
2200 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2201 func testChunkedNoContent(t *testing.T, mode testMode) {
2202 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2203 w.WriteHeader(StatusNoContent)
2204 })).ts
2205
2206 c := ts.Client()
2207 for _, closeBody := range []bool{true, false} {
2208 const n = 4
2209 for i := 1; i <= n; i++ {
2210 res, err := c.Get(ts.URL)
2211 if err != nil {
2212 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2213 } else {
2214 if closeBody {
2215 res.Body.Close()
2216 }
2217 }
2218 }
2219 }
2220 }
2221
// TestTransportConcurrency sends numReqs requests from a pool of goroutines
// that share one Client. It checks that each response body echoes the value
// its own request sent.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Run with testNotParallel (see above): the test changes GOMAXPROCS and
	// installs package-level dial hooks.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Also make wg account for dials still pending: one hook adds to wg and
	// the other calls wg.Done, so wg.Wait below waits for outstanding dials
	// as well as for the requests.
	// NOTE(review): when exactly the hooks fire is inferred from their use
	// here — confirm against SetPendingDialHooks.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs on return ends the worker goroutines' range loops.
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Tolerated on netbsd: logged instead of failing the
						// test (see the issue linked below).
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2284
2285 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2286 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2287 mux := NewServeMux()
2288 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2289 io.Copy(w, neverEnding('a'))
2290 })
2291 ts := newClientServerTest(t, mode, mux).ts
2292
2293 connc := make(chan net.Conn, 1)
2294 c := ts.Client()
2295 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2296 conn, err := net.Dial(n, addr)
2297 if err != nil {
2298 return nil, err
2299 }
2300 select {
2301 case connc <- conn:
2302 default:
2303 }
2304 return conn, nil
2305 }
2306
2307 res, err := c.Get(ts.URL + "/get")
2308 if err != nil {
2309 t.Fatalf("Error issuing GET: %v", err)
2310 }
2311 defer res.Body.Close()
2312
2313 conn := <-connc
2314 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2315 _, err = io.Copy(io.Discard, res.Body)
2316 if err == nil {
2317 t.Errorf("Unexpected successful copy")
2318 }
2319 }
2320
// TestIssue4191_InfiniteGetToPutTimeout streams an endless GET response body
// into a PUT request. Every dialed connection has a short deadline, and the
// PUT is expected to fail rather than hang.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	// Set to true to log connection traffic and progress.
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	// timeout is the deadline given to each newly dialed connection. It is
	// raised once, by 10x, if a GET fails.
	timeout := 100 * time.Millisecond

	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// The GET itself failed, presumably because the deadline
				// is too short for this machine. Retry this run once with
				// a 10x longer timeout.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		if err == nil {
			sres.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
		sres.Body.Close()
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}
2386
// TestTransportResponseHeaderTimeout checks Transport.ResponseHeaderTimeout.
// A request whose handler does not respond in time must fail with a timeout
// error mentioning "timeout awaiting response headers". Fast requests on the
// same client must still succeed.
func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping timeout test in -short mode")
	}

	// Start with a very short timeout. If a request that should be fast
	// times out, the timeout is doubled and the whole sequence is rerun.
	timeout := 2 * time.Millisecond
	retry := true
	for retry && !t.Failed() {
		// srvWG tracks the three handler invocations; inHandler signals
		// that a handler has started.
		var srvWG sync.WaitGroup
		inHandler := make(chan bool, 1)
		mux := NewServeMux()
		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
			inHandler <- true
			srvWG.Done()
		})
		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
			inHandler <- true
			// Block until the request's context is canceled, presumably
			// when the client abandons the request after the timeout.
			<-r.Context().Done()
			srvWG.Done()
		})
		ts := newClientServerTest(t, mode, mux).ts

		c := ts.Client()
		c.Transport.(*Transport).ResponseHeaderTimeout = timeout

		retry = false
		srvWG.Add(3)
		tests := []struct {
			path        string
			wantTimeout bool
		}{
			{path: "/fast"},
			{path: "/slow", wantTimeout: true},
			{path: "/fast"},
		}
		for i, tt := range tests {
			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
			req = req.WithT(t)
			res, err := c.Do(req)
			// Make sure the handler actually ran before judging the result.
			<-inHandler
			if err != nil {
				uerr, ok := err.(*url.Error)
				if !ok {
					t.Errorf("error is not a url.Error; got: %#v", err)
					continue
				}
				nerr, ok := uerr.Err.(net.Error)
				if !ok {
					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
					continue
				}
				if !nerr.Timeout() {
					t.Errorf("want timeout error; got: %q", nerr)
					continue
				}
				if !tt.wantTimeout {
					if !retry {
						// A fast path timed out, so the timeout is probably too
						// short for this machine. Double it (at most once per
						// pass) and rerun the sequence.
						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
						timeout *= 2
						retry = true
					}
				}
				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
					t.Errorf("%d. unexpected error: %v", i, err)
				}
				continue
			}
			if tt.wantTimeout {
				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
				continue
			}
			if res.StatusCode != 200 {
				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
			}
		}

		// Wait for every handler to return before shutting the server down.
		srvWG.Wait()
		ts.Close()
	}
}
2469
2470 func TestTransportCancelRequest(t *testing.T) {
2471 run(t, testTransportCancelRequest, []testMode{http1Mode})
2472 }
2473 func testTransportCancelRequest(t *testing.T, mode testMode) {
2474 if testing.Short() {
2475 t.Skip("skipping test in -short mode")
2476 }
2477
2478 const msg = "Hello"
2479 unblockc := make(chan bool)
2480 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2481 io.WriteString(w, msg)
2482 w.(Flusher).Flush()
2483 <-unblockc
2484 })).ts
2485 defer close(unblockc)
2486
2487 c := ts.Client()
2488 tr := c.Transport.(*Transport)
2489
2490 req, _ := NewRequest("GET", ts.URL, nil)
2491 res, err := c.Do(req)
2492 if err != nil {
2493 t.Fatal(err)
2494 }
2495 body := make([]byte, len(msg))
2496 n, _ := io.ReadFull(res.Body, body)
2497 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2498 t.Errorf("Body = %q; want %q", body[:n], msg)
2499 }
2500 tr.CancelRequest(req)
2501
2502 tail, err := io.ReadAll(res.Body)
2503 res.Body.Close()
2504 if err != ExportErrRequestCanceled {
2505 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2506 } else if len(tail) > 0 {
2507 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2508 }
2509
2510
2511
2512 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2513 n := tr.NumPendingRequestsForTesting()
2514 if n > 0 {
2515 if d > 0 {
2516 t.Logf("pending requests = %d after %v (want 0)", n, d)
2517 }
2518 return false
2519 }
2520 return true
2521 })
2522 }
2523
2524 func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
2525 if testing.Short() {
2526 t.Skip("skipping test in -short mode")
2527 }
2528 unblockc := make(chan bool)
2529 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2530 <-unblockc
2531 })).ts
2532 defer close(unblockc)
2533
2534 c := ts.Client()
2535 tr := c.Transport.(*Transport)
2536
2537 donec := make(chan bool)
2538 req, _ := NewRequest("GET", ts.URL, body)
2539 go func() {
2540 defer close(donec)
2541 c.Do(req)
2542 }()
2543
2544 unblockc <- true
2545 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2546 tr.CancelRequest(req)
2547 select {
2548 case <-donec:
2549 return true
2550 default:
2551 if d > 0 {
2552 t.Logf("Do of canceled request has not returned after %v", d)
2553 }
2554 return false
2555 }
2556 })
2557 }
2558
2559 func TestTransportCancelRequestInDo(t *testing.T) {
2560 run(t, func(t *testing.T, mode testMode) {
2561 testTransportCancelRequestInDo(t, mode, nil)
2562 }, []testMode{http1Mode})
2563 }
2564
2565 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2566 run(t, func(t *testing.T, mode testMode) {
2567 testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
2568 }, []testMode{http1Mode})
2569 }
2570
// TestTransportCancelRequestInDial checks that canceling a request while the
// Transport's Dial function is still blocked makes Client.Do return with a
// "request canceled while waiting for connection" error. Do must not wait
// for the dial, which stays blocked until the test returns. The order of
// events is recorded in logbuf and compared with the expected sequence at
// the end.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is closed only when the test returns, so Dial stays
	// blocked while the test observes Do.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial lets the test rendezvous with the running Dial function.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	// Once this send completes, Dial has logged "dial: blocking" and is
	// about to block on unblockDial, so the cancel below happens mid-dial.
	inDial <- true

	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	tr.CancelRequest(req) // a repeated cancel of the same request must be harmless

	if d, ok := t.Deadline(); ok {
		// If Do hangs, panic shortly before the test deadline so that the
		// failure includes the recorded events instead of a bare timeout.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2628
2629 func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
2630 func testCancelRequestWithChannel(t *testing.T, mode testMode) {
2631 if testing.Short() {
2632 t.Skip("skipping test in -short mode")
2633 }
2634
2635 const msg = "Hello"
2636 unblockc := make(chan struct{})
2637 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2638 io.WriteString(w, msg)
2639 w.(Flusher).Flush()
2640 <-unblockc
2641 })).ts
2642 defer close(unblockc)
2643
2644 c := ts.Client()
2645 tr := c.Transport.(*Transport)
2646
2647 req, _ := NewRequest("GET", ts.URL, nil)
2648 cancel := make(chan struct{})
2649 req.Cancel = cancel
2650
2651 res, err := c.Do(req)
2652 if err != nil {
2653 t.Fatal(err)
2654 }
2655 body := make([]byte, len(msg))
2656 n, _ := io.ReadFull(res.Body, body)
2657 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2658 t.Errorf("Body = %q; want %q", body[:n], msg)
2659 }
2660 close(cancel)
2661
2662 tail, err := io.ReadAll(res.Body)
2663 res.Body.Close()
2664 if err != ExportErrRequestCanceled {
2665 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2666 } else if len(tail) > 0 {
2667 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2668 }
2669
2670
2671
2672 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2673 n := tr.NumPendingRequestsForTesting()
2674 if n > 0 {
2675 if d > 0 {
2676 t.Logf("pending requests = %d after %v (want 0)", n, d)
2677 }
2678 return false
2679 }
2680 return true
2681 })
2682 }
2683
2684
2685 func TestCancelRequestWithBodyWithChannel(t *testing.T) {
2686 run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
2687 }
2688 func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
2689 if testing.Short() {
2690 t.Skip("skipping test in -short mode")
2691 }
2692
2693 const msg = "Hello"
2694 unblockc := make(chan struct{})
2695 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2696 io.WriteString(w, msg)
2697 w.(Flusher).Flush()
2698 <-unblockc
2699 })).ts
2700 defer close(unblockc)
2701
2702 c := ts.Client()
2703 tr := c.Transport.(*Transport)
2704
2705 req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
2706 cancel := make(chan struct{})
2707 req.Cancel = cancel
2708
2709 res, err := c.Do(req)
2710 if err != nil {
2711 t.Fatal(err)
2712 }
2713 body := make([]byte, len(msg))
2714 n, _ := io.ReadFull(res.Body, body)
2715 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2716 t.Errorf("Body = %q; want %q", body[:n], msg)
2717 }
2718 close(cancel)
2719
2720 tail, err := io.ReadAll(res.Body)
2721 res.Body.Close()
2722 if err != ExportErrRequestCanceled {
2723 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2724 } else if len(tail) > 0 {
2725 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2726 }
2727
2728
2729
2730 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2731 n := tr.NumPendingRequestsForTesting()
2732 if n > 0 {
2733 if d > 0 {
2734 t.Logf("pending requests = %d after %v (want 0)", n, d)
2735 }
2736 return false
2737 }
2738 return true
2739 })
2740 }
2741
2742 func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
2743 run(t, func(t *testing.T, mode testMode) {
2744 testCancelRequestWithChannelBeforeDo(t, mode, false)
2745 })
2746 }
2747 func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
2748 run(t, func(t *testing.T, mode testMode) {
2749 testCancelRequestWithChannelBeforeDo(t, mode, true)
2750 })
2751 }
2752 func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
2753 unblockc := make(chan bool)
2754 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2755 <-unblockc
2756 })).ts
2757 defer close(unblockc)
2758
2759 c := ts.Client()
2760
2761 req, _ := NewRequest("GET", ts.URL, nil)
2762 if withCtx {
2763 ctx, cancel := context.WithCancel(context.Background())
2764 cancel()
2765 req = req.WithContext(ctx)
2766 } else {
2767 ch := make(chan struct{})
2768 req.Cancel = ch
2769 close(ch)
2770 }
2771
2772 _, err := c.Do(req)
2773 if ue, ok := err.(*url.Error); ok {
2774 err = ue.Err
2775 }
2776 if withCtx {
2777 if err != context.Canceled {
2778 t.Errorf("Do error = %v; want %v", err, context.Canceled)
2779 }
2780 } else {
2781 if err == nil || !strings.Contains(err.Error(), "canceled") {
2782 t.Errorf("Do error = %v; want cancellation", err)
2783 }
2784 }
2785 }
2786
2787
2788 func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
2789 defer afterTest(t)
2790
2791 serverConnCh := make(chan net.Conn, 1)
2792 tr := &Transport{
2793 Dial: func(network, addr string) (net.Conn, error) {
2794 cc, sc := net.Pipe()
2795 serverConnCh <- sc
2796 return cc, nil
2797 },
2798 }
2799 defer tr.CloseIdleConnections()
2800 errc := make(chan error, 1)
2801 req, _ := NewRequest("GET", "http://example.com/", nil)
2802 go func() {
2803 _, err := tr.RoundTrip(req)
2804 errc <- err
2805 }()
2806
2807 sc := <-serverConnCh
2808 verb := make([]byte, 3)
2809 if _, err := io.ReadFull(sc, verb); err != nil {
2810 t.Errorf("Error reading HTTP verb from server: %v", err)
2811 }
2812 if string(verb) != "GET" {
2813 t.Errorf("server received %q; want GET", verb)
2814 }
2815 defer sc.Close()
2816
2817 tr.CancelRequest(req)
2818
2819 err := <-errc
2820 if err == nil {
2821 t.Fatalf("unexpected success from RoundTrip")
2822 }
2823 if err != ExportErrRequestCanceled {
2824 t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
2825 }
2826 }
2827
2828
2829
2830
2831 func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
2832 func testTransportCloseResponseBody(t *testing.T, mode testMode) {
2833 writeErr := make(chan error, 1)
2834 msg := []byte("young\n")
2835 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2836 for {
2837 _, err := w.Write(msg)
2838 if err != nil {
2839 writeErr <- err
2840 return
2841 }
2842 w.(Flusher).Flush()
2843 }
2844 })).ts
2845
2846 c := ts.Client()
2847 tr := c.Transport.(*Transport)
2848
2849 req, _ := NewRequest("GET", ts.URL, nil)
2850 defer tr.CancelRequest(req)
2851
2852 res, err := c.Do(req)
2853 if err != nil {
2854 t.Fatal(err)
2855 }
2856
2857 const repeats = 3
2858 buf := make([]byte, len(msg)*repeats)
2859 want := bytes.Repeat(msg, repeats)
2860
2861 _, err = io.ReadFull(res.Body, buf)
2862 if err != nil {
2863 t.Fatal(err)
2864 }
2865 if !bytes.Equal(buf, want) {
2866 t.Fatalf("read %q; want %q", buf, want)
2867 }
2868
2869 if err := res.Body.Close(); err != nil {
2870 t.Errorf("Close = %v", err)
2871 }
2872
2873 if err := <-writeErr; err == nil {
2874 t.Errorf("expected non-nil write error")
2875 }
2876 }
2877
// fooProto is a RoundTripper that answers every request locally, without
// any network I/O, with a 200 response whose body echoes the request URL.
type fooProto struct{}

// RoundTrip returns a canned "200 OK" response with an empty header and the
// body "You wanted <URL>". It never returns an error.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	text := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(text)),
	}, nil
}
2889
2890 func TestTransportAltProto(t *testing.T) {
2891 defer afterTest(t)
2892 tr := &Transport{}
2893 c := &Client{Transport: tr}
2894 tr.RegisterProtocol("foo", fooProto{})
2895 res, err := c.Get("foo://bar.com/path")
2896 if err != nil {
2897 t.Fatal(err)
2898 }
2899 bodyb, err := io.ReadAll(res.Body)
2900 if err != nil {
2901 t.Fatal(err)
2902 }
2903 body := string(bodyb)
2904 if e := "You wanted foo://bar.com/path"; body != e {
2905 t.Errorf("got response %q, want %q", body, e)
2906 }
2907 }
2908
2909 func TestTransportNoHost(t *testing.T) {
2910 defer afterTest(t)
2911 tr := &Transport{}
2912 _, err := tr.RoundTrip(&Request{
2913 Header: make(Header),
2914 URL: &url.URL{
2915 Scheme: "http",
2916 },
2917 })
2918 want := "http: no Host in request URL"
2919 if got := fmt.Sprint(err); got != want {
2920 t.Errorf("error = %v; want %q", err, want)
2921 }
2922 }
2923
2924
// TestTransportEmptyMethod checks that a request with an empty Method is
// serialized as a GET, as reported by httputil.DumpRequestOut.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = ""
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if wire := string(dump); !strings.Contains(wire, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
2936
// TestTransportSocketLateBinding checks that a request is not bound to the
// connection whose dial it triggered.
//
// While /foo holds the only real connection, a request for /bar starts a
// second dial that can never complete. Once /foo finishes, /bar must be
// served on /foo's freed connection: the server must see the same
// RemoteAddr for both requests.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// /foo flushes its headers (including the client address it saw) and
	// then keeps its connection busy until fooGate is signaled.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate decides whether a pending Dial may proceed: true dials for
	// real, and a closed channel makes Dial fail. While a Dial waits on
	// dialGate, it keeps offering values on dialing, so the test can see
	// that a dial is in progress.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one real dial, which /foo will use.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Decide when to let /foo finish.
		//   - HTTP/1: wait until /bar's dial is observed blocking.
		//   - HTTP/2: wait until it is clear that no second dial happens.
		if mode == http2Mode {
			// In HTTP/2 mode /bar can share the existing connection, so a
			// second Dial is unexpected. Give one 10ms to show up.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3016
3017
3018 func TestTransportReading100Continue(t *testing.T) {
3019 defer afterTest(t)
3020
3021 const numReqs = 5
3022 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
3023 reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
3024
3025 send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
3026 defer w.Close()
3027 defer r.Close()
3028 br := bufio.NewReader(r)
3029 n := 0
3030 for {
3031 n++
3032 req, err := ReadRequest(br)
3033 if err == io.EOF {
3034 return
3035 }
3036 if err != nil {
3037 t.Error(err)
3038 return
3039 }
3040 slurp, err := io.ReadAll(req.Body)
3041 if err != nil {
3042 t.Errorf("Server request body slurp: %v", err)
3043 return
3044 }
3045 id := req.Header.Get("Request-Id")
3046 resCode := req.Header.Get("X-Want-Response-Code")
3047 if resCode == "" {
3048 resCode = "100 Continue"
3049 if string(slurp) != reqBody(n) {
3050 t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
3051 }
3052 }
3053 body := fmt.Sprintf("Response number %d", n)
3054 v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
3055 Date: Thu, 28 Feb 2013 17:55:41 GMT
3056
3057 HTTP/1.1 200 OK
3058 Content-Type: text/html
3059 Echo-Request-Id: %s
3060 Content-Length: %d
3061
3062 %s`, resCode, id, len(body), body), "\n", "\r\n", -1))
3063 w.Write(v)
3064 if id == reqID(numReqs) {
3065 return
3066 }
3067 }
3068
3069 }
3070
3071 tr := &Transport{
3072 Dial: func(n, addr string) (net.Conn, error) {
3073 sr, sw := io.Pipe()
3074 cr, cw := io.Pipe()
3075 conn := &rwTestConn{
3076 Reader: cr,
3077 Writer: sw,
3078 closeFunc: func() error {
3079 sw.Close()
3080 cw.Close()
3081 return nil
3082 },
3083 }
3084 go send100Response(cw, sr)
3085 return conn, nil
3086 },
3087 DisableKeepAlives: false,
3088 }
3089 defer tr.CloseIdleConnections()
3090 c := &Client{Transport: tr}
3091
3092 testResponse := func(req *Request, name string, wantCode int) {
3093 t.Helper()
3094 res, err := c.Do(req)
3095 if err != nil {
3096 t.Fatalf("%s: Do: %v", name, err)
3097 }
3098 if res.StatusCode != wantCode {
3099 t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
3100 }
3101 if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
3102 t.Errorf("%s: response id %q != request id %q", name, idBack, id)
3103 }
3104 _, err = io.ReadAll(res.Body)
3105 if err != nil {
3106 t.Fatalf("%s: Slurp error: %v", name, err)
3107 }
3108 }
3109
3110
3111 for i := 1; i <= numReqs; i++ {
3112 req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
3113 req.Header.Set("Request-Id", reqID(i))
3114 testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
3115 }
3116 }
3117
3118
3119
3120 func TestTransportIgnore1xxResponses(t *testing.T) {
3121 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3122 }
3123 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3124 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3125 conn, buf, _ := w.(Hijacker).Hijack()
3126 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3127 buf.Flush()
3128 conn.Close()
3129 }))
3130 cst.tr.DisableKeepAlives = true
3131
3132 var got strings.Builder
3133
3134 req, _ := NewRequest("GET", cst.ts.URL, nil)
3135 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3136 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3137 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3138 return nil
3139 },
3140 }))
3141 res, err := cst.c.Do(req)
3142 if err != nil {
3143 t.Fatal(err)
3144 }
3145 defer res.Body.Close()
3146
3147 res.Write(&got)
3148 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3149 if got.String() != want {
3150 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3151 }
3152 }
3153
3154 func TestTransportLimits1xxResponses(t *testing.T) {
3155 run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
3156 }
3157 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3158 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3159 conn, buf, _ := w.(Hijacker).Hijack()
3160 for i := 0; i < 10; i++ {
3161 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
3162 }
3163 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3164 buf.Flush()
3165 conn.Close()
3166 }))
3167 cst.tr.DisableKeepAlives = true
3168
3169 res, err := cst.c.Get(cst.ts.URL)
3170 if res != nil {
3171 defer res.Body.Close()
3172 }
3173 got := fmt.Sprint(err)
3174 wantSub := "too many 1xx informational responses"
3175 if !strings.Contains(got, wantSub) {
3176 t.Errorf("Get error = %v; want substring %q", err, wantSub)
3177 }
3178 }
3179
3180
3181
3182 func TestTransportTreat101Terminal(t *testing.T) {
3183 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3184 }
3185 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3186 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3187 conn, buf, _ := w.(Hijacker).Hijack()
3188 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3189 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3190 buf.Flush()
3191 conn.Close()
3192 }))
3193 res, err := cst.c.Get(cst.ts.URL)
3194 if err != nil {
3195 t.Fatal(err)
3196 }
3197 defer res.Body.Close()
3198 if res.StatusCode != StatusSwitchingProtocols {
3199 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3200 }
3201 }
3202
// proxyFromEnvTest describes one ProxyFromEnvironment scenario: the proxy
// environment to install, the request URL, and the expected outcome.
type proxyFromEnvTest struct {
	req string // URL to fetch; blank means "http://example.com"

	env      string // HTTP_PROXY value
	httpsenv string // HTTPS_PROXY value
	noenv    string // NO_PROXY value
	reqmeth  string // REQUEST_METHOD value

	want    string // expected proxy URL as formatted with %s, or "<nil>"
	wanterr error  // expected error, or nil
}

// String renders the non-empty environment settings, followed by the
// effective request URL, as space-separated key="value" pairs. It is used to
// identify the case in test failure messages.
func (t proxyFromEnvTest) String() string {
	var parts []string
	add := func(format, val string) {
		if val != "" {
			parts = append(parts, fmt.Sprintf(format, val))
		}
	}
	add("http_proxy=%q", t.env)
	add("https_proxy=%q", t.httpsenv)
	add("no_proxy=%q", t.noenv)
	add("request_method=%q", t.reqmeth)

	target := t.req
	if target == "" {
		target = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", target))
	return strings.Join(parts, " ")
}
3245
// proxyFromEnvTests is the case table shared by TestProxyFromEnvironment and
// TestProxyFromEnvironmentLowerCase.
var proxyFromEnvTests = []proxyFromEnvTest{
	// Values without a scheme get an implicit "http://". Explicit http,
	// https and socks5 schemes are kept.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},

	// An http:// request uses HTTP_PROXY even when HTTPS_PROXY is set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https:// request uses HTTPS_PROXY. A scheme-less value still gets
	// "http://", and an explicit "https://" is kept.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// With REQUEST_METHOD set, as in a CGI environment, HTTP_PROXY is refused
	// with an error and no proxy is returned.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy environment at all means no proxy.
	{want: "<nil>"},

	// NO_PROXY matching:
	//   - "example.com" excludes the host and its subdomains.
	//   - ".example.com" does not exclude the bare host.
	//   - "ample.com" only matches on a label boundary.
	//   - an unrelated domain has no effect.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3275
3276 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3277 t.Helper()
3278 reqURL := tt.req
3279 if reqURL == "" {
3280 reqURL = "http://example.com"
3281 }
3282 req, _ := NewRequest("GET", reqURL, nil)
3283 url, err := proxyForRequest(req)
3284 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3285 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3286 return
3287 }
3288 if got := fmt.Sprintf("%s", url); got != tt.want {
3289 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3290 }
3291 }
3292
3293 func TestProxyFromEnvironment(t *testing.T) {
3294 ResetProxyEnv()
3295 defer ResetProxyEnv()
3296 for _, tt := range proxyFromEnvTests {
3297 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3298 os.Setenv("HTTP_PROXY", tt.env)
3299 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3300 os.Setenv("NO_PROXY", tt.noenv)
3301 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3302 ResetCachedEnvironment()
3303 return ProxyFromEnvironment(req)
3304 })
3305 }
3306 }
3307
3308 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3309 ResetProxyEnv()
3310 defer ResetProxyEnv()
3311 for _, tt := range proxyFromEnvTests {
3312 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3313 os.Setenv("http_proxy", tt.env)
3314 os.Setenv("https_proxy", tt.httpsenv)
3315 os.Setenv("no_proxy", tt.noenv)
3316 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3317 ResetCachedEnvironment()
3318 return ProxyFromEnvironment(req)
3319 })
3320 }
3321 }
3322
// TestIdleConnChannelLeak checks that the Transport's idle-connection wait
// map (as reported by IdleConnWaitMapSizeForTesting) is empty after
// requests to several distinct hosts. It checks this both with and without
// keep-alives.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// Counts handler invocations. The count is not checked in this
	// function.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// Buffered so the hook never blocks. The test waits for one signal per
	// request before inspecting the map. NOTE(review): by its name, the hook
	// presumably fires each time a connection's read loop is about to read
	// again; confirm in transport.go.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Every dial goes to the test server whatever the host name, so the
	// distinct foo-host-N.tld names presumably yield distinct connection
	// keys while all hitting one server.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// The response body is deliberately not closed. The handler
			// writes nothing, so the response is presumably empty
			// (Content-Length: 0), and the transport can release the
			// connection without an explicit Close.
		}

		// Wait for one hook signal per request so the read loops have
		// progressed before the map size is checked.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3378
3379
3380
3381
3382 func TestTransportClosesRequestBody(t *testing.T) {
3383 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3384 }
3385 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3386 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3387 io.Copy(io.Discard, r.Body)
3388 })).ts
3389
3390 c := ts.Client()
3391
3392 closes := 0
3393
3394 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3395 if err != nil {
3396 t.Fatal(err)
3397 }
3398 res.Body.Close()
3399 if closes != 1 {
3400 t.Errorf("closes = %d; want 1", closes)
3401 }
3402 }
3403
3404 func TestTransportTLSHandshakeTimeout(t *testing.T) {
3405 defer afterTest(t)
3406 if testing.Short() {
3407 t.Skip("skipping in short mode")
3408 }
3409 ln := newLocalListener(t)
3410 defer ln.Close()
3411 testdonec := make(chan struct{})
3412 defer close(testdonec)
3413
3414 go func() {
3415 c, err := ln.Accept()
3416 if err != nil {
3417 t.Error(err)
3418 return
3419 }
3420 <-testdonec
3421 c.Close()
3422 }()
3423
3424 tr := &Transport{
3425 Dial: func(_, _ string) (net.Conn, error) {
3426 return net.Dial("tcp", ln.Addr().String())
3427 },
3428 TLSHandshakeTimeout: 250 * time.Millisecond,
3429 }
3430 cl := &Client{Transport: tr}
3431 _, err := cl.Get("https://dummy.tld/")
3432 if err == nil {
3433 t.Error("expected error")
3434 return
3435 }
3436 ue, ok := err.(*url.Error)
3437 if !ok {
3438 t.Errorf("expected url.Error; got %#v", err)
3439 return
3440 }
3441 ne, ok := ue.Err.(net.Error)
3442 if !ok {
3443 t.Errorf("expected net.Error; got %#v", err)
3444 return
3445 }
3446 if !ne.Timeout() {
3447 t.Errorf("expected timeout error; got %v", err)
3448 }
3449 if !strings.Contains(err.Error(), "handshake timeout") {
3450 t.Errorf("expected 'handshake timeout' in error; got %v", err)
3451 }
3452 }
3453
3454
3455 func TestTLSServerClosesConnection(t *testing.T) {
3456 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3457 }
3458 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3459 closedc := make(chan bool, 1)
3460 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3461 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3462 conn, _, _ := w.(Hijacker).Hijack()
3463 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3464 conn.Close()
3465 closedc <- true
3466 return
3467 }
3468 fmt.Fprintf(w, "hello")
3469 })).ts
3470
3471 c := ts.Client()
3472 tr := c.Transport.(*Transport)
3473
3474 var nSuccess = 0
3475 var errs []error
3476 const trials = 20
3477 for i := 0; i < trials; i++ {
3478 tr.CloseIdleConnections()
3479 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3480 if err != nil {
3481 t.Fatal(err)
3482 }
3483 <-closedc
3484 slurp, err := io.ReadAll(res.Body)
3485 if err != nil {
3486 t.Fatal(err)
3487 }
3488 if string(slurp) != "foo" {
3489 t.Errorf("Got %q, want foo", slurp)
3490 }
3491
3492
3493
3494 res, err = c.Get(ts.URL + "/")
3495 if err != nil {
3496 errs = append(errs, err)
3497 continue
3498 }
3499 slurp, err = io.ReadAll(res.Body)
3500 if err != nil {
3501 errs = append(errs, err)
3502 continue
3503 }
3504 nSuccess++
3505 }
3506 if nSuccess > 0 {
3507 t.Logf("successes = %d of %d", nSuccess, trials)
3508 } else {
3509 t.Errorf("All runs failed:")
3510 }
3511 for _, err := range errs {
3512 t.Logf(" err: %v", err)
3513 }
3514 }
3515
3516
3517
3518
// byteFromChanReader is an io.Reader that yields at most one byte per Read,
// taken from the channel. Read blocks until a byte is available and returns
// io.EOF once the channel is closed and drained.
type byteFromChanReader chan byte

// Read stores the next byte from c in p[0]. A zero-length p returns (0, nil)
// without touching the channel.
func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		b, open := <-c
		if !open {
			return 0, io.EOF
		}
		p[0] = b
		n = 1
	}
	return n, nil
}
3532
3533
3534
3535
3536
3537
3538
// TestTransportNoReuseAfterEarlyResponse: the server answers a large POST
// before the client has finished sending its body, and a subsequent GET
// must still get its own correct response. NOTE(review): going by the test
// name, this presumably checks that the transport does not reuse the POST's
// connection while that request's write is still in progress.
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shorten the package-level MaxWriteWaitBeforeConnReuse for this test,
	// restoring the previous value on return. Mutating this shared setting
	// is presumably why the test is run with testNotParallel.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server side of the POST connection.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	var getOkay bool           // set once the GET succeeded; suppresses the close log
	var copying sync.WaitGroup // tracks the goroutine draining sconn
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: hijack, write the complete response right away, and only
		// then start draining whatever request body the client still sends.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))

		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The POST body declares bodySize bytes, but its final byte comes from
	// finalBit, which is only fed after the GET has succeeded. Until then,
	// the client cannot finish writing the POST request.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true
	// Let the POST body complete.
	finalBit <- 'x'
	close(finalBit)
}
3606
3607
3608
3609 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3610 func testTransportIssue10457(t *testing.T, mode testMode) {
3611 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3612
3613
3614
3615
3616
3617 conn, _, _ := w.(Hijacker).Hijack()
3618 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3619 conn.Close()
3620 })).ts
3621 c := ts.Client()
3622
3623 res, err := c.Get(ts.URL)
3624 if err != nil {
3625 t.Fatalf("Get: %v", err)
3626 }
3627 defer res.Body.Close()
3628
3629
3630
3631
3632 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3633 t.Errorf("Foo header = %q; want %q", got, want)
3634 }
3635 }
3636
// closerFunc adapts a plain function to the io.Closer interface.
type closerFunc func() error

// Close invokes the underlying function and returns its error.
func (fn closerFunc) Close() error {
	return fn()
}
3640
// writerFuncConn is a net.Conn whose Write calls are routed to the write
// func; all other methods come from the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write hands p to c.write and returns whatever it reports.
func (c writerFuncConn) Write(p []byte) (n int, err error) {
	n, err = c.write(p)
	return n, err
}
3647
3648
3649
3650
3651
3652
3653
3654
3655
3656
3657
3658
3659
3660 func TestRetryRequestsOnError(t *testing.T) {
3661 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3662 }
3663 func testRetryRequestsOnError(t *testing.T, mode testMode) {
3664 newRequest := func(method, urlStr string, body io.Reader) *Request {
3665 req, err := NewRequest(method, urlStr, body)
3666 if err != nil {
3667 t.Fatal(err)
3668 }
3669 return req
3670 }
3671
3672 testCases := []struct {
3673 name string
3674 failureN int
3675 failureErr error
3676
3677
3678
3679 req func() *Request
3680 reqString string
3681 }{
3682 {
3683 name: "IdempotentNoBodySomeWritten",
3684
3685
3686 failureN: 1,
3687
3688 failureErr: ExportErrServerClosedIdle,
3689 req: func() *Request {
3690 return newRequest("GET", "http://fake.golang", nil)
3691 },
3692 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3693 },
3694 {
3695 name: "IdempotentGetBodySomeWritten",
3696
3697
3698 failureN: 1,
3699
3700 failureErr: ExportErrServerClosedIdle,
3701 req: func() *Request {
3702 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3703 },
3704 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3705 },
3706 {
3707 name: "NothingWrittenNoBody",
3708
3709
3710 failureN: 0,
3711 failureErr: errors.New("second write fails"),
3712 req: func() *Request {
3713 return newRequest("DELETE", "http://fake.golang", nil)
3714 },
3715 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3716 },
3717 {
3718 name: "NothingWrittenGetBody",
3719
3720
3721 failureN: 0,
3722 failureErr: errors.New("second write fails"),
3723
3724
3725 req: func() *Request {
3726 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3727 },
3728 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3729 },
3730 }
3731
3732 for _, tc := range testCases {
3733 t.Run(tc.name, func(t *testing.T) {
3734 var (
3735 mu sync.Mutex
3736 logbuf strings.Builder
3737 )
3738 logf := func(format string, args ...any) {
3739 mu.Lock()
3740 defer mu.Unlock()
3741 fmt.Fprintf(&logbuf, format, args...)
3742 logbuf.WriteByte('\n')
3743 }
3744
3745 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3746 logf("Handler")
3747 w.Header().Set("X-Status", "ok")
3748 })).ts
3749
3750 var writeNumAtomic int32
3751 c := ts.Client()
3752 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3753 logf("Dial")
3754 c, err := net.Dial(network, ts.Listener.Addr().String())
3755 if err != nil {
3756 logf("Dial error: %v", err)
3757 return nil, err
3758 }
3759 return &writerFuncConn{
3760 Conn: c,
3761 write: func(p []byte) (n int, err error) {
3762 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3763 logf("intentional write failure")
3764 return tc.failureN, tc.failureErr
3765 }
3766 logf("Write(%q)", p)
3767 return c.Write(p)
3768 },
3769 }, nil
3770 }
3771
3772 SetRoundTripRetried(func() {
3773 logf("Retried.")
3774 })
3775 defer SetRoundTripRetried(nil)
3776
3777 for i := 0; i < 3; i++ {
3778 t0 := time.Now()
3779 req := tc.req()
3780 res, err := c.Do(req)
3781 if err != nil {
3782 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3783 mu.Lock()
3784 got := logbuf.String()
3785 mu.Unlock()
3786 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
3787 }
3788 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
3789 }
3790 res.Body.Close()
3791 if res.Request != req {
3792 t.Errorf("Response.Request != original request; want identical Request")
3793 }
3794 }
3795
3796 mu.Lock()
3797 got := logbuf.String()
3798 mu.Unlock()
3799 want := fmt.Sprintf(`Dial
3800 Write("%s")
3801 Handler
3802 intentional write failure
3803 Retried.
3804 Dial
3805 Write("%s")
3806 Handler
3807 Write("%s")
3808 Handler
3809 `, tc.reqString, tc.reqString, tc.reqString)
3810 if got != want {
3811 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
3812 }
3813 })
3814 }
3815 }
3816
3817
3818 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
3819 func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
3820 readBody := make(chan error, 1)
3821 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3822 _, err := io.ReadAll(r.Body)
3823 readBody <- err
3824 })).ts
3825 c := ts.Client()
3826 fakeErr := errors.New("fake error")
3827 didClose := make(chan bool, 1)
3828 req, _ := NewRequest("POST", ts.URL, struct {
3829 io.Reader
3830 io.Closer
3831 }{
3832 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
3833 closerFunc(func() error {
3834 select {
3835 case didClose <- true:
3836 default:
3837 }
3838 return nil
3839 }),
3840 })
3841 res, err := c.Do(req)
3842 if res != nil {
3843 defer res.Body.Close()
3844 }
3845 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
3846 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
3847 }
3848 if err := <-readBody; err == nil {
3849 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
3850 }
3851 select {
3852 case <-didClose:
3853 default:
3854 t.Errorf("didn't see Body.Close")
3855 }
3856 }
3857
3858 func TestTransportDialTLS(t *testing.T) {
3859 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
3860 }
3861 func testTransportDialTLS(t *testing.T, mode testMode) {
3862 var mu sync.Mutex
3863 var gotReq, didDial bool
3864
3865 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3866 mu.Lock()
3867 gotReq = true
3868 mu.Unlock()
3869 })).ts
3870 c := ts.Client()
3871 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
3872 mu.Lock()
3873 didDial = true
3874 mu.Unlock()
3875 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3876 if err != nil {
3877 return nil, err
3878 }
3879 return c, c.Handshake()
3880 }
3881
3882 res, err := c.Get(ts.URL)
3883 if err != nil {
3884 t.Fatal(err)
3885 }
3886 res.Body.Close()
3887 mu.Lock()
3888 if !gotReq {
3889 t.Error("didn't get request")
3890 }
3891 if !didDial {
3892 t.Error("didn't use dial hook")
3893 }
3894 }
3895
3896 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
3897 func testTransportDialContext(t *testing.T, mode testMode) {
3898 var mu sync.Mutex
3899 var gotReq bool
3900 var receivedContext context.Context
3901
3902 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3903 mu.Lock()
3904 gotReq = true
3905 mu.Unlock()
3906 })).ts
3907 c := ts.Client()
3908 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3909 mu.Lock()
3910 receivedContext = ctx
3911 mu.Unlock()
3912 return net.Dial(netw, addr)
3913 }
3914
3915 req, err := NewRequest("GET", ts.URL, nil)
3916 if err != nil {
3917 t.Fatal(err)
3918 }
3919 ctx := context.WithValue(context.Background(), "some-key", "some-value")
3920 res, err := c.Do(req.WithContext(ctx))
3921 if err != nil {
3922 t.Fatal(err)
3923 }
3924 res.Body.Close()
3925 mu.Lock()
3926 if !gotReq {
3927 t.Error("didn't get request")
3928 }
3929 if receivedContext != ctx {
3930 t.Error("didn't receive correct context")
3931 }
3932 }
3933
3934 func TestTransportDialTLSContext(t *testing.T) {
3935 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
3936 }
3937 func testTransportDialTLSContext(t *testing.T, mode testMode) {
3938 var mu sync.Mutex
3939 var gotReq bool
3940 var receivedContext context.Context
3941
3942 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3943 mu.Lock()
3944 gotReq = true
3945 mu.Unlock()
3946 })).ts
3947 c := ts.Client()
3948 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3949 mu.Lock()
3950 receivedContext = ctx
3951 mu.Unlock()
3952 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3953 if err != nil {
3954 return nil, err
3955 }
3956 return c, c.HandshakeContext(ctx)
3957 }
3958
3959 req, err := NewRequest("GET", ts.URL, nil)
3960 if err != nil {
3961 t.Fatal(err)
3962 }
3963 ctx := context.WithValue(context.Background(), "some-key", "some-value")
3964 res, err := c.Do(req.WithContext(ctx))
3965 if err != nil {
3966 t.Fatal(err)
3967 }
3968 res.Body.Close()
3969 mu.Lock()
3970 if !gotReq {
3971 t.Error("didn't get request")
3972 }
3973 if receivedContext != ctx {
3974 t.Error("didn't receive correct context")
3975 }
3976 }
3977
3978
3979
3980 func TestRoundTripReturnsProxyError(t *testing.T) {
3981 badProxy := func(*Request) (*url.URL, error) {
3982 return nil, errors.New("errorMessage")
3983 }
3984
3985 tr := &Transport{Proxy: badProxy}
3986
3987 req, _ := NewRequest("GET", "http://example.com", nil)
3988
3989 _, err := tr.RoundTrip(req)
3990
3991 if err == nil {
3992 t.Error("Expected proxy error to be returned by RoundTrip")
3993 }
3994 }
3995
3996
3997 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
3998 tr := &Transport{}
3999 wantIdle := func(when string, n int) bool {
4000 got := tr.IdleConnCountForTesting("http", "example.com")
4001 if got == n {
4002 return true
4003 }
4004 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4005 return false
4006 }
4007 wantIdle("start", 0)
4008 if !tr.PutIdleTestConn("http", "example.com") {
4009 t.Fatal("put failed")
4010 }
4011 if !tr.PutIdleTestConn("http", "example.com") {
4012 t.Fatal("second put failed")
4013 }
4014 wantIdle("after put", 2)
4015 tr.CloseIdleConnections()
4016 if !tr.IsIdleForTesting() {
4017 t.Error("should be idle after CloseIdleConnections")
4018 }
4019 wantIdle("after close idle", 0)
4020 if tr.PutIdleTestConn("http", "example.com") {
4021 t.Fatal("put didn't fail")
4022 }
4023 wantIdle("after second put", 0)
4024
4025 tr.QueueForIdleConnForTesting()
4026 if tr.IsIdleForTesting() {
4027 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4028 }
4029 if !tr.PutIdleTestConn("http", "example.com") {
4030 t.Fatal("after re-activation")
4031 }
4032 wantIdle("after final put", 1)
4033 }
4034
4035
4036
4037 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4038 tr := &Transport{}
4039 wantIdle := func(when string, n int) bool {
4040 got := tr.IdleConnCountForTesting("https", "example.com:443")
4041 if got == n {
4042 return true
4043 }
4044 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4045 return false
4046 }
4047 wantIdle("start", 0)
4048 alt := funcRoundTripper(func() {})
4049 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4050 t.Fatal("put failed")
4051 }
4052 wantIdle("after put", 1)
4053 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4054 GotConn: func(httptrace.GotConnInfo) {
4055
4056 t.Error("GotConn called")
4057 },
4058 })
4059 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4060 _, err := tr.RoundTrip(req)
4061 if err != errFakeRoundTrip {
4062 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4063 }
4064 wantIdle("after round trip", 1)
4065 }
4066
// TestTransportRemovesH2ConnsAfterIdle checks that an HTTP/2 connection left
// idle past IdleConnTimeout does not cause a later request to fail.
func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
	run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
}
func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Start with a tiny idle timeout. Double it whenever a request fails with
	// a closed-connection error, which suggests the timeout fired during the
	// request itself.
	timeout := 1 * time.Millisecond
	retry := true
	for retry {
		// At most one connection per host, with IdleConnTimeout = timeout.
		trFunc := func(tr *Transport) {
			tr.MaxConnsPerHost = 1
			tr.MaxIdleConnsPerHost = 1
			tr.IdleConnTimeout = timeout
		}
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)

		retry = false
		// tooShort reports whether err is a "use of closed network connection"
		// error. On the first such error in an iteration it also doubles
		// timeout, schedules another iteration, and closes this iteration's cst.
		tooShort := func(err error) bool {
			if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
				return false
			}
			if !retry {
				t.Helper()
				t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
				timeout *= 2
				retry = true
				cst.close()
			}
			return true
		}

		// The first request opens the connection.
		// NOTE(review): the responses here are never closed; the handler
		// writes no body.
		if _, err := cst.c.Get(cst.ts.URL); err != nil {
			if tooShort(err) {
				continue
			}
			t.Fatalf("got error: %s", err)
		}

		// Leave the connection idle well past timeout. The next request must
		// not fail; a closed-connection error instead triggers a retry with a
		// longer timeout.
		time.Sleep(10 * timeout)
		if _, err := cst.c.Get(cst.ts.URL); err != nil {
			if tooShort(err) {
				continue
			}
			t.Fatalf("got error: %s", err)
		}
	}
}
4116
4117
4118
4119
4120
4121 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4122 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4123 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4124 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4125 t.Error("Transport advertised gzip support in the Accept header")
4126 }
4127 if r.Header.Get("Range") == "" {
4128 t.Error("no Range in request")
4129 }
4130 })).ts
4131 c := ts.Client()
4132
4133 req, _ := NewRequest("GET", ts.URL, nil)
4134 req.Header.Set("Range", "bytes=7-11")
4135 res, err := c.Do(req)
4136 if err != nil {
4137 t.Fatal(err)
4138 }
4139 res.Body.Close()
4140 }
4141
4142
4143 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
4144 func testTransportResponseCancelRace(t *testing.T, mode testMode) {
4145 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4146
4147 var b [1024]byte
4148 w.Write(b[:])
4149 })).ts
4150 tr := ts.Client().Transport.(*Transport)
4151
4152 req, err := NewRequest("GET", ts.URL, nil)
4153 if err != nil {
4154 t.Fatal(err)
4155 }
4156 res, err := tr.RoundTrip(req)
4157 if err != nil {
4158 t.Fatal(err)
4159 }
4160
4161
4162
4163 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4164 t.Fatal(err)
4165 }
4166
4167 req2, err := NewRequest("GET", ts.URL, nil)
4168 if err != nil {
4169 t.Fatal(err)
4170 }
4171 tr.CancelRequest(req)
4172 res, err = tr.RoundTrip(req2)
4173 if err != nil {
4174 t.Fatal(err)
4175 }
4176 res.Body.Close()
4177 }
4178
4179
4180 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4181 run(t, testTransportContentEncodingCaseInsensitive)
4182 }
4183 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4184 for _, ce := range []string{"gzip", "GZIP"} {
4185 ce := ce
4186 t.Run(ce, func(t *testing.T) {
4187 const encodedString = "Hello Gopher"
4188 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4189 w.Header().Set("Content-Encoding", ce)
4190 gz := gzip.NewWriter(w)
4191 gz.Write([]byte(encodedString))
4192 gz.Close()
4193 })).ts
4194
4195 res, err := ts.Client().Get(ts.URL)
4196 if err != nil {
4197 t.Fatal(err)
4198 }
4199
4200 body, err := io.ReadAll(res.Body)
4201 res.Body.Close()
4202 if err != nil {
4203 t.Fatal(err)
4204 }
4205
4206 if string(body) != encodedString {
4207 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4208 }
4209 })
4210 }
4211 }
4212
4213 func TestTransportDialCancelRace(t *testing.T) {
4214 run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode})
4215 }
4216 func testTransportDialCancelRace(t *testing.T, mode testMode) {
4217 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
4218 tr := ts.Client().Transport.(*Transport)
4219
4220 req, err := NewRequest("GET", ts.URL, nil)
4221 if err != nil {
4222 t.Fatal(err)
4223 }
4224 SetEnterRoundTripHook(func() {
4225 tr.CancelRequest(req)
4226 })
4227 defer SetEnterRoundTripHook(nil)
4228 res, err := tr.RoundTrip(req)
4229 if err != ExportErrRequestCanceled {
4230 t.Errorf("expected canceled request error; got %v", err)
4231 if err == nil {
4232 res.Body.Close()
4233 }
4234 }
4235 }
4236
4237
4238 func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4239 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4240 }
4241 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4242 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4243 func(tr *Transport) {
4244 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4245
4246 return &funcConn{
4247 read: func([]byte) (int, error) {
4248 return 0, errors.New("error")
4249 },
4250 write: func([]byte) (int, error) {
4251 return 0, errors.New("error")
4252 },
4253 }, nil
4254 }
4255 },
4256 ).ts
4257
4258
4259
4260
4261
4262 SetEnterRoundTripHook(func() {
4263 time.Sleep(1 * time.Millisecond)
4264 })
4265 defer SetEnterRoundTripHook(nil)
4266 var closes int
4267 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4268 if err == nil {
4269 t.Fatalf("expected request to fail, but it did not")
4270 }
4271 if closes != 1 {
4272 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4273 }
4274 }
4275
4276
4277
4278
// logWritesConn is a net.Conn stand-in that records a copy of every Write and
// forwards it to w. Reads come from the first io.Reader received on rch.
type logWritesConn struct {
	net.Conn // embedded to satisfy net.Conn; only Write, Read and Close are overridden

	w io.Writer // receives every written byte

	rch <-chan io.Reader // supplies the reader used by Read
	r   io.Reader        // the reader taken from rch, once received

	mu     sync.Mutex // guards writes
	writes []string   // payload of each Write call, in order
}

// Write records a copy of p under c.mu and forwards it to c.w while still
// holding the lock.
func (c *logWritesConn) Write(p []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	record := string(p)
	c.writes = append(c.writes, record)
	n, err = c.w.Write(p)
	return n, err
}

// Read blocks on rch for the first call only; later calls keep reading from
// the same reader.
func (c *logWritesConn) Read(p []byte) (n int, err error) {
	src := c.r
	if src == nil {
		src = <-c.rch
		c.r = src
	}
	return src.Read(p)
}

// Close does nothing and always reports success.
func (c *logWritesConn) Close() error {
	return nil
}
4306
4307
// TestTransportFlushesBodyChunks checks that a streamed request body is sent
// with chunked encoding, and that each chunk reaches the connection in its own
// Write call.
func TestTransportFlushesBodyChunks(t *testing.T) {
	defer afterTest(t)
	// resBody delivers the canned server response to lw's first Read.
	resBody := make(chan io.Reader, 1)
	// Bytes the Transport writes to lw come out of connr.
	connr, connw := io.Pipe()
	lw := &logWritesConn{
		rch: resBody,
		w:   connw,
	}
	// Every dial returns the same logging conn.
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			return lw, nil
		},
	}
	// The request body is a pipe that receives three short lines and is
	// then closed.
	bodyr, bodyw := io.Pipe()
	go func() {
		defer bodyw.Close()
		for i := 0; i < 3; i++ {
			fmt.Fprintf(bodyw, "num%d\n", i)
		}
	}()
	// resc receives the response, or is closed if RoundTrip fails.
	resc := make(chan *Response)
	go func() {
		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
		req.Header.Set("User-Agent", "x")
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			close(resc)
			return
		}
		resc <- res
	}()

	// Act as the server: parse the request from the pipe and read its whole
	// body.
	req, err := ReadRequest(bufio.NewReader(connr))
	if err != nil {
		t.Fatal(err)
	}
	io.Copy(io.Discard, req.Body)

	// Hand lw the response to read only after the request has been fully
	// consumed.
	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
	res, ok := <-resc
	if !ok {
		// RoundTrip failed; the goroutine has already reported the error.
		return
	}
	defer res.Body.Close()

	// Expect one Write for the header block, one per body line, and one for
	// the terminating zero-length chunk.
	want := []string{
		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
		"5\r\nnum0\n\r\n",
		"5\r\nnum1\n\r\n",
		"5\r\nnum2\n\r\n",
		"0\r\n\r\n",
	}
	if !reflect.DeepEqual(lw.writes, want) {
		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
	}
}
4367
4368
4369 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4370 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4371 gotReq := make(chan struct{})
4372 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4373 close(gotReq)
4374 }))
4375
4376 pr, pw := io.Pipe()
4377 req, err := NewRequest("POST", cst.ts.URL, pr)
4378 if err != nil {
4379 t.Fatal(err)
4380 }
4381 gotRes := make(chan struct{})
4382 go func() {
4383 defer close(gotRes)
4384 res, err := cst.tr.RoundTrip(req)
4385 if err != nil {
4386 t.Error(err)
4387 return
4388 }
4389 res.Body.Close()
4390 }()
4391
4392 <-gotReq
4393 pw.Close()
4394 <-gotRes
4395 }
4396
// wgReadCloser is an io.ReadCloser that calls wg.Done on its first Close.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

// Close marks the body closed and calls wg.Done the first time it is called.
// Every later call returns net.ErrClosed without touching wg.
func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4411
4412
4413 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4414
4415 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4416 }
4417 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4418 if testing.Short() {
4419 t.Skip("skipping in short mode")
4420 }
4421
4422 runTimeSensitiveTest(t, []time.Duration{
4423 1 * time.Millisecond,
4424 5 * time.Millisecond,
4425 10 * time.Millisecond,
4426 50 * time.Millisecond,
4427 100 * time.Millisecond,
4428 500 * time.Millisecond,
4429 time.Second,
4430 5 * time.Second,
4431 }, func(t *testing.T, timeout time.Duration) error {
4432 SetRSTAvoidanceDelay(t, timeout)
4433 t.Logf("set RST avoidance delay to %v", timeout)
4434
4435 const contentLengthLimit = 1024 * 1024
4436 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4437 if r.ContentLength >= contentLengthLimit {
4438 w.WriteHeader(StatusBadRequest)
4439 r.Body.Close()
4440 return
4441 }
4442 w.WriteHeader(StatusOK)
4443 }))
4444
4445
4446 defer cst.close()
4447 ts := cst.ts
4448 c := ts.Client()
4449
4450 count := 100
4451
4452 bigBody := strings.Repeat("a", contentLengthLimit*2)
4453 var wg sync.WaitGroup
4454 defer wg.Wait()
4455 getBody := func() (io.ReadCloser, error) {
4456 wg.Add(1)
4457 body := &wgReadCloser{
4458 Reader: strings.NewReader(bigBody),
4459 wg: &wg,
4460 }
4461 return body, nil
4462 }
4463
4464 for i := 0; i < count; i++ {
4465 reqBody, _ := getBody()
4466 req, err := NewRequest("PUT", ts.URL, reqBody)
4467 if err != nil {
4468 reqBody.Close()
4469 t.Fatal(err)
4470 }
4471 req.ContentLength = int64(len(bigBody))
4472 req.GetBody = getBody
4473
4474 resp, err := c.Do(req)
4475 if err != nil {
4476 return fmt.Errorf("Do %d: %v", i, err)
4477 } else {
4478 resp.Body.Close()
4479 if resp.StatusCode != 400 {
4480 t.Errorf("Expected status code 400, got %v", resp.Status)
4481 }
4482 }
4483 }
4484 return nil
4485 })
4486 }
4487
4488 func TestTransportAutomaticHTTP2(t *testing.T) {
4489 testTransportAutoHTTP(t, &Transport{}, true)
4490 }
4491
4492 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4493 testTransportAutoHTTP(t, &Transport{
4494 ForceAttemptHTTP2: true,
4495 TLSClientConfig: new(tls.Config),
4496 }, true)
4497 }
4498
4499
4500 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4501 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4502 }
4503
4504 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4505 testTransportAutoHTTP(t, &Transport{
4506 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4507 }, false)
4508 }
4509
4510 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4511 testTransportAutoHTTP(t, &Transport{
4512 TLSClientConfig: new(tls.Config),
4513 }, false)
4514 }
4515
4516 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4517 testTransportAutoHTTP(t, &Transport{
4518 ExpectContinueTimeout: 1 * time.Second,
4519 }, true)
4520 }
4521
4522 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4523 var d net.Dialer
4524 testTransportAutoHTTP(t, &Transport{
4525 Dial: d.Dial,
4526 }, false)
4527 }
4528
4529 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4530 var d net.Dialer
4531 testTransportAutoHTTP(t, &Transport{
4532 DialContext: d.DialContext,
4533 }, false)
4534 }
4535
4536 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4537 testTransportAutoHTTP(t, &Transport{
4538 DialTLS: func(network, addr string) (net.Conn, error) {
4539 panic("unused")
4540 },
4541 }, false)
4542 }
4543
4544 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4545 CondSkipHTTP2(t)
4546 _, err := tr.RoundTrip(new(Request))
4547 if err == nil {
4548 t.Error("expected error from RoundTrip")
4549 }
4550 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4551 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4552 }
4553 }
4554
4555
4556
4557
4558
4559
4560
4561
4562 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4563 run(t, testTransportReuseConnEmptyResponseBody)
4564 }
4565 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4566 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4567 w.Header().Set("X-Addr", r.RemoteAddr)
4568
4569 }))
4570 n := 100
4571 if testing.Short() {
4572 n = 10
4573 }
4574 var firstAddr string
4575 for i := 0; i < n; i++ {
4576 res, err := cst.c.Get(cst.ts.URL)
4577 if err != nil {
4578 log.Fatal(err)
4579 }
4580 addr := res.Header.Get("X-Addr")
4581 if i == 0 {
4582 firstAddr = addr
4583 } else if addr != firstAddr {
4584 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4585 }
4586 res.Body.Close()
4587 }
4588 }
4589
4590
// TestNoCrashReturningTransportAltConn covers a request canceled while DialTLS
// is still running, on a connection that negotiated an ALPN protocol ("foo")
// registered in TLSNextProto. Do must fail with errRequestCanceledConn, and the
// alternate RoundTripper built for the late connection must never be used.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Count pending dials through the test hooks so the end of the test can
	// wait for all of them to finish.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// Server: accept one TLS connection offering only "foo", complete the
	// handshake, and keep the conn open until the test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	// doReturned releases DialTLS once Do has returned.
	doReturned := make(chan bool, 1)
	// madeRoundTripper signals that the "foo" TLSNextProto constructor ran.
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			// Records that the alternate RoundTripper was built; it must
			// never actually be asked to round trip.
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// Plain dials must not happen in this test.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		// Connect to the local server (ignoring the requested address),
		// cancel the request once the handshake completes, and return the
		// conn only after Do has returned.
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release DialTLS. Then wait for its late conn to reach the "foo"
	// constructor and for every pending dial to complete.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4672
4673 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4674 run(t, func(t *testing.T, mode testMode) {
4675 testTransportReuseConnection_Gzip(t, mode, true)
4676 })
4677 }
4678
4679 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4680 run(t, func(t *testing.T, mode testMode) {
4681 testTransportReuseConnection_Gzip(t, mode, false)
4682 })
4683 }
4684
4685
4686 func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
4687 addr := make(chan string, 2)
4688 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4689 addr <- r.RemoteAddr
4690 w.Header().Set("Content-Encoding", "gzip")
4691 if chunked {
4692 w.(Flusher).Flush()
4693 }
4694 w.Write(rgz)
4695 })).ts
4696 c := ts.Client()
4697
4698 trace := &httptrace.ClientTrace{
4699 GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
4700 GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
4701 PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
4702 ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
4703 ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
4704 }
4705 ctx := httptrace.WithClientTrace(context.Background(), trace)
4706
4707 for i := 0; i < 2; i++ {
4708 req, _ := NewRequest("GET", ts.URL, nil)
4709 req = req.WithContext(ctx)
4710 res, err := c.Do(req)
4711 if err != nil {
4712 t.Fatal(err)
4713 }
4714 buf := make([]byte, len(rgz))
4715 if n, err := io.ReadFull(res.Body, buf); err != nil {
4716 t.Errorf("%d. ReadFull = %v, %v", i, n, err)
4717 }
4718
4719
4720
4721 }
4722 a1, a2 := <-addr, <-addr
4723 if a1 != a2 {
4724 t.Fatalf("didn't reuse connection")
4725 }
4726 }
4727
4728 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
4729 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
4730 if mode == http2Mode {
4731 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
4732 }
4733 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4734 if r.URL.Path == "/long" {
4735 w.Header().Set("Long", strings.Repeat("a", 1<<20))
4736 }
4737 })).ts
4738 c := ts.Client()
4739 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4740
4741 if res, err := c.Get(ts.URL); err != nil {
4742 t.Fatal(err)
4743 } else {
4744 res.Body.Close()
4745 }
4746
4747 res, err := c.Get(ts.URL + "/long")
4748 if err == nil {
4749 defer res.Body.Close()
4750 var n int64
4751 for k, vv := range res.Header {
4752 for _, v := range vv {
4753 n += int64(len(k)) + int64(len(v))
4754 }
4755 }
4756 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
4757 }
4758 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
4759 t.Errorf("got error: %v; want %q", err, want)
4760 }
4761 }
4762
4763 func TestTransportEventTrace(t *testing.T) {
4764 run(t, func(t *testing.T, mode testMode) {
4765 testTransportEventTrace(t, mode, false)
4766 }, testNotParallel)
4767 }
4768
4769
4770 func TestTransportEventTrace_NoHooks(t *testing.T) {
4771 run(t, func(t *testing.T, mode testMode) {
4772 testTransportEventTrace(t, mode, true)
4773 }, testNotParallel)
4774 }
4775
// testTransportEventTrace sends a POST with "Expect: 100-continue" to the
// host name "dns-is-faked.golang", which is resolved by a fake resolver
// installed in the request context via nettrace.LookupIPAltResolverKey.
// Every httptrace hook appends a line to a shared log, and the log is then
// checked for the expected events. A follow-up GET on the same client checks
// that GetConn fires again for the second request.
//
// If noHooks is true, an empty ClientTrace is attached instead and only the
// round trip itself (status and body) is verified.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Signaled by the WroteRequest hook. It is buffered so the hook does not
	// have to wait for the handler to receive.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET at the end of the test needs no body.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		// Hold the response until the client has logged WroteRequest, so
		// that event is recorded before the response events.
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	// Make the transport wait for the server's 100 Continue before sending
	// the body, so Wait100Continue/Got100Continue are observed.
	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// mu guards buf; hooks may run on goroutines other than the test's.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Resolve "dns-is-faked.golang" to the test server's IP without touching
	// the real resolver; any other lookup is a test failure.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			// Let the handler send its response (see handler above).
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook but keep the same non-nil *ClientTrace attached.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// With no hooks nothing was logged; the successful round trip above
		// is all that is checked.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// Host and Content-Length are reported through WroteHeaderField
		// as well.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// A second request with the same trace must log GetConn again, so
	// the log should now contain it twice.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}
}
4957
4958 func TestTransportEventTraceTLSVerify(t *testing.T) {
4959 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
4960 }
4961 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
4962 var mu sync.Mutex
4963 var buf strings.Builder
4964 logf := func(format string, args ...any) {
4965 mu.Lock()
4966 defer mu.Unlock()
4967 fmt.Fprintf(&buf, format, args...)
4968 buf.WriteByte('\n')
4969 }
4970
4971 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4972 t.Error("Unexpected request")
4973 }), func(ts *httptest.Server) {
4974 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
4975 logf("%s", p)
4976 return len(p), nil
4977 }), "", 0)
4978 }).ts
4979
4980 certpool := x509.NewCertPool()
4981 certpool.AddCert(ts.Certificate())
4982
4983 c := &Client{Transport: &Transport{
4984 TLSClientConfig: &tls.Config{
4985 ServerName: "dns-is-faked.golang",
4986 RootCAs: certpool,
4987 },
4988 }}
4989
4990 trace := &httptrace.ClientTrace{
4991 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
4992 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
4993 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
4994 },
4995 }
4996
4997 req, _ := NewRequest("GET", ts.URL, nil)
4998 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
4999 _, err := c.Do(req)
5000 if err == nil {
5001 t.Error("Expected request to fail TLS verification")
5002 }
5003
5004 mu.Lock()
5005 got := buf.String()
5006 mu.Unlock()
5007
5008 wantOnce := func(sub string) {
5009 if strings.Count(got, sub) != 1 {
5010 t.Errorf("expected substring %q exactly once in output.", sub)
5011 }
5012 }
5013
5014 wantOnce("TLSHandshakeStart")
5015 wantOnce("TLSHandshakeDone")
5016 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5017
5018 if t.Failed() {
5019 t.Errorf("Output:\n%s", got)
5020 }
5021 }
5022
// isDNSHijackedOnce guards the one-time resolver probe; isDNSHijacked holds
// its result.
var (
	isDNSHijackedOnce sync.Once
	isDNSHijacked     bool
)

// skipIfDNSHijacked skips t when the system resolver returns addresses for
// "dns-should-not-resolve.golang", a name that should never resolve. Such a
// resolver (for example one that rewrites failed lookups) would defeat tests
// that rely on DNS failures. The probe runs at most once per test binary.
func skipIfDNSHijacked(t *testing.T) {
	isDNSHijackedOnce.Do(func() {
		if addrs, _ := net.LookupHost("dns-should-not-resolve.golang"); len(addrs) > 0 {
			isDNSHijacked = true
		}
	})
	if !isDNSHijacked {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
5040
5041 func TestTransportEventTraceRealDNS(t *testing.T) {
5042 skipIfDNSHijacked(t)
5043 defer afterTest(t)
5044 tr := &Transport{}
5045 defer tr.CloseIdleConnections()
5046 c := &Client{Transport: tr}
5047
5048 var mu sync.Mutex
5049 var buf strings.Builder
5050 logf := func(format string, args ...any) {
5051 mu.Lock()
5052 defer mu.Unlock()
5053 fmt.Fprintf(&buf, format, args...)
5054 buf.WriteByte('\n')
5055 }
5056
5057 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5058 trace := &httptrace.ClientTrace{
5059 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5060 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5061 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5062 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5063 }
5064 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5065
5066 resp, err := c.Do(req)
5067 if err == nil {
5068 resp.Body.Close()
5069 t.Fatal("expected error during DNS lookup")
5070 }
5071
5072 mu.Lock()
5073 got := buf.String()
5074 mu.Unlock()
5075
5076 wantSub := func(sub string) {
5077 if !strings.Contains(got, sub) {
5078 t.Errorf("expected substring %q in output.", sub)
5079 }
5080 }
5081 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5082 wantSub("DNSDone: {Addrs:[] Err:")
5083 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5084 t.Errorf("should not see Connect events")
5085 }
5086 if t.Failed() {
5087 t.Errorf("Output:\n%s", got)
5088 }
5089 }
5090
5091
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// letters makes Get fail with a *url.Error that names the bad port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	switch ue := err.(type) {
	case nil:
		res.Body.Close()
		t.Fatal("unexpected success")
	case *url.Error:
		const want = `invalid port ":123foo" after host`
		if got := ue.Err.Error(); got != want {
			t.Errorf("got error %q; want %q", got, want)
		}
	default:
		t.Fatalf("got %#v; want *url.Error", err)
	}
}
5108
5109
5110
5111 func TestTLSHandshakeTrace(t *testing.T) {
5112 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5113 }
5114 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5115 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5116
5117 var mu sync.Mutex
5118 var start, done bool
5119 trace := &httptrace.ClientTrace{
5120 TLSHandshakeStart: func() {
5121 mu.Lock()
5122 defer mu.Unlock()
5123 start = true
5124 },
5125 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5126 mu.Lock()
5127 defer mu.Unlock()
5128 done = true
5129 if err != nil {
5130 t.Fatal("Expected error to be nil but was:", err)
5131 }
5132 },
5133 }
5134
5135 c := ts.Client()
5136 req, err := NewRequest("GET", ts.URL, nil)
5137 if err != nil {
5138 t.Fatal("Unable to construct test request:", err)
5139 }
5140 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5141
5142 r, err := c.Do(req)
5143 if err != nil {
5144 t.Fatal("Unexpected error making request:", err)
5145 }
5146 r.Body.Close()
5147 mu.Lock()
5148 defer mu.Unlock()
5149 if !start {
5150 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5151 }
5152 if !done {
5153 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5154 }
5155 }
5156
5157 func TestTransportMaxIdleConns(t *testing.T) {
5158 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5159 }
5160 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5161 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5162
5163 })).ts
5164 c := ts.Client()
5165 tr := c.Transport.(*Transport)
5166 tr.MaxIdleConns = 4
5167
5168 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5169 if err != nil {
5170 t.Fatal(err)
5171 }
5172 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5173 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5174 })
5175
5176 hitHost := func(n int) {
5177 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5178 req = req.WithContext(ctx)
5179 res, err := c.Do(req)
5180 if err != nil {
5181 t.Fatal(err)
5182 }
5183 res.Body.Close()
5184 }
5185 for i := 0; i < 4; i++ {
5186 hitHost(i)
5187 }
5188 want := []string{
5189 "|http|host-0.dns-is-faked.golang:" + port,
5190 "|http|host-1.dns-is-faked.golang:" + port,
5191 "|http|host-2.dns-is-faked.golang:" + port,
5192 "|http|host-3.dns-is-faked.golang:" + port,
5193 }
5194 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5195 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5196 }
5197
5198
5199 hitHost(4)
5200 want = []string{
5201 "|http|host-1.dns-is-faked.golang:" + port,
5202 "|http|host-2.dns-is-faked.golang:" + port,
5203 "|http|host-3.dns-is-faked.golang:" + port,
5204 "|http|host-4.dns-is-faked.golang:" + port,
5205 }
5206 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5207 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5208 }
5209 }
5210
5211 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
5212 func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
5213 if testing.Short() {
5214 t.Skip("skipping in short mode")
5215 }
5216
5217 timeout := 1 * time.Millisecond
5218 timeoutLoop:
5219 for {
5220 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5221
5222 }))
5223 tr := cst.tr
5224 tr.IdleConnTimeout = timeout
5225 defer tr.CloseIdleConnections()
5226 c := &Client{Transport: tr}
5227
5228 idleConns := func() []string {
5229 if mode == http2Mode {
5230 return tr.IdleConnStrsForTesting_h2()
5231 } else {
5232 return tr.IdleConnStrsForTesting()
5233 }
5234 }
5235
5236 var conn string
5237 doReq := func(n int) (timeoutOk bool) {
5238 req, _ := NewRequest("GET", cst.ts.URL, nil)
5239 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5240 PutIdleConn: func(err error) {
5241 if err != nil {
5242 t.Errorf("failed to keep idle conn: %v", err)
5243 }
5244 },
5245 }))
5246 res, err := c.Do(req)
5247 if err != nil {
5248 if strings.Contains(err.Error(), "use of closed network connection") {
5249 t.Logf("req %v: connection closed prematurely", n)
5250 return false
5251 }
5252 }
5253 res.Body.Close()
5254 conns := idleConns()
5255 if len(conns) != 1 {
5256 if len(conns) == 0 {
5257 t.Logf("req %v: no idle conns", n)
5258 return false
5259 }
5260 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
5261 }
5262 if conn == "" {
5263 conn = conns[0]
5264 }
5265 if conn != conns[0] {
5266 t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
5267 return false
5268 }
5269 return true
5270 }
5271 for i := 0; i < 3; i++ {
5272 if !doReq(i) {
5273 t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
5274 timeout *= 2
5275 cst.close()
5276 continue timeoutLoop
5277 }
5278 time.Sleep(timeout / 2)
5279 }
5280
5281 waitCondition(t, timeout/2, func(d time.Duration) bool {
5282 if got := idleConns(); len(got) != 0 {
5283 if d >= timeout*3/2 {
5284 t.Logf("after %v, idle conns = %q", d, got)
5285 }
5286 return false
5287 }
5288 return true
5289 })
5290 break
5291 }
5292 }
5293
5294
5295
5296
5297
5298
5299
5300
5301
5302
5303
5304
// TestIdleConnH2Crash runs testIdleConnH2Crash in HTTP/2 mode only.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }

// testIdleConnH2Crash cancels the request's context from inside a custom
// DialTLS. The dial then returns a valid h2 connection only after Do has
// already failed, and the test sleeps long enough for IdleConnTimeout to
// elapse several times.
//
// NOTE(review): presumably a regression test for a crash when an HTTP/2
// connection whose dial outlived its request later hits the idle timeout.
// Beyond Do failing there are no assertions, so it passes as long as
// nothing panics.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Empty handler; the test does not depend on what it serves.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is sent on once Do has returned its expected error.
	// testDone unblocks the dialer if the test returns first.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request while its dial is still in progress...
		cancel()

		// ...and hand back the connection only after Do has failed (or the
		// test has ended).
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Leave time for the idle timeout to fire on the late connection.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5355
// funcConn is a net.Conn whose Read and Write delegate to the read and write
// funcs and whose Close is a no-op. Any other net.Conn method goes to the
// embedded Conn, which is nil unless the caller sets it.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read returns whatever fc.read returns for p.
func (fc funcConn) Read(p []byte) (int, error) {
	return fc.read(p)
}

// Write returns whatever fc.write returns for p.
func (fc funcConn) Write(p []byte) (int, error) {
	return fc.write(p)
}

// Close does nothing and always reports success.
func (fc funcConn) Close() error {
	return nil
}
5365
5366
5367
5368 func TestTransportReturnsPeekError(t *testing.T) {
5369 errValue := errors.New("specific error value")
5370
5371 wrote := make(chan struct{})
5372 var wroteOnce sync.Once
5373
5374 tr := &Transport{
5375 Dial: func(network, addr string) (net.Conn, error) {
5376 c := funcConn{
5377 read: func([]byte) (int, error) {
5378 <-wrote
5379 return 0, errValue
5380 },
5381 write: func(p []byte) (int, error) {
5382 wroteOnce.Do(func() { close(wrote) })
5383 return len(p), nil
5384 },
5385 }
5386 return c, nil
5387 },
5388 }
5389 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5390 if err != errValue {
5391 t.Errorf("error = %#v; want %v", err, errValue)
5392 }
5393 }
5394
5395
5396 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
5397 func testTransportIDNA(t *testing.T, mode testMode) {
5398 const uniDomain = "гофер.го"
5399 const punyDomain = "xn--c1ae0ajs.xn--c1aw"
5400
5401 var port string
5402 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5403 want := punyDomain + ":" + port
5404 if r.Host != want {
5405 t.Errorf("Host header = %q; want %q", r.Host, want)
5406 }
5407 if mode == http2Mode {
5408 if r.TLS == nil {
5409 t.Errorf("r.TLS == nil")
5410 } else if r.TLS.ServerName != punyDomain {
5411 t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
5412 }
5413 }
5414 w.Header().Set("Hit-Handler", "1")
5415 }), func(tr *Transport) {
5416 if tr.TLSClientConfig != nil {
5417 tr.TLSClientConfig.InsecureSkipVerify = true
5418 }
5419 })
5420
5421 ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
5422 if err != nil {
5423 t.Fatal(err)
5424 }
5425
5426
5427 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
5428 if host != punyDomain {
5429 t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
5430 return nil, nil
5431 }
5432 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5433 })
5434
5435 req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
5436 trace := &httptrace.ClientTrace{
5437 GetConn: func(hostPort string) {
5438 want := net.JoinHostPort(punyDomain, port)
5439 if hostPort != want {
5440 t.Errorf("getting conn for %q; want %q", hostPort, want)
5441 }
5442 },
5443 DNSStart: func(e httptrace.DNSStartInfo) {
5444 if e.Host != punyDomain {
5445 t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
5446 }
5447 },
5448 }
5449 req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
5450
5451 res, err := cst.tr.RoundTrip(req)
5452 if err != nil {
5453 t.Fatal(err)
5454 }
5455 defer res.Body.Close()
5456 if res.Header.Get("Hit-Handler") != "1" {
5457 out, err := httputil.DumpResponse(res, true)
5458 if err != nil {
5459 t.Fatal(err)
5460 }
5461 t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
5462 }
5463 }
5464
5465
5466 func TestTransportProxyConnectHeader(t *testing.T) {
5467 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5468 }
5469 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5470 reqc := make(chan *Request, 1)
5471 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5472 if r.Method != "CONNECT" {
5473 t.Errorf("method = %q; want CONNECT", r.Method)
5474 }
5475 reqc <- r
5476 c, _, err := w.(Hijacker).Hijack()
5477 if err != nil {
5478 t.Errorf("Hijack: %v", err)
5479 return
5480 }
5481 c.Close()
5482 })).ts
5483
5484 c := ts.Client()
5485 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5486 return url.Parse(ts.URL)
5487 }
5488 c.Transport.(*Transport).ProxyConnectHeader = Header{
5489 "User-Agent": {"foo"},
5490 "Other": {"bar"},
5491 }
5492
5493 res, err := c.Get("https://dummy.tld/")
5494 if err == nil {
5495 res.Body.Close()
5496 t.Errorf("unexpected success")
5497 }
5498
5499 r := <-reqc
5500 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5501 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5502 }
5503 if got, want := r.Header.Get("Other"), "bar"; got != want {
5504 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5505 }
5506 }
5507
5508 func TestTransportProxyGetConnectHeader(t *testing.T) {
5509 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5510 }
5511 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5512 reqc := make(chan *Request, 1)
5513 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5514 if r.Method != "CONNECT" {
5515 t.Errorf("method = %q; want CONNECT", r.Method)
5516 }
5517 reqc <- r
5518 c, _, err := w.(Hijacker).Hijack()
5519 if err != nil {
5520 t.Errorf("Hijack: %v", err)
5521 return
5522 }
5523 c.Close()
5524 })).ts
5525
5526 c := ts.Client()
5527 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5528 return url.Parse(ts.URL)
5529 }
5530
5531 c.Transport.(*Transport).ProxyConnectHeader = Header{
5532 "User-Agent": {"foo"},
5533 "Other": {"bar"},
5534 }
5535 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5536 return Header{
5537 "User-Agent": {"foo2"},
5538 "Other": {"bar2"},
5539 }, nil
5540 }
5541
5542 res, err := c.Get("https://dummy.tld/")
5543 if err == nil {
5544 res.Body.Close()
5545 t.Errorf("unexpected success")
5546 }
5547
5548 r := <-reqc
5549 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5550 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5551 }
5552 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5553 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5554 }
5555 }
5556
// errFakeRoundTrip is the error every funcRoundTripper returns.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the func and then fails with
// errFakeRoundTrip without producing a Response.
type funcRoundTripper func()

// RoundTrip ignores the request, calls f, and returns (nil, errFakeRoundTrip).
func (f funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	f()
	return nil, errFakeRoundTrip
}
5565
// wantBody checks that res carries exactly the body want.
//
// err is the error from the request that produced res; if it is non-nil it
// is returned unchanged and res is not touched. Otherwise the body is read
// in full and compared with want. res.Body is closed on every path, so a
// failed check does not leak the body. A Close error is reported only when
// reading and comparison succeeded.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		res.Body.Close()
		return fmt.Errorf("error reading body: %v", err)
	}
	if string(slurp) != want {
		res.Body.Close()
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if err := res.Body.Close(); err != nil {
		return fmt.Errorf("body Close = %v", err)
	}
	return nil
}
5582
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 127.0.0.1 first, then IPv6 [::1], and fails the test with the
// last error if neither works.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err == nil {
		return ln
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal does not return
}
5593
// countCloseReader is an io.ReadCloser that reads from the embedded Reader
// and whose Close only increments the counter *n.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close adds one to *r.n and always succeeds.
func (r countCloseReader) Close() error {
	*r.n++
	return nil
}
5603
5604
// rgz is a hand-built gzip stream. It begins with the gzip magic bytes
// 0x1f 0x8b and deflate method 0x08, and its header sets the FNAME flag and
// names the member "recursive".
// NOTE(review): presumably a gzip quine (it decompresses to itself) used to
// exercise gzip response handling; confirm against its users elsewhere in
// the file.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5639
5640
5641
5642 func TestMissingStatusNoPanic(t *testing.T) {
5643 t.Parallel()
5644
5645 const want = "unknown status code"
5646
5647 ln := newLocalListener(t)
5648 addr := ln.Addr().String()
5649 done := make(chan bool)
5650 fullAddrURL := fmt.Sprintf("http://%s", addr)
5651 raw := "HTTP/1.1 400\r\n" +
5652 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5653 "Content-Type: text/html; charset=utf-8\r\n" +
5654 "Content-Length: 10\r\n" +
5655 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5656 "Vary: Accept-Encoding\r\n\r\n" +
5657 "Aloha Olaa"
5658
5659 go func() {
5660 defer close(done)
5661
5662 conn, _ := ln.Accept()
5663 if conn != nil {
5664 io.WriteString(conn, raw)
5665 io.ReadAll(conn)
5666 conn.Close()
5667 }
5668 }()
5669
5670 proxyURL, err := url.Parse(fullAddrURL)
5671 if err != nil {
5672 t.Fatalf("proxyURL: %v", err)
5673 }
5674
5675 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5676
5677 req, _ := NewRequest("GET", "https://golang.org/", nil)
5678 res, err, panicked := doFetchCheckPanic(tr, req)
5679 if panicked {
5680 t.Error("panicked, expecting an error")
5681 }
5682 if res != nil && res.Body != nil {
5683 io.Copy(io.Discard, res.Body)
5684 res.Body.Close()
5685 }
5686
5687 if err == nil || !strings.Contains(err.Error(), want) {
5688 t.Errorf("got=%v want=%q", err, want)
5689 }
5690
5691 ln.Close()
5692 <-done
5693 }
5694
// doFetchCheckPanic calls tr.RoundTrip(req). If the call panics, the panic is
// recovered and reported as panicked == true, with res and err left at their
// zero values.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5704
5705
5706
5707 func TestNoBodyOnChunked304Response(t *testing.T) {
5708 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
5709 }
5710 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
5711 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5712 conn, buf, _ := w.(Hijacker).Hijack()
5713 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
5714 buf.Flush()
5715 conn.Close()
5716 }))
5717
5718
5719
5720
5721
5722 cst.tr.DisableKeepAlives = true
5723
5724 res, err := cst.c.Get(cst.ts.URL)
5725 if err != nil {
5726 t.Fatal(err)
5727 }
5728
5729 if res.Body != NoBody {
5730 t.Errorf("Unexpected body on 304 response")
5731 }
5732 }
5733
// funcWriter adapts a plain func to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write passes p to f and returns f's results unchanged.
func (f funcWriter) Write(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
5737
// doneContext is a context.Context that always reports itself as done. Done
// returns an already-closed channel and Err returns the stored err. All
// other methods come from the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a newly made channel that is already closed.
func (doneContext) Done() <-chan struct{} {
	ch := make(chan struct{})
	close(ch)
	return ch
}

// Err returns the error the doneContext was constructed with.
func (d doneContext) Err() error {
	return d.err
}
5750
5751
5752 func TestTransportCheckContextDoneEarly(t *testing.T) {
5753 tr := &Transport{}
5754 req, _ := NewRequest("GET", "http://fake.example/", nil)
5755 wantErr := errors.New("some error")
5756 req = req.WithContext(doneContext{context.Background(), wantErr})
5757 _, err := tr.RoundTrip(req)
5758 if err != wantErr {
5759 t.Errorf("error = %v; want %v", err, wantErr)
5760 }
5761 }
5762
5763
5764
5765
5766
5767
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// expires before any response headers arrive, the client closes the
// underlying connection. The server sees this as EOF on the hijacked conn.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}

// testClientTimeoutKillsConn_BeforeHeaders starts with a 1ms timeout and
// doubles it (with a fresh server) whenever the handler is not reached in
// time.
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Never respond. Wait until the server cancels the request
			// context, presumably because the timed-out client closed the
			// connection.
			<-r.Context().Done()

			// Either the test gave up on this attempt, or it is waiting
			// for the handler to check the connection.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client should have closed the connection, so reading
			// from the raw conn must return EOF with no data.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler did not get to the check within 10x the
			// timeout. Treat the timeout as too short for this machine
			// and retry with a longer one.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
5826
5827
5828
5829
5830
5831
// TestClientTimeoutKillsConn_AfterHeaders checks that canceling a request
// after its response headers have arrived, while the body is still being
// read, closes the underlying connection.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Send headers promising 100 bytes of body, then stall.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Only 3 of the promised 100 bytes are ever written.
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// Once the client cancels, the read must fail without data.
		// NOTE(review): any non-nil error is accepted rather than only
		// io.EOF, presumably because the exact error (EOF vs. reset)
		// can vary; confirm.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// A long but non-zero Client.Timeout, presumably so the client's
	// timeout handling is active without ever firing; the request is
	// canceled explicitly through req.Cancel instead.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Headers are in. Cancel now, so the body read must fail instead of
	// waiting for the 97 bytes that will never come.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Wait for the handler to observe the closed connection.
	<-inHandler
	<-handlerDone
}
5897
5898 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
5899 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
5900 }
5901 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
5902 done := make(chan struct{})
5903 defer close(done)
5904 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5905 conn, _, err := w.(Hijacker).Hijack()
5906 if err != nil {
5907 t.Error(err)
5908 return
5909 }
5910 defer conn.Close()
5911 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
5912 bs := bufio.NewScanner(conn)
5913 bs.Scan()
5914 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
5915 <-done
5916 }))
5917
5918 req, _ := NewRequest("GET", cst.ts.URL, nil)
5919 req.Header.Set("Upgrade", "foo")
5920 req.Header.Set("Connection", "upgrade")
5921 res, err := cst.c.Do(req)
5922 if err != nil {
5923 t.Fatal(err)
5924 }
5925 if res.StatusCode != 101 {
5926 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
5927 }
5928 rwc, ok := res.Body.(io.ReadWriteCloser)
5929 if !ok {
5930 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
5931 }
5932 defer rwc.Close()
5933 bs := bufio.NewScanner(rwc)
5934 if !bs.Scan() {
5935 t.Fatalf("expected readable input")
5936 }
5937 if got, want := bs.Text(), "Some buffered data"; got != want {
5938 t.Errorf("read %q; want %q", got, want)
5939 }
5940 io.WriteString(rwc, "echo\n")
5941 if !bs.Scan() {
5942 t.Fatalf("expected another line")
5943 }
5944 if got, want := bs.Text(), "ECHO"; got != want {
5945 t.Errorf("read %q; want %q", got, want)
5946 }
5947 }
5948
5949 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
5950 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
5951 const target = "backend:443"
5952 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5953 if r.Method != "CONNECT" {
5954 t.Errorf("unexpected method %q", r.Method)
5955 w.WriteHeader(500)
5956 return
5957 }
5958 if r.RequestURI != target {
5959 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
5960 w.WriteHeader(500)
5961 return
5962 }
5963 nc, brw, err := w.(Hijacker).Hijack()
5964 if err != nil {
5965 t.Error(err)
5966 return
5967 }
5968 defer nc.Close()
5969 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
5970
5971 for {
5972 line, err := brw.ReadString('\n')
5973 if err != nil {
5974 if err != io.EOF {
5975 t.Error(err)
5976 }
5977 return
5978 }
5979 io.WriteString(brw, strings.ToUpper(line))
5980 brw.Flush()
5981 }
5982 }))
5983 pr, pw := io.Pipe()
5984 defer pw.Close()
5985 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
5986 if err != nil {
5987 t.Fatal(err)
5988 }
5989 req.URL.Opaque = target
5990 res, err := cst.c.Do(req)
5991 if err != nil {
5992 t.Fatal(err)
5993 }
5994 defer res.Body.Close()
5995 if res.StatusCode != 200 {
5996 t.Fatalf("status code = %d; want 200", res.StatusCode)
5997 }
5998 br := bufio.NewReader(res.Body)
5999 for _, str := range []string{"foo", "bar", "baz"} {
6000 fmt.Fprintf(pw, "%s\n", str)
6001 got, err := br.ReadString('\n')
6002 if err != nil {
6003 t.Fatal(err)
6004 }
6005 got = strings.TrimSpace(got)
6006 want := strings.ToUpper(str)
6007 if got != want {
6008 t.Fatalf("got %q; want %q", got, want)
6009 }
6010 }
6011 }
6012
6013 func TestTransportRequestReplayable(t *testing.T) {
6014 someBody := io.NopCloser(strings.NewReader(""))
6015 tests := []struct {
6016 name string
6017 req *Request
6018 want bool
6019 }{
6020 {
6021 name: "GET",
6022 req: &Request{Method: "GET"},
6023 want: true,
6024 },
6025 {
6026 name: "GET_http.NoBody",
6027 req: &Request{Method: "GET", Body: NoBody},
6028 want: true,
6029 },
6030 {
6031 name: "GET_body",
6032 req: &Request{Method: "GET", Body: someBody},
6033 want: false,
6034 },
6035 {
6036 name: "POST",
6037 req: &Request{Method: "POST"},
6038 want: false,
6039 },
6040 {
6041 name: "POST_idempotency-key",
6042 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6043 want: true,
6044 },
6045 {
6046 name: "POST_x-idempotency-key",
6047 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6048 want: true,
6049 },
6050 {
6051 name: "POST_body",
6052 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6053 want: false,
6054 },
6055 }
6056 for _, tt := range tests {
6057 t.Run(tt.name, func(t *testing.T) {
6058 got := tt.req.ExportIsReplayable()
6059 if got != tt.want {
6060 t.Errorf("replyable = %v; want %v", got, tt.want)
6061 }
6062 })
6063 }
6064 }
6065
6066
6067
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom method
// has been invoked. Tests use it to tell whether a request body was written
// through the connection's io.ReaderFrom path.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool
}

// ReadFrom sets ReadFromCalled and then delegates to the embedded
// *net.TCPConn.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (n int64, err error) {
	m.ReadFromCalled = true
	n, err = m.TCPConn.ReadFrom(src)
	return n, err
}
6078
6079 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6080 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6081 nBytes := int64(1 << 10)
6082 newFileFunc := func() (r io.Reader, done func(), err error) {
6083 f, err := os.CreateTemp("", "net-http-newfilefunc")
6084 if err != nil {
6085 return nil, nil, err
6086 }
6087
6088
6089 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6090 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6091 }
6092 if _, err := f.Seek(0, 0); err != nil {
6093 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6094 }
6095
6096 done = func() {
6097 f.Close()
6098 os.Remove(f.Name())
6099 }
6100
6101 return f, done, nil
6102 }
6103
6104 newBufferFunc := func() (io.Reader, func(), error) {
6105 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6106 }
6107
6108 cases := []struct {
6109 name string
6110 readerFunc func() (io.Reader, func(), error)
6111 contentLength int64
6112 expectedReadFrom bool
6113 }{
6114 {
6115 name: "file, length",
6116 readerFunc: newFileFunc,
6117 contentLength: nBytes,
6118 expectedReadFrom: true,
6119 },
6120 {
6121 name: "file, no length",
6122 readerFunc: newFileFunc,
6123 },
6124 {
6125 name: "file, negative length",
6126 readerFunc: newFileFunc,
6127 contentLength: -1,
6128 },
6129 {
6130 name: "buffer",
6131 contentLength: nBytes,
6132 readerFunc: newBufferFunc,
6133 },
6134 {
6135 name: "buffer, no length",
6136 readerFunc: newBufferFunc,
6137 },
6138 {
6139 name: "buffer, length -1",
6140 contentLength: -1,
6141 readerFunc: newBufferFunc,
6142 },
6143 }
6144
6145 for _, tc := range cases {
6146 t.Run(tc.name, func(t *testing.T) {
6147 r, cleanup, err := tc.readerFunc()
6148 if err != nil {
6149 t.Fatal(err)
6150 }
6151 defer cleanup()
6152
6153 tConn := &testMockTCPConn{}
6154 trFunc := func(tr *Transport) {
6155 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6156 var d net.Dialer
6157 conn, err := d.DialContext(ctx, network, addr)
6158 if err != nil {
6159 return nil, err
6160 }
6161
6162 tcpConn, ok := conn.(*net.TCPConn)
6163 if !ok {
6164 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6165 }
6166
6167 tConn.TCPConn = tcpConn
6168 return tConn, nil
6169 }
6170 }
6171
6172 cst := newClientServerTest(
6173 t,
6174 mode,
6175 HandlerFunc(func(w ResponseWriter, r *Request) {
6176 io.Copy(io.Discard, r.Body)
6177 r.Body.Close()
6178 w.WriteHeader(200)
6179 }),
6180 trFunc,
6181 )
6182
6183 req, err := NewRequest("PUT", cst.ts.URL, r)
6184 if err != nil {
6185 t.Fatal(err)
6186 }
6187 req.ContentLength = tc.contentLength
6188 req.Header.Set("Content-Type", "application/octet-stream")
6189 resp, err := cst.c.Do(req)
6190 if err != nil {
6191 t.Fatal(err)
6192 }
6193 defer resp.Body.Close()
6194 if resp.StatusCode != 200 {
6195 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6196 }
6197
6198 expectedReadFrom := tc.expectedReadFrom
6199 if mode != http1Mode {
6200 expectedReadFrom = false
6201 }
6202 if !tConn.ReadFromCalled && expectedReadFrom {
6203 t.Fatalf("did not call ReadFrom")
6204 }
6205
6206 if tConn.ReadFromCalled && !expectedReadFrom {
6207 t.Fatalf("ReadFrom was unexpectedly invoked")
6208 }
6209 })
6210 }
6211 }
6212
// TestTransportClone populates every exported Transport field with a
// non-zero value and checks that each one is still non-zero in the clone,
// including the TLSNextProto map entry. It also checks that a nil
// TLSNextProto stays nil after Clone.
func TestTransportClone(t *testing.T) {
	src := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	clone := src.Clone()

	cv := reflect.ValueOf(clone).Elem()
	ct := cv.Type()
	for i, n := 0, ct.NumField(); i < n; i++ {
		name := ct.Field(i).Name
		if token.IsExported(name) && cv.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", name)
		}
	}

	if _, found := clone.TLSNextProto["foo"]; !found {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// A nil TLSNextProto must be cloned as nil, not as an empty map.
	if empty := new(Transport).Clone(); empty.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6267
6268 func TestIs408(t *testing.T) {
6269 tests := []struct {
6270 in string
6271 want bool
6272 }{
6273 {"HTTP/1.0 408", true},
6274 {"HTTP/1.1 408", true},
6275 {"HTTP/1.8 408", true},
6276 {"HTTP/2.0 408", false},
6277 {"HTTP/1.1 408 ", true},
6278 {"HTTP/1.1 40", false},
6279 {"http/1.0 408", false},
6280 {"HTTP/1-1 408", false},
6281 }
6282 for _, tt := range tests {
6283 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6284 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6285 }
6286 }
6287 }
6288
6289 func TestTransportIgnores408(t *testing.T) {
6290 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6291 }
6292 func testTransportIgnores408(t *testing.T, mode testMode) {
6293
6294 defer log.SetOutput(log.Writer())
6295
6296 var logout strings.Builder
6297 log.SetOutput(&logout)
6298
6299 const target = "backend:443"
6300
6301 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6302 nc, _, err := w.(Hijacker).Hijack()
6303 if err != nil {
6304 t.Error(err)
6305 return
6306 }
6307 defer nc.Close()
6308 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6309 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6310 }))
6311 req, err := NewRequest("GET", cst.ts.URL, nil)
6312 if err != nil {
6313 t.Fatal(err)
6314 }
6315 res, err := cst.c.Do(req)
6316 if err != nil {
6317 t.Fatal(err)
6318 }
6319 slurp, err := io.ReadAll(res.Body)
6320 if err != nil {
6321 t.Fatal(err)
6322 }
6323 if err != nil {
6324 t.Fatal(err)
6325 }
6326 if string(slurp) != "ok" {
6327 t.Fatalf("got %q; want ok", slurp)
6328 }
6329
6330 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6331 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6332 if d > 0 {
6333 t.Logf("%v idle conns still present after %v", n, d)
6334 }
6335 return false
6336 }
6337 return true
6338 })
6339 if got := logout.String(); got != "" {
6340 t.Fatalf("expected no log output; got: %s", got)
6341 }
6342 }
6343
6344 func TestInvalidHeaderResponse(t *testing.T) {
6345 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6346 }
6347 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6348 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6349 conn, buf, _ := w.(Hijacker).Hijack()
6350 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6351 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6352 "Content-Type: text/html; charset=utf-8\r\n" +
6353 "Content-Length: 0\r\n" +
6354 "Foo : bar\r\n\r\n"))
6355 buf.Flush()
6356 conn.Close()
6357 }))
6358 res, err := cst.c.Get(cst.ts.URL)
6359 if err != nil {
6360 t.Fatal(err)
6361 }
6362 defer res.Body.Close()
6363 if v := res.Header.Get("Foo"); v != "" {
6364 t.Errorf(`unexpected "Foo" header: %q`, v)
6365 }
6366 if v := res.Header.Get("Foo "); v != "bar" {
6367 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6368 }
6369 }
6370
// bodyCloser is an always-empty request body that records whether Close has
// been called on it.
type bodyCloser bool

// Close marks the body as closed and never fails.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}

// Read reports end of input immediately and leaves the buffer untouched.
func (b *bodyCloser) Read([]byte) (int, error) {
	return 0, io.EOF
}
6380
6381
6382
6383 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6384 run(t, testTransportClosesBodyOnInvalidRequests)
6385 }
6386 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6387 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6388 t.Errorf("Should not have been invoked")
6389 })).ts
6390
6391 u, _ := url.Parse(cst.URL)
6392
6393 tests := []struct {
6394 name string
6395 req *Request
6396 wantErr string
6397 }{
6398 {
6399 name: "invalid method",
6400 req: &Request{
6401 Method: " ",
6402 URL: u,
6403 },
6404 wantErr: `invalid method " "`,
6405 },
6406 {
6407 name: "nil URL",
6408 req: &Request{
6409 Method: "GET",
6410 },
6411 wantErr: `nil Request.URL`,
6412 },
6413 {
6414 name: "invalid header key",
6415 req: &Request{
6416 Method: "GET",
6417 Header: Header{"💡": {"emoji"}},
6418 URL: u,
6419 },
6420 wantErr: `invalid header field name "💡"`,
6421 },
6422 {
6423 name: "invalid header value",
6424 req: &Request{
6425 Method: "POST",
6426 Header: Header{"key": {"\x19"}},
6427 URL: u,
6428 },
6429 wantErr: `invalid header field value for "key"`,
6430 },
6431 {
6432 name: "non HTTP(s) scheme",
6433 req: &Request{
6434 Method: "POST",
6435 URL: &url.URL{Scheme: "faux"},
6436 },
6437 wantErr: `unsupported protocol scheme "faux"`,
6438 },
6439 {
6440 name: "no Host in URL",
6441 req: &Request{
6442 Method: "POST",
6443 URL: &url.URL{Scheme: "http"},
6444 },
6445 wantErr: `no Host in request URL`,
6446 },
6447 }
6448
6449 for _, tt := range tests {
6450 t.Run(tt.name, func(t *testing.T) {
6451 var bc bodyCloser
6452 req := tt.req
6453 req.Body = &bc
6454 _, err := cst.Client().Do(tt.req)
6455 if err == nil {
6456 t.Fatal("Expected an error")
6457 }
6458 if !bc {
6459 t.Fatal("Expected body to have been closed")
6460 }
6461 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6462 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6463 }
6464 })
6465 }
6466 }
6467
6468
6469
// breakableConn wraps a net.Conn. While its shared brokenState is marked
// broken, every Write fails without touching the wrapped connection.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is the mutex-guarded flag consulted by breakableConn.Write. It
// may be shared by several connections.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write forwards b to the wrapped connection, or returns an error if the
// state is marked broken. The lock is held across the underlying Write, so
// anyone changing the flag under the same lock waits for an in-flight write.
func (bc *breakableConn) Write(b []byte) (n int, err error) {
	bc.Lock()
	defer bc.Unlock()
	if !bc.broken {
		return bc.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6488
6489
6490 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6491 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6492 }
6493 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6494 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6495
6496 var brokenState brokenState
6497
6498 const numReqs = 5
6499 var numDials, gotConns uint32
6500
6501 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6502 atomic.AddUint32(&numDials, 1)
6503 c, err := net.Dial(netw, addr)
6504 if err != nil {
6505 t.Errorf("unexpected Dial error: %v", err)
6506 return nil, err
6507 }
6508 return &breakableConn{c, &brokenState}, err
6509 }
6510
6511 for i := 1; i <= numReqs; i++ {
6512 brokenState.Lock()
6513 brokenState.broken = false
6514 brokenState.Unlock()
6515
6516
6517
6518
6519 doBreak := i != numReqs
6520
6521 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6522 GotConn: func(info httptrace.GotConnInfo) {
6523 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6524 atomic.AddUint32(&gotConns, 1)
6525 },
6526 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6527 brokenState.Lock()
6528 defer brokenState.Unlock()
6529 if doBreak {
6530 brokenState.broken = true
6531 }
6532 },
6533 })
6534 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6535 if err != nil {
6536 t.Fatal(err)
6537 }
6538 _, err = cst.c.Do(req)
6539 if doBreak != (err != nil) {
6540 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6541 }
6542 }
6543 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6544 t.Errorf("GotConn calls = %v; want %v", got, want)
6545 }
6546 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6547 t.Errorf("Dials = %v; want %v", got, want)
6548 }
6549 }
6550
6551
6552
6553
6554
6555 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6556 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6557 }
6558 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6559 CondSkipHTTP2(t)
6560
6561 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6562 _, err := w.Write([]byte("foo"))
6563 if err != nil {
6564 t.Fatalf("Write: %v", err)
6565 }
6566 })
6567
6568 ts := newClientServerTest(t, mode, h).ts
6569
6570 c := ts.Client()
6571 tr := c.Transport.(*Transport)
6572 tr.MaxConnsPerHost = 1
6573
6574 errCh := make(chan error, 300)
6575 doReq := func() {
6576 resp, err := c.Get(ts.URL)
6577 if err != nil {
6578 errCh <- fmt.Errorf("request failed: %v", err)
6579 return
6580 }
6581 defer resp.Body.Close()
6582 _, err = io.ReadAll(resp.Body)
6583 if err != nil {
6584 errCh <- fmt.Errorf("read body failed: %v", err)
6585 }
6586 }
6587
6588 var wg sync.WaitGroup
6589 for i := 0; i < 300; i++ {
6590 wg.Add(1)
6591 go func() {
6592 defer wg.Done()
6593 doReq()
6594 }()
6595 }
6596 wg.Wait()
6597 close(errCh)
6598
6599 for err := range errCh {
6600 t.Errorf("error occurred: %v", err)
6601 }
6602 }
6603
6604
6605
6606
6607 func TestAltProtoCancellation(t *testing.T) {
6608 defer afterTest(t)
6609 tr := &Transport{}
6610 c := &Client{
6611 Transport: tr,
6612 Timeout: time.Millisecond,
6613 }
6614 tr.RegisterProtocol("cancel", cancelProto{})
6615 _, err := c.Get("cancel://bar.com/path")
6616 if err == nil {
6617 t.Error("request unexpectedly succeeded")
6618 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6619 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6620 }
6621 }
6622
// errCancelProto is the error cancelProto.RoundTrip returns once its request
// has been canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never completes a request on its own.
type cancelProto struct{}

// RoundTrip waits until r.Cancel is closed (forever if it is nil) and then
// fails with errCancelProto.
func (cancelProto) RoundTrip(r *Request) (res *Response, err error) {
	<-r.Cancel
	err = errCancelProto
	return res, err
}
6631
// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip calls fn with req and returns its results unchanged.
func (fn roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	res, err := fn(req)
	return res, err
}
6635
6636
6637 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6638 func testIssue32441(t *testing.T, mode testMode) {
6639 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6640 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6641 t.Error("body length is zero")
6642 }
6643 })).ts
6644 c := ts.Client()
6645 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6646
6647 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6648 t.Error("body length is zero during round trip")
6649 }
6650 return nil, ErrSkipAltProtocol
6651 }))
6652 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6653 t.Error(err)
6654 }
6655 }
6656
6657
6658
6659 func TestTransportRejectsSignInContentLength(t *testing.T) {
6660 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6661 }
6662 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6663 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6664 w.Header().Set("Content-Length", "+3")
6665 w.Write([]byte("abc"))
6666 })).ts
6667
6668 c := cst.Client()
6669 res, err := c.Get(cst.URL)
6670 if err == nil || res != nil {
6671 t.Fatal("Expected a non-nil error and a nil http.Response")
6672 }
6673 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6674 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6675 }
6676 }
6677
6678
// dumpConn is a minimal net.Conn whose writes go to Writer and whose reads
// come from Reader. Close, the address accessors and the deadline setters
// are no-ops.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error                     { return nil }
func (*dumpConn) LocalAddr() net.Addr              { return nil }
func (*dumpConn) RemoteAddr() net.Addr             { return nil }
func (*dumpConn) SetDeadline(time.Time) error      { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error  { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
6690
6691
6692
// delegateReader is an io.Reader whose real source arrives later on channel c.
// The first Read blocks until a reader is received, then every Read is
// forwarded to it. If c is closed first, Read fails with "delegate closed".
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r != nil {
		return d.r.Read(p)
	}
	next, ok := <-d.c
	if !ok {
		return 0, errors.New("delegate closed")
	}
	d.r = next
	return d.r.Read(p)
}
6707
// testTransportRace sends req through a throwaway Transport whose "network" is
// an in-memory pipe plus a delegateReader, and then restores req.Body to its
// original value. A helper goroutine reads the request from the pipe and
// answers with a canned 204 response. The function waits for that goroutine
// before touching req.Body again, so any remaining concurrent access to the
// request by the transport (e.g. a still-running write) would show up as a
// data race under -race. The RoundTrip result is deliberately ignored.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	// The response bytes are only handed to dr once the request has been read.
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			// Every dial returns a conn that writes into the pipe and reads
			// from dr.
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})
	// Server side: read the request, then reply. The reply is abandoned if
	// the main goroutine stops waiting for it first.
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Drain and close the body so the whole request is consumed
			// before the reply is offered.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// The main goroutine is no longer waiting for a response; close
			// the delegate so a pending delegateReader.Read returns instead
			// of blocking forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Close the pipe and wait for the reader goroutine to finish before
	// resetting req.Body, so the restore below cannot race with anything
	// still reading the request.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
6751
6752
6753
6754
6755
6756 func TestErrorWriteLoopRace(t *testing.T) {
6757 if testing.Short() {
6758 return
6759 }
6760 t.Parallel()
6761 for i := 0; i < 1000; i++ {
6762 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
6763 ctx, cancel := context.WithTimeout(context.Background(), delay)
6764 defer cancel()
6765
6766 r := bytes.NewBuffer(make([]byte, 10000))
6767 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
6768 if err != nil {
6769 t.Fatal(err)
6770 }
6771
6772 testTransportRace(req)
6773 }
6774 }
6775
6776
6777
6778
// TestCancelRequestWhenSharingConnection runs two requests over a transport
// limited to a single connection:
//   - request 1 is held inside its PutIdleConn trace hook;
//   - request 2 reaches the server meanwhile and is then canceled.
//
// Request 2 must fail with context.Canceled, and the whole choreography must
// complete without deadlocking. If request 1's result is observed before its
// hook fires, it must have succeeded.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}

func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// The handler announces each request by sending a fresh channel on reqc,
	// then blocks until the test sends on or closes that channel.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	// At most one connection, so both requests must use the same one.
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	// Request 1.
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Hand the test a channel and block until it is closed.
				// close(putidlec) panics if this hook ever runs twice.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		reqerrc <- err
		if err == nil {
			res.Body.Close()
		}
	}()

	// Let request 1's handler return, then wait until request 1's connection
	// reaches the PutIdleConn hook. Do may return before or after the hook
	// fires, so both orders are accepted.
	r1c := <-reqc
	close(r1c)
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		if err != nil {
			t.Fatalf("request 1: got err %v, want nil", err)
		}
		idlec = <-putidlec
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	// Request 2: expected to end with context.Canceled.
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Only after request 2 has finished, release request 1 from its
		// PutIdleConn hook.
		close(idlec)
	}()

	// Once request 2 is inside the handler, cancel it.
	r2c := <-reqc
	cancel()

	// Wait for request 2 to finish (it closes idlec).
	<-idlec

	// Let request 2's handler return, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
6862
6863 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
6864 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
6865 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6866 go io.Copy(io.Discard, req.Body)
6867 panic(ErrAbortHandler)
6868 })).ts
6869
6870 var wg sync.WaitGroup
6871 for i := 0; i < 2; i++ {
6872 wg.Add(1)
6873 go func() {
6874 defer wg.Done()
6875 for j := 0; j < 10; j++ {
6876 const reqLen = 6 * 1024 * 1024
6877 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
6878 req.ContentLength = reqLen
6879 resp, _ := ts.Client().Transport.RoundTrip(req)
6880 if resp != nil {
6881 resp.Body.Close()
6882 }
6883 }
6884 }()
6885 }
6886 wg.Wait()
6887 }
6888
6889 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
6890 func testRequestSanitization(t *testing.T, mode testMode) {
6891 if mode == http2Mode {
6892
6893 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
6894 }
6895 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6896 if h, ok := req.Header["X-Evil"]; ok {
6897 t.Errorf("request has X-Evil header: %q", h)
6898 }
6899 })).ts
6900 req, _ := NewRequest("GET", ts.URL, nil)
6901 req.Host = "go.dev\r\nX-Evil:evil"
6902 resp, _ := ts.Client().Do(req)
6903 if resp != nil {
6904 resp.Body.Close()
6905 }
6906 }
6907
6908 func TestProxyAuthHeader(t *testing.T) {
6909
6910 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
6911 }
6912 func testProxyAuthHeader(t *testing.T, mode testMode) {
6913 const username = "u"
6914 const password = "@/?!"
6915 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6916
6917
6918 var r2 Request
6919 r2.Header = Header{
6920 "Authorization": req.Header["Proxy-Authorization"],
6921 }
6922 gotuser, gotpass, ok := r2.BasicAuth()
6923 if !ok || gotuser != username || gotpass != password {
6924 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
6925 }
6926 }))
6927 u, err := url.Parse(cst.ts.URL)
6928 if err != nil {
6929 t.Fatal(err)
6930 }
6931 u.User = url.UserPassword(username, password)
6932 t.Setenv("HTTP_PROXY", u.String())
6933 cst.tr.Proxy = ProxyURL(u)
6934 resp, err := cst.c.Get("http://_/")
6935 if err != nil {
6936 t.Fatal(err)
6937 }
6938 resp.Body.Close()
6939 }
6940
View as plain text