1
2
3
4
5 package http2
6
7 import (
8 "bufio"
9 "bytes"
10 "compress/gzip"
11 "context"
12 "crypto/tls"
13 "encoding/hex"
14 "errors"
15 "flag"
16 "fmt"
17 "io"
18 "io/fs"
19 "io/ioutil"
20 "log"
21 "math/rand"
22 "net"
23 "net/http"
24 "net/http/httptest"
25 "net/http/httptrace"
26 "net/textproto"
27 "net/url"
28 "os"
29 "reflect"
30 "runtime"
31 "sort"
32 "strconv"
33 "strings"
34 "sync"
35 "sync/atomic"
36 "testing"
37 "time"
38
39 "golang.org/x/net/http2/hpack"
40 )
41
// extNet enables tests that reach out to hosts on the public network.
var extNet = flag.Bool("extnet", false, "do external network tests")

// transportHost is the remote host dialed by TestTransportExternal.
var transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")

// insecure is the -insecure command-line flag ("insecure TLS dials").
var insecure = flag.Bool("insecure", false, "insecure TLS dials")
47
// tlsConfigInsecure disables certificate verification for clients dialing the
// local test servers.
var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}

// canceledCtx is a context that has already been canceled, built once at
// package initialization.
var canceledCtx context.Context = func() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	return ctx
}()
57
58 func TestTransportExternal(t *testing.T) {
59 if !*extNet {
60 t.Skip("skipping external network test")
61 }
62 req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
63 rt := &Transport{TLSClientConfig: tlsConfigInsecure}
64 res, err := rt.RoundTrip(req)
65 if err != nil {
66 t.Fatalf("%v", err)
67 }
68 res.Write(os.Stdout)
69 }
70
71 type fakeTLSConn struct {
72 net.Conn
73 }
74
75 func (c *fakeTLSConn) ConnectionState() tls.ConnectionState {
76 return tls.ConnectionState{
77 Version: tls.VersionTLS12,
78 CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
79 }
80 }
81
82 func startH2cServer(t *testing.T) net.Listener {
83 h2Server := &Server{}
84 l := newLocalListener(t)
85 go func() {
86 conn, err := l.Accept()
87 if err != nil {
88 t.Error(err)
89 return
90 }
91 h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
92 fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
93 })})
94 }()
95 return l
96 }
97
98 func TestTransportH2c(t *testing.T) {
99 l := startH2cServer(t)
100 defer l.Close()
101 req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil)
102 if err != nil {
103 t.Fatal(err)
104 }
105 var gotConnCnt int32
106 trace := &httptrace.ClientTrace{
107 GotConn: func(connInfo httptrace.GotConnInfo) {
108 if !connInfo.Reused {
109 atomic.AddInt32(&gotConnCnt, 1)
110 }
111 },
112 }
113 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
114 tr := &Transport{
115 AllowHTTP: true,
116 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
117 return net.Dial(network, addr)
118 },
119 }
120 res, err := tr.RoundTrip(req)
121 if err != nil {
122 t.Fatal(err)
123 }
124 if res.ProtoMajor != 2 {
125 t.Fatal("proto not h2c")
126 }
127 body, err := ioutil.ReadAll(res.Body)
128 if err != nil {
129 t.Fatal(err)
130 }
131 if got, want := string(body), "Hello, /foobar, http: true"; got != want {
132 t.Fatalf("response got %v, want %v", got, want)
133 }
134 if got, want := gotConnCnt, int32(1); got != want {
135 t.Errorf("Too many got connections: %d", gotConnCnt)
136 }
137 }
138
139 func TestTransport(t *testing.T) {
140 const body = "sup"
141 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
142 io.WriteString(w, body)
143 }, optOnlyServer)
144 defer st.Close()
145
146 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
147 defer tr.CloseIdleConnections()
148
149 u, err := url.Parse(st.ts.URL)
150 if err != nil {
151 t.Fatal(err)
152 }
153 for i, m := range []string{"GET", ""} {
154 req := &http.Request{
155 Method: m,
156 URL: u,
157 }
158 res, err := tr.RoundTrip(req)
159 if err != nil {
160 t.Fatalf("%d: %s", i, err)
161 }
162
163 t.Logf("%d: Got res: %+v", i, res)
164 if g, w := res.StatusCode, 200; g != w {
165 t.Errorf("%d: StatusCode = %v; want %v", i, g, w)
166 }
167 if g, w := res.Status, "200 OK"; g != w {
168 t.Errorf("%d: Status = %q; want %q", i, g, w)
169 }
170 wantHeader := http.Header{
171 "Content-Length": []string{"3"},
172 "Content-Type": []string{"text/plain; charset=utf-8"},
173 "Date": []string{"XXX"},
174 }
175 cleanDate(res)
176 if !reflect.DeepEqual(res.Header, wantHeader) {
177 t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader)
178 }
179 if res.Request != req {
180 t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req)
181 }
182 if res.TLS == nil {
183 t.Errorf("%d: Response.TLS = nil; want non-nil", i)
184 }
185 slurp, err := ioutil.ReadAll(res.Body)
186 if err != nil {
187 t.Errorf("%d: Body read: %v", i, err)
188 } else if string(slurp) != body {
189 t.Errorf("%d: Body = %q; want %q", i, slurp, body)
190 }
191 res.Body.Close()
192 }
193 }
194
195 func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq func(*http.Request)) {
196 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
197 io.WriteString(w, r.RemoteAddr)
198 }, optOnlyServer, func(c net.Conn, st http.ConnState) {
199 t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
200 })
201 defer st.Close()
202 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
203 if useClient {
204 tr.ConnPool = noDialClientConnPool{new(clientConnPool)}
205 }
206 defer tr.CloseIdleConnections()
207 get := func() string {
208 req, err := http.NewRequest("GET", st.ts.URL, nil)
209 if err != nil {
210 t.Fatal(err)
211 }
212 modReq(req)
213 var res *http.Response
214 if useClient {
215 c := st.ts.Client()
216 ConfigureTransports(c.Transport.(*http.Transport))
217 res, err = c.Do(req)
218 } else {
219 res, err = tr.RoundTrip(req)
220 }
221 if err != nil {
222 t.Fatal(err)
223 }
224 defer res.Body.Close()
225 slurp, err := ioutil.ReadAll(res.Body)
226 if err != nil {
227 t.Fatalf("Body read: %v", err)
228 }
229 addr := strings.TrimSpace(string(slurp))
230 if addr == "" {
231 t.Fatalf("didn't get an addr in response")
232 }
233 return addr
234 }
235 first := get()
236 second := get()
237 if got := first == second; got != wantSame {
238 t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame)
239 }
240 }
241
242 func TestTransportReusesConns(t *testing.T) {
243 for _, test := range []struct {
244 name string
245 modReq func(*http.Request)
246 wantSame bool
247 }{{
248 name: "ReuseConn",
249 modReq: func(*http.Request) {},
250 wantSame: true,
251 }, {
252 name: "RequestClose",
253 modReq: func(r *http.Request) { r.Close = true },
254 wantSame: false,
255 }, {
256 name: "ConnClose",
257 modReq: func(r *http.Request) { r.Header.Set("Connection", "close") },
258 wantSame: false,
259 }} {
260 t.Run(test.name, func(t *testing.T) {
261 t.Run("Transport", func(t *testing.T) {
262 const useClient = false
263 testTransportReusesConns(t, useClient, test.wantSame, test.modReq)
264 })
265 t.Run("Client", func(t *testing.T) {
266 const useClient = true
267 testTransportReusesConns(t, useClient, test.wantSame, test.modReq)
268 })
269 })
270 }
271 }
272
273 func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) {
274 testTransportGetGotConnHooks(t, false)
275 }
276 func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) }
277
278 func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
279 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
280 io.WriteString(w, r.RemoteAddr)
281 }, func(s *httptest.Server) {
282 s.EnableHTTP2 = true
283 }, optOnlyServer)
284 defer st.Close()
285
286 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
287 client := st.ts.Client()
288 ConfigureTransports(client.Transport.(*http.Transport))
289
290 var (
291 getConns int32
292 gotConns int32
293 )
294 for i := 0; i < 2; i++ {
295 trace := &httptrace.ClientTrace{
296 GetConn: func(hostport string) {
297 atomic.AddInt32(&getConns, 1)
298 },
299 GotConn: func(connInfo httptrace.GotConnInfo) {
300 got := atomic.AddInt32(&gotConns, 1)
301 wantReused, wantWasIdle := false, false
302 if got > 1 {
303 wantReused, wantWasIdle = true, true
304 }
305 if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle {
306 t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle)
307 }
308 },
309 }
310 req, err := http.NewRequest("GET", st.ts.URL, nil)
311 if err != nil {
312 t.Fatal(err)
313 }
314 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
315
316 var res *http.Response
317 if useClient {
318 res, err = client.Do(req)
319 } else {
320 res, err = tr.RoundTrip(req)
321 }
322 if err != nil {
323 t.Fatal(err)
324 }
325 res.Body.Close()
326 if get := atomic.LoadInt32(&getConns); get != int32(i+1) {
327 t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1)
328 }
329 if got := atomic.LoadInt32(&gotConns); got != int32(i+1) {
330 t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1)
331 }
332 }
333 }
334
// testNetConn wraps a net.Conn and invokes onClose the first time Close is
// called, letting tests count connection closes.
//
// The closed flag is not synchronized; callers must not call Close concurrently.
type testNetConn struct {
	net.Conn
	closed  bool
	onClose func() // optional; nil means no callback
}

// Close runs onClose (when set) on the first call only, marks the conn closed,
// and forwards every call to the wrapped connection's Close.
func (c *testNetConn) Close() error {
	if !c.closed && c.onClose != nil {
		c.onClose()
	}
	c.closed = true
	return c.Conn.Close()
}
349
350
351
352 func TestTransportGroupsPendingDials(t *testing.T) {
353 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
354 }, optOnlyServer)
355 defer st.Close()
356 var (
357 mu sync.Mutex
358 dialCount int
359 closeCount int
360 )
361 tr := &Transport{
362 TLSClientConfig: tlsConfigInsecure,
363 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
364 mu.Lock()
365 dialCount++
366 mu.Unlock()
367 c, err := tls.Dial(network, addr, cfg)
368 return &testNetConn{
369 Conn: c,
370 onClose: func() {
371 mu.Lock()
372 closeCount++
373 mu.Unlock()
374 },
375 }, err
376 },
377 }
378 defer tr.CloseIdleConnections()
379 var wg sync.WaitGroup
380 for i := 0; i < 10; i++ {
381 wg.Add(1)
382 go func() {
383 defer wg.Done()
384 req, err := http.NewRequest("GET", st.ts.URL, nil)
385 if err != nil {
386 t.Error(err)
387 return
388 }
389 res, err := tr.RoundTrip(req)
390 if err != nil {
391 t.Error(err)
392 return
393 }
394 res.Body.Close()
395 }()
396 }
397 wg.Wait()
398 tr.CloseIdleConnections()
399 if dialCount != 1 {
400 t.Errorf("saw %d dials; want 1", dialCount)
401 }
402 if closeCount != 1 {
403 t.Errorf("saw %d closes; want 1", closeCount)
404 }
405 }
406
// retry calls fn up to tries times and stops at the first nil error. It sleeps
// for delay between attempts but not after the last one. It returns nil on
// success, otherwise the error from the final attempt. When tries <= 0, fn is
// never called and retry returns nil.
func retry(tries int, delay time.Duration, fn func() error) error {
	var err error
	for i := 0; i < tries; i++ {
		if i > 0 {
			time.Sleep(delay)
		}
		if err = fn(); err == nil {
			return nil
		}
	}
	return err
}
418
419 func TestTransportAbortClosesPipes(t *testing.T) {
420 shutdown := make(chan struct{})
421 st := newServerTester(t,
422 func(w http.ResponseWriter, r *http.Request) {
423 w.(http.Flusher).Flush()
424 <-shutdown
425 },
426 optOnlyServer,
427 )
428 defer st.Close()
429 defer close(shutdown)
430
431 errCh := make(chan error)
432 go func() {
433 defer close(errCh)
434 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
435 req, err := http.NewRequest("GET", st.ts.URL, nil)
436 if err != nil {
437 errCh <- err
438 return
439 }
440 res, err := tr.RoundTrip(req)
441 if err != nil {
442 errCh <- err
443 return
444 }
445 defer res.Body.Close()
446 st.closeConn()
447 _, err = ioutil.ReadAll(res.Body)
448 if err == nil {
449 errCh <- errors.New("expected error from res.Body.Read")
450 return
451 }
452 }()
453
454 select {
455 case err := <-errCh:
456 if err != nil {
457 t.Fatal(err)
458 }
459
460 case <-time.After(3 * time.Second):
461 t.Fatal("timeout")
462 }
463 }
464
465
466
467 func TestTransportPath(t *testing.T) {
468 gotc := make(chan *url.URL, 1)
469 st := newServerTester(t,
470 func(w http.ResponseWriter, r *http.Request) {
471 gotc <- r.URL
472 },
473 optOnlyServer,
474 )
475 defer st.Close()
476
477 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
478 defer tr.CloseIdleConnections()
479 const (
480 path = "/testpath"
481 query = "q=1"
482 )
483 surl := st.ts.URL + path + "?" + query
484 req, err := http.NewRequest("POST", surl, nil)
485 if err != nil {
486 t.Fatal(err)
487 }
488 c := &http.Client{Transport: tr}
489 res, err := c.Do(req)
490 if err != nil {
491 t.Fatal(err)
492 }
493 defer res.Body.Close()
494 got := <-gotc
495 if got.Path != path {
496 t.Errorf("Read Path = %q; want %q", got.Path, path)
497 }
498 if got.RawQuery != query {
499 t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
500 }
501 }
502
// randString returns n pseudo-random bytes (not necessarily valid UTF-8) as a
// string. The generator is seeded with n, so a given length always yields the
// same output.
func randString(n int) string {
	gen := rand.New(rand.NewSource(int64(n)))
	buf := make([]byte, 0, n)
	for len(buf) < n {
		buf = append(buf, byte(gen.Intn(256)))
	}
	return string(buf)
}
511
// panicReader is an io.ReadCloser whose methods both panic. Use it as a
// request body when the code under test must not touch the body.
type panicReader struct{}

// Read always panics with "unexpected Read".
func (panicReader) Read(_ []byte) (int, error) {
	panic("unexpected Read")
}

// Close always panics with "unexpected Close".
func (panicReader) Close() error {
	panic("unexpected Close")
}
516
517 func TestActualContentLength(t *testing.T) {
518 tests := []struct {
519 req *http.Request
520 want int64
521 }{
522
523 0: {
524 req: &http.Request{Body: panicReader{}},
525 want: -1,
526 },
527
528 1: {
529 req: &http.Request{Body: nil, ContentLength: 5},
530 want: 0,
531 },
532
533 2: {
534 req: &http.Request{Body: panicReader{}, ContentLength: 5},
535 want: 5,
536 },
537
538 3: {
539 req: &http.Request{Body: http.NoBody},
540 want: 0,
541 },
542 }
543 for i, tt := range tests {
544 got := actualContentLength(tt.req)
545 if got != tt.want {
546 t.Errorf("test[%d]: got %d; want %d", i, got, tt.want)
547 }
548 }
549 }
550
551 func TestTransportBody(t *testing.T) {
552 bodyTests := []struct {
553 body string
554 noContentLen bool
555 }{
556 {body: "some message"},
557 {body: "some message", noContentLen: true},
558 {body: strings.Repeat("a", 1<<20), noContentLen: true},
559 {body: strings.Repeat("a", 1<<20)},
560 {body: randString(16<<10 - 1)},
561 {body: randString(16 << 10)},
562 {body: randString(16<<10 + 1)},
563 {body: randString(512<<10 - 1)},
564 {body: randString(512 << 10)},
565 {body: randString(512<<10 + 1)},
566 {body: randString(1<<20 - 1)},
567 {body: randString(1 << 20)},
568 {body: randString(1<<20 + 2)},
569 }
570
571 type reqInfo struct {
572 req *http.Request
573 slurp []byte
574 err error
575 }
576 gotc := make(chan reqInfo, 1)
577 st := newServerTester(t,
578 func(w http.ResponseWriter, r *http.Request) {
579 slurp, err := ioutil.ReadAll(r.Body)
580 if err != nil {
581 gotc <- reqInfo{err: err}
582 } else {
583 gotc <- reqInfo{req: r, slurp: slurp}
584 }
585 },
586 optOnlyServer,
587 )
588 defer st.Close()
589
590 for i, tt := range bodyTests {
591 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
592 defer tr.CloseIdleConnections()
593
594 var body io.Reader = strings.NewReader(tt.body)
595 if tt.noContentLen {
596 body = struct{ io.Reader }{body}
597 }
598 req, err := http.NewRequest("POST", st.ts.URL, body)
599 if err != nil {
600 t.Fatalf("#%d: %v", i, err)
601 }
602 c := &http.Client{Transport: tr}
603 res, err := c.Do(req)
604 if err != nil {
605 t.Fatalf("#%d: %v", i, err)
606 }
607 defer res.Body.Close()
608 ri := <-gotc
609 if ri.err != nil {
610 t.Errorf("#%d: read error: %v", i, ri.err)
611 continue
612 }
613 if got := string(ri.slurp); got != tt.body {
614 t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
615 }
616 wantLen := int64(len(tt.body))
617 if tt.noContentLen && tt.body != "" {
618 wantLen = -1
619 }
620 if ri.req.ContentLength != wantLen {
621 t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
622 }
623 }
624 }
625
// shortString abbreviates v for log output. Strings of at most 100 bytes are
// returned as-is. Longer ones keep their first and last 50 bytes around a
// note giving how many bytes were dropped.
func shortString(v string) string {
	const maxLen = 100
	if len(v) > maxLen {
		half := maxLen / 2
		head, tail := v[:half], v[len(v)-half:]
		return fmt.Sprintf("%v[...%d bytes omitted...]%v", head, len(v)-maxLen, tail)
	}
	return v
}
633
634 func TestTransportDialTLS(t *testing.T) {
635 var mu sync.Mutex
636 var gotReq, didDial bool
637
638 ts := newServerTester(t,
639 func(w http.ResponseWriter, r *http.Request) {
640 mu.Lock()
641 gotReq = true
642 mu.Unlock()
643 },
644 optOnlyServer,
645 )
646 defer ts.Close()
647 tr := &Transport{
648 DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
649 mu.Lock()
650 didDial = true
651 mu.Unlock()
652 cfg.InsecureSkipVerify = true
653 c, err := tls.Dial(netw, addr, cfg)
654 if err != nil {
655 return nil, err
656 }
657 return c, c.Handshake()
658 },
659 }
660 defer tr.CloseIdleConnections()
661 client := &http.Client{Transport: tr}
662 res, err := client.Get(ts.ts.URL)
663 if err != nil {
664 t.Fatal(err)
665 }
666 res.Body.Close()
667 mu.Lock()
668 if !gotReq {
669 t.Error("didn't get request")
670 }
671 if !didDial {
672 t.Error("didn't use dial hook")
673 }
674 }
675
676 func TestConfigureTransport(t *testing.T) {
677 t1 := &http.Transport{}
678 err := ConfigureTransport(t1)
679 if err != nil {
680 t.Fatal(err)
681 }
682 if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) {
683
684 t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
685 }
686 wantNextProtos := []string{"h2", "http/1.1"}
687 if t1.TLSClientConfig == nil {
688 t.Errorf("nil t1.TLSClientConfig")
689 } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) {
690 t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos)
691 }
692 if err := ConfigureTransport(t1); err == nil {
693 t.Error("unexpected success on second call to ConfigureTransport")
694 }
695
696
697 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
698 io.WriteString(w, r.Proto)
699 }, optOnlyServer)
700 defer st.Close()
701
702 t1.TLSClientConfig.InsecureSkipVerify = true
703 c := &http.Client{Transport: t1}
704 res, err := c.Get(st.ts.URL)
705 if err != nil {
706 t.Fatal(err)
707 }
708 slurp, err := ioutil.ReadAll(res.Body)
709 if err != nil {
710 t.Fatal(err)
711 }
712 if got, want := string(slurp), "HTTP/2.0"; got != want {
713 t.Errorf("body = %q; want %q", got, want)
714 }
715 }
716
// capitalizeReader upper-cases the ASCII letters a-z in everything read from r.
// All other bytes pass through unchanged.
type capitalizeReader struct {
	r io.Reader
}

// Read reads from the underlying reader, then upper-cases the bytes it got.
func (cr capitalizeReader) Read(p []byte) (int, error) {
	n, err := cr.r.Read(p)
	const delta = 'a' - 'A'
	for i := 0; i < n; i++ {
		if c := p[i]; 'a' <= c && c <= 'z' {
			p[i] = c - delta
		}
	}
	return n, err
}
730
// flushWriter forwards writes to w. When w also implements http.Flusher, it
// flushes after every Write (even a failed one) so each chunk is pushed out
// immediately.
type flushWriter struct {
	w io.Writer
}

// Write writes p to the underlying writer, then flushes it if it can.
func (fw flushWriter) Write(p []byte) (int, error) {
	n, err := fw.w.Write(p)
	f, canFlush := fw.w.(http.Flusher)
	if canFlush {
		f.Flush()
	}
	return n, err
}
742
743 type clientTester struct {
744 t *testing.T
745 tr *Transport
746 sc, cc net.Conn
747 fr *Framer
748 settings *SettingsFrame
749 client func() error
750 server func() error
751 }
752
753 func newClientTester(t *testing.T) *clientTester {
754 var dialOnce struct {
755 sync.Mutex
756 dialed bool
757 }
758 ct := &clientTester{
759 t: t,
760 }
761 ct.tr = &Transport{
762 TLSClientConfig: tlsConfigInsecure,
763 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
764 dialOnce.Lock()
765 defer dialOnce.Unlock()
766 if dialOnce.dialed {
767 return nil, errors.New("only one dial allowed in test mode")
768 }
769 dialOnce.dialed = true
770 return ct.cc, nil
771 },
772 }
773
774 ln := newLocalListener(t)
775 cc, err := net.Dial("tcp", ln.Addr().String())
776 if err != nil {
777 t.Fatal(err)
778 }
779 sc, err := ln.Accept()
780 if err != nil {
781 t.Fatal(err)
782 }
783 ln.Close()
784 ct.cc = cc
785 ct.sc = sc
786 ct.fr = NewFramer(sc, sc)
787 return ct
788 }
789
// newLocalListener opens a TCP listener on an ephemeral loopback port. It
// tries IPv4 first and falls back to IPv6, failing the test if neither works.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
801
802 func (ct *clientTester) greet(settings ...Setting) {
803 buf := make([]byte, len(ClientPreface))
804 _, err := io.ReadFull(ct.sc, buf)
805 if err != nil {
806 ct.t.Fatalf("reading client preface: %v", err)
807 }
808 f, err := ct.fr.ReadFrame()
809 if err != nil {
810 ct.t.Fatalf("Reading client settings frame: %v", err)
811 }
812 var ok bool
813 if ct.settings, ok = f.(*SettingsFrame); !ok {
814 ct.t.Fatalf("Wanted client settings frame; got %v", f)
815 }
816 if err := ct.fr.WriteSettings(settings...); err != nil {
817 ct.t.Fatal(err)
818 }
819 if err := ct.fr.WriteSettingsAck(); err != nil {
820 ct.t.Fatal(err)
821 }
822 }
823
824 func (ct *clientTester) readNonSettingsFrame() (Frame, error) {
825 for {
826 f, err := ct.fr.ReadFrame()
827 if err != nil {
828 return nil, err
829 }
830 if _, ok := f.(*SettingsFrame); ok {
831 continue
832 }
833 return f, nil
834 }
835 }
836
837
838
839
840 func (ct *clientTester) writeReadPing() error {
841 data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
842 if err := ct.fr.WritePing(false, data); err != nil {
843 return fmt.Errorf("Error writing PING: %v", err)
844 }
845 f, err := ct.readNonSettingsFrame()
846 if err != nil {
847 return err
848 }
849 p, ok := f.(*PingFrame)
850 if !ok {
851 return fmt.Errorf("got a %v, want a PING ACK", f)
852 }
853 if p.Flags&FlagPingAck == 0 {
854 return fmt.Errorf("got a PING, want a PING ACK")
855 }
856 if p.Data != data {
857 return fmt.Errorf("got PING data = %x, want %x", p.Data, data)
858 }
859 return nil
860 }
861
862 func (ct *clientTester) inflowWindow(streamID uint32) int32 {
863 pool := ct.tr.connPoolOrDef.(*clientConnPool)
864 pool.mu.Lock()
865 defer pool.mu.Unlock()
866 if n := len(pool.keys); n != 1 {
867 ct.t.Errorf("clientConnPool contains %v keys, expected 1", n)
868 return -1
869 }
870 for cc := range pool.keys {
871 cc.mu.Lock()
872 defer cc.mu.Unlock()
873 if streamID == 0 {
874 return cc.inflow.avail + cc.inflow.unsent
875 }
876 cs := cc.streams[streamID]
877 if cs == nil {
878 ct.t.Errorf("no stream with id %v", streamID)
879 return -1
880 }
881 return cs.inflow.avail + cs.inflow.unsent
882 }
883 return -1
884 }
885
886 func (ct *clientTester) cleanup() {
887 ct.tr.CloseIdleConnections()
888
889
890 ct.sc.Close()
891 ct.cc.Close()
892 }
893
// run executes ct.client and ct.server concurrently and waits for both.
// The first function to return a non-nil error is reported via t.Errorf and
// triggers ct.cleanup. Later errors are not reported, because errOnce has
// already fired. If neither side fails, cleanup runs once after both return.
// The shared sync.Once ensures cleanup runs exactly once either way.
func (ct *clientTester) run() {
	var errOnce sync.Once
	var wg sync.WaitGroup

	// run invokes fn. On error it reports the failure, tagged with which, and
	// tears the tester down, at most once across both sides.
	run := func(which string, fn func() error) {
		defer wg.Done()
		if err := fn(); err != nil {
			errOnce.Do(func() {
				ct.t.Errorf("%s: %v", which, err)
				ct.cleanup()
			})
		}
	}

	wg.Add(2)
	go run("client", ct.client)
	go run("server", ct.server)
	wg.Wait()

	// No-op if an error already triggered cleanup above.
	errOnce.Do(ct.cleanup)
}
915
916 func (ct *clientTester) readFrame() (Frame, error) {
917 return ct.fr.ReadFrame()
918 }
919
920 func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
921 for {
922 f, err := ct.readFrame()
923 if err != nil {
924 return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
925 }
926 switch f.(type) {
927 case *WindowUpdateFrame, *SettingsFrame:
928 continue
929 }
930 hf, ok := f.(*HeadersFrame)
931 if !ok {
932 return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
933 }
934 return hf, nil
935 }
936 }
937
// countingReader is an endless reader. Every Read fills the whole buffer,
// setting byte i to byte(i), and atomically adds the bytes produced to *n.
// It never returns an error.
type countingReader struct {
	n *int64
}

// Read fills p completely and records len(p) in the shared counter.
func (r countingReader) Read(p []byte) (int, error) {
	size := len(p)
	for i := 0; i < size; i++ {
		p[i] = byte(i)
	}
	atomic.AddInt64(r.n, int64(size))
	return size, nil
}
949
950 func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
951 func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
952
// testTransportReqBodyAfterResponse uploads a 10 MiB request body. Half is
// buffered before RoundTrip and the other half is written only after the
// response headers arrive. The fake server answers with `status` as soon as
// the first DATA arrives.
//   - For 200, the server waits for the whole body; the Transport must deliver
//     all bodySize bytes.
//   - For other statuses, the server ends the response immediately; the
//     Transport must stop partway, sending more than 0 but fewer than
//     bodySize bytes.
func testTransportReqBodyAfterResponse(t *testing.T, status int) {
	const bodySize = 10 << 20
	// Closed by the client when it finishes, so the server can treat a
	// subsequent read error as normal shutdown.
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	// Carries the number of DATA bytes the server saw once the stream ended.
	recvLen := make(chan int64, 1)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// NOTE(review): on plan9 the conn is also fully closed, presumably
			// because CloseWrite alone is not sufficient there — confirm.
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)

		// Pre-fill the first half of the body before starting the request.
		body := &pipe{b: new(bytes.Buffer)}
		io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
		req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
		if err != nil {
			return err
		}
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		if res.StatusCode != status {
			return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
		}
		// Supply the second half only after the response has been received.
		io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
		body.CloseWithError(io.EOF)
		slurp, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("Slurp: %v", err)
		}
		if len(slurp) > 0 {
			return fmt.Errorf("unexpected body: %q", slurp)
		}
		res.Body.Close()
		if status == 200 {
			if got := <-recvLen; got != bodySize {
				return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
			}
		} else {
			if got := <-recvLen; got == 0 || got >= bodySize {
				return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
			}
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		defer close(recvLen)
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		var dataRecv int64 // total DATA payload bytes received so far
		var closed bool    // whether the server has already ended its response
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// The client has finished and closed its side, so a read
					// error here is expected shutdown rather than a failure.
					return nil
				default:
					return err
				}
			}

			ended := false
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				if f.StreamEnded() {
					return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
				}
			case *DataFrame:
				dataLen := len(f.Data())
				if dataLen > 0 {
					// On the first non-empty DATA frame, send the response
					// headers with the requested status (without END_STREAM).
					if dataRecv == 0 {
						enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
						ct.fr.WriteHeaders(HeadersFrameParam{
							StreamID:      f.StreamID,
							EndHeaders:    true,
							EndStream:     false,
							BlockFragment: buf.Bytes(),
						})
					}
					// Return the consumed bytes to both the connection-level
					// and stream-level flow-control windows.
					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
						return err
					}
					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
						return err
					}
				}
				dataRecv += int64(dataLen)

				// End the response with an empty END_STREAM DATA frame. For a
				// non-200 status this happens as soon as any data has arrived;
				// for 200, only once the request body has ended.
				if !closed && ((status != 200 && dataRecv > 0) ||
					(status == 200 && f.StreamEnded())) {
					closed = true
					if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
						return err
					}
				}

				if f.StreamEnded() {
					ended = true
				}
			case *RSTStreamFrame:
				// A reset from the client is acceptable only for non-200 statuses.
				if status == 200 {
					return fmt.Errorf("Unexpected client frame %v", f)
				}
				ended = true
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
			if ended {
				// Non-blocking: if a count is already pending in the buffered
				// channel, this later value is dropped.
				select {
				case recvLen <- dataRecv:
				default:
				}
			}
		}
	}
	ct.run()
}
1081
1082
1083 func TestTransportFullDuplex(t *testing.T) {
1084 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1085 w.WriteHeader(200)
1086 w.(http.Flusher).Flush()
1087 io.Copy(flushWriter{w}, capitalizeReader{r.Body})
1088 fmt.Fprintf(w, "bye.\n")
1089 }, optOnlyServer)
1090 defer st.Close()
1091
1092 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
1093 defer tr.CloseIdleConnections()
1094 c := &http.Client{Transport: tr}
1095
1096 pr, pw := io.Pipe()
1097 req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
1098 if err != nil {
1099 t.Fatal(err)
1100 }
1101 req.ContentLength = -1
1102 res, err := c.Do(req)
1103 if err != nil {
1104 t.Fatal(err)
1105 }
1106 defer res.Body.Close()
1107 if res.StatusCode != 200 {
1108 t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
1109 }
1110 bs := bufio.NewScanner(res.Body)
1111 want := func(v string) {
1112 if !bs.Scan() {
1113 t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
1114 }
1115 }
1116 write := func(v string) {
1117 _, err := io.WriteString(pw, v)
1118 if err != nil {
1119 t.Fatalf("pipe write: %v", err)
1120 }
1121 }
1122 write("foo\n")
1123 want("FOO")
1124 write("bar\n")
1125 want("BAR")
1126 pw.Close()
1127 want("bye.")
1128 if err := bs.Err(); err != nil {
1129 t.Fatal(err)
1130 }
1131 }
1132
1133 func TestTransportConnectRequest(t *testing.T) {
1134 gotc := make(chan *http.Request, 1)
1135 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1136 gotc <- r
1137 }, optOnlyServer)
1138 defer st.Close()
1139
1140 u, err := url.Parse(st.ts.URL)
1141 if err != nil {
1142 t.Fatal(err)
1143 }
1144
1145 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
1146 defer tr.CloseIdleConnections()
1147 c := &http.Client{Transport: tr}
1148
1149 tests := []struct {
1150 req *http.Request
1151 want string
1152 }{
1153 {
1154 req: &http.Request{
1155 Method: "CONNECT",
1156 Header: http.Header{},
1157 URL: u,
1158 },
1159 want: u.Host,
1160 },
1161 {
1162 req: &http.Request{
1163 Method: "CONNECT",
1164 Header: http.Header{},
1165 URL: u,
1166 Host: "example.com:123",
1167 },
1168 want: "example.com:123",
1169 },
1170 }
1171
1172 for i, tt := range tests {
1173 res, err := c.Do(tt.req)
1174 if err != nil {
1175 t.Errorf("%d. RoundTrip = %v", i, err)
1176 continue
1177 }
1178 res.Body.Close()
1179 req := <-gotc
1180 if req.Method != "CONNECT" {
1181 t.Errorf("method = %q; want CONNECT", req.Method)
1182 }
1183 if req.Host != tt.want {
1184 t.Errorf("Host = %q; want %q", req.Host, tt.want)
1185 }
1186 if req.URL.Host != tt.want {
1187 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
1188 }
1189 }
1190 }
1191
// headerType selects how a header block is sent by testTransportResPattern.
// It applies to the 100-continue response, the response headers, and the
// trailers.
type headerType int

const (
	noHeader    headerType = 0 // no header block is sent
	oneHeader   headerType = 1 // presumably a single HEADERS frame
	splitHeader headerType = 2 // presumably HEADERS followed by CONTINUATION frame(s)
)

// Short aliases used by the TestTransportResPattern_* wrappers: fN picks a
// headerType and dN picks the withData flag.
const (
	f0, f1, f2 = noHeader, oneHeader, splitHeader
	d0, d1     = false, true
)
1207
1208
1209
1210
1211
1212
1213 func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
1214 func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
1215 func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
1216 func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
1217 func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
1218 func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
1219 func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
1220 func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
1221 func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
1222 func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
1223 func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
1224 func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
1225 func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
1226 func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
1227 func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
1228 func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
1229 func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
1230 func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
1231 func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
1232 func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
1233 func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
1234 func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
1235 func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
1236 func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
1237 func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
1238 func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
1239 func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
1240 func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
1241 func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
1242 func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
1243 func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
1244 func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
1245 func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
1246 func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
1247 func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
1248 func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
1249
1250 func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
1251 const reqBody = "some request body"
1252 const resBody = "some response body"
1253
1254 if resHeader == noHeader {
1255
1256
1257 panic("invalid combination")
1258 }
1259
1260 ct := newClientTester(t)
1261 ct.client = func() error {
1262 req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
1263 if expect100Continue != noHeader {
1264 req.Header.Set("Expect", "100-continue")
1265 }
1266 res, err := ct.tr.RoundTrip(req)
1267 if err != nil {
1268 return fmt.Errorf("RoundTrip: %v", err)
1269 }
1270 defer res.Body.Close()
1271 if res.StatusCode != 200 {
1272 return fmt.Errorf("status code = %v; want 200", res.StatusCode)
1273 }
1274 slurp, err := ioutil.ReadAll(res.Body)
1275 if err != nil {
1276 return fmt.Errorf("Slurp: %v", err)
1277 }
1278 wantBody := resBody
1279 if !withData {
1280 wantBody = ""
1281 }
1282 if string(slurp) != wantBody {
1283 return fmt.Errorf("body = %q; want %q", slurp, wantBody)
1284 }
1285 if trailers == noHeader {
1286 if len(res.Trailer) > 0 {
1287 t.Errorf("Trailer = %v; want none", res.Trailer)
1288 }
1289 } else {
1290 want := http.Header{"Some-Trailer": {"some-value"}}
1291 if !reflect.DeepEqual(res.Trailer, want) {
1292 t.Errorf("Trailer = %v; want %v", res.Trailer, want)
1293 }
1294 }
1295 return nil
1296 }
1297 ct.server = func() error {
1298 ct.greet()
1299 var buf bytes.Buffer
1300 enc := hpack.NewEncoder(&buf)
1301
1302 for {
1303 f, err := ct.fr.ReadFrame()
1304 if err != nil {
1305 return err
1306 }
1307 endStream := false
1308 send := func(mode headerType) {
1309 hbf := buf.Bytes()
1310 switch mode {
1311 case oneHeader:
1312 ct.fr.WriteHeaders(HeadersFrameParam{
1313 StreamID: f.Header().StreamID,
1314 EndHeaders: true,
1315 EndStream: endStream,
1316 BlockFragment: hbf,
1317 })
1318 case splitHeader:
1319 if len(hbf) < 2 {
1320 panic("too small")
1321 }
1322 ct.fr.WriteHeaders(HeadersFrameParam{
1323 StreamID: f.Header().StreamID,
1324 EndHeaders: false,
1325 EndStream: endStream,
1326 BlockFragment: hbf[:1],
1327 })
1328 ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
1329 default:
1330 panic("bogus mode")
1331 }
1332 }
1333 switch f := f.(type) {
1334 case *WindowUpdateFrame, *SettingsFrame:
1335 case *DataFrame:
1336 if !f.StreamEnded() {
1337
1338 continue
1339 }
1340
1341 {
1342 buf.Reset()
1343 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
1344 enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
1345 enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
1346 if trailers != noHeader {
1347 enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
1348 }
1349 endStream = withData == false && trailers == noHeader
1350 send(resHeader)
1351 }
1352 if withData {
1353 endStream = trailers == noHeader
1354 ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
1355 }
1356 if trailers != noHeader {
1357 endStream = true
1358 buf.Reset()
1359 enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
1360 send(trailers)
1361 }
1362 if endStream {
1363 return nil
1364 }
1365 case *HeadersFrame:
1366 if expect100Continue != noHeader {
1367 buf.Reset()
1368 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
1369 send(expect100Continue)
1370 }
1371 }
1372 }
1373 }
1374 ct.run()
1375 }
1376
1377
1378 func TestTransportUnknown1xx(t *testing.T) {
1379 var buf bytes.Buffer
1380 defer func() { got1xxFuncForTests = nil }()
1381 got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error {
1382 fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
1383 return nil
1384 }
1385
1386 ct := newClientTester(t)
1387 ct.client = func() error {
1388 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1389 res, err := ct.tr.RoundTrip(req)
1390 if err != nil {
1391 return fmt.Errorf("RoundTrip: %v", err)
1392 }
1393 defer res.Body.Close()
1394 if res.StatusCode != 204 {
1395 return fmt.Errorf("status code = %v; want 204", res.StatusCode)
1396 }
1397 want := `code=110 header=map[Foo-Bar:[110]]
1398 code=111 header=map[Foo-Bar:[111]]
1399 code=112 header=map[Foo-Bar:[112]]
1400 code=113 header=map[Foo-Bar:[113]]
1401 code=114 header=map[Foo-Bar:[114]]
1402 `
1403 if got := buf.String(); got != want {
1404 t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
1405 }
1406 return nil
1407 }
1408 ct.server = func() error {
1409 ct.greet()
1410 var buf bytes.Buffer
1411 enc := hpack.NewEncoder(&buf)
1412
1413 for {
1414 f, err := ct.fr.ReadFrame()
1415 if err != nil {
1416 return err
1417 }
1418 switch f := f.(type) {
1419 case *WindowUpdateFrame, *SettingsFrame:
1420 case *HeadersFrame:
1421 for i := 110; i <= 114; i++ {
1422 buf.Reset()
1423 enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)})
1424 enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)})
1425 ct.fr.WriteHeaders(HeadersFrameParam{
1426 StreamID: f.StreamID,
1427 EndHeaders: true,
1428 EndStream: false,
1429 BlockFragment: buf.Bytes(),
1430 })
1431 }
1432 buf.Reset()
1433 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
1434 ct.fr.WriteHeaders(HeadersFrameParam{
1435 StreamID: f.StreamID,
1436 EndHeaders: true,
1437 EndStream: false,
1438 BlockFragment: buf.Bytes(),
1439 })
1440 return nil
1441 }
1442 }
1443 }
1444 ct.run()
1445
1446 }
1447
1448 func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
1449 ct := newClientTester(t)
1450 ct.client = func() error {
1451 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1452 res, err := ct.tr.RoundTrip(req)
1453 if err != nil {
1454 return fmt.Errorf("RoundTrip: %v", err)
1455 }
1456 defer res.Body.Close()
1457 if res.StatusCode != 200 {
1458 return fmt.Errorf("status code = %v; want 200", res.StatusCode)
1459 }
1460 slurp, err := ioutil.ReadAll(res.Body)
1461 if err != nil {
1462 return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
1463 }
1464 if len(slurp) > 0 {
1465 return fmt.Errorf("body = %q; want nothing", slurp)
1466 }
1467 if _, ok := res.Trailer["Some-Trailer"]; !ok {
1468 return fmt.Errorf("expected Some-Trailer")
1469 }
1470 return nil
1471 }
1472 ct.server = func() error {
1473 ct.greet()
1474
1475 var n int
1476 var hf *HeadersFrame
1477 for hf == nil && n < 10 {
1478 f, err := ct.fr.ReadFrame()
1479 if err != nil {
1480 return err
1481 }
1482 hf, _ = f.(*HeadersFrame)
1483 n++
1484 }
1485
1486 var buf bytes.Buffer
1487 enc := hpack.NewEncoder(&buf)
1488
1489
1490 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
1491 ct.fr.WriteHeaders(HeadersFrameParam{
1492 StreamID: hf.StreamID,
1493 EndHeaders: true,
1494 EndStream: false,
1495 BlockFragment: buf.Bytes(),
1496 })
1497
1498
1499 buf.Reset()
1500 enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
1501 ct.fr.WriteHeaders(HeadersFrameParam{
1502 StreamID: hf.StreamID,
1503 EndHeaders: true,
1504 EndStream: true,
1505 BlockFragment: buf.Bytes(),
1506 })
1507 return nil
1508 }
1509 ct.run()
1510 }
1511
1512 func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
1513 testTransportInvalidTrailer_Pseudo(t, oneHeader)
1514 }
1515 func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
1516 testTransportInvalidTrailer_Pseudo(t, splitHeader)
1517 }
1518 func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
1519 testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
1520 enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
1521 enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
1522 })
1523 }
1524
1525 func TestTransportInvalidTrailer_Capital1(t *testing.T) {
1526 testTransportInvalidTrailer_Capital(t, oneHeader)
1527 }
1528 func TestTransportInvalidTrailer_Capital2(t *testing.T) {
1529 testTransportInvalidTrailer_Capital(t, splitHeader)
1530 }
1531 func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
1532 testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
1533 enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
1534 enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
1535 })
1536 }
1537 func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
1538 testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
1539 enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
1540 })
1541 }
1542 func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
1543 testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) {
1544 enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
1545 })
1546 }
1547
1548 func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
1549 ct := newClientTester(t)
1550 ct.client = func() error {
1551 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1552 res, err := ct.tr.RoundTrip(req)
1553 if err != nil {
1554 return fmt.Errorf("RoundTrip: %v", err)
1555 }
1556 defer res.Body.Close()
1557 if res.StatusCode != 200 {
1558 return fmt.Errorf("status code = %v; want 200", res.StatusCode)
1559 }
1560 slurp, err := ioutil.ReadAll(res.Body)
1561 se, ok := err.(StreamError)
1562 if !ok || se.Cause != wantErr {
1563 return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
1564 }
1565 if len(slurp) > 0 {
1566 return fmt.Errorf("body = %q; want nothing", slurp)
1567 }
1568 return nil
1569 }
1570 ct.server = func() error {
1571 ct.greet()
1572 var buf bytes.Buffer
1573 enc := hpack.NewEncoder(&buf)
1574
1575 for {
1576 f, err := ct.fr.ReadFrame()
1577 if err != nil {
1578 return err
1579 }
1580 switch f := f.(type) {
1581 case *HeadersFrame:
1582 var endStream bool
1583 send := func(mode headerType) {
1584 hbf := buf.Bytes()
1585 switch mode {
1586 case oneHeader:
1587 ct.fr.WriteHeaders(HeadersFrameParam{
1588 StreamID: f.StreamID,
1589 EndHeaders: true,
1590 EndStream: endStream,
1591 BlockFragment: hbf,
1592 })
1593 case splitHeader:
1594 if len(hbf) < 2 {
1595 panic("too small")
1596 }
1597 ct.fr.WriteHeaders(HeadersFrameParam{
1598 StreamID: f.StreamID,
1599 EndHeaders: false,
1600 EndStream: endStream,
1601 BlockFragment: hbf[:1],
1602 })
1603 ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
1604 default:
1605 panic("bogus mode")
1606 }
1607 }
1608
1609 {
1610 buf.Reset()
1611 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
1612 enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
1613 endStream = false
1614 send(oneHeader)
1615 }
1616
1617 {
1618 endStream = true
1619 buf.Reset()
1620 writeTrailer(enc)
1621 send(trailers)
1622 }
1623 return nil
1624 }
1625 }
1626 }
1627 ct.run()
1628 }
1629
1630
1631
1632
1633
1634 func headerListSize(h http.Header) (size uint32) {
1635 for k, vv := range h {
1636 for _, v := range vv {
1637 hf := hpack.HeaderField{Name: k, Value: v}
1638 size += hf.Size()
1639 }
1640 }
1641 return size
1642 }
1643
1644
1645
1646
1647
1648
1649
1650
1651 func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
1652 if limit > 0xffffffff {
1653 t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
1654 }
1655 hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1656 minPadding := uint64(hf.Size())
1657 size := uint64(headerListSize(h))
1658
1659 minlimit := size + minPadding
1660 if limit < minlimit {
1661 t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
1662 }
1663
1664
1665
1666 nameFmt := "Pad-Headers-%06d"
1667 hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
1668 fieldSize := uint64(hf.Size())
1669
1670
1671
1672 limit = limit - minPadding
1673 for i := 0; size+fieldSize < limit; i++ {
1674 name := fmt.Sprintf(nameFmt, i)
1675 h.Add(name, filler)
1676 size += fieldSize
1677 }
1678
1679
1680 remain := limit - size
1681 lastValue := strings.Repeat("*", int(remain))
1682 h.Add("Pad-Headers", lastValue)
1683 }
1684
1685 func TestPadHeaders(t *testing.T) {
1686 check := func(h http.Header, limit uint32, fillerLen int) {
1687 if h == nil {
1688 h = make(http.Header)
1689 }
1690 filler := strings.Repeat("f", fillerLen)
1691 padHeaders(t, h, uint64(limit), filler)
1692 gotSize := headerListSize(h)
1693 if gotSize != limit {
1694 t.Errorf("Got size = %v; want %v", gotSize, limit)
1695 }
1696 }
1697
1698 hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1699 minLimit := hf.Size()
1700 for limit := minLimit; limit <= 128; limit++ {
1701 for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
1702 check(nil, limit, fillerLen)
1703 }
1704 }
1705
1706
1707
1708
1709
1710
1711 tests := []struct {
1712 fillerLen int
1713 limit uint32
1714 }{
1715 {
1716 fillerLen: 64,
1717 limit: 1024,
1718 },
1719 {
1720 fillerLen: 1024,
1721 limit: 1286,
1722 },
1723 {
1724 fillerLen: 256,
1725 limit: 2048,
1726 },
1727 {
1728 fillerLen: 1024,
1729 limit: 10 * 1024,
1730 },
1731 {
1732 fillerLen: 1023,
1733 limit: 11 * 1024,
1734 },
1735 }
1736 h := make(http.Header)
1737 for _, tc := range tests {
1738 check(nil, tc.limit, tc.fillerLen)
1739 check(h, tc.limit, tc.fillerLen)
1740 }
1741 }
1742
1743 func TestTransportChecksRequestHeaderListSize(t *testing.T) {
1744 st := newServerTester(t,
1745 func(w http.ResponseWriter, r *http.Request) {
1746
1747
1748
1749
1750
1751
1752 ioutil.ReadAll(r.Body)
1753 r.Body.Close()
1754 },
1755 func(ts *httptest.Server) {
1756 ts.Config.MaxHeaderBytes = 16 << 10
1757 },
1758 optOnlyServer,
1759 optQuiet,
1760 )
1761 defer st.Close()
1762
1763 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
1764 defer tr.CloseIdleConnections()
1765
1766 checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
1767
1768
1769 req0, err := http.NewRequest("GET", st.ts.URL, nil)
1770 if err != nil {
1771 t.Fatalf("newRequest: NewRequest: %v", err)
1772 }
1773 res0, err := tr.RoundTrip(req0)
1774 if err != nil {
1775 t.Errorf("%v: Initial RoundTrip err = %v", desc, err)
1776 }
1777 res0.Body.Close()
1778
1779 res, err := tr.RoundTrip(req)
1780 if err != wantErr {
1781 if res != nil {
1782 res.Body.Close()
1783 }
1784 t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
1785 return
1786 }
1787 if err == nil {
1788 if res == nil {
1789 t.Errorf("%v: response nil; want non-nil.", desc)
1790 return
1791 }
1792 defer res.Body.Close()
1793 if res.StatusCode != http.StatusOK {
1794 t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
1795 }
1796 return
1797 }
1798 if res != nil {
1799 t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
1800 }
1801 }
1802 headerListSizeForRequest := func(req *http.Request) (size uint64) {
1803 contentLen := actualContentLength(req)
1804 trailers, err := commaSeparatedTrailers(req)
1805 if err != nil {
1806 t.Fatalf("headerListSizeForRequest: %v", err)
1807 }
1808 cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
1809 cc.henc = hpack.NewEncoder(&cc.hbuf)
1810 cc.mu.Lock()
1811 hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
1812 cc.mu.Unlock()
1813 if err != nil {
1814 t.Fatalf("headerListSizeForRequest: %v", err)
1815 }
1816 hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
1817 size += uint64(hf.Size())
1818 })
1819 if len(hdrs) > 0 {
1820 if _, err := hpackDec.Write(hdrs); err != nil {
1821 t.Fatalf("headerListSizeForRequest: %v", err)
1822 }
1823 }
1824 return size
1825 }
1826
1827
1828
1829 newRequest := func() *http.Request {
1830
1831 body := strings.NewReader("hello")
1832 req, err := http.NewRequest("POST", st.ts.URL, body)
1833 if err != nil {
1834 t.Fatalf("newRequest: NewRequest: %v", err)
1835 }
1836 return req
1837 }
1838
1839
1840 req := newRequest()
1841 checkRoundTrip(req, nil, "Initial request")
1842 addr := authorityAddr(req.URL.Scheme, req.URL.Host)
1843 cc, err := tr.connPool().GetClientConn(req, addr)
1844 if err != nil {
1845 t.Fatalf("GetClientConn: %v", err)
1846 }
1847 cc.mu.Lock()
1848 peerSize := cc.peerMaxHeaderListSize
1849 cc.mu.Unlock()
1850 st.scMu.Lock()
1851 wantSize := uint64(st.sc.maxHeaderListSize())
1852 st.scMu.Unlock()
1853 if peerSize != wantSize {
1854 t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
1855 }
1856
1857
1858
1859 wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
1860 if peerSize != wantHeaderBytes {
1861 t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
1862 }
1863
1864
1865 req = newRequest()
1866 req.Header = make(http.Header)
1867 req.Trailer = make(http.Header)
1868 filler := strings.Repeat("*", 1024)
1869 padHeaders(t, req.Trailer, peerSize, filler)
1870
1871
1872 defaultBytes := headerListSizeForRequest(req)
1873 padHeaders(t, req.Header, peerSize-defaultBytes, filler)
1874 checkRoundTrip(req, nil, "Headers & Trailers under limit")
1875
1876
1877 req = newRequest()
1878 req.Header = make(http.Header)
1879 padHeaders(t, req.Header, peerSize, filler)
1880 checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")
1881
1882
1883 req = newRequest()
1884 req.Trailer = make(http.Header)
1885 padHeaders(t, req.Trailer, peerSize+1, filler)
1886 checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")
1887
1888
1889 req = newRequest()
1890 filler = strings.Repeat("*", int(peerSize))
1891 req.Header = make(http.Header)
1892 req.Header.Set("Big", filler)
1893 checkRoundTrip(req, errRequestHeaderListSize, "Single large header")
1894
1895
1896 req = newRequest()
1897 req.Trailer = make(http.Header)
1898 req.Trailer.Set("Big", filler)
1899 checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
1900 }
1901
// TestTransportChecksResponseHeaderListSize checks that the Transport rejects
// a response whose decoded header list is very large. RoundTrip must fail with
// errResponseHeaderListSize, which may arrive as the Cause of a StreamError.
func TestTransportChecksResponseHeaderListSize(t *testing.T) {
	ct := newClientTester(t)
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		// Accept the sentinel either bare or wrapped in a StreamError.
		if e, ok := err.(StreamError); ok {
			err = e.Cause
		}
		if err != errResponseHeaderListSize {
			// For the failure message, total the header bytes that were
			// accepted, counting each field as len(name)+len(value)+32.
			size := int64(0)
			if res != nil {
				res.Body.Close()
				for k, vv := range res.Header {
					for _, v := range vv {
						size += int64(len(k)) + int64(len(v)) + 32
					}
				}
			}
			return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *HeadersFrame:
				// Encode ":status 200" followed by 5042 copies of the same
				// field, whose name and value are each 1KiB of 'a'.
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				large := strings.Repeat("a", 1<<10)
				for i := 0; i < 5042; i++ {
					enc.WriteField(hpack.HeaderField{Name: large, Value: large})
				}
				if size, want := buf.Len(), 6329; size != want {
					// Sanity check on the encoded size. The repeated fields
					// keep the encoded block small even though the decoded
					// list is over 10MB (see the message below).
					// NOTE(review): 6329 presumably depends on later copies
					// being emitted as references to the first — update it
					// if the hpack encoder's output changes.
					return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
				}
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
				return nil
			}
		}
	}
	ct.run()
}
1962
1963 func TestTransportCookieHeaderSplit(t *testing.T) {
1964 ct := newClientTester(t)
1965 ct.client = func() error {
1966 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1967 req.Header.Add("Cookie", "a=b;c=d; e=f;")
1968 req.Header.Add("Cookie", "e=f;g=h; ")
1969 req.Header.Add("Cookie", "i=j")
1970 _, err := ct.tr.RoundTrip(req)
1971 return err
1972 }
1973 ct.server = func() error {
1974 ct.greet()
1975 for {
1976 f, err := ct.fr.ReadFrame()
1977 if err != nil {
1978 return err
1979 }
1980 switch f := f.(type) {
1981 case *HeadersFrame:
1982 dec := hpack.NewDecoder(initialHeaderTableSize, nil)
1983 hfs, err := dec.DecodeFull(f.HeaderBlockFragment())
1984 if err != nil {
1985 return err
1986 }
1987 got := []string{}
1988 want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}
1989 for _, hf := range hfs {
1990 if hf.Name == "cookie" {
1991 got = append(got, hf.Value)
1992 }
1993 }
1994 if !reflect.DeepEqual(got, want) {
1995 t.Errorf("Cookies = %#v, want %#v", got, want)
1996 }
1997
1998 var buf bytes.Buffer
1999 enc := hpack.NewEncoder(&buf)
2000 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2001 ct.fr.WriteHeaders(HeadersFrameParam{
2002 StreamID: f.StreamID,
2003 EndHeaders: true,
2004 EndStream: true,
2005 BlockFragment: buf.Bytes(),
2006 })
2007 return nil
2008 }
2009 }
2010 }
2011 ct.run()
2012 }
2013
2014
2015
2016
2017 func TestTransportBodyReadErrorType(t *testing.T) {
2018 doPanic := make(chan bool, 1)
2019 st := newServerTester(t,
2020 func(w http.ResponseWriter, r *http.Request) {
2021 w.(http.Flusher).Flush()
2022 <-doPanic
2023 panic("boom")
2024 },
2025 optOnlyServer,
2026 optQuiet,
2027 )
2028 defer st.Close()
2029
2030 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2031 defer tr.CloseIdleConnections()
2032 c := &http.Client{Transport: tr}
2033
2034 res, err := c.Get(st.ts.URL)
2035 if err != nil {
2036 t.Fatal(err)
2037 }
2038 defer res.Body.Close()
2039 doPanic <- true
2040 buf := make([]byte, 100)
2041 n, err := res.Body.Read(buf)
2042 got, ok := err.(StreamError)
2043 want := StreamError{StreamID: 0x1, Code: 0x2}
2044 if !ok || got.StreamID != want.StreamID || got.Code != want.Code {
2045 t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
2046 }
2047 }
2048
2049
2050
2051
2052 func TestTransportDoubleCloseOnWriteError(t *testing.T) {
2053 var (
2054 mu sync.Mutex
2055 conn net.Conn
2056 )
2057
2058 st := newServerTester(t,
2059 func(w http.ResponseWriter, r *http.Request) {
2060 mu.Lock()
2061 defer mu.Unlock()
2062 if conn != nil {
2063 conn.Close()
2064 }
2065 },
2066 optOnlyServer,
2067 )
2068 defer st.Close()
2069
2070 tr := &Transport{
2071 TLSClientConfig: tlsConfigInsecure,
2072 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
2073 tc, err := tls.Dial(network, addr, cfg)
2074 if err != nil {
2075 return nil, err
2076 }
2077 mu.Lock()
2078 defer mu.Unlock()
2079 conn = tc
2080 return tc, nil
2081 },
2082 }
2083 defer tr.CloseIdleConnections()
2084 c := &http.Client{Transport: tr}
2085 c.Get(st.ts.URL)
2086 }
2087
2088
2089
2090
2091 func TestTransportDisableKeepAlives(t *testing.T) {
2092 st := newServerTester(t,
2093 func(w http.ResponseWriter, r *http.Request) {
2094 io.WriteString(w, "hi")
2095 },
2096 optOnlyServer,
2097 )
2098 defer st.Close()
2099
2100 connClosed := make(chan struct{})
2101 tr := &Transport{
2102 t1: &http.Transport{
2103 DisableKeepAlives: true,
2104 },
2105 TLSClientConfig: tlsConfigInsecure,
2106 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
2107 tc, err := tls.Dial(network, addr, cfg)
2108 if err != nil {
2109 return nil, err
2110 }
2111 return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
2112 },
2113 }
2114 c := &http.Client{Transport: tr}
2115 res, err := c.Get(st.ts.URL)
2116 if err != nil {
2117 t.Fatal(err)
2118 }
2119 if _, err := ioutil.ReadAll(res.Body); err != nil {
2120 t.Fatal(err)
2121 }
2122 defer res.Body.Close()
2123
2124 select {
2125 case <-connClosed:
2126 case <-time.After(1 * time.Second):
2127 t.Errorf("timeout")
2128 }
2129
2130 }
2131
2132
2133
2134 func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
2135 const D = 25 * time.Millisecond
2136 st := newServerTester(t,
2137 func(w http.ResponseWriter, r *http.Request) {
2138 time.Sleep(D)
2139 io.WriteString(w, "hi")
2140 },
2141 optOnlyServer,
2142 )
2143 defer st.Close()
2144
2145 var dials int32
2146 var conns sync.WaitGroup
2147 tr := &Transport{
2148 t1: &http.Transport{
2149 DisableKeepAlives: true,
2150 },
2151 TLSClientConfig: tlsConfigInsecure,
2152 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
2153 tc, err := tls.Dial(network, addr, cfg)
2154 if err != nil {
2155 return nil, err
2156 }
2157 atomic.AddInt32(&dials, 1)
2158 conns.Add(1)
2159 return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
2160 },
2161 }
2162 c := &http.Client{Transport: tr}
2163 var reqs sync.WaitGroup
2164 const N = 20
2165 for i := 0; i < N; i++ {
2166 reqs.Add(1)
2167 if i == N-1 {
2168
2169
2170
2171
2172
2173
2174 time.Sleep(D * 2)
2175 }
2176 go func() {
2177 defer reqs.Done()
2178 res, err := c.Get(st.ts.URL)
2179 if err != nil {
2180 t.Error(err)
2181 return
2182 }
2183 if _, err := ioutil.ReadAll(res.Body); err != nil {
2184 t.Error(err)
2185 return
2186 }
2187 res.Body.Close()
2188 }()
2189 }
2190 reqs.Wait()
2191 conns.Wait()
2192 t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
2193 }
2194
// noteCloseConn wraps a net.Conn and runs closefn the first time Close is
// called. Every Close call is still forwarded to the wrapped Conn.
type noteCloseConn struct {
	net.Conn
	onceClose sync.Once // guards the single closefn invocation
	closefn   func()
}

// Close runs closefn at most once, then closes the wrapped connection and
// returns that connection's Close error.
func (c *noteCloseConn) Close() error {
	notify := c.closefn
	c.onceClose.Do(notify)
	return c.Conn.Close()
}
2205
// isTimeout reports whether err is a timeout. A *url.Error is unwrapped
// (recursively) to its Err. Otherwise err counts as a timeout only if it
// implements net.Error and its Timeout method returns true. A nil error is
// never a timeout.
func isTimeout(err error) bool {
	// *url.Error also satisfies net.Error, so it must be checked first.
	if ue, isURLErr := err.(*url.Error); isURLErr {
		return isTimeout(ue.Err)
	}
	if ne, isNetErr := err.(net.Error); isNetErr {
		return ne.Timeout()
	}
	return false
}
2217
2218
2219 func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
2220 testTransportResponseHeaderTimeout(t, false)
2221 }
2222 func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
2223 testTransportResponseHeaderTimeout(t, true)
2224 }
2225
2226 func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
2227 ct := newClientTester(t)
2228 ct.tr.t1 = &http.Transport{
2229 ResponseHeaderTimeout: 5 * time.Millisecond,
2230 }
2231 ct.client = func() error {
2232 c := &http.Client{Transport: ct.tr}
2233 var err error
2234 var n int64
2235 const bodySize = 4 << 20
2236 if body {
2237 _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
2238 } else {
2239 _, err = c.Get("https://dummy.tld/")
2240 }
2241 if !isTimeout(err) {
2242 t.Errorf("client expected timeout error; got %#v", err)
2243 }
2244 if body && n != bodySize {
2245 t.Errorf("only read %d bytes of body; want %d", n, bodySize)
2246 }
2247 return nil
2248 }
2249 ct.server = func() error {
2250 ct.greet()
2251 for {
2252 f, err := ct.fr.ReadFrame()
2253 if err != nil {
2254 t.Logf("ReadFrame: %v", err)
2255 return nil
2256 }
2257 switch f := f.(type) {
2258 case *DataFrame:
2259 dataLen := len(f.Data())
2260 if dataLen > 0 {
2261 if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
2262 return err
2263 }
2264 if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
2265 return err
2266 }
2267 }
2268 case *RSTStreamFrame:
2269 if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
2270 return nil
2271 }
2272 }
2273 }
2274 }
2275 ct.run()
2276 }
2277
2278 func TestTransportDisableCompression(t *testing.T) {
2279 const body = "sup"
2280 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2281 want := http.Header{
2282 "User-Agent": []string{"Go-http-client/2.0"},
2283 }
2284 if !reflect.DeepEqual(r.Header, want) {
2285 t.Errorf("request headers = %v; want %v", r.Header, want)
2286 }
2287 }, optOnlyServer)
2288 defer st.Close()
2289
2290 tr := &Transport{
2291 TLSClientConfig: tlsConfigInsecure,
2292 t1: &http.Transport{
2293 DisableCompression: true,
2294 },
2295 }
2296 defer tr.CloseIdleConnections()
2297
2298 req, err := http.NewRequest("GET", st.ts.URL, nil)
2299 if err != nil {
2300 t.Fatal(err)
2301 }
2302 res, err := tr.RoundTrip(req)
2303 if err != nil {
2304 t.Fatal(err)
2305 }
2306 defer res.Body.Close()
2307 }
2308
2309
2310 func TestTransportRejectsConnHeaders(t *testing.T) {
2311 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2312 var got []string
2313 for k := range r.Header {
2314 got = append(got, k)
2315 }
2316 sort.Strings(got)
2317 w.Header().Set("Got-Header", strings.Join(got, ","))
2318 }, optOnlyServer)
2319 defer st.Close()
2320
2321 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2322 defer tr.CloseIdleConnections()
2323
2324 tests := []struct {
2325 key string
2326 value []string
2327 want string
2328 }{
2329 {
2330 key: "Upgrade",
2331 value: []string{"anything"},
2332 want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
2333 },
2334 {
2335 key: "Connection",
2336 value: []string{"foo"},
2337 want: "ERROR: http2: invalid Connection request header: [\"foo\"]",
2338 },
2339 {
2340 key: "Connection",
2341 value: []string{"close"},
2342 want: "Accept-Encoding,User-Agent",
2343 },
2344 {
2345 key: "Connection",
2346 value: []string{"CLoSe"},
2347 want: "Accept-Encoding,User-Agent",
2348 },
2349 {
2350 key: "Connection",
2351 value: []string{"close", "something-else"},
2352 want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
2353 },
2354 {
2355 key: "Connection",
2356 value: []string{"keep-alive"},
2357 want: "Accept-Encoding,User-Agent",
2358 },
2359 {
2360 key: "Connection",
2361 value: []string{"Keep-ALIVE"},
2362 want: "Accept-Encoding,User-Agent",
2363 },
2364 {
2365 key: "Proxy-Connection",
2366 value: []string{"keep-alive"},
2367 want: "Accept-Encoding,User-Agent",
2368 },
2369 {
2370 key: "Transfer-Encoding",
2371 value: []string{""},
2372 want: "Accept-Encoding,User-Agent",
2373 },
2374 {
2375 key: "Transfer-Encoding",
2376 value: []string{"foo"},
2377 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
2378 },
2379 {
2380 key: "Transfer-Encoding",
2381 value: []string{"chunked"},
2382 want: "Accept-Encoding,User-Agent",
2383 },
2384 {
2385 key: "Transfer-Encoding",
2386 value: []string{"chunKed"},
2387 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]",
2388 },
2389 {
2390 key: "Transfer-Encoding",
2391 value: []string{"chunked", "other"},
2392 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
2393 },
2394 {
2395 key: "Content-Length",
2396 value: []string{"123"},
2397 want: "Accept-Encoding,User-Agent",
2398 },
2399 {
2400 key: "Keep-Alive",
2401 value: []string{"doop"},
2402 want: "Accept-Encoding,User-Agent",
2403 },
2404 }
2405
2406 for _, tt := range tests {
2407 req, _ := http.NewRequest("GET", st.ts.URL, nil)
2408 req.Header[tt.key] = tt.value
2409 res, err := tr.RoundTrip(req)
2410 var got string
2411 if err != nil {
2412 got = fmt.Sprintf("ERROR: %v", err)
2413 } else {
2414 got = res.Header.Get("Got-Header")
2415 res.Body.Close()
2416 }
2417 if got != tt.want {
2418 t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
2419 }
2420 }
2421 }
2422
2423
2424
2425 func TestTransportRejectsContentLengthWithSign(t *testing.T) {
2426 tests := []struct {
2427 name string
2428 cl []string
2429 wantCL string
2430 }{
2431 {
2432 name: "proper content-length",
2433 cl: []string{"3"},
2434 wantCL: "3",
2435 },
2436 {
2437 name: "ignore cl with plus sign",
2438 cl: []string{"+3"},
2439 wantCL: "",
2440 },
2441 {
2442 name: "ignore cl with minus sign",
2443 cl: []string{"-3"},
2444 wantCL: "",
2445 },
2446 {
2447 name: "max int64, for safe uint64->int64 conversion",
2448 cl: []string{"9223372036854775807"},
2449 wantCL: "9223372036854775807",
2450 },
2451 {
2452 name: "overflows int64, so ignored",
2453 cl: []string{"9223372036854775808"},
2454 wantCL: "",
2455 },
2456 }
2457
2458 for _, tt := range tests {
2459 tt := tt
2460 t.Run(tt.name, func(t *testing.T) {
2461 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2462 w.Header().Set("Content-Length", tt.cl[0])
2463 }, optOnlyServer)
2464 defer st.Close()
2465 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2466 defer tr.CloseIdleConnections()
2467
2468 req, _ := http.NewRequest("HEAD", st.ts.URL, nil)
2469 res, err := tr.RoundTrip(req)
2470
2471 var got string
2472 if err != nil {
2473 got = fmt.Sprintf("ERROR: %v", err)
2474 } else {
2475 got = res.Header.Get("Content-Length")
2476 res.Body.Close()
2477 }
2478
2479 if got != tt.wantCL {
2480 t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL)
2481 }
2482 })
2483 }
2484 }
2485
2486
2487 func TestTransportFailsOnInvalidHeaders(t *testing.T) {
2488 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2489 var got []string
2490 for k := range r.Header {
2491 got = append(got, k)
2492 }
2493 sort.Strings(got)
2494 w.Header().Set("Got-Header", strings.Join(got, ","))
2495 }, optOnlyServer)
2496 defer st.Close()
2497
2498 tests := [...]struct {
2499 h http.Header
2500 wantErr string
2501 }{
2502 0: {
2503 h: http.Header{"with space": {"foo"}},
2504 wantErr: `invalid HTTP header name "with space"`,
2505 },
2506 1: {
2507 h: http.Header{"name": {"Брэд"}},
2508 wantErr: "",
2509 },
2510 2: {
2511 h: http.Header{"имя": {"Brad"}},
2512 wantErr: `invalid HTTP header name "имя"`,
2513 },
2514 3: {
2515 h: http.Header{"foo": {"foo\x01bar"}},
2516 wantErr: `invalid HTTP header value for header "foo"`,
2517 },
2518 }
2519
2520 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2521 defer tr.CloseIdleConnections()
2522
2523 for i, tt := range tests {
2524 req, _ := http.NewRequest("GET", st.ts.URL, nil)
2525 req.Header = tt.h
2526 res, err := tr.RoundTrip(req)
2527 var bad bool
2528 if tt.wantErr == "" {
2529 if err != nil {
2530 bad = true
2531 t.Errorf("case %d: error = %v; want no error", i, err)
2532 }
2533 } else {
2534 if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
2535 bad = true
2536 t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
2537 }
2538 }
2539 if err == nil {
2540 if bad {
2541 t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
2542 }
2543 res.Body.Close()
2544 }
2545 }
2546 }
2547
2548
2549
2550 func TestGzipReader_DoubleReadCrash(t *testing.T) {
2551 gz := &gzipReader{
2552 body: ioutil.NopCloser(strings.NewReader("0123456789")),
2553 }
2554 var buf [1]byte
2555 n, err1 := gz.Read(buf[:])
2556 if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") {
2557 t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1)
2558 }
2559 n, err2 := gz.Read(buf[:])
2560 if n != 0 || err2 != err1 {
2561 t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
2562 }
2563 }
2564
2565 func TestGzipReader_ReadAfterClose(t *testing.T) {
2566 body := bytes.Buffer{}
2567 w := gzip.NewWriter(&body)
2568 w.Write([]byte("012345679"))
2569 w.Close()
2570 gz := &gzipReader{
2571 body: ioutil.NopCloser(&body),
2572 }
2573 var buf [1]byte
2574 n, err := gz.Read(buf[:])
2575 if n != 1 || err != nil {
2576 t.Fatalf("first Read = %v, %v; want 1, nil", n, err)
2577 }
2578 if err := gz.Close(); err != nil {
2579 t.Fatalf("gz Close error: %v", err)
2580 }
2581 n, err = gz.Read(buf[:])
2582 if n != 0 || err != fs.ErrClosed {
2583 t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err)
2584 }
2585 }
2586
2587 func TestTransportNewTLSConfig(t *testing.T) {
2588 tests := [...]struct {
2589 conf *tls.Config
2590 host string
2591 want *tls.Config
2592 }{
2593
2594 0: {
2595 conf: nil,
2596 host: "foo.com",
2597 want: &tls.Config{
2598 ServerName: "foo.com",
2599 NextProtos: []string{NextProtoTLS},
2600 },
2601 },
2602
2603
2604 1: {
2605 conf: &tls.Config{
2606 ServerName: "bar.com",
2607 },
2608 host: "foo.com",
2609 want: &tls.Config{
2610 ServerName: "bar.com",
2611 NextProtos: []string{NextProtoTLS},
2612 },
2613 },
2614
2615
2616 2: {
2617 conf: &tls.Config{
2618 NextProtos: []string{"foo", "bar"},
2619 },
2620 host: "example.com",
2621 want: &tls.Config{
2622 ServerName: "example.com",
2623 NextProtos: []string{NextProtoTLS, "foo", "bar"},
2624 },
2625 },
2626
2627
2628 3: {
2629 conf: &tls.Config{
2630 NextProtos: []string{"foo", "bar", NextProtoTLS},
2631 },
2632 host: "example.com",
2633 want: &tls.Config{
2634 ServerName: "example.com",
2635 NextProtos: []string{"foo", "bar", NextProtoTLS},
2636 },
2637 },
2638 }
2639 for i, tt := range tests {
2640
2641
2642 if tt.conf != nil {
2643 tt.conf.SessionTicketsDisabled = true
2644 }
2645
2646 tr := &Transport{TLSClientConfig: tt.conf}
2647 got := tr.newTLSConfig(tt.host)
2648
2649 got.SessionTicketsDisabled = false
2650
2651 if !reflect.DeepEqual(got, tt.want) {
2652 t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
2653 }
2654 }
2655 }
2656
2657
2658
2659
2660 func TestTransportReadHeadResponse(t *testing.T) {
2661 ct := newClientTester(t)
2662 clientDone := make(chan struct{})
2663 ct.client = func() error {
2664 defer close(clientDone)
2665 req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
2666 res, err := ct.tr.RoundTrip(req)
2667 if err != nil {
2668 return err
2669 }
2670 if res.ContentLength != 123 {
2671 return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength)
2672 }
2673 slurp, err := ioutil.ReadAll(res.Body)
2674 if err != nil {
2675 return fmt.Errorf("ReadAll: %v", err)
2676 }
2677 if len(slurp) > 0 {
2678 return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
2679 }
2680 return nil
2681 }
2682 ct.server = func() error {
2683 ct.greet()
2684 for {
2685 f, err := ct.fr.ReadFrame()
2686 if err != nil {
2687 t.Logf("ReadFrame: %v", err)
2688 return nil
2689 }
2690 hf, ok := f.(*HeadersFrame)
2691 if !ok {
2692 continue
2693 }
2694 var buf bytes.Buffer
2695 enc := hpack.NewEncoder(&buf)
2696 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2697 enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
2698 ct.fr.WriteHeaders(HeadersFrameParam{
2699 StreamID: hf.StreamID,
2700 EndHeaders: true,
2701 EndStream: false,
2702 BlockFragment: buf.Bytes(),
2703 })
2704 ct.fr.WriteData(hf.StreamID, true, nil)
2705
2706 <-clientDone
2707 return nil
2708 }
2709 }
2710 ct.run()
2711 }
2712
2713 func TestTransportReadHeadResponseWithBody(t *testing.T) {
2714
2715
2716 log.SetOutput(ioutil.Discard)
2717 defer log.SetOutput(os.Stderr)
2718
2719 response := "redirecting to /elsewhere"
2720 ct := newClientTester(t)
2721 clientDone := make(chan struct{})
2722 ct.client = func() error {
2723 defer close(clientDone)
2724 req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
2725 res, err := ct.tr.RoundTrip(req)
2726 if err != nil {
2727 return err
2728 }
2729 if res.ContentLength != int64(len(response)) {
2730 return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response))
2731 }
2732 slurp, err := ioutil.ReadAll(res.Body)
2733 if err != nil {
2734 return fmt.Errorf("ReadAll: %v", err)
2735 }
2736 if len(slurp) > 0 {
2737 return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
2738 }
2739 return nil
2740 }
2741 ct.server = func() error {
2742 ct.greet()
2743 for {
2744 f, err := ct.fr.ReadFrame()
2745 if err != nil {
2746 t.Logf("ReadFrame: %v", err)
2747 return nil
2748 }
2749 hf, ok := f.(*HeadersFrame)
2750 if !ok {
2751 continue
2752 }
2753 var buf bytes.Buffer
2754 enc := hpack.NewEncoder(&buf)
2755 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2756 enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))})
2757 ct.fr.WriteHeaders(HeadersFrameParam{
2758 StreamID: hf.StreamID,
2759 EndHeaders: true,
2760 EndStream: false,
2761 BlockFragment: buf.Bytes(),
2762 })
2763 ct.fr.WriteData(hf.StreamID, true, []byte(response))
2764
2765 <-clientDone
2766 return nil
2767 }
2768 }
2769 ct.run()
2770 }
2771
// neverEnding is an io.Reader that produces an endless stream of one byte
// value. Every Read fills all of p and never returns an error.
type neverEnding byte

func (b neverEnding) Read(p []byte) (int, error) {
	if len(p) > 0 {
		p[0] = byte(b)
		// Fill the rest by repeatedly doubling the already-filled prefix.
		for filled := 1; filled < len(p); filled *= 2 {
			copy(p[filled:], p[:filled])
		}
	}
	return len(p), nil
}
2780
2781
2782
2783
2784
2785 func TestTransportHandlerBodyClose(t *testing.T) {
2786 const bodySize = 10 << 20
2787 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2788 r.Body.Close()
2789 io.Copy(w, io.LimitReader(neverEnding('A'), bodySize))
2790 }, optOnlyServer)
2791 defer st.Close()
2792
2793 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2794 defer tr.CloseIdleConnections()
2795
2796 g0 := runtime.NumGoroutine()
2797
2798 const numReq = 10
2799 for i := 0; i < numReq; i++ {
2800 req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
2801 if err != nil {
2802 t.Fatal(err)
2803 }
2804 res, err := tr.RoundTrip(req)
2805 if err != nil {
2806 t.Fatal(err)
2807 }
2808 n, err := io.Copy(ioutil.Discard, res.Body)
2809 res.Body.Close()
2810 if n != bodySize || err != nil {
2811 t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize)
2812 }
2813 }
2814 tr.CloseIdleConnections()
2815
2816 if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool {
2817 gd := runtime.NumGoroutine() - g0
2818 return gd < numReq/2
2819 }) {
2820 t.Errorf("appeared to leak goroutines")
2821 }
2822 }
2823
2824
2825 func TestTransportFlowControl(t *testing.T) {
2826 const bufLen = 64 << 10
2827 var total int64 = 100 << 20
2828 if testing.Short() {
2829 total = 10 << 20
2830 }
2831
2832 var wrote int64
2833 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2834 b := make([]byte, bufLen)
2835 for wrote < total {
2836 n, err := w.Write(b)
2837 atomic.AddInt64(&wrote, int64(n))
2838 if err != nil {
2839 t.Errorf("ResponseWriter.Write error: %v", err)
2840 break
2841 }
2842 w.(http.Flusher).Flush()
2843 }
2844 }, optOnlyServer)
2845
2846 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2847 defer tr.CloseIdleConnections()
2848 req, err := http.NewRequest("GET", st.ts.URL, nil)
2849 if err != nil {
2850 t.Fatal("NewRequest error:", err)
2851 }
2852 resp, err := tr.RoundTrip(req)
2853 if err != nil {
2854 t.Fatal("RoundTrip error:", err)
2855 }
2856 defer resp.Body.Close()
2857
2858 var read int64
2859 b := make([]byte, bufLen)
2860 for {
2861 n, err := resp.Body.Read(b)
2862 if err == io.EOF {
2863 break
2864 }
2865 if err != nil {
2866 t.Fatal("Read error:", err)
2867 }
2868 read += int64(n)
2869
2870 const max = transportDefaultStreamFlow
2871 if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
2872 t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
2873 }
2874
2875
2876 time.Sleep(1 * time.Millisecond)
2877 }
2878 }
2879
2880
2881
2882
2883
2884
2885 func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
2886 testTransportUsesGoAwayDebugError(t, false)
2887 }
2888
2889 func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
2890 testTransportUsesGoAwayDebugError(t, true)
2891 }
2892
2893 func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
2894 ct := newClientTester(t)
2895 clientDone := make(chan struct{})
2896
2897 const goAwayErrCode = ErrCodeHTTP11Required
2898 const goAwayDebugData = "some debug data"
2899
2900 ct.client = func() error {
2901 defer close(clientDone)
2902 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2903 res, err := ct.tr.RoundTrip(req)
2904 if failMidBody {
2905 if err != nil {
2906 return fmt.Errorf("unexpected client RoundTrip error: %v", err)
2907 }
2908 _, err = io.Copy(ioutil.Discard, res.Body)
2909 res.Body.Close()
2910 }
2911 want := GoAwayError{
2912 LastStreamID: 5,
2913 ErrCode: goAwayErrCode,
2914 DebugData: goAwayDebugData,
2915 }
2916 if !reflect.DeepEqual(err, want) {
2917 t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
2918 }
2919 return nil
2920 }
2921 ct.server = func() error {
2922 ct.greet()
2923 for {
2924 f, err := ct.fr.ReadFrame()
2925 if err != nil {
2926 t.Logf("ReadFrame: %v", err)
2927 return nil
2928 }
2929 hf, ok := f.(*HeadersFrame)
2930 if !ok {
2931 continue
2932 }
2933 if failMidBody {
2934 var buf bytes.Buffer
2935 enc := hpack.NewEncoder(&buf)
2936 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2937 enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
2938 ct.fr.WriteHeaders(HeadersFrameParam{
2939 StreamID: hf.StreamID,
2940 EndHeaders: true,
2941 EndStream: false,
2942 BlockFragment: buf.Bytes(),
2943 })
2944 }
2945
2946
2947 ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
2948 ct.fr.WriteGoAway(5, goAwayErrCode, nil)
2949 ct.sc.(*net.TCPConn).CloseWrite()
2950 if runtime.GOOS == "plan9" {
2951
2952 ct.sc.(*net.TCPConn).Close()
2953 }
2954 <-clientDone
2955 return nil
2956 }
2957 }
2958 ct.run()
2959 }
2960
2961 func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
2962 ct := newClientTester(t)
2963
2964 ct.client = func() error {
2965 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2966 res, err := ct.tr.RoundTrip(req)
2967 if err != nil {
2968 return err
2969 }
2970
2971 if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
2972 return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
2973 }
2974 res.Body.Close()
2975
2976 return nil
2977 }
2978 ct.server = func() error {
2979 ct.greet()
2980
2981 var hf *HeadersFrame
2982 for {
2983 f, err := ct.fr.ReadFrame()
2984 if err != nil {
2985 return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
2986 }
2987 switch f.(type) {
2988 case *WindowUpdateFrame, *SettingsFrame:
2989 continue
2990 }
2991 var ok bool
2992 hf, ok = f.(*HeadersFrame)
2993 if !ok {
2994 return fmt.Errorf("Got %T; want HeadersFrame", f)
2995 }
2996 break
2997 }
2998
2999 var buf bytes.Buffer
3000 enc := hpack.NewEncoder(&buf)
3001 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
3002 enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
3003 ct.fr.WriteHeaders(HeadersFrameParam{
3004 StreamID: hf.StreamID,
3005 EndHeaders: true,
3006 EndStream: false,
3007 BlockFragment: buf.Bytes(),
3008 })
3009 initialInflow := ct.inflowWindow(0)
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021 if oneDataFrame {
3022 ct.fr.WriteData(hf.StreamID, false , make([]byte, 5000))
3023 } else {
3024 ct.fr.WriteData(hf.StreamID, false , make([]byte, 1))
3025 }
3026
3027 wantRST := true
3028 wantWUF := true
3029 if !oneDataFrame {
3030 wantWUF = false
3031 }
3032 for wantRST || wantWUF {
3033 f, err := ct.readNonSettingsFrame()
3034 if err != nil {
3035 return err
3036 }
3037 switch f := f.(type) {
3038 case *RSTStreamFrame:
3039 if !wantRST {
3040 return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
3041 }
3042 if f.ErrCode != ErrCodeCancel {
3043 return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
3044 }
3045 wantRST = false
3046 case *WindowUpdateFrame:
3047 if !wantWUF {
3048 return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
3049 }
3050 if f.Increment != 5000 {
3051 return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
3052 }
3053 wantWUF = false
3054 default:
3055 return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
3056 }
3057 }
3058 if !oneDataFrame {
3059 ct.fr.WriteData(hf.StreamID, false , make([]byte, 4999))
3060 f, err := ct.readNonSettingsFrame()
3061 if err != nil {
3062 return err
3063 }
3064 wuf, ok := f.(*WindowUpdateFrame)
3065 if !ok || wuf.Increment != 5000 {
3066 return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f))
3067 }
3068 }
3069 if err := ct.writeReadPing(); err != nil {
3070 return err
3071 }
3072 if got, want := ct.inflowWindow(0), initialInflow; got != want {
3073 return fmt.Errorf("connection flow tokens = %v, want %v", got, want)
3074 }
3075 return nil
3076 }
3077 ct.run()
3078 }
3079
3080
3081 func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
3082 testTransportReturnsUnusedFlowControl(t, true)
3083 }
3084
3085
3086 func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
3087 testTransportReturnsUnusedFlowControl(t, false)
3088 }
3089
3090
3091
3092 func TestTransportAdjustsFlowControl(t *testing.T) {
3093 ct := newClientTester(t)
3094 clientDone := make(chan struct{})
3095
3096 const bodySize = 1 << 20
3097
3098 ct.client = func() error {
3099 defer ct.cc.(*net.TCPConn).CloseWrite()
3100 if runtime.GOOS == "plan9" {
3101
3102 defer ct.cc.(*net.TCPConn).Close()
3103 }
3104 defer close(clientDone)
3105
3106 req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
3107 res, err := ct.tr.RoundTrip(req)
3108 if err != nil {
3109 return err
3110 }
3111 res.Body.Close()
3112 return nil
3113 }
3114 ct.server = func() error {
3115 _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
3116 if err != nil {
3117 return fmt.Errorf("reading client preface: %v", err)
3118 }
3119
3120 var gotBytes int64
3121 var sentSettings bool
3122 for {
3123 f, err := ct.fr.ReadFrame()
3124 if err != nil {
3125 select {
3126 case <-clientDone:
3127 return nil
3128 default:
3129 return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
3130 }
3131 }
3132 switch f := f.(type) {
3133 case *DataFrame:
3134 gotBytes += int64(len(f.Data()))
3135
3136
3137
3138
3139 if gotBytes >= initialWindowSize/2 && !sentSettings {
3140 sentSettings = true
3141
3142 ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
3143 ct.fr.WriteWindowUpdate(0, bodySize)
3144 ct.fr.WriteSettingsAck()
3145 }
3146
3147 if f.StreamEnded() {
3148 var buf bytes.Buffer
3149 enc := hpack.NewEncoder(&buf)
3150 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
3151 ct.fr.WriteHeaders(HeadersFrameParam{
3152 StreamID: f.StreamID,
3153 EndHeaders: true,
3154 EndStream: true,
3155 BlockFragment: buf.Bytes(),
3156 })
3157 }
3158 }
3159 }
3160 }
3161 ct.run()
3162 }
3163
3164
3165 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
3166 ct := newClientTester(t)
3167
3168 unblockClient := make(chan bool, 1)
3169
3170 ct.client = func() error {
3171 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3172 res, err := ct.tr.RoundTrip(req)
3173 if err != nil {
3174 return err
3175 }
3176 defer res.Body.Close()
3177 <-unblockClient
3178 return nil
3179 }
3180 ct.server = func() error {
3181 ct.greet()
3182
3183 var hf *HeadersFrame
3184 for {
3185 f, err := ct.fr.ReadFrame()
3186 if err != nil {
3187 return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
3188 }
3189 switch f.(type) {
3190 case *WindowUpdateFrame, *SettingsFrame:
3191 continue
3192 }
3193 var ok bool
3194 hf, ok = f.(*HeadersFrame)
3195 if !ok {
3196 return fmt.Errorf("Got %T; want HeadersFrame", f)
3197 }
3198 break
3199 }
3200
3201 initialConnWindow := ct.inflowWindow(0)
3202
3203 var buf bytes.Buffer
3204 enc := hpack.NewEncoder(&buf)
3205 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
3206 enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
3207 ct.fr.WriteHeaders(HeadersFrameParam{
3208 StreamID: hf.StreamID,
3209 EndHeaders: true,
3210 EndStream: false,
3211 BlockFragment: buf.Bytes(),
3212 })
3213 initialStreamWindow := ct.inflowWindow(hf.StreamID)
3214 pad := make([]byte, 5)
3215 ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad)
3216 if err := ct.writeReadPing(); err != nil {
3217 return err
3218 }
3219
3220 if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want {
3221 t.Errorf("conn inflow window = %v, want %v", got, want)
3222 }
3223 if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want {
3224 t.Errorf("stream inflow window = %v, want %v", got, want)
3225 }
3226 unblockClient <- true
3227 return nil
3228 }
3229 ct.run()
3230 }
3231
3232
3233
3234 func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
3235 ct := newClientTester(t)
3236
3237 ct.client = func() error {
3238 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3239 res, err := ct.tr.RoundTrip(req)
3240 if err == nil {
3241 res.Body.Close()
3242 return errors.New("unexpected successful GET")
3243 }
3244 want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
3245 if !reflect.DeepEqual(want, err) {
3246 t.Errorf("RoundTrip error = %#v; want %#v", err, want)
3247 }
3248 return nil
3249 }
3250 ct.server = func() error {
3251 ct.greet()
3252
3253 hf, err := ct.firstHeaders()
3254 if err != nil {
3255 return err
3256 }
3257
3258 var buf bytes.Buffer
3259 enc := hpack.NewEncoder(&buf)
3260 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
3261 enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"})
3262 ct.fr.WriteHeaders(HeadersFrameParam{
3263 StreamID: hf.StreamID,
3264 EndHeaders: true,
3265 EndStream: false,
3266 BlockFragment: buf.Bytes(),
3267 })
3268
3269 for {
3270 fr, err := ct.readFrame()
3271 if err != nil {
3272 return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
3273 }
3274 if _, ok := fr.(*SettingsFrame); ok {
3275 continue
3276 }
3277 if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
3278 t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
3279 }
3280 break
3281 }
3282
3283 return nil
3284 }
3285 ct.run()
3286 }
3287
3288
3289
// byteAndEOFReader is an io.Reader whose every Read returns its single byte
// together with io.EOF. It models a body whose final data and EOF arrive in
// the same call. Reading into an empty buffer panics.
type byteAndEOFReader byte

func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) != 0 {
		p[0] = byte(b)
		return 1, io.EOF
	}
	panic("unexpected useless call")
}
3299
3300
3301
3302
3303
3304
3305
3306
3307
3308
3309 func TestTransportBodyDoubleEndStream(t *testing.T) {
3310 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3311
3312 }, optOnlyServer)
3313 defer st.Close()
3314
3315 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3316 defer tr.CloseIdleConnections()
3317
3318 for i := 0; i < 2; i++ {
3319 req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a'))
3320 req.ContentLength = 1
3321 res, err := tr.RoundTrip(req)
3322 if err != nil {
3323 t.Fatalf("failure on req %d: %v", i+1, err)
3324 }
3325 defer res.Body.Close()
3326 }
3327 }
3328
3329
3330 func TestTransportRequestPathPseudo(t *testing.T) {
3331 type result struct {
3332 path string
3333 err string
3334 }
3335 tests := []struct {
3336 req *http.Request
3337 want result
3338 }{
3339 0: {
3340 req: &http.Request{
3341 Method: "GET",
3342 URL: &url.URL{
3343 Host: "foo.com",
3344 Path: "/foo",
3345 },
3346 },
3347 want: result{path: "/foo"},
3348 },
3349
3350
3351
3352 1: {
3353 req: &http.Request{
3354 Method: "GET",
3355 URL: &url.URL{
3356 Host: "foo.com",
3357 Path: "//foo",
3358 },
3359 },
3360 want: result{path: "//foo"},
3361 },
3362
3363
3364 2: {
3365 req: &http.Request{
3366 Method: "GET",
3367 URL: &url.URL{
3368 Scheme: "https",
3369 Opaque: "//foo.com/path",
3370 Host: "foo.com",
3371 Path: "/ignored",
3372 },
3373 },
3374 want: result{path: "/path"},
3375 },
3376
3377
3378 3: {
3379 req: &http.Request{
3380 Method: "GET",
3381 Host: "bar.com",
3382 URL: &url.URL{
3383 Scheme: "https",
3384 Opaque: "//bar.com/path",
3385 Host: "foo.com",
3386 Path: "/ignored",
3387 },
3388 },
3389 want: result{path: "/path"},
3390 },
3391
3392
3393 4: {
3394 req: &http.Request{
3395 Method: "GET",
3396 URL: &url.URL{
3397 Opaque: "/path",
3398 Host: "foo.com",
3399 Path: "/ignored",
3400 },
3401 },
3402 want: result{path: "/path"},
3403 },
3404
3405
3406 5: {
3407 req: &http.Request{
3408 Method: "GET",
3409 URL: &url.URL{
3410 Scheme: "https",
3411 Opaque: "//unknown_host/path",
3412 Host: "foo.com",
3413 Path: "/ignored",
3414 },
3415 },
3416 want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
3417 },
3418
3419
3420 6: {
3421 req: &http.Request{
3422 Method: "CONNECT",
3423 URL: &url.URL{
3424 Host: "foo.com",
3425 },
3426 },
3427 want: result{},
3428 },
3429 }
3430 for i, tt := range tests {
3431 cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
3432 cc.henc = hpack.NewEncoder(&cc.hbuf)
3433 cc.mu.Lock()
3434 hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
3435 cc.mu.Unlock()
3436 var got result
3437 hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
3438 if f.Name == ":path" {
3439 got.path = f.Value
3440 }
3441 })
3442 if err != nil {
3443 got.err = err.Error()
3444 } else if len(hdrs) > 0 {
3445 if _, err := hpackDec.Write(hdrs); err != nil {
3446 t.Errorf("%d. bogus hpack: %v", i, err)
3447 continue
3448 }
3449 }
3450 if got != tt.want {
3451 t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
3452 }
3453
3454 }
3455
3456 }
3457
3458
3459
3460 func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
3461 const body = "foo"
3462 req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body)))
3463 cc := &ClientConn{
3464 closed: true,
3465 reqHeaderMu: make(chan struct{}, 1),
3466 }
3467 _, err := cc.RoundTrip(req)
3468 if err != errClientConnUnusable {
3469 t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err)
3470 }
3471 slurp, err := ioutil.ReadAll(req.Body)
3472 if err != nil {
3473 t.Errorf("ReadAll = %v", err)
3474 }
3475 if string(slurp) != body {
3476 t.Errorf("Body = %q; want %q", slurp, body)
3477 }
3478 }
3479
3480 func TestClientConnPing(t *testing.T) {
3481 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
3482 defer st.Close()
3483 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3484 defer tr.CloseIdleConnections()
3485 ctx := context.Background()
3486 cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
3487 if err != nil {
3488 t.Fatal(err)
3489 }
3490 if err = cc.Ping(context.Background()); err != nil {
3491 t.Fatal(err)
3492 }
3493 }
3494
3495
3496
3497
3498
3499 func TestTransportCancelDataResponseRace(t *testing.T) {
3500 cancel := make(chan struct{})
3501 clientGotResponse := make(chan bool, 1)
3502
3503 const msg = "Hello."
3504 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3505 if strings.Contains(r.URL.Path, "/hello") {
3506 time.Sleep(50 * time.Millisecond)
3507 io.WriteString(w, msg)
3508 return
3509 }
3510 for i := 0; i < 50; i++ {
3511 io.WriteString(w, "Some data.")
3512 w.(http.Flusher).Flush()
3513 if i == 2 {
3514 <-clientGotResponse
3515 close(cancel)
3516 }
3517 time.Sleep(10 * time.Millisecond)
3518 }
3519 }, optOnlyServer)
3520 defer st.Close()
3521
3522 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3523 defer tr.CloseIdleConnections()
3524
3525 c := &http.Client{Transport: tr}
3526 req, _ := http.NewRequest("GET", st.ts.URL, nil)
3527 req.Cancel = cancel
3528 res, err := c.Do(req)
3529 clientGotResponse <- true
3530 if err != nil {
3531 t.Fatal(err)
3532 }
3533 if _, err = io.Copy(ioutil.Discard, res.Body); err == nil {
3534 t.Fatal("unexpected success")
3535 }
3536
3537 res, err = c.Get(st.ts.URL + "/hello")
3538 if err != nil {
3539 t.Fatal(err)
3540 }
3541 slurp, err := ioutil.ReadAll(res.Body)
3542 if err != nil {
3543 t.Fatal(err)
3544 }
3545 if string(slurp) != msg {
3546 t.Errorf("Got = %q; want %q", slurp, msg)
3547 }
3548 }
3549
3550
3551
3552 func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
3553 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3554 w.WriteHeader(200)
3555 io.WriteString(w, "body")
3556 }, optOnlyServer)
3557 defer st.Close()
3558
3559 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3560 defer tr.CloseIdleConnections()
3561
3562 req, _ := http.NewRequest("GET", st.ts.URL, nil)
3563 resp, err := tr.RoundTrip(req)
3564 if err != nil {
3565 t.Fatal(err)
3566 }
3567 if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
3568 t.Fatalf("error reading response body: %v", err)
3569 }
3570 if err := resp.Body.Close(); err != nil {
3571 t.Fatalf("error closing response body: %v", err)
3572 }
3573
3574
3575 req.Header = http.Header{}
3576 }
3577
3578 func TestTransportCloseAfterLostPing(t *testing.T) {
3579 clientDone := make(chan struct{})
3580 ct := newClientTester(t)
3581 ct.tr.PingTimeout = 1 * time.Second
3582 ct.tr.ReadIdleTimeout = 1 * time.Second
3583 ct.client = func() error {
3584 defer ct.cc.(*net.TCPConn).CloseWrite()
3585 defer close(clientDone)
3586 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3587 _, err := ct.tr.RoundTrip(req)
3588 if err == nil || !strings.Contains(err.Error(), "client connection lost") {
3589 return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
3590 }
3591 return nil
3592 }
3593 ct.server = func() error {
3594 ct.greet()
3595 <-clientDone
3596 return nil
3597 }
3598 ct.run()
3599 }
3600
3601 func TestTransportPingWriteBlocks(t *testing.T) {
3602 st := newServerTester(t,
3603 func(w http.ResponseWriter, r *http.Request) {},
3604 optOnlyServer,
3605 )
3606 defer st.Close()
3607 tr := &Transport{
3608 TLSClientConfig: tlsConfigInsecure,
3609 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
3610 s, c := net.Pipe()
3611 go func() {
3612
3613
3614
3615 var buf [1024]byte
3616 s.Read(buf[:])
3617 }()
3618 return c, nil
3619 },
3620 PingTimeout: 1 * time.Millisecond,
3621 ReadIdleTimeout: 1 * time.Millisecond,
3622 }
3623 defer tr.CloseIdleConnections()
3624 c := &http.Client{Transport: tr}
3625 _, err := c.Get(st.ts.URL)
3626 if err == nil {
3627 t.Fatalf("Get = nil, want error")
3628 }
3629 }
3630
3631 func TestTransportPingWhenReading(t *testing.T) {
3632 testCases := []struct {
3633 name string
3634 readIdleTimeout time.Duration
3635 deadline time.Duration
3636 expectedPingCount int
3637 }{
3638 {
3639 name: "two pings",
3640 readIdleTimeout: 100 * time.Millisecond,
3641 deadline: time.Second,
3642 expectedPingCount: 2,
3643 },
3644 {
3645 name: "zero ping",
3646 readIdleTimeout: time.Second,
3647 deadline: 200 * time.Millisecond,
3648 expectedPingCount: 0,
3649 },
3650 {
3651 name: "0 readIdleTimeout means no ping",
3652 readIdleTimeout: 0 * time.Millisecond,
3653 deadline: 500 * time.Millisecond,
3654 expectedPingCount: 0,
3655 },
3656 }
3657
3658 for _, tc := range testCases {
3659 tc := tc
3660 t.Run(tc.name, func(t *testing.T) {
3661 testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
3662 })
3663 }
3664 }
3665
// testTransportPingWhenReading issues one GET through a clientTester. The
// Transport's ReadIdleTimeout is set to readIdleTimeout, and the request is
// bounded by a context that expires after deadline.
//
// The fake server behaves as follows:
//   - It replies with response HEADERS that do not end the stream.
//   - It echoes every PING it receives.
//   - Once it has counted expectedPingCount PINGs, it sends the final DATA
//     frame with END_STREAM, which lets the client's body read complete.
//
// When expectedPingCount is 0 that final frame is never sent, so the client
// treats the context deadline expiring as success.
func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) {
	var pingCount int // only read and written inside ct.server
	ct := newClientTester(t)
	ct.tr.ReadIdleTimeout = readIdleTimeout

	ctx, cancel := context.WithTimeout(context.Background(), deadline)
	defer cancel()
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// NOTE(review): presumably CloseWrite is not supported on Plan 9,
			// so the connection is also fully closed there.
			defer ct.cc.(*net.TCPConn).Close()
		}
		req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != 200 {
			return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
		}
		_, err = ioutil.ReadAll(res.Body)
		// In the no-ping cases the server never ends the body, so the
		// deadline firing is the expected way out of ReadAll.
		if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) {
			return nil
		}

		// Cancel ctx so that the server's next failing ReadFrame takes the
		// ctx.Done branch and returns nil rather than an error.
		cancel()
		return err
	}

	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		var streamID uint32 // stream of the request; used for the final DATA frame
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-ctx.Done():
					// The client side has finished (deadline or cancel above)
					// and reports its own result; a read error here is just
					// connection teardown.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				// Respond with headers only; the body stays open until the
				// expected number of PINGs has arrived.
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: buf.Bytes(),
				})
				streamID = f.StreamID
			case *PingFrame:
				pingCount++
				if pingCount == expectedPingCount {
					// End the response body once the expected PING count is reached.
					if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
						return err
					}
				}
				// Acknowledge every PING.
				if err := ct.fr.WritePing(true, f.Data); err != nil {
					return err
				}
			case *RSTStreamFrame:
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
3747
// testClientMultipleDials runs client against a Transport whose DialTLS
// connects to a local TCP listener. Each dial does the following:
//   - accepts the server side of the new connection;
//   - wraps both ends in a clientTester;
//   - runs server(count, ct) on a new goroutine. count is the 1-based index
//     of that dial.
//
// After client returns, the function cleans up and waits:
//   - closes idle connections, the listener and every connection it created;
//   - waits for all server goroutines to exit.
func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
	ln := newLocalListener(t)
	defer ln.Close()

	var (
		mu    sync.Mutex // guards count and conns
		count int        // number of dials so far
		conns []net.Conn // both ends of every dialed connection, closed at the end
	)
	var wg sync.WaitGroup // tracks the per-dial server goroutines
	tr := &Transport{
		TLSClientConfig: tlsConfigInsecure,
	}
	tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
		mu.Lock()
		defer mu.Unlock()
		count++
		cc, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			return nil, fmt.Errorf("dial error: %v", err)
		}
		conns = append(conns, cc)
		sc, err := ln.Accept()
		if err != nil {
			return nil, fmt.Errorf("accept error: %v", err)
		}
		conns = append(conns, sc)
		ct := &clientTester{
			t:  t,
			tr: tr,
			cc: cc,
			sc: sc,
			fr: NewFramer(sc, sc),
		}
		wg.Add(1)
		// Pass count by value so each server sees its own dial index.
		go func(count int) {
			defer wg.Done()
			server(count, ct)
		}(count)
		return cc, nil
	}

	client(tr)
	tr.CloseIdleConnections()
	ln.Close()
	// Closing every connection unblocks server goroutines stuck in reads.
	for _, c := range conns {
		c.Close()
	}
	wg.Wait()
}
3798
3799 func TestTransportRetryAfterGOAWAY(t *testing.T) {
3800 client := func(tr *Transport) {
3801 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3802 res, err := tr.RoundTrip(req)
3803 if res != nil {
3804 res.Body.Close()
3805 if got := res.Header.Get("Foo"); got != "bar" {
3806 err = fmt.Errorf("foo header = %q; want bar", got)
3807 }
3808 }
3809 if err != nil {
3810 t.Errorf("RoundTrip: %v", err)
3811 }
3812 }
3813
3814 server := func(count int, ct *clientTester) {
3815 switch count {
3816 case 1:
3817 ct.greet()
3818 hf, err := ct.firstHeaders()
3819 if err != nil {
3820 t.Errorf("server1 failed reading HEADERS: %v", err)
3821 return
3822 }
3823 t.Logf("server1 got %v", hf)
3824 if err := ct.fr.WriteGoAway(0 , ErrCodeNo, nil); err != nil {
3825 t.Errorf("server1 failed writing GOAWAY: %v", err)
3826 return
3827 }
3828 case 2:
3829 ct.greet()
3830 hf, err := ct.firstHeaders()
3831 if err != nil {
3832 t.Errorf("server2 failed reading HEADERS: %v", err)
3833 return
3834 }
3835 t.Logf("server2 got %v", hf)
3836
3837 var buf bytes.Buffer
3838 enc := hpack.NewEncoder(&buf)
3839 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
3840 enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
3841 err = ct.fr.WriteHeaders(HeadersFrameParam{
3842 StreamID: hf.StreamID,
3843 EndHeaders: true,
3844 EndStream: false,
3845 BlockFragment: buf.Bytes(),
3846 })
3847 if err != nil {
3848 t.Errorf("server2 failed writing response HEADERS: %v", err)
3849 }
3850 default:
3851 t.Errorf("unexpected number of dials")
3852 return
3853 }
3854 }
3855
3856 testClientMultipleDials(t, client, server)
3857 }
3858
// TestTransportRetryAfterRefusedStream checks that a request reset by the
// server with REFUSED_STREAM is retried. RoundTrip must return the 204
// response given to the retry.
func TestTransportRetryAfterRefusedStream(t *testing.T) {
	clientDone := make(chan struct{}) // closed once the client has finished
	client := func(tr *Transport) {
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		resp, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			return
		}
		resp.Body.Close()
		if resp.StatusCode != 204 {
			t.Errorf("Status = %v; want 204", resp.StatusCode)
			return
		}
	}

	// The dial index is ignored. count is local to one server call, so it
	// counts request HEADERS frames seen on this connection.
	server := func(_ int, ct *clientTester) {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		var count int
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// The client has finished and reported its own errors;
					// a read error now is just connection teardown.
				default:
					t.Error(err)
				}
				return
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					t.Errorf("headers should have END_HEADERS be ended: %v", f)
					return
				}
				count++
				if count == 1 {
					// Refuse the first request so the client has to retry it.
					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
				} else {
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				}
			default:
				t.Errorf("Unexpected client frame %v", f)
				return
			}
		}
	}

	testClientMultipleDials(t, client, server)
}
3922
// TestTransportRetryHasLimit checks that RoundTrip eventually returns an
// error, rather than retrying forever, when the server refuses every attempt
// with REFUSED_STREAM.
func TestTransportRetryHasLimit(t *testing.T) {
	// NOTE(review): marked as a long test; presumably the retry loop is slow
	// without the backoff hook below.
	if testing.Short() {
		t.Skip("skipping long test in short mode")
	}
	// Make every retry backoff timer fire immediately; restored on return.
	retryBackoffHook = func(d time.Duration) *time.Timer {
		return time.NewTimer(0)
	}
	defer func() {
		retryBackoffHook = nil
	}()
	clientDone := make(chan struct{}) // closed once the client has finished
	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// NOTE(review): presumably CloseWrite is not supported on Plan 9,
			// so the connection is also fully closed there.
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		resp, err := ct.tr.RoundTrip(req)
		if err == nil {
			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
		}
		t.Logf("expected error, got: %v", err)
		return nil
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// The client has finished and reports its own result;
					// a read error now is just connection teardown.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				// Refuse every attempt.
				ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
3980
// TestTransportResponseDataBeforeHeaders checks that a DATA frame arriving
// on a stream before any response HEADERS makes that request fail with a
// stream PROTOCOL_ERROR. An earlier well-formed exchange (stream 1) on the
// same fake server must still succeed.
func TestTransportResponseDataBeforeHeaders(t *testing.T) {
	// The server deliberately sends an invalid response; discard the log
	// output that produces for the duration of the test.
	log.SetOutput(ioutil.Discard)
	defer log.SetOutput(os.Stderr)

	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// NOTE(review): presumably CloseWrite is not supported on Plan 9,
			// so the connection is also fully closed there.
			defer ct.cc.(*net.TCPConn).Close()
		}
		req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
		// The first request gets a valid response, so the failure below
		// cannot be attributed to the whole connection.
		_, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip expected no error, got: %v", err)
		}

		// The second request (stream 3) receives DATA with no HEADERS.
		resp, err := ct.tr.RoundTrip(req)
		if err == nil {
			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
		}
		if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
			return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err == io.EOF {
				return nil
			} else if err != nil {
				return err
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame:
			case *HeadersFrame:
				switch f.StreamID {
				case 1:
					// Valid response for the first request.
					var buf bytes.Buffer
					enc := hpack.NewEncoder(&buf)
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				case 3:
					// Invalid: body data with no preceding response HEADERS.
					ct.fr.WriteData(f.StreamID, true, []byte("payload"))
				}
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
4044
4045 func TestTransportMaxFrameReadSize(t *testing.T) {
4046 for _, test := range []struct {
4047 maxReadFrameSize uint32
4048 want uint32
4049 }{{
4050 maxReadFrameSize: 64000,
4051 want: 64000,
4052 }, {
4053 maxReadFrameSize: 1024,
4054 want: minMaxFrameSize,
4055 }} {
4056 ct := newClientTester(t)
4057 ct.tr.MaxReadFrameSize = test.maxReadFrameSize
4058 ct.client = func() error {
4059 req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
4060 ct.tr.RoundTrip(req)
4061 return nil
4062 }
4063 ct.server = func() error {
4064 defer ct.cc.(*net.TCPConn).Close()
4065 ct.greet()
4066 var got uint32
4067 ct.settings.ForeachSetting(func(s Setting) error {
4068 switch s.ID {
4069 case SettingMaxFrameSize:
4070 got = s.Val
4071 }
4072 return nil
4073 })
4074 if got != test.want {
4075 t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
4076 }
4077 return nil
4078 }
4079 ct.run()
4080 }
4081 }
4082
4083 func TestTransportRequestsLowServerLimit(t *testing.T) {
4084 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4085 }, optOnlyServer, func(s *Server) {
4086 s.MaxConcurrentStreams = 1
4087 })
4088 defer st.Close()
4089
4090 var (
4091 connCountMu sync.Mutex
4092 connCount int
4093 )
4094 tr := &Transport{
4095 TLSClientConfig: tlsConfigInsecure,
4096 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
4097 connCountMu.Lock()
4098 defer connCountMu.Unlock()
4099 connCount++
4100 return tls.Dial(network, addr, cfg)
4101 },
4102 }
4103 defer tr.CloseIdleConnections()
4104
4105 const reqCount = 3
4106 for i := 0; i < reqCount; i++ {
4107 req, err := http.NewRequest("GET", st.ts.URL, nil)
4108 if err != nil {
4109 t.Fatal(err)
4110 }
4111 res, err := tr.RoundTrip(req)
4112 if err != nil {
4113 t.Fatal(err)
4114 }
4115 if got, want := res.StatusCode, 200; got != want {
4116 t.Errorf("StatusCode = %v; want %v", got, want)
4117 }
4118 if res != nil && res.Body != nil {
4119 res.Body.Close()
4120 }
4121 }
4122
4123 if connCount != 1 {
4124 t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
4125 }
4126 }
4127
4128
// TestTransportRequestsStallAtServerLimit exercises
// Transport.StrictMaxConcurrentStreams. The server advertises
// SETTINGS_MAX_CONCURRENT_STREAMS = maxConcurrent, and the client sends
// maxConcurrent+2 requests:
//   - the first maxConcurrent reach the server;
//   - the next one is canceled;
//   - the last one must not reach the server until the server starts
//     answering the earlier ones.
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
	const maxConcurrent = 2

	greet := make(chan struct{})      // closed when the server read loop sees a SETTINGS frame
	gotRequest := make(chan struct{}) // one send per request HEADERS the server reads
	clientDone := make(chan struct{}) // closed after all client goroutines finish
	cancelClientRequest := make(chan struct{})

	// Errors from helper goroutines are collected in errs and reported on
	// the test goroutine once wg has drained.
	var wg sync.WaitGroup
	errs := make(chan error, 100)
	defer func() {
		wg.Wait()
		close(errs)
		for err := range errs {
			t.Error(err)
		}
	}()

	// The checker goroutine below sequences the test in four stages:
	//  1. wait for the server to receive the first maxConcurrent requests;
	//  2. release the remaining client requests and wait for the canceled
	//     one to return from RoundTrip;
	//  3. after a short pause, verify no further request has reached the
	//     server, then let the server start writing responses;
	//  4. wait for the final request to reach the server.
	wg.Add(1)
	unblockClient := make(chan struct{})
	clientRequestCancelled := make(chan struct{})
	unblockServer := make(chan struct{})
	go func() {
		defer wg.Done()
		// Stage 1.
		for k := 0; k < maxConcurrent; k++ {
			<-gotRequest
		}
		// Stage 2.
		close(unblockClient)
		<-clientRequestCancelled
		// Stage 3: give the final RoundTrip time to be scheduled, then check
		// that its request was not sent.
		time.Sleep(50 * time.Millisecond)
		select {
		case <-gotRequest:
			errs <- errors.New("last request did not stall")
			close(unblockServer)
			return
		default:
		}
		close(unblockServer)
		// Stage 4.
		<-gotRequest
	}()

	ct := newClientTester(t)
	ct.tr.StrictMaxConcurrentStreams = true
	ct.client = func() error {
		var wg sync.WaitGroup
		defer func() {
			wg.Wait()
			close(clientDone)
			ct.cc.(*net.TCPConn).CloseWrite()
			if runtime.GOOS == "plan9" {
				// NOTE(review): presumably CloseWrite is not supported on
				// Plan 9, so the connection is also fully closed there.
				ct.cc.(*net.TCPConn).Close()
			}
		}()
		for k := 0; k < maxConcurrent+2; k++ {
			wg.Add(1)
			go func(k int) {
				defer wg.Done()
				// Only request 0 goes out before the server's SETTINGS have
				// been seen. NOTE(review): presumably it triggers the dial and
				// greet, and the others wait so they are not admitted under a
				// default (larger) concurrency limit.
				if k > 0 {
					<-greet
				}
				// Requests beyond the limit wait for the checker's go-ahead.
				if k >= maxConcurrent {
					<-unblockClient
				}
				body := newStaticCloseChecker("")
				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
				if k == maxConcurrent {
					// This request is canceled before RoundTrip is called,
					// so it must fail.
					req.Cancel = cancelClientRequest
					close(cancelClientRequest)
					_, err := ct.tr.RoundTrip(req)
					close(clientRequestCancelled)
					if err == nil {
						errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
						return
					}
				} else {
					resp, err := ct.tr.RoundTrip(req)
					if err != nil {
						errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
						return
					}
					ioutil.ReadAll(resp.Body)
					resp.Body.Close()
					if resp.StatusCode != 204 {
						errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
						return
					}
				}
				// Every request body, including the canceled one, must be closed.
				if err := body.isClosed(); err != nil {
					errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
				}
			}(k)
		}
		return nil
	}

	ct.server = func() error {
		var wg sync.WaitGroup
		defer wg.Wait()

		ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

		// Write loop: once unblockServer is closed, answer each stream ID
		// sent on writeResp with a 204 that ends the stream.
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		writeResp := make(chan uint32, maxConcurrent+1)

		wg.Add(1)
		go func() {
			defer wg.Done()
			<-unblockServer
			for id := range writeResp {
				buf.Reset()
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      id,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
			}
		}()

		// Read loop.
		var nreq int
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// The client finished and reported its own errors via errs.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame:
			case *SettingsFrame:
				// NOTE(review): presumably the client's ACK of the server
				// SETTINGS; a second SETTINGS frame here would panic on the
				// double close.
				close(greet)
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				gotRequest <- struct{}{}
				nreq++
				writeResp <- f.StreamID
				// The canceled request never arrives, so maxConcurrent+1
				// HEADERS frames is the full set.
				if nreq == maxConcurrent+1 {
					close(writeResp)
				}
			case *DataFrame:
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}

	ct.run()
}
4307
4308 func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
4309 ct := newClientTester(t)
4310 var reqSize, resSize uint32 = 8192, 16384
4311 ct.tr.MaxDecoderHeaderTableSize = reqSize
4312 ct.client = func() error {
4313 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
4314 cc, err := ct.tr.NewClientConn(ct.cc)
4315 if err != nil {
4316 return err
4317 }
4318 _, err = cc.RoundTrip(req)
4319 if err != nil {
4320 return err
4321 }
4322 if got, want := cc.peerMaxHeaderTableSize, resSize; got != want {
4323 return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want)
4324 }
4325 return nil
4326 }
4327 ct.server = func() error {
4328 buf := make([]byte, len(ClientPreface))
4329 _, err := io.ReadFull(ct.sc, buf)
4330 if err != nil {
4331 return fmt.Errorf("reading client preface: %v", err)
4332 }
4333 f, err := ct.fr.ReadFrame()
4334 if err != nil {
4335 return err
4336 }
4337 sf, ok := f.(*SettingsFrame)
4338 if !ok {
4339 ct.t.Fatalf("wanted client settings frame; got %v", f)
4340 _ = sf
4341 }
4342 var found bool
4343 err = sf.ForeachSetting(func(s Setting) error {
4344 if s.ID == SettingHeaderTableSize {
4345 found = true
4346 if got, want := s.Val, reqSize; got != want {
4347 return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want)
4348 }
4349 }
4350 return nil
4351 })
4352 if err != nil {
4353 return err
4354 }
4355 if !found {
4356 return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting")
4357 }
4358 if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil {
4359 ct.t.Fatal(err)
4360 }
4361 if err := ct.fr.WriteSettingsAck(); err != nil {
4362 ct.t.Fatal(err)
4363 }
4364
4365 for {
4366 f, err := ct.fr.ReadFrame()
4367 if err != nil {
4368 return err
4369 }
4370 switch f := f.(type) {
4371 case *HeadersFrame:
4372 var buf bytes.Buffer
4373 enc := hpack.NewEncoder(&buf)
4374 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
4375 ct.fr.WriteHeaders(HeadersFrameParam{
4376 StreamID: f.StreamID,
4377 EndHeaders: true,
4378 EndStream: true,
4379 BlockFragment: buf.Bytes(),
4380 })
4381 return nil
4382 }
4383 }
4384 }
4385 ct.run()
4386 }
4387
4388 func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
4389 ct := newClientTester(t)
4390 var peerAdvertisedMaxHeaderTableSize uint32 = 16384
4391 ct.tr.MaxEncoderHeaderTableSize = 8192
4392 ct.client = func() error {
4393 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
4394 cc, err := ct.tr.NewClientConn(ct.cc)
4395 if err != nil {
4396 return err
4397 }
4398 _, err = cc.RoundTrip(req)
4399 if err != nil {
4400 return err
4401 }
4402 if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want {
4403 return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
4404 }
4405 return nil
4406 }
4407 ct.server = func() error {
4408 buf := make([]byte, len(ClientPreface))
4409 _, err := io.ReadFull(ct.sc, buf)
4410 if err != nil {
4411 return fmt.Errorf("reading client preface: %v", err)
4412 }
4413 f, err := ct.fr.ReadFrame()
4414 if err != nil {
4415 return err
4416 }
4417 sf, ok := f.(*SettingsFrame)
4418 if !ok {
4419 ct.t.Fatalf("wanted client settings frame; got %v", f)
4420 _ = sf
4421 }
4422 if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil {
4423 ct.t.Fatal(err)
4424 }
4425 if err := ct.fr.WriteSettingsAck(); err != nil {
4426 ct.t.Fatal(err)
4427 }
4428
4429 for {
4430 f, err := ct.fr.ReadFrame()
4431 if err != nil {
4432 return err
4433 }
4434 switch f := f.(type) {
4435 case *HeadersFrame:
4436 var buf bytes.Buffer
4437 enc := hpack.NewEncoder(&buf)
4438 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
4439 ct.fr.WriteHeaders(HeadersFrameParam{
4440 StreamID: f.StreamID,
4441 EndHeaders: true,
4442 EndStream: true,
4443 BlockFragment: buf.Bytes(),
4444 })
4445 return nil
4446 }
4447 }
4448 }
4449 ct.run()
4450 }
4451
4452 func TestAuthorityAddr(t *testing.T) {
4453 tests := []struct {
4454 scheme, authority string
4455 want string
4456 }{
4457 {"http", "foo.com", "foo.com:80"},
4458 {"https", "foo.com", "foo.com:443"},
4459 {"https", "foo.com:", "foo.com:443"},
4460 {"https", "foo.com:1234", "foo.com:1234"},
4461 {"https", "1.2.3.4:1234", "1.2.3.4:1234"},
4462 {"https", "1.2.3.4", "1.2.3.4:443"},
4463 {"https", "1.2.3.4:", "1.2.3.4:443"},
4464 {"https", "[::1]:1234", "[::1]:1234"},
4465 {"https", "[::1]", "[::1]:443"},
4466 {"https", "[::1]:", "[::1]:443"},
4467 }
4468 for _, tt := range tests {
4469 got := authorityAddr(tt.scheme, tt.authority)
4470 if got != tt.want {
4471 t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
4472 }
4473 }
4474 }
4475
4476
4477
// TestTransportAllocationsAfterResponseBodyClose checks what happens when a
// response body is closed after reading just one byte:
//   - the stream's receive buffer is released (cs.bufPipe.b becomes nil);
//   - the handler, which tries to write 100 x 1 MiB, gets errStreamClosed
//     instead of completing.
func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
	megabyteZero := make([]byte, 1<<20)

	writeErr := make(chan error, 1) // the handler's final write result

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.(http.Flusher).Flush()
		var sum int64
		for i := 0; i < 100; i++ {
			n, err := w.Write(megabyteZero)
			sum += int64(n)
			if err != nil {
				writeErr <- err
				return
			}
		}
		t.Logf("wrote all %d bytes", sum)
		writeErr <- nil
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()
	c := &http.Client{Transport: tr}
	res, err := c.Get(st.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Read a single byte, then abandon the rest of the body.
	var buf [1]byte
	if _, err := res.Body.Read(buf[:]); err != nil {
		t.Error(err)
	}
	if err := res.Body.Close(); err != nil {
		t.Error(err)
	}

	// Inspect the transport's internals: the closed body's buffer must be freed.
	trb, ok := res.Body.(transportResponseBody)
	if !ok {
		t.Fatalf("res.Body = %T; want transportResponseBody", res.Body)
	}
	if trb.cs.bufPipe.b != nil {
		t.Errorf("response body pipe is still open")
	}

	gotErr := <-writeErr
	if gotErr == nil {
		t.Errorf("Handler unexpectedly managed to write its entire response without getting an error")
	} else if gotErr != errStreamClosed {
		t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr)
	}
}
4529
4530
4531
4532 func TestTransportNoBodyMeansNoDATA(t *testing.T) {
4533 ct := newClientTester(t)
4534
4535 unblockClient := make(chan bool)
4536
4537 ct.client = func() error {
4538 req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
4539 ct.tr.RoundTrip(req)
4540 <-unblockClient
4541 return nil
4542 }
4543 ct.server = func() error {
4544 defer close(unblockClient)
4545 defer ct.cc.(*net.TCPConn).Close()
4546 ct.greet()
4547
4548 for {
4549 f, err := ct.fr.ReadFrame()
4550 if err != nil {
4551 return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
4552 }
4553 switch f := f.(type) {
4554 default:
4555 return fmt.Errorf("Got %T; want HeadersFrame", f)
4556 case *WindowUpdateFrame, *SettingsFrame:
4557 continue
4558 case *HeadersFrame:
4559 if !f.StreamEnded() {
4560 return fmt.Errorf("got headers frame without END_STREAM")
4561 }
4562 return nil
4563 }
4564 }
4565 }
4566 ct.run()
4567 }
4568
4569 func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
4570 defer disableGoroutineTracking()()
4571 b.ReportAllocs()
4572 st := newServerTester(b,
4573 func(w http.ResponseWriter, r *http.Request) {
4574 for i := 0; i < nResHeader; i++ {
4575 name := fmt.Sprint("A-", i)
4576 w.Header().Set(name, "*")
4577 }
4578 },
4579 optOnlyServer,
4580 optQuiet,
4581 )
4582 defer st.Close()
4583
4584 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4585 defer tr.CloseIdleConnections()
4586
4587 req, err := http.NewRequest("GET", st.ts.URL, nil)
4588 if err != nil {
4589 b.Fatal(err)
4590 }
4591
4592 for i := 0; i < nReqHeaders; i++ {
4593 name := fmt.Sprint("A-", i)
4594 req.Header.Set(name, "*")
4595 }
4596
4597 b.ResetTimer()
4598
4599 for i := 0; i < b.N; i++ {
4600 res, err := tr.RoundTrip(req)
4601 if err != nil {
4602 if res != nil {
4603 res.Body.Close()
4604 }
4605 b.Fatalf("RoundTrip err = %v; want nil", err)
4606 }
4607 res.Body.Close()
4608 if res.StatusCode != http.StatusOK {
4609 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
4610 }
4611 }
4612 }
4613
// infiniteReader is an io.Reader that never returns EOF. Every Read reports
// the whole buffer as filled and leaves the buffer's contents untouched.
type infiniteReader struct{}

func (infiniteReader) Read(p []byte) (int, error) {
	n := len(p)
	return n, nil
}
4619
4620
4621
4622 func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
4623 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4624 w.WriteHeader(http.StatusOK)
4625 }, optOnlyServer)
4626 defer st.Close()
4627
4628 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4629 defer tr.CloseIdleConnections()
4630
4631
4632 req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{})
4633 res, err := tr.RoundTrip(req)
4634 if err != nil {
4635 t.Fatal(err)
4636 }
4637 if res.StatusCode != http.StatusOK {
4638 t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
4639 }
4640 }
4641
4642
4643
4644 func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
4645 ct := newClientTester(t)
4646 ct.client = func() error {
4647 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
4648 _, err := ct.tr.RoundTrip(req)
4649 const substr = "malformed response from server: missing status pseudo header"
4650 if !strings.Contains(fmt.Sprint(err), substr) {
4651 return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
4652 }
4653 return nil
4654 }
4655 ct.server = func() error {
4656 ct.greet()
4657 var buf bytes.Buffer
4658 enc := hpack.NewEncoder(&buf)
4659
4660 for {
4661 f, err := ct.fr.ReadFrame()
4662 if err != nil {
4663 return err
4664 }
4665 switch f := f.(type) {
4666 case *HeadersFrame:
4667 enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"})
4668 ct.fr.WriteHeaders(HeadersFrameParam{
4669 StreamID: f.StreamID,
4670 EndHeaders: true,
4671 EndStream: false,
4672 BlockFragment: buf.Bytes(),
4673 })
4674 ct.fr.WriteData(f.StreamID, true, []byte("payload"))
4675 return nil
4676 }
4677 }
4678 }
4679 ct.run()
4680 }
4681
4682 func BenchmarkClientRequestHeaders(b *testing.B) {
4683 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
4684 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) })
4685 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) })
4686 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) })
4687 }
4688
4689 func BenchmarkClientResponseHeaders(b *testing.B) {
4690 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
4691 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) })
4692 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) })
4693 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
4694 }
4695
4696 func BenchmarkDownloadFrameSize(b *testing.B) {
4697 b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) })
4698 b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) })
4699 b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) })
4700 b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) })
4701 b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
4702 }
// benchLargeDownloadRoundTrip measures download throughput:
//   - the handler writes a transferSize (1 GiB) body in 1 MiB chunks;
//   - the Transport reads it with MaxReadFrameSize set to frameSize;
//   - throughput is reported via b.SetBytes.
//
// NOTE(review): b.N is overridden to 3 below, so the framework's chosen
// iteration count is ignored. Each iteration also buffers the whole body
// with io.ReadAll, which counts toward the reported allocations.
func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
	defer disableGoroutineTracking()()
	const transferSize = 1024 * 1024 * 1024
	b.ReportAllocs()
	st := newServerTester(b,
		func(w http.ResponseWriter, r *http.Request) {
			// Announce the exact length, then stream zeroed 1 MiB chunks.
			w.Header().Set("Content-Length", strconv.Itoa(transferSize))
			w.Header().Set("Content-Transfer-Encoding", "binary")
			var data [1024 * 1024]byte
			for i := 0; i < transferSize/(1024*1024); i++ {
				w.Write(data[:])
			}
		}, optQuiet,
	)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize}
	defer tr.CloseIdleConnections()

	req, err := http.NewRequest("GET", st.ts.URL, nil)
	if err != nil {
		b.Fatal(err)
	}

	b.N = 3
	b.SetBytes(transferSize)
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		res, err := tr.RoundTrip(req)
		if err != nil {
			if res != nil {
				res.Body.Close()
			}
			b.Fatalf("RoundTrip err = %v; want nil", err)
		}
		data, _ := io.ReadAll(res.Body)
		if len(data) != transferSize {
			b.Fatalf("Response length invalid")
		}
		res.Body.Close()
		if res.StatusCode != http.StatusOK {
			b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
		}
	}
}
4750
4751 func activeStreams(cc *ClientConn) int {
4752 count := 0
4753 cc.mu.Lock()
4754 defer cc.mu.Unlock()
4755 for _, cs := range cc.streams {
4756 select {
4757 case <-cs.abort:
4758 default:
4759 count++
4760 }
4761 }
4762 return count
4763 }
4764
// closeMode selects how testClientConnClose ends the connection under test.
type closeMode int

const (
	// closeAtHeaders: ClientConn.Close runs in the handler before it writes
	// the response header.
	closeAtHeaders closeMode = 0
	// closeAtBody: ClientConn.Close runs after part of the body is flushed.
	closeAtBody closeMode = 1
	// shutdown: ClientConn.Shutdown with a context that is never canceled.
	shutdown closeMode = 2
	// shutdownCancel: ClientConn.Shutdown with an already-canceled context.
	shutdownCancel closeMode = 3
)
4773
4774
4775 func testClientConnClose(t *testing.T, closeMode closeMode) {
4776 clientDone := make(chan struct{})
4777 defer close(clientDone)
4778 handlerDone := make(chan struct{})
4779 closeDone := make(chan struct{})
4780 beforeHeader := func() {}
4781 bodyWrite := func(w http.ResponseWriter) {}
4782 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4783 defer close(handlerDone)
4784 beforeHeader()
4785 w.WriteHeader(http.StatusOK)
4786 w.(http.Flusher).Flush()
4787 bodyWrite(w)
4788 select {
4789 case <-w.(http.CloseNotifier).CloseNotify():
4790
4791 if closeMode == shutdown || closeMode == shutdownCancel {
4792 t.Error("expected request to complete")
4793 }
4794 case <-clientDone:
4795 if closeMode == closeAtHeaders || closeMode == closeAtBody {
4796 t.Error("expected connection closed by client")
4797 }
4798 }
4799 }, optOnlyServer)
4800 defer st.Close()
4801 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4802 defer tr.CloseIdleConnections()
4803 ctx := context.Background()
4804 cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
4805 req, err := http.NewRequest("GET", st.ts.URL, nil)
4806 if err != nil {
4807 t.Fatal(err)
4808 }
4809 if closeMode == closeAtHeaders {
4810 beforeHeader = func() {
4811 if err := cc.Close(); err != nil {
4812 t.Error(err)
4813 }
4814 close(closeDone)
4815 }
4816 }
4817 var sendBody chan struct{}
4818 if closeMode == closeAtBody {
4819 sendBody = make(chan struct{})
4820 bodyWrite = func(w http.ResponseWriter) {
4821 <-sendBody
4822 b := make([]byte, 32)
4823 w.Write(b)
4824 w.(http.Flusher).Flush()
4825 if err := cc.Close(); err != nil {
4826 t.Errorf("unexpected ClientConn close error: %v", err)
4827 }
4828 close(closeDone)
4829 w.Write(b)
4830 w.(http.Flusher).Flush()
4831 }
4832 }
4833 res, err := cc.RoundTrip(req)
4834 if res != nil {
4835 defer res.Body.Close()
4836 }
4837 if closeMode == closeAtHeaders {
4838 got := fmt.Sprint(err)
4839 want := "http2: client connection force closed via ClientConn.Close"
4840 if got != want {
4841 t.Fatalf("RoundTrip error = %v, want %v", got, want)
4842 }
4843 } else {
4844 if err != nil {
4845 t.Fatalf("RoundTrip: %v", err)
4846 }
4847 if got, want := activeStreams(cc), 1; got != want {
4848 t.Errorf("got %d active streams, want %d", got, want)
4849 }
4850 }
4851 switch closeMode {
4852 case shutdownCancel:
4853 if err = cc.Shutdown(canceledCtx); err != context.Canceled {
4854 t.Errorf("got %v, want %v", err, context.Canceled)
4855 }
4856 if cc.closing == false {
4857 t.Error("expected closing to be true")
4858 }
4859 if cc.CanTakeNewRequest() == true {
4860 t.Error("CanTakeNewRequest to return false")
4861 }
4862 if v, want := len(cc.streams), 1; v != want {
4863 t.Errorf("expected %d active streams, got %d", want, v)
4864 }
4865 clientDone <- struct{}{}
4866 <-handlerDone
4867 case shutdown:
4868 wait := make(chan struct{})
4869 shutdownEnterWaitStateHook = func() {
4870 close(wait)
4871 shutdownEnterWaitStateHook = func() {}
4872 }
4873 defer func() { shutdownEnterWaitStateHook = func() {} }()
4874 shutdown := make(chan struct{}, 1)
4875 go func() {
4876 if err = cc.Shutdown(context.Background()); err != nil {
4877 t.Error(err)
4878 }
4879 close(shutdown)
4880 }()
4881
4882 <-wait
4883 cc.mu.Lock()
4884 if cc.closing == false {
4885 t.Error("expected closing to be true")
4886 }
4887 cc.mu.Unlock()
4888 if cc.CanTakeNewRequest() == true {
4889 t.Error("CanTakeNewRequest to return false")
4890 }
4891 if got, want := activeStreams(cc), 1; got != want {
4892 t.Errorf("got %d active streams, want %d", got, want)
4893 }
4894
4895 clientDone <- struct{}{}
4896
4897 select {
4898 case <-shutdown:
4899 case <-time.After(2 * time.Second):
4900 t.Fatal("expected server connection to close")
4901 }
4902 case closeAtHeaders, closeAtBody:
4903 if closeMode == closeAtBody {
4904 go close(sendBody)
4905 if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
4906 t.Error("expected a Copy error, got nil")
4907 }
4908 }
4909 <-closeDone
4910 if got, want := activeStreams(cc), 0; got != want {
4911 t.Errorf("got %d active streams, want %d", got, want)
4912 }
4913
4914 select {
4915 case <-handlerDone:
4916 case <-time.After(2 * time.Second):
4917 t.Fatal("expected server connection to close")
4918 }
4919 }
4920 }
4921
4922
4923
4924
4925 func TestClientConnCloseAtHeaders(t *testing.T) {
4926 testClientConnClose(t, closeAtHeaders)
4927 }
4928
4929
4930
4931 func TestClientConnCloseAtBody(t *testing.T) {
4932 testClientConnClose(t, closeAtBody)
4933 }
4934
4935
4936
4937 func TestClientConnShutdown(t *testing.T) {
4938 testClientConnClose(t, shutdown)
4939 }
4940
4941
4942
4943
4944 func TestClientConnShutdownCancel(t *testing.T) {
4945 testClientConnClose(t, shutdownCancel)
4946 }
4947
4948
4949
4950
4951
4952
// TestTransportUsesGetBodyWhenPresent verifies that shouldRetryRequest, given
// a request with GetBody set and the errClientConnUnusable error, calls
// GetBody exactly once and returns a distinct *http.Request whose Body is the
// freshly obtained one and whose GetBody is still set.
func TestTransportUsesGetBodyWhenPresent(t *testing.T) {
	calls := 0
	someBody := func() io.ReadCloser {
		return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))}
	}
	req := &http.Request{
		Body: someBody(),
		GetBody: func() (io.ReadCloser, error) {
			calls++
			return someBody(), nil
		},
	}

	req2, err := shouldRetryRequest(req, errClientConnUnusable)
	if err != nil {
		t.Fatal(err)
	}
	if calls != 1 {
		t.Errorf("Calls = %d; want 1", calls)
	}
	// Check for nil first so the checks below can dereference req2.
	if req2 == nil {
		t.Fatal("req2 is nil")
	}
	// The retried request is expected to be a new *http.Request.
	if req2 == req {
		t.Error("req2 == req; want a new *http.Request")
	}
	if req2.Body == nil {
		t.Fatal("req2.Body is nil")
	}
	if req2.GetBody == nil {
		t.Fatal("req2.GetBody is nil")
	}
	if req2.Body == req.Body {
		t.Error("req2.Body unchanged")
	}
}
4989
4990
4991
// TestNoDialH2RoundTripperType checks the shape net/http relies on when it
// inspects the registered "https" RoundTripper: a struct with exactly one
// field, and that field holds a *Transport.
func TestNoDialH2RoundTripperType(t *testing.T) {
	rt := noDialH2RoundTripper{new(Transport)}
	if err := registerHTTPSProtocol(new(http.Transport), rt); err != nil {
		t.Fatal(err)
	}

	typ := reflect.TypeOf(rt)
	if kind := typ.Kind(); kind != reflect.Struct {
		t.Fatalf("kind = %v; net/http expects struct", kind)
	}
	if n := typ.NumField(); n != 1 {
		t.Fatalf("fields = %d; net/http expects 1", n)
	}

	inner := reflect.ValueOf(rt).Field(0).Interface()
	if _, isTransport := inner.(*Transport); !isTransport {
		t.Fatalf("wrong kind %T; want *Transport", inner)
	}
}
5011
// errReader is an io.Reader that hands out the bytes in body and, once body
// is exhausted, returns err from every subsequent Read.
type errReader struct {
	body []byte // bytes still to be returned before failing
	err  error  // error returned once body is drained
}

// Read copies as much of the remaining body into p as fits. With nothing left
// it returns (0, r.err).
func (r *errReader) Read(p []byte) (int, error) {
	if len(r.body) == 0 {
		return 0, r.err
	}
	copied := copy(p, r.body)
	r.body = r.body[copied:]
	return copied, nil
}
5025
// testTransportBodyReadError sends a PUT whose request body yields the given
// bytes and then fails with a read error. It checks that:
//   - RoundTrip returns exactly that read error;
//   - the pool keeps one connection for dummy.tld:443 with no active streams;
//   - the server received the body bytes followed by exactly one RST_STREAM.
func testTransportBodyReadError(t *testing.T, body []byte) {
	if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
		// Known to be flaky on these platforms; see the linked issue.
		t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS)
	}
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// NOTE(review): presumably CloseWrite alone is not enough on
			// plan9, so the conn is also fully closed. Unreachable while the
			// skip above covers plan9.
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)

		// checkNoStreams verifies the pool holds a single connection for the
		// test host and that it has no streams left open.
		checkNoStreams := func() error {
			cp, ok := ct.tr.connPool().(*clientConnPool)
			if !ok {
				return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
			}
			cp.mu.Lock()
			defer cp.mu.Unlock()
			conns, ok := cp.conns["dummy.tld:443"]
			if !ok {
				return fmt.Errorf("missing connection")
			}
			if len(conns) != 1 {
				return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
			}
			if activeStreams(conns[0]) != 0 {
				return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
			}
			return nil
		}
		bodyReadError := errors.New("body read error")
		// Named reqBody so it does not shadow the body parameter, which the
		// server closure compares against.
		reqBody := &errReader{body, bodyReadError}
		req, err := http.NewRequest("PUT", "https://dummy.tld/", reqBody)
		if err != nil {
			return err
		}
		_, err = ct.tr.RoundTrip(req)
		if err != bodyReadError {
			return fmt.Errorf("err = %v; want %v", err, bodyReadError)
		}
		if err = checkNoStreams(); err != nil {
			return err
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var receivedBody []byte
		var resetCount int
		for {
			f, err := ct.fr.ReadFrame()
			t.Logf("server: ReadFrame = %v, %v", f, err)
			if err != nil {
				select {
				case <-clientDone:
					// The client has finished and closed its write side, so a
					// read error is expected here. Validate what arrived.
					if !bytes.Equal(receivedBody, body) {
						return fmt.Errorf("body: %q; expected %q", receivedBody, body)
					}
					if resetCount != 1 {
						return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
					}
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
			case *DataFrame:
				receivedBody = append(receivedBody, f.Data()...)
			case *RSTStreamFrame:
				resetCount++
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
5117
5118 func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
// TestTransportBodyReadError_Some covers a request body that yields a few
// bytes before its Read fails.
func TestTransportBodyReadError_Some(t *testing.T) {
	prefix := []byte("123")
	testTransportBodyReadError(t, prefix)
}
5120
5121
5122
5123
// TestTransportBodyEagerEndStream sends a PUT with a strings.Reader body and
// has the fake server require that the first DATA frame it sees already
// carries END_STREAM. Otherwise the server refuses the stream and fails the
// test. On success the server replies with a 200 and a short body.
func TestTransportBodyEagerEndStream(t *testing.T) {
	const reqBody = "some request body"
	const resBody = "some response body"

	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// NOTE(review): presumably CloseWrite alone is not enough on plan9.
			defer ct.cc.(*net.TCPConn).Close()
		}
		req, err := http.NewRequest("PUT", "https://dummy.tld/", strings.NewReader(reqBody))
		if err != nil {
			return err
		}
		_, err = ct.tr.RoundTrip(req)
		return err
	}
	ct.server = func() error {
		ct.greet()

		for {
			frame, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}

			switch frame := frame.(type) {
			case *DataFrame:
				if !frame.StreamEnded() {
					ct.fr.WriteRSTStream(frame.StreamID, ErrCodeRefusedStream)
					return fmt.Errorf("data frame without END_STREAM %v", frame)
				}
				var hdrs bytes.Buffer
				hpack.NewEncoder(&hdrs).WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      frame.Header().StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: hdrs.Bytes(),
				})
				ct.fr.WriteData(frame.StreamID, true, []byte(resBody))
				return nil
			case *WindowUpdateFrame, *SettingsFrame, *HeadersFrame, *RSTStreamFrame:
				// Ignored.
			default:
				return fmt.Errorf("Unexpected client frame %v", frame)
			}
		}
	}
	ct.run()
}
5182
// chunkReader returns one pre-set chunk per Read call and panics if read more
// times than it has chunks.
type chunkReader struct {
	chunks [][]byte // remaining chunks, one consumed per Read
}

// Read copies the next chunk into p and consumes it whole, even if p is too
// small to hold all of it.
func (r *chunkReader) Read(p []byte) (int, error) {
	if len(r.chunks) == 0 {
		panic("shouldn't read this many times")
	}
	next := r.chunks[0]
	r.chunks = r.chunks[1:]
	return copy(p, next), nil
}
5195
5196
5197
5198
5199
5200
5201
5202
5203
// TestTransportBodyLargerThanSpecifiedContentLength_len3 declares a 3-byte
// body but supplies two 3-byte chunks.
func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) {
	chunks := [][]byte{[]byte("123"), []byte("456")}
	testTransportBodyLargerThanSpecifiedContentLength(t, &chunkReader{chunks}, 3)
}
5211
// TestTransportBodyLargerThanSpecifiedContentLength_len2 declares a 2-byte
// body but supplies a single 3-byte chunk.
func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) {
	chunks := [][]byte{[]byte("123")}
	testTransportBodyLargerThanSpecifiedContentLength(t, &chunkReader{chunks}, 2)
}
5218
// testTransportBodyLargerThanSpecifiedContentLength POSTs body with
// req.ContentLength forced to contentLen and expects RoundTrip to fail with
// errReqBodyTooLong.
func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		r.Body.Read(make([]byte, 6))
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	req, _ := http.NewRequest("POST", st.ts.URL, body)
	req.ContentLength = contentLen
	if _, err := tr.RoundTrip(req); err != errReqBodyTooLong {
		t.Fatalf("expected %v, got %v", errReqBodyTooLong, err)
	}
}
5235
// TestClientConnTooIdle is a table test for ClientConn.tooIdleLocked. Each
// case pairs an idleTimeout and lastIdle time with the expected result.
func TestClientConnTooIdle(t *testing.T) {
	// tenSecondsAgo is evaluated lazily, when each case's constructor runs.
	tenSecondsAgo := func() time.Time { return time.Now().Add(-10 * time.Second) }

	tests := []struct {
		cc   func() *ClientConn
		want bool
	}{
		{
			cc:   func() *ClientConn { return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: tenSecondsAgo()} },
			want: true,
		},
		{
			cc:   func() *ClientConn { return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} },
			want: false,
		},
		{
			cc:   func() *ClientConn { return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: tenSecondsAgo()} },
			want: false,
		},
		{
			cc:   func() *ClientConn { return &ClientConn{idleTimeout: 0, lastIdle: tenSecondsAgo()} },
			want: false,
		},
	}
	for i, tt := range tests {
		if got := tt.cc().tooIdleLocked(); got != tt.want {
			t.Errorf("%d. got %v; want %v", i, got, tt.want)
		}
	}
}
5273
5274 type fakeConnErr struct {
5275 net.Conn
5276 writeErr error
5277 closed bool
5278 }
5279
5280 func (fce *fakeConnErr) Write(b []byte) (n int, err error) {
5281 return 0, fce.writeErr
5282 }
5283
5284 func (fce *fakeConnErr) Close() error {
5285 fce.closed = true
5286 return nil
5287 }
5288
5289
// TestTransportNewClientConnCloseOnWriteError checks that NewClientConn
// returns the underlying write error and closes the conn when its write
// fails.
func TestTransportNewClientConnCloseOnWriteError(t *testing.T) {
	writeErr := errors.New("write error")
	conn := &fakeConnErr{writeErr: writeErr}

	if _, err := (&Transport{}).NewClientConn(conn); err != writeErr {
		t.Fatalf("expected %v, got %v", writeErr, err)
	}
	if !conn.closed {
		t.Error("expected closed conn")
	}
}
5302
// TestTransportRoundtripCloseOnWriteError injects a write error into an
// established ClientConn (via cc.werr) and checks that RoundTrip returns that
// error and leaves the ClientConn marked closed.
func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
	req, err := http.NewRequest("GET", "https://dummy.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()
	addr := st.ts.Listener.Addr().String()
	cc, err := tr.dialClientConn(context.Background(), addr, false)
	if err != nil {
		t.Fatal(err)
	}

	writeErr := errors.New("write error")
	cc.wmu.Lock()
	cc.werr = writeErr
	cc.wmu.Unlock()

	if _, err := cc.RoundTrip(req); err != writeErr {
		t.Fatalf("expected %v, got %v", writeErr, err)
	}

	cc.mu.Lock()
	isClosed := cc.closed
	cc.mu.Unlock()
	if !isClosed {
		t.Fatal("expected closed")
	}
}
5336
5337
5338
5339
// TestTransportBodyRewindRace fires 50 concurrent POSTs through a net/http
// Transport upgraded with ConfigureTransport and limited to one connection
// per host. The server replies with "Connection: close". Errors from
// client.Do are deliberately ignored; the test only requires that the
// requests run to completion (and, under -race, without data races).
func TestTransportBodyRewindRace(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Connection", "close")
		w.WriteHeader(http.StatusOK)
	}, optOnlyServer)
	defer st.Close()

	tr := &http.Transport{TLSClientConfig: tlsConfigInsecure, MaxConnsPerHost: 1}
	if err := ConfigureTransport(tr); err != nil {
		t.Fatal(err)
	}
	client := &http.Client{Transport: tr}

	const clients = 50

	var wg sync.WaitGroup
	wg.Add(clients)
	for i := 0; i < clients; i++ {
		req, err := http.NewRequest("POST", st.ts.URL, bytes.NewBufferString("abcdef"))
		if err != nil {
			t.Fatalf("unexpect new request error: %v", err)
		}
		go func(r *http.Request) {
			defer wg.Done()
			if res, err := client.Do(r); err == nil {
				res.Body.Close()
			}
		}(req)
	}
	wg.Wait()
}
5381
5382
5383
// TestTransportServerResetStreamAtHeaders sends an "Expect: 100-continue"
// POST with a zero ContentLength to a handler that immediately answers 401.
// It checks that client.Do still returns a response rather than an error.
func TestTransportServerResetStreamAtHeaders(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}, optOnlyServer)
	defer st.Close()

	tr := &http.Transport{
		TLSClientConfig:       tlsConfigInsecure,
		MaxConnsPerHost:       1,
		ExpectContinueTimeout: 10 * time.Second,
	}
	if err := ConfigureTransport(tr); err != nil {
		t.Fatal(err)
	}
	client := &http.Client{Transport: tr}

	req, err := http.NewRequest("POST", st.ts.URL, errorReader{io.EOF})
	if err != nil {
		t.Fatalf("unexpect new request error: %v", err)
	}
	req.ContentLength = 0
	req.Header.Set("Expect", "100-continue")

	res, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
5417
5418 type trackingReader struct {
5419 rdr io.Reader
5420 wasRead uint32
5421 }
5422
5423 func (tr *trackingReader) Read(p []byte) (int, error) {
5424 atomic.StoreUint32(&tr.wasRead, 1)
5425 return tr.rdr.Read(p)
5426 }
5427
5428 func (tr *trackingReader) WasRead() bool {
5429 return atomic.LoadUint32(&tr.wasRead) != 0
5430 }
5431
5432 func TestTransportExpectContinue(t *testing.T) {
5433 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
5434 switch r.URL.Path {
5435 case "/reject":
5436 w.WriteHeader(403)
5437 default:
5438 io.Copy(io.Discard, r.Body)
5439 }
5440 }, optOnlyServer)
5441 defer st.Close()
5442
5443 tr := &http.Transport{
5444 TLSClientConfig: tlsConfigInsecure,
5445 MaxConnsPerHost: 1,
5446 ExpectContinueTimeout: 10 * time.Second,
5447 }
5448
5449 err := ConfigureTransport(tr)
5450 if err != nil {
5451 t.Fatal(err)
5452 }
5453 client := &http.Client{
5454 Transport: tr,
5455 }
5456
5457 testCases := []struct {
5458 Name string
5459 Path string
5460 Body *trackingReader
5461 ExpectedCode int
5462 ShouldRead bool
5463 }{
5464 {
5465 Name: "read-all",
5466 Path: "/",
5467 Body: &trackingReader{rdr: strings.NewReader("hello")},
5468 ExpectedCode: 200,
5469 ShouldRead: true,
5470 },
5471 {
5472 Name: "reject",
5473 Path: "/reject",
5474 Body: &trackingReader{rdr: strings.NewReader("hello")},
5475 ExpectedCode: 403,
5476 ShouldRead: false,
5477 },
5478 }
5479
5480 for _, tc := range testCases {
5481 t.Run(tc.Name, func(t *testing.T) {
5482 startTime := time.Now()
5483
5484 req, err := http.NewRequest("POST", st.ts.URL+tc.Path, tc.Body)
5485 if err != nil {
5486 t.Fatal(err)
5487 }
5488 req.Header.Set("Expect", "100-continue")
5489 res, err := client.Do(req)
5490 if err != nil {
5491 t.Fatal(err)
5492 }
5493 res.Body.Close()
5494
5495 if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout {
5496 t.Error("Request didn't finish before expect continue timeout")
5497 }
5498 if res.StatusCode != tc.ExpectedCode {
5499 t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode)
5500 }
5501 if tc.Body.WasRead() != tc.ShouldRead {
5502 t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead)
5503 }
5504 })
5505 }
5506 }
5507
5508 type closeChecker struct {
5509 io.ReadCloser
5510 closed chan struct{}
5511 }
5512
5513 func newCloseChecker(r io.ReadCloser) *closeChecker {
5514 return &closeChecker{r, make(chan struct{})}
5515 }
5516
5517 func newStaticCloseChecker(body string) *closeChecker {
5518 return newCloseChecker(io.NopCloser(strings.NewReader("body")))
5519 }
5520
5521 func (rc *closeChecker) Read(b []byte) (n int, err error) {
5522 select {
5523 default:
5524 case <-rc.closed:
5525
5526
5527
5528 return 0, errors.New("read after Body.Close")
5529 }
5530 return rc.ReadCloser.Read(b)
5531 }
5532
5533 func (rc *closeChecker) Close() error {
5534 close(rc.closed)
5535 return rc.ReadCloser.Close()
5536 }
5537
5538 func (rc *closeChecker) isClosed() error {
5539
5540
5541
5542 timeout := time.Duration(10 * time.Second)
5543 select {
5544 case <-rc.closed:
5545 case <-time.After(timeout):
5546 return fmt.Errorf("body not closed after %v", timeout)
5547 }
5548 return nil
5549 }
5550
5551
5552 type blockingWriteConn struct {
5553 net.Conn
5554 writeOnce sync.Once
5555 writec chan struct{}
5556 unblockc chan struct{}
5557 count, limit int
5558 }
5559
5560 func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn {
5561 return &blockingWriteConn{
5562 Conn: conn,
5563 limit: limit,
5564 writec: make(chan struct{}),
5565 unblockc: make(chan struct{}),
5566 }
5567 }
5568
5569
5570 func (c *blockingWriteConn) wait() {
5571 <-c.writec
5572 }
5573
5574
5575 func (c *blockingWriteConn) unblock() {
5576 close(c.unblockc)
5577 }
5578
5579 func (c *blockingWriteConn) Write(b []byte) (n int, err error) {
5580 if c.count+len(b) > c.limit {
5581 c.writeOnce.Do(func() {
5582 close(c.writec)
5583 })
5584 <-c.unblockc
5585 }
5586 n, err = c.Conn.Write(b)
5587 c.count += n
5588 return n, err
5589 }
5590
5591
5592
// TestTransportFrameBufferReuse issues ten concurrent POSTs whose header,
// body and trailer each carry the same large hex value. The handler checks
// that all three arrive intact. NOTE(review): judging by the name, this
// guards against frame buffers being reused across concurrent requests; to
// confirm.
func TestTransportFrameBufferReuse(t *testing.T) {
	filler := hex.EncodeToString([]byte(randString(2048)))

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		if got := r.Header.Get("Big"); got != filler {
			t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, filler)
		}
		b, err := ioutil.ReadAll(r.Body)
		if err != nil {
			t.Errorf("error reading request body: %v", err)
		}
		if got := string(b); got != filler {
			t.Errorf("request body = %q, want %q", got, filler)
		}
		if got := r.Trailer.Get("Big"); got != filler {
			t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, filler)
		}
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	const numRequests = 10
	var wg sync.WaitGroup
	defer wg.Wait()

	doRequest := func() {
		defer wg.Done()
		req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler))
		if err != nil {
			t.Error(err)
			return
		}
		req.Header.Set("Big", filler)
		req.Trailer = http.Header{}
		req.Trailer.Set("Big", filler)
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Error(err)
			return
		}
		if res.StatusCode != 200 {
			t.Errorf("StatusCode = %v; want %v", res.StatusCode, 200)
		}
		if res != nil && res.Body != nil {
			res.Body.Close()
		}
	}
	for i := 0; i < numRequests; i++ {
		wg.Add(1)
		go doRequest()
	}
}
5645
5646
5647
5648
5649
// TestTransportBlockingRequestWrite checks that a request whose write stalls
// on the connection (because its headers, body or trailers exceed what the
// conn will accept) does not block other requests. The server allows one
// concurrent stream, so while the stalled request holds the first
// connection, a third request must be served on a second one.
func TestTransportBlockingRequestWrite(t *testing.T) {
	filler := hex.EncodeToString([]byte(randString(2048)))
	for _, test := range []struct {
		name string
		req  func(url string) (*http.Request, error)
	}{{
		name: "headers",
		req: func(url string) (*http.Request, error) {
			req, err := http.NewRequest("POST", url, nil)
			if err != nil {
				return nil, err
			}
			req.Header.Set("Big", filler)
			return req, err
		},
	}, {
		name: "body",
		req: func(url string) (*http.Request, error) {
			req, err := http.NewRequest("POST", url, strings.NewReader(filler))
			if err != nil {
				return nil, err
			}
			return req, err
		},
	}, {
		name: "trailer",
		req: func(url string) (*http.Request, error) {
			req, err := http.NewRequest("POST", url, strings.NewReader("body"))
			if err != nil {
				return nil, err
			}
			req.Trailer = make(http.Header)
			req.Trailer.Set("Big", filler)
			return req, err
		},
	}} {
		test := test
		t.Run(test.name, func(t *testing.T) {
			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
				if v := r.Header.Get("Big"); v != "" && v != filler {
					t.Errorf("request header mismatch")
				}
				if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler {
					t.Errorf("request body mismatch\ngot: %q\nwant: %q", string(v), filler)
				}
				if v := r.Trailer.Get("Big"); v != "" && v != filler {
					t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", v, filler)
				}
			}, optOnlyServer, func(s *Server) {
				s.MaxConcurrentStreams = 1
			})
			defer st.Close()

			// Wrap each dialed conn so that writes block once they would pass
			// 1024 bytes. The first wrapped conn is handed to the test via
			// connc; later ones are not.
			connc := make(chan *blockingWriteConn, 1)
			connCount := 0
			tr := &Transport{
				TLSClientConfig: tlsConfigInsecure,
				DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
					connCount++
					c, err := tls.Dial(network, addr, cfg)
					if err != nil {
						// Don't wrap (or publish) a conn that failed to dial.
						return nil, err
					}
					wc := newBlockingWriteConn(c, 1024)
					select {
					case connc <- wc:
					default:
					}
					return wc, nil
				},
			}
			defer tr.CloseIdleConnections()

			// Request 1: a small request that establishes the first connection.
			{
				req, err := http.NewRequest("POST", st.ts.URL, nil)
				if err != nil {
					t.Fatal(err)
				}
				res, err := tr.RoundTrip(req)
				if err != nil {
					t.Fatal(err)
				}
				if got, want := res.StatusCode, 200; got != want {
					t.Errorf("StatusCode = %v; want %v", got, want)
				}
				if res != nil && res.Body != nil {
					res.Body.Close()
				}
			}

			// Request 2: the large request under test. It runs in the
			// background, and the test waits until its write hits the limit.
			reqc := make(chan struct{})
			go func() {
				defer close(reqc)
				req, err := test.req(st.ts.URL)
				if err != nil {
					t.Error(err)
					return
				}
				res, _ := tr.RoundTrip(req)
				if res != nil && res.Body != nil {
					res.Body.Close()
				}
			}()
			conn := <-connc
			conn.wait()

			// Request 3: must succeed while request 2 is stalled. With
			// MaxConcurrentStreams = 1 it needs a second connection.
			{
				req, err := http.NewRequest("POST", st.ts.URL, nil)
				if err != nil {
					t.Fatal(err)
				}
				res, err := tr.RoundTrip(req)
				if err != nil {
					t.Fatal(err)
				}
				if got, want := res.StatusCode, 200; got != want {
					t.Errorf("StatusCode = %v; want %v", got, want)
				}
				if res != nil && res.Body != nil {
					res.Body.Close()
				}
			}

			// Request 2 must still be stuck on its write.
			select {
			case <-reqc:
				t.Errorf("request 2 unexpectedly completed")
			default:
			}

			conn.unblock()
			<-reqc

			if connCount != 2 {
				t.Errorf("created %v connections, want 2", connCount)
			}
		})
	}
}
5791
// TestTransportCloseRequestBody sends a PUT whose body is a never-written
// pipe, for both a 200 and a 401 response. It checks that the request body is
// closed once the response body is closed.
func TestTransportCloseRequestBody(t *testing.T) {
	var statusCode int
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(statusCode)
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()
	addr := st.ts.Listener.Addr().String()
	cc, err := tr.dialClientConn(context.Background(), addr, false)
	if err != nil {
		t.Fatal(err)
	}

	statuses := []int{200, 401}
	for _, status := range statuses {
		name := fmt.Sprintf("status=%d", status)
		t.Run(name, func(t *testing.T) {
			statusCode = status
			pipeR, pipeW := io.Pipe()
			checker := newCloseChecker(pipeR)
			req, err := http.NewRequest("PUT", "https://dummy.tld/", checker)
			if err != nil {
				t.Fatal(err)
			}
			res, err := cc.RoundTrip(req)
			if err != nil {
				t.Fatal(err)
			}
			res.Body.Close()
			pipeW.Close()
			if err := checker.isClosed(); err != nil {
				t.Fatal(err)
			}
		})
	}
}
5828
5829
5830
5831 type collectClientsConnPool struct {
5832 lower ClientConnPool
5833
5834 mu sync.Mutex
5835 getErrs int
5836 got []*ClientConn
5837 }
5838
5839 func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
5840 cc, err := p.lower.GetClientConn(req, addr)
5841 p.mu.Lock()
5842 defer p.mu.Unlock()
5843 if err != nil {
5844 p.getErrs++
5845 return nil, err
5846 }
5847 p.got = append(p.got, cc)
5848 return cc, nil
5849 }
5850
5851 func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
5852 p.lower.MarkDead(cc)
5853 }
5854
// TestTransportRetriesOnStreamProtocolError checks how the Transport reacts
// when the server resets one stream with PROTOCOL_ERROR while another stream
// on the same connection is still open.
//
// The server answers the first request's HEADERS with response headers but
// leaves that stream open. On the second request's HEADERS it sends
// RST_STREAM(PROTOCOL_ERROR) and then ends the first stream with an empty
// DATA frame carrying END_STREAM. The client side expects:
//   - the "recv_rststream_PROTOCOL_ERROR" error to be counted;
//   - the second RoundTrip to end up attempting a new dial, which the test
//     transport refuses with "only one dial allowed in test mode";
//   - the first response body to then read as (0, io.EOF);
//   - two successful pool lookups returning the same ClientConn, now marked
//     doNotReuse, and one failed lookup;
//   - cc.readerDone to be closed within 5 seconds.
func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
	ct := newClientTester(t)
	pool := &collectClientsConnPool{
		lower: &clientConnPool{t: ct.tr},
	}
	ct.tr.ConnPool = pool

	// Buffered with a non-blocking send so the CountError hook never blocks;
	// at most one PROTOCOL_ERROR reset is recorded.
	gotProtoError := make(chan bool, 1)
	ct.tr.CountError = func(errType string) {
		if errType == "recv_rststream_PROTOCOL_ERROR" {
			select {
			case gotProtoError <- true:
			default:
			}
		}
	}
	ct.client = func() error {
		// Request 1 gets headers only (Is-Long: 1) and its stream is left
		// open, so the connection has an active stream when request 2 is
		// reset.
		//
		// Request 2 is reset with PROTOCOL_ERROR. The client then expects a
		// redial attempt, which the test transport rejects.
		//
		//
		//
		req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil)
		res1, err := ct.tr.RoundTrip(req1)
		if err != nil {
			return err
		}
		if got, want := res1.Header.Get("Is-Long"), "1"; got != want {
			return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want)
		}

		req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil)
		res, err := ct.tr.RoundTrip(req)
		const want = "only one dial allowed in test mode"
		if got := fmt.Sprint(err); got != want {
			t.Errorf("didn't dial again: got %#q; want %#q", got, want)
		}
		if res != nil {
			res.Body.Close()
		}
		select {
		case <-gotProtoError:
		default:
			t.Errorf("didn't get stream protocol error")
		}

		// The server ends stream 1 with an empty END_STREAM DATA frame.
		if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 {
			t.Errorf("unexpected body read %v, %v", n, err)
		}

		pool.mu.Lock()
		defer pool.mu.Unlock()
		// The single failed lookup is the refused redial.
		if pool.getErrs != 1 {
			t.Errorf("pool get errors = %v; want 1", pool.getErrs)
		}
		if len(pool.got) == 2 {
			if pool.got[0] != pool.got[1] {
				t.Errorf("requests went on different connections")
			}
			cc := pool.got[0]
			cc.mu.Lock()
			if !cc.doNotReuse {
				t.Error("ClientConn not marked doNotReuse")
			}
			cc.mu.Unlock()

			select {
			case <-cc.readerDone:
			case <-time.After(5 * time.Second):
				t.Errorf("timeout waiting for reader to be done")
			}
		} else {
			t.Errorf("pool get success = %v; want 2", len(pool.got))
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var sentErr bool
		var numHeaders int
		var firstStreamID uint32

		var hbuf bytes.Buffer
		enc := hpack.NewEncoder(&hbuf)

		for {
			f, err := ct.fr.ReadFrame()
			if err == io.EOF {
				// The client closed its side of the connection.
				return nil
			}
			// Any other read error also ends the server without failing the
			// test.
			if err != nil {
				return nil
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				numHeaders++
				if numHeaders == 1 {
					// First request: reply with headers but no END_STREAM,
					// keeping the stream open.
					firstStreamID = f.StreamID
					hbuf.Reset()
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     false,
						BlockFragment: hbuf.Bytes(),
					})
					continue
				}
				if !sentErr {
					// Second request: reset it with PROTOCOL_ERROR, then
					// finish the first stream.
					sentErr = true
					ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol)
					ct.fr.WriteData(firstStreamID, true, nil)
					continue
				}
			}
		}
	}
	ct.run()
}
5979
// TestClientConnReservations checks ReserveNewRequest accounting against the
// server's MaxConcurrentStreams limit:
//   - a fresh conn allows exactly that many reservations;
//   - one completed RoundTrip frees exactly one slot;
//   - spending all reservations on RoundTrips restores the full count.
func TestClientConnReservations(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
	}, func(s *Server) {
		s.MaxConcurrentStreams = initialMaxConcurrentStreams
	})
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	cc, err := tr.newClientConn(st.cc, false)
	if err != nil {
		t.Fatal(err)
	}

	// reserveUpTo counts successful reservations, making at most limit+1
	// attempts.
	reserveUpTo := func(limit int) int {
		count := 0
		for count <= limit && cc.ReserveNewRequest() {
			count++
		}
		return count
	}

	req, _ := http.NewRequest("GET", st.ts.URL, nil)
	n := reserveUpTo(initialMaxConcurrentStreams)
	if n != initialMaxConcurrentStreams {
		t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
	}
	if _, err := cc.RoundTrip(req); err != nil {
		t.Fatalf("RoundTrip error = %v", err)
	}
	if n2 := reserveUpTo(5); n2 != 1 {
		t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
	}

	// Use up every outstanding reservation.
	for i := 0; i < n; i++ {
		cc.RoundTrip(req)
	}

	if n2 := reserveUpTo(initialMaxConcurrentStreams); n2 != n {
		t.Errorf("after reset, reservations = %v; want %v", n2, n)
	}
}
6027
// TestTransportTimeoutServerHangs sends a PUT with a 1 MiB header and a
// 1-second context deadline to a fake server that only greets and then
// stalls. It expects RoundTrip to fail with a net.Error reporting Timeout().
func TestTransportTimeoutServerHangs(t *testing.T) {
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		defer close(clientDone)

		req, err := http.NewRequest("PUT", "https://dummy.tld/", nil)
		if err != nil {
			return err
		}

		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
		defer cancel()
		req = req.WithContext(ctx)
		req.Header.Add("Big", strings.Repeat("a", 1<<20))

		if _, err = ct.tr.RoundTrip(req); err == nil {
			return errors.New("error should not be nil")
		}
		if netErr, isNet := err.(net.Error); !isNet || !netErr.Timeout() {
			return fmt.Errorf("error should be a net error timeout: %v", err)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		// Stall until the client gives up, or at most 5 seconds.
		select {
		case <-clientDone:
		case <-time.After(5 * time.Second):
		}
		return nil
	}
	ct.run()
}
6063
6064 func TestTransportContentLengthWithoutBody(t *testing.T) {
6065 contentLength := ""
6066 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
6067 w.Header().Set("Content-Length", contentLength)
6068 }, optOnlyServer)
6069 defer st.Close()
6070 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
6071 defer tr.CloseIdleConnections()
6072
6073 for _, test := range []struct {
6074 name string
6075 contentLength string
6076 wantBody string
6077 wantErr error
6078 wantContentLength int64
6079 }{
6080 {
6081 name: "non-zero content length",
6082 contentLength: "42",
6083 wantErr: io.ErrUnexpectedEOF,
6084 wantContentLength: 42,
6085 },
6086 {
6087 name: "zero content length",
6088 contentLength: "0",
6089 wantErr: nil,
6090 wantContentLength: 0,
6091 },
6092 } {
6093 t.Run(test.name, func(t *testing.T) {
6094 contentLength = test.contentLength
6095
6096 req, _ := http.NewRequest("GET", st.ts.URL, nil)
6097 res, err := tr.RoundTrip(req)
6098 if err != nil {
6099 t.Fatal(err)
6100 }
6101 defer res.Body.Close()
6102 body, err := io.ReadAll(res.Body)
6103
6104 if err != test.wantErr {
6105 t.Errorf("Expected error %v, got: %v", test.wantErr, err)
6106 }
6107 if len(body) > 0 {
6108 t.Errorf("Expected empty body, got: %v", body)
6109 }
6110 if res.ContentLength != test.wantContentLength {
6111 t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength)
6112 }
6113 })
6114 }
6115 }
6116
// TestTransportCloseResponseBodyWhileRequestBodyHangs uses one end of a
// net.Pipe as the request body and never writes to it, so reads of the body
// block until the other end is closed. The test requires that closing the
// response body still returns in that state.
func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		io.Copy(io.Discard, r.Body)
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	bodyReader, bodyWriter := net.Pipe()
	req, err := http.NewRequest("GET", st.ts.URL, bodyReader)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}

	res.Body.Close()
	bodyWriter.Close()
}
6141
// TestTransport300ResponseBody has the handler flush a 300 status and write
// its body only after the client has received the headers. It checks that the
// client still reads that body in full, while its own request body (a
// never-written pipe) remains open.
func TestTransport300ResponseBody(t *testing.T) {
	headersSeen := make(chan struct{})
	want := []byte("response body")
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(300)
		w.(http.Flusher).Flush()
		<-headersSeen
		w.Write(want)
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	bodyReader, bodyWriter := net.Pipe()
	req, err := http.NewRequest("GET", st.ts.URL, bodyReader)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	close(headersSeen)

	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("error reading response body: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Errorf("got response body %q, want %q", string(got), string(want))
	}
	res.Body.Close()
	bodyWriter.Close()
}
6176
6177 func TestTransportWriteByteTimeout(t *testing.T) {
6178 st := newServerTester(t,
6179 func(w http.ResponseWriter, r *http.Request) {},
6180 optOnlyServer,
6181 )
6182 defer st.Close()
6183 tr := &Transport{
6184 TLSClientConfig: tlsConfigInsecure,
6185 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
6186 _, c := net.Pipe()
6187 return c, nil
6188 },
6189 WriteByteTimeout: 1 * time.Millisecond,
6190 }
6191 defer tr.CloseIdleConnections()
6192 c := &http.Client{Transport: tr}
6193
6194 _, err := c.Get(st.ts.URL)
6195 if !errors.Is(err, os.ErrDeadlineExceeded) {
6196 t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
6197 }
6198 }
6199
// slowWriteConn wraps a net.Conn to imitate a peer that accepts data very
// slowly. While a write deadline is in effect, a Write of more than one byte
// forwards only the first byte and then reports a deadline error.
type slowWriteConn struct {
	net.Conn
	// hasWriteDeadline records whether the most recent SetWriteDeadline call
	// set a non-zero deadline.
	hasWriteDeadline bool
}

// SetWriteDeadline only records whether a deadline is set; the deadline is
// not forwarded to the wrapped conn, and the call always succeeds.
func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
	c.hasWriteDeadline = !t.IsZero()
	return nil
}

// Write passes b straight through unless a write deadline is set and b is
// longer than one byte. In that case only b[0] is written, and the result is
// either the underlying write error or an error wrapping
// os.ErrDeadlineExceeded.
func (c *slowWriteConn) Write(b []byte) (n int, err error) {
	if !c.hasWriteDeadline || len(b) <= 1 {
		return c.Conn.Write(b)
	}
	if n, err = c.Conn.Write(b[:1]); err != nil {
		return n, err
	}
	return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
}
6220
6221 func TestTransportSlowWrites(t *testing.T) {
6222 st := newServerTester(t,
6223 func(w http.ResponseWriter, r *http.Request) {},
6224 optOnlyServer,
6225 )
6226 defer st.Close()
6227 tr := &Transport{
6228 TLSClientConfig: tlsConfigInsecure,
6229 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
6230 cfg.InsecureSkipVerify = true
6231 c, err := tls.Dial(network, addr, cfg)
6232 return &slowWriteConn{Conn: c}, err
6233 },
6234 WriteByteTimeout: 1 * time.Millisecond,
6235 }
6236 defer tr.CloseIdleConnections()
6237 c := &http.Client{Transport: tr}
6238
6239 const bodySize = 1 << 20
6240 resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
6241 if err != nil {
6242 t.Fatal(err)
6243 }
6244 resp.Body.Close()
6245 }
6246
6247 func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) {
6248 testTransportClosesConnAfterGoAway(t, 0)
6249 }
6250 func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) {
6251 testTransportClosesConnAfterGoAway(t, 1)
6252 }
6253
// closeOnceConn wraps a net.Conn so that only the first Close reaches the
// wrapped connection; every later Close returns errClosed.
type closeOnceConn struct {
	net.Conn
	closed uint32 // atomically switched from 0 to 1 by the first Close
}

// errClosed is what closeOnceConn.Close returns on every call after the first.
var errClosed = errors.New("Close of closed connection")

// Close closes the wrapped conn on the first call and returns errClosed on
// all later calls. The compare-and-swap guarantees that exactly one caller
// performs the real Close, even when callers race.
func (c *closeOnceConn) Close() error {
	if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
		return errClosed
	}
	return c.Conn.Close()
}
6267
6268
6269
6270
6271
6272
6273
6274 func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) {
6275 ct := newClientTester(t)
6276 ct.cc = &closeOnceConn{Conn: ct.cc}
6277
6278 var wg sync.WaitGroup
6279 wg.Add(1)
6280 ct.client = func() error {
6281 defer wg.Done()
6282 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
6283 res, err := ct.tr.RoundTrip(req)
6284 if err == nil {
6285 res.Body.Close()
6286 }
6287 if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
6288 t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
6289 }
6290 if err = ct.cc.Close(); err != errClosed {
6291 return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err)
6292 }
6293 return nil
6294 }
6295
6296 ct.server = func() error {
6297 defer wg.Wait()
6298 ct.greet()
6299 hf, err := ct.firstHeaders()
6300 if err != nil {
6301 return fmt.Errorf("server failed reading HEADERS: %v", err)
6302 }
6303 if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil {
6304 return fmt.Errorf("server failed writing GOAWAY: %v", err)
6305 }
6306 if lastStream > 0 {
6307
6308 var buf bytes.Buffer
6309 enc := hpack.NewEncoder(&buf)
6310 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
6311 ct.fr.WriteHeaders(HeadersFrameParam{
6312 StreamID: hf.StreamID,
6313 EndHeaders: true,
6314 EndStream: true,
6315 BlockFragment: buf.Bytes(),
6316 })
6317 }
6318 return nil
6319 }
6320
6321 ct.run()
6322 }
6323
// slowCloser is an always-empty body whose Close blocks. Close first closes
// the closing channel to announce itself, then waits until someone closes the
// closed channel.
type slowCloser struct {
	closing chan struct{} // closed when Close is entered
	closed  chan struct{} // Close returns once this is closed
}

// Read reports end-of-file immediately and never fills p.
func (r *slowCloser) Read(p []byte) (n int, err error) {
	return 0, io.EOF
}

// Close signals on r.closing and then blocks until r.closed is closed. It
// always returns nil; a second call panics because r.closing is closed again.
func (r *slowCloser) Close() error {
	close(r.closing)
	<-r.closed
	return nil
}
6338
6339 func TestTransportSlowClose(t *testing.T) {
6340 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
6341 }, optOnlyServer)
6342 defer st.Close()
6343
6344 client := st.ts.Client()
6345 body := &slowCloser{
6346 closing: make(chan struct{}),
6347 closed: make(chan struct{}),
6348 }
6349
6350 reqc := make(chan struct{})
6351 go func() {
6352 defer close(reqc)
6353 res, err := client.Post(st.ts.URL, "text/plain", body)
6354 if err != nil {
6355 t.Error(err)
6356 }
6357 res.Body.Close()
6358 }()
6359 defer func() {
6360 close(body.closed)
6361 <-reqc
6362 }()
6363
6364 <-body.closing
6365
6366 res, err := client.Get(st.ts.URL)
6367 if err != nil {
6368 t.Fatal(err)
6369 }
6370 res.Body.Close()
6371 }
6372
6373 func TestTransportDialTLSContext(t *testing.T) {
6374 blockCh := make(chan struct{})
6375 serverTLSConfigFunc := func(ts *httptest.Server) {
6376 ts.Config.TLSConfig = &tls.Config{
6377
6378
6379 ClientAuth: tls.RequestClientCert,
6380 }
6381 }
6382 ts := newServerTester(t,
6383 func(w http.ResponseWriter, r *http.Request) {},
6384 optOnlyServer,
6385 serverTLSConfigFunc,
6386 )
6387 defer ts.Close()
6388 tr := &Transport{
6389 TLSClientConfig: &tls.Config{
6390 GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
6391
6392
6393 close(blockCh)
6394 <-cri.Context().Done()
6395 return nil, cri.Context().Err()
6396 },
6397 InsecureSkipVerify: true,
6398 },
6399 }
6400 defer tr.CloseIdleConnections()
6401 req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
6402 if err != nil {
6403 t.Fatal(err)
6404 }
6405 ctx, cancel := context.WithCancel(context.Background())
6406 defer cancel()
6407 req = req.WithContext(ctx)
6408 errCh := make(chan error)
6409 go func() {
6410 defer close(errCh)
6411 res, err := tr.RoundTrip(req)
6412 if err != nil {
6413 errCh <- err
6414 return
6415 }
6416 res.Body.Close()
6417 }()
6418
6419 <-blockCh
6420
6421 cancel()
6422
6423 err = <-errCh
6424 if err == nil {
6425 t.Fatal("cancelling context during client certificate fetch did not error as expected")
6426 return
6427 }
6428 if !errors.Is(err, context.Canceled) {
6429 t.Fatalf("unexpected error returned after cancellation: %v", err)
6430 }
6431 }
6432
6433
6434
6435
6436
6437 func TestDialRaceResumesDial(t *testing.T) {
6438 blockCh := make(chan struct{})
6439 serverTLSConfigFunc := func(ts *httptest.Server) {
6440 ts.Config.TLSConfig = &tls.Config{
6441
6442
6443 ClientAuth: tls.RequestClientCert,
6444 }
6445 }
6446 ts := newServerTester(t,
6447 func(w http.ResponseWriter, r *http.Request) {},
6448 optOnlyServer,
6449 serverTLSConfigFunc,
6450 )
6451 defer ts.Close()
6452 tr := &Transport{
6453 TLSClientConfig: &tls.Config{
6454 GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
6455 select {
6456 case <-blockCh:
6457
6458 return &tls.Certificate{}, nil
6459 default:
6460 }
6461 close(blockCh)
6462 <-cri.Context().Done()
6463 return nil, cri.Context().Err()
6464 },
6465 InsecureSkipVerify: true,
6466 },
6467 }
6468 defer tr.CloseIdleConnections()
6469 req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
6470 if err != nil {
6471 t.Fatal(err)
6472 }
6473
6474 ctx1, cancel1 := context.WithCancel(context.Background())
6475 defer cancel1()
6476 req1 := req.WithContext(ctx1)
6477 ctx2, cancel2 := context.WithCancel(context.Background())
6478 defer cancel2()
6479 req2 := req.WithContext(ctx2)
6480 errCh := make(chan error)
6481 go func() {
6482 res, err := tr.RoundTrip(req1)
6483 if err != nil {
6484 errCh <- err
6485 return
6486 }
6487 res.Body.Close()
6488 }()
6489 successCh := make(chan struct{})
6490 go func() {
6491
6492
6493 <-blockCh
6494 res, err := tr.RoundTrip(req2)
6495 if err != nil {
6496 errCh <- err
6497 return
6498 }
6499 res.Body.Close()
6500
6501
6502 close(successCh)
6503 }()
6504
6505 <-blockCh
6506
6507 cancel1()
6508
6509 err = <-errCh
6510 if err == nil {
6511 t.Fatal("cancelling context during client certificate fetch did not error as expected")
6512 return
6513 }
6514 if !errors.Is(err, context.Canceled) {
6515 t.Fatalf("unexpected error returned after cancellation: %v", err)
6516 }
6517 select {
6518 case err := <-errCh:
6519 t.Fatalf("unexpected second error: %v", err)
6520 case <-successCh:
6521 }
6522 }
6523
View as plain text