// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package http2 import ( "bufio" "bytes" "compress/gzip" "context" "crypto/tls" "encoding/hex" "errors" "flag" "fmt" "io" "io/fs" "io/ioutil" "log" "math/rand" "net" "net/http" "net/http/httptest" "net/http/httptrace" "net/textproto" "net/url" "os" "reflect" "runtime" "sort" "strconv" "strings" "sync" "sync/atomic" "testing" "time" "golang.org/x/net/http2/hpack" ) var ( extNet = flag.Bool("extnet", false, "do external network tests") transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport") insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove? ) var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true} var canceledCtx context.Context func init() { ctx, cancel := context.WithCancel(context.Background()) cancel() canceledCtx = ctx } func TestTransportExternal(t *testing.T) { if !*extNet { t.Skip("skipping external network test") } req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) rt := &Transport{TLSClientConfig: tlsConfigInsecure} res, err := rt.RoundTrip(req) if err != nil { t.Fatalf("%v", err) } res.Write(os.Stdout) } type fakeTLSConn struct { net.Conn } func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { return tls.ConnectionState{ Version: tls.VersionTLS12, CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } } func startH2cServer(t *testing.T) net.Listener { h2Server := &Server{} l := newLocalListener(t) go func() { conn, err := l.Accept() if err != nil { t.Error(err) return } h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) })}) }() return l } func TestTransportH2c(t *testing.T) { l := startH2cServer(t) defer l.Close() req, err := 
http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil) if err != nil { t.Fatal(err) } var gotConnCnt int32 trace := &httptrace.ClientTrace{ GotConn: func(connInfo httptrace.GotConnInfo) { if !connInfo.Reused { atomic.AddInt32(&gotConnCnt, 1) } }, } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) tr := &Transport{ AllowHTTP: true, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { return net.Dial(network, addr) }, } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } if res.ProtoMajor != 2 { t.Fatal("proto not h2c") } body, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatal(err) } if got, want := string(body), "Hello, /foobar, http: true"; got != want { t.Fatalf("response got %v, want %v", got, want) } if got, want := gotConnCnt, int32(1); got != want { t.Errorf("Too many got connections: %d", gotConnCnt) } } func TestTransport(t *testing.T) { const body = "sup" st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, body) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() u, err := url.Parse(st.ts.URL) if err != nil { t.Fatal(err) } for i, m := range []string{"GET", ""} { req := &http.Request{ Method: m, URL: u, } res, err := tr.RoundTrip(req) if err != nil { t.Fatalf("%d: %s", i, err) } t.Logf("%d: Got res: %+v", i, res) if g, w := res.StatusCode, 200; g != w { t.Errorf("%d: StatusCode = %v; want %v", i, g, w) } if g, w := res.Status, "200 OK"; g != w { t.Errorf("%d: Status = %q; want %q", i, g, w) } wantHeader := http.Header{ "Content-Length": []string{"3"}, "Content-Type": []string{"text/plain; charset=utf-8"}, "Date": []string{"XXX"}, // see cleanDate } cleanDate(res) if !reflect.DeepEqual(res.Header, wantHeader) { t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader) } if res.Request != req { t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req) } if res.TLS == 
nil { t.Errorf("%d: Response.TLS = nil; want non-nil", i) } slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Errorf("%d: Body read: %v", i, err) } else if string(slurp) != body { t.Errorf("%d: Body = %q; want %q", i, slurp, body) } res.Body.Close() } } func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq func(*http.Request)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.RemoteAddr) }, optOnlyServer, func(c net.Conn, st http.ConnState) { t.Logf("conn %v is now state %v", c.RemoteAddr(), st) }) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} if useClient { tr.ConnPool = noDialClientConnPool{new(clientConnPool)} } defer tr.CloseIdleConnections() get := func() string { req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) } modReq(req) var res *http.Response if useClient { c := st.ts.Client() ConfigureTransports(c.Transport.(*http.Transport)) res, err = c.Do(req) } else { res, err = tr.RoundTrip(req) } if err != nil { t.Fatal(err) } defer res.Body.Close() slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("Body read: %v", err) } addr := strings.TrimSpace(string(slurp)) if addr == "" { t.Fatalf("didn't get an addr in response") } return addr } first := get() second := get() if got := first == second; got != wantSame { t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame) } } func TestTransportReusesConns(t *testing.T) { for _, test := range []struct { name string modReq func(*http.Request) wantSame bool }{{ name: "ReuseConn", modReq: func(*http.Request) {}, wantSame: true, }, { name: "RequestClose", modReq: func(r *http.Request) { r.Close = true }, wantSame: false, }, { name: "ConnClose", modReq: func(r *http.Request) { r.Header.Set("Connection", "close") }, wantSame: false, }} { t.Run(test.name, func(t *testing.T) { t.Run("Transport", func(t *testing.T) { const useClient = false 
testTransportReusesConns(t, useClient, test.wantSame, test.modReq) }) t.Run("Client", func(t *testing.T) { const useClient = true testTransportReusesConns(t, useClient, test.wantSame, test.modReq) }) }) } } func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) { testTransportGetGotConnHooks(t, false) } func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) } func testTransportGetGotConnHooks(t *testing.T, useClient bool) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.RemoteAddr) }, func(s *httptest.Server) { s.EnableHTTP2 = true }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} client := st.ts.Client() ConfigureTransports(client.Transport.(*http.Transport)) var ( getConns int32 gotConns int32 ) for i := 0; i < 2; i++ { trace := &httptrace.ClientTrace{ GetConn: func(hostport string) { atomic.AddInt32(&getConns, 1) }, GotConn: func(connInfo httptrace.GotConnInfo) { got := atomic.AddInt32(&gotConns, 1) wantReused, wantWasIdle := false, false if got > 1 { wantReused, wantWasIdle = true, true } if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle { t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle) } }, } req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) var res *http.Response if useClient { res, err = client.Do(req) } else { res, err = tr.RoundTrip(req) } if err != nil { t.Fatal(err) } res.Body.Close() if get := atomic.LoadInt32(&getConns); get != int32(i+1) { t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1) } if got := atomic.LoadInt32(&gotConns); got != int32(i+1) { t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1) } } } type testNetConn struct { net.Conn closed bool onClose func() } func (c 
*testNetConn) Close() error { if !c.closed { // We can call Close multiple times on the same net.Conn. c.onClose() } c.closed = true return c.Conn.Close() } // Tests that the Transport only keeps one pending dial open per destination address. // https://golang.org/issue/13397 func TestTransportGroupsPendingDials(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, optOnlyServer) defer st.Close() var ( mu sync.Mutex dialCount int closeCount int ) tr := &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() dialCount++ mu.Unlock() c, err := tls.Dial(network, addr, cfg) return &testNetConn{ Conn: c, onClose: func() { mu.Lock() closeCount++ mu.Unlock() }, }, err }, } defer tr.CloseIdleConnections() var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Error(err) return } res, err := tr.RoundTrip(req) if err != nil { t.Error(err) return } res.Body.Close() }() } wg.Wait() tr.CloseIdleConnections() if dialCount != 1 { t.Errorf("saw %d dials; want 1", dialCount) } if closeCount != 1 { t.Errorf("saw %d closes; want 1", closeCount) } } func retry(tries int, delay time.Duration, fn func() error) error { var err error for i := 0; i < tries; i++ { err = fn() if err == nil { return nil } time.Sleep(delay) } return err } func TestTransportAbortClosesPipes(t *testing.T) { shutdown := make(chan struct{}) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.(http.Flusher).Flush() <-shutdown }, optOnlyServer, ) defer st.Close() defer close(shutdown) // we must shutdown before st.Close() to avoid hanging errCh := make(chan error) go func() { defer close(errCh) tr := &Transport{TLSClientConfig: tlsConfigInsecure} req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { errCh <- err return } res, err := tr.RoundTrip(req) if err != nil { errCh <- 
err return } defer res.Body.Close() st.closeConn() _, err = ioutil.ReadAll(res.Body) if err == nil { errCh <- errors.New("expected error from res.Body.Read") return } }() select { case err := <-errCh: if err != nil { t.Fatal(err) } // deadlock? that's a bug. case <-time.After(3 * time.Second): t.Fatal("timeout") } } // TODO: merge this with TestTransportBody to make TestTransportRequest? This // could be a table-driven test with extra goodies. func TestTransportPath(t *testing.T) { gotc := make(chan *url.URL, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { gotc <- r.URL }, optOnlyServer, ) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() const ( path = "/testpath" query = "q=1" ) surl := st.ts.URL + path + "?" + query req, err := http.NewRequest("POST", surl, nil) if err != nil { t.Fatal(err) } c := &http.Client{Transport: tr} res, err := c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() got := <-gotc if got.Path != path { t.Errorf("Read Path = %q; want %q", got.Path, path) } if got.RawQuery != query { t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query) } } func randString(n int) string { rnd := rand.New(rand.NewSource(int64(n))) b := make([]byte, n) for i := range b { b[i] = byte(rnd.Intn(256)) } return string(b) } type panicReader struct{} func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } func (panicReader) Close() error { panic("unexpected Close") } func TestActualContentLength(t *testing.T) { tests := []struct { req *http.Request want int64 }{ // Verify we don't read from Body: 0: { req: &http.Request{Body: panicReader{}}, want: -1, }, // nil Body means 0, regardless of ContentLength: 1: { req: &http.Request{Body: nil, ContentLength: 5}, want: 0, }, // ContentLength is used if set. 2: { req: &http.Request{Body: panicReader{}, ContentLength: 5}, want: 5, }, // http.NoBody means 0, not -1. 
3: { req: &http.Request{Body: http.NoBody}, want: 0, }, } for i, tt := range tests { got := actualContentLength(tt.req) if got != tt.want { t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) } } } func TestTransportBody(t *testing.T) { bodyTests := []struct { body string noContentLen bool }{ {body: "some message"}, {body: "some message", noContentLen: true}, {body: strings.Repeat("a", 1<<20), noContentLen: true}, {body: strings.Repeat("a", 1<<20)}, {body: randString(16<<10 - 1)}, {body: randString(16 << 10)}, {body: randString(16<<10 + 1)}, {body: randString(512<<10 - 1)}, {body: randString(512 << 10)}, {body: randString(512<<10 + 1)}, {body: randString(1<<20 - 1)}, {body: randString(1 << 20)}, {body: randString(1<<20 + 2)}, } type reqInfo struct { req *http.Request slurp []byte err error } gotc := make(chan reqInfo, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { slurp, err := ioutil.ReadAll(r.Body) if err != nil { gotc <- reqInfo{err: err} } else { gotc <- reqInfo{req: r, slurp: slurp} } }, optOnlyServer, ) defer st.Close() for i, tt := range bodyTests { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() var body io.Reader = strings.NewReader(tt.body) if tt.noContentLen { body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods } req, err := http.NewRequest("POST", st.ts.URL, body) if err != nil { t.Fatalf("#%d: %v", i, err) } c := &http.Client{Transport: tr} res, err := c.Do(req) if err != nil { t.Fatalf("#%d: %v", i, err) } defer res.Body.Close() ri := <-gotc if ri.err != nil { t.Errorf("#%d: read error: %v", i, ri.err) continue } if got := string(ri.slurp); got != tt.body { t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body)) } wantLen := int64(len(tt.body)) if tt.noContentLen && tt.body != "" { wantLen = -1 } if ri.req.ContentLength != wantLen { t.Errorf("#%d. 
handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen) } } } func shortString(v string) string { const maxLen = 100 if len(v) <= maxLen { return v } return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:]) } func TestTransportDialTLS(t *testing.T) { var mu sync.Mutex // guards following var gotReq, didDial bool ts := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { mu.Lock() gotReq = true mu.Unlock() }, optOnlyServer, ) defer ts.Close() tr := &Transport{ DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() didDial = true mu.Unlock() cfg.InsecureSkipVerify = true c, err := tls.Dial(netw, addr, cfg) if err != nil { return nil, err } return c, c.Handshake() }, } defer tr.CloseIdleConnections() client := &http.Client{Transport: tr} res, err := client.Get(ts.ts.URL) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if !didDial { t.Error("didn't use dial hook") } } func TestConfigureTransport(t *testing.T) { t1 := &http.Transport{} err := ConfigureTransport(t1) if err != nil { t.Fatal(err) } if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) { // Laziness, to avoid buildtags. t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got) } wantNextProtos := []string{"h2", "http/1.1"} if t1.TLSClientConfig == nil { t.Errorf("nil t1.TLSClientConfig") } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) { t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos) } if err := ConfigureTransport(t1); err == nil { t.Error("unexpected success on second call to ConfigureTransport") } // And does it work? 
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.Proto) }, optOnlyServer) defer st.Close() t1.TLSClientConfig.InsecureSkipVerify = true c := &http.Client{Transport: t1} res, err := c.Get(st.ts.URL) if err != nil { t.Fatal(err) } slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatal(err) } if got, want := string(slurp), "HTTP/2.0"; got != want { t.Errorf("body = %q; want %q", got, want) } } type capitalizeReader struct { r io.Reader } func (cr capitalizeReader) Read(p []byte) (n int, err error) { n, err = cr.r.Read(p) for i, b := range p[:n] { if b >= 'a' && b <= 'z' { p[i] = b - ('a' - 'A') } } return } type flushWriter struct { w io.Writer } func (fw flushWriter) Write(p []byte) (n int, err error) { n, err = fw.w.Write(p) if f, ok := fw.w.(http.Flusher); ok { f.Flush() } return } type clientTester struct { t *testing.T tr *Transport sc, cc net.Conn // server and client conn fr *Framer // server's framer settings *SettingsFrame client func() error server func() error } func newClientTester(t *testing.T) *clientTester { var dialOnce struct { sync.Mutex dialed bool } ct := &clientTester{ t: t, } ct.tr = &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { dialOnce.Lock() defer dialOnce.Unlock() if dialOnce.dialed { return nil, errors.New("only one dial allowed in test mode") } dialOnce.dialed = true return ct.cc, nil }, } ln := newLocalListener(t) cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } sc, err := ln.Accept() if err != nil { t.Fatal(err) } ln.Close() ct.cc = cc ct.sc = sc ct.fr = NewFramer(sc, sc) return ct } func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp4", "127.0.0.1:0") if err == nil { return ln } ln, err = net.Listen("tcp6", "[::1]:0") if err != nil { t.Fatal(err) } return ln } func (ct *clientTester) greet(settings ...Setting) { buf := make([]byte, len(ClientPreface)) _, err 
:= io.ReadFull(ct.sc, buf) if err != nil { ct.t.Fatalf("reading client preface: %v", err) } f, err := ct.fr.ReadFrame() if err != nil { ct.t.Fatalf("Reading client settings frame: %v", err) } var ok bool if ct.settings, ok = f.(*SettingsFrame); !ok { ct.t.Fatalf("Wanted client settings frame; got %v", f) } if err := ct.fr.WriteSettings(settings...); err != nil { ct.t.Fatal(err) } if err := ct.fr.WriteSettingsAck(); err != nil { ct.t.Fatal(err) } } func (ct *clientTester) readNonSettingsFrame() (Frame, error) { for { f, err := ct.fr.ReadFrame() if err != nil { return nil, err } if _, ok := f.(*SettingsFrame); ok { continue } return f, nil } } // writeReadPing sends a PING and immediately reads the PING ACK. // It will fail if any other unread data was pending on the connection, // aside from SETTINGS frames. func (ct *clientTester) writeReadPing() error { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} if err := ct.fr.WritePing(false, data); err != nil { return fmt.Errorf("Error writing PING: %v", err) } f, err := ct.readNonSettingsFrame() if err != nil { return err } p, ok := f.(*PingFrame) if !ok { return fmt.Errorf("got a %v, want a PING ACK", f) } if p.Flags&FlagPingAck == 0 { return fmt.Errorf("got a PING, want a PING ACK") } if p.Data != data { return fmt.Errorf("got PING data = %x, want %x", p.Data, data) } return nil } func (ct *clientTester) inflowWindow(streamID uint32) int32 { pool := ct.tr.connPoolOrDef.(*clientConnPool) pool.mu.Lock() defer pool.mu.Unlock() if n := len(pool.keys); n != 1 { ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) return -1 } for cc := range pool.keys { cc.mu.Lock() defer cc.mu.Unlock() if streamID == 0 { return cc.inflow.avail + cc.inflow.unsent } cs := cc.streams[streamID] if cs == nil { ct.t.Errorf("no stream with id %v", streamID) return -1 } return cs.inflow.avail + cs.inflow.unsent } return -1 } func (ct *clientTester) cleanup() { ct.tr.CloseIdleConnections() // close both connections, ignore the error if its already 
closed ct.sc.Close() ct.cc.Close() } func (ct *clientTester) run() { var errOnce sync.Once var wg sync.WaitGroup run := func(which string, fn func() error) { defer wg.Done() if err := fn(); err != nil { errOnce.Do(func() { ct.t.Errorf("%s: %v", which, err) ct.cleanup() }) } } wg.Add(2) go run("client", ct.client) go run("server", ct.server) wg.Wait() errOnce.Do(ct.cleanup) // clean up if no error } func (ct *clientTester) readFrame() (Frame, error) { return ct.fr.ReadFrame() } func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { for { f, err := ct.readFrame() if err != nil { return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } switch f.(type) { case *WindowUpdateFrame, *SettingsFrame: continue } hf, ok := f.(*HeadersFrame) if !ok { return nil, fmt.Errorf("Got %T; want HeadersFrame", f) } return hf, nil } } type countingReader struct { n *int64 } func (r countingReader) Read(p []byte) (n int, err error) { for i := range p { p[i] = byte(i) } atomic.AddInt64(r.n, int64(len(p))) return len(p), err } func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } func testTransportReqBodyAfterResponse(t *testing.T, status int) { const bodySize = 10 << 20 clientDone := make(chan struct{}) ct := newClientTester(t) recvLen := make(chan int64, 1) ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() if runtime.GOOS == "plan9" { // CloseWrite not supported on Plan 9; Issue 17906 defer ct.cc.(*net.TCPConn).Close() } defer close(clientDone) body := &pipe{b: new(bytes.Buffer)} io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) req, err := http.NewRequest("PUT", "https://dummy.tld/", body) if err != nil { return err } res, err := ct.tr.RoundTrip(req) if err != nil { return fmt.Errorf("RoundTrip: %v", err) } if res.StatusCode != status { return fmt.Errorf("status code = %v; want %v", 
res.StatusCode, status) } io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) body.CloseWithError(io.EOF) slurp, err := ioutil.ReadAll(res.Body) if err != nil { return fmt.Errorf("Slurp: %v", err) } if len(slurp) > 0 { return fmt.Errorf("unexpected body: %q", slurp) } res.Body.Close() if status == 200 { if got := <-recvLen; got != bodySize { return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) } } else { if got := <-recvLen; got == 0 || got >= bodySize { return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) } } return nil } ct.server = func() error { ct.greet() defer close(recvLen) var buf bytes.Buffer enc := hpack.NewEncoder(&buf) var dataRecv int64 var closed bool for { f, err := ct.fr.ReadFrame() if err != nil { select { case <-clientDone: // If the client's done, it // will have reported any // errors on its side. return nil default: return err } } //println(fmt.Sprintf("server got frame: %v", f)) ended := false switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: if !f.HeadersEnded() { return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) } if f.StreamEnded() { return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) } case *DataFrame: dataLen := len(f.Data()) if dataLen > 0 { if dataRecv == 0 { enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) } if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { return err } if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { return err } } dataRecv += int64(dataLen) if !closed && ((status != 200 && dataRecv > 0) || (status == 200 && f.StreamEnded())) { closed = true if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { return err } } if f.StreamEnded() { ended = true 
} case *RSTStreamFrame: if status == 200 { return fmt.Errorf("Unexpected client frame %v", f) } ended = true default: return fmt.Errorf("Unexpected client frame %v", f) } if ended { select { case recvLen <- dataRecv: default: } } } } ct.run() } // See golang.org/issue/13444 func TestTransportFullDuplex(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) // redundant but for clarity w.(http.Flusher).Flush() io.Copy(flushWriter{w}, capitalizeReader{r.Body}) fmt.Fprintf(w, "bye.\n") }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} pr, pw := io.Pipe() req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr)) if err != nil { t.Fatal(err) } req.ContentLength = -1 res, err := c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.StatusCode != 200 { t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200) } bs := bufio.NewScanner(res.Body) want := func(v string) { if !bs.Scan() { t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err()) } } write := func(v string) { _, err := io.WriteString(pw, v) if err != nil { t.Fatalf("pipe write: %v", err) } } write("foo\n") want("FOO") write("bar\n") want("BAR") pw.Close() want("bye.") if err := bs.Err(); err != nil { t.Fatal(err) } } func TestTransportConnectRequest(t *testing.T) { gotc := make(chan *http.Request, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { gotc <- r }, optOnlyServer) defer st.Close() u, err := url.Parse(st.ts.URL) if err != nil { t.Fatal(err) } tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} tests := []struct { req *http.Request want string }{ { req: &http.Request{ Method: "CONNECT", Header: http.Header{}, URL: u, }, want: u.Host, }, { req: &http.Request{ Method: "CONNECT", Header: http.Header{}, URL: u, Host: 
"example.com:123", }, want: "example.com:123", }, } for i, tt := range tests { res, err := c.Do(tt.req) if err != nil { t.Errorf("%d. RoundTrip = %v", i, err) continue } res.Body.Close() req := <-gotc if req.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", req.Method) } if req.Host != tt.want { t.Errorf("Host = %q; want %q", req.Host, tt.want) } if req.URL.Host != tt.want { t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) } } } type headerType int const ( noHeader headerType = iota // omitted oneHeader splitHeader // broken into continuation on purpose ) const ( f0 = noHeader f1 = oneHeader f2 = splitHeader d0 = false d1 = true ) // Test all 36 combinations of response frame orders: // // (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) } // // Generated by http://play.golang.org/p/SScqYKJYXd func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) } func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) } func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) } func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) } func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) } func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) } func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) } func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) } func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) } func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) } func TestTransportResPattern_c0h2d1t1(t 
*testing.T) { testTransportResPattern(t, f0, f2, d1, f1) } func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) } func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) } func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) } func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) } func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) } func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) } func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) } func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) } func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) } func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) } func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) } func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) } func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) } func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) } func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) } func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) } func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) } func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) } func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) } func TestTransportResPattern_c2h2d0t0(t *testing.T) { 
testTransportResPattern(t, f2, f2, d0, f0) }

// The remaining TestTransportResPattern_c2h2* permutations. Each one forwards
// a combination of framing modes (f0/f1/f2, d0/d1 are declared earlier in this
// file) to testTransportResPattern.
func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }

// testTransportResPattern sends one POST through a client tester and checks
// what the Transport reports for a particular sequence of response frames
// written by the scripted server.
//
// expect100Continue selects whether the request carries "Expect: 100-continue"
// and, if it does, how the server frames its ":status 100" reply. resHeader
// and trailers select how the final response headers and the trailers are
// framed: oneHeader writes a single HEADERS frame, splitHeader writes a
// HEADERS frame followed by a CONTINUATION frame, and (for trailers) noHeader
// means no trailers are sent. withData controls whether a DATA frame carrying
// the response body is sent. resHeader must not be noHeader.
func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
	const reqBody = "some request body"
	const resBody = "some response body"

	if resHeader == noHeader {
		// TODO: test 100-continue followed by immediate
		// server stream reset, without headers in the middle?
		panic("invalid combination")
	}

	ct := newClientTester(t)
	ct.client = func() error {
		req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
		if expect100Continue != noHeader {
			req.Header.Set("Expect", "100-continue")
		}
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != 200 {
			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
		}
		slurp, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("Slurp: %v", err)
		}
		wantBody := resBody
		if !withData {
			wantBody = ""
		}
		if string(slurp) != wantBody {
			return fmt.Errorf("body = %q; want %q", slurp, wantBody)
		}
		if trailers == noHeader {
			if len(res.Trailer) > 0 {
				t.Errorf("Trailer = %v; want none", res.Trailer)
			}
		} else {
			want := http.Header{"Some-Trailer": {"some-value"}}
			if !reflect.DeepEqual(res.Trailer, want) {
				t.Errorf("Trailer = %v; want %v", res.Trailer, want)
			}
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			endStream := false
			// send writes the header block currently held in buf on
			// the stream of the frame just read, framed per mode. It
			// reads endStream at call time.
			send := func(mode headerType) {
				hbf := buf.Bytes()
				switch mode {
				case oneHeader:
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.Header().StreamID,
						EndHeaders:    true,
						EndStream:     endStream,
						BlockFragment: hbf,
					})
				case splitHeader:
					if len(hbf) < 2 {
						panic("too small")
					}
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.Header().StreamID,
						EndHeaders:    false,
						EndStream:     endStream,
						BlockFragment: hbf[:1],
					})
					ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
				default:
					panic("bogus mode")
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *DataFrame:
				if !f.StreamEnded() {
					// No need to send flow control tokens. The test request body is tiny.
					continue
				}
				// Response headers (1+ frames; 1 or 2 in this test, but never 0)
				{
					buf.Reset()
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
					enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
					if trailers != noHeader {
						enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
					}
					// The headers end the stream only when neither
					// data nor trailers follow.
					endStream = !withData && trailers == noHeader
					send(resHeader)
				}
				if withData {
					endStream = trailers == noHeader
					ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
				}
				if trailers != noHeader {
					endStream = true
					buf.Reset()
					enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
					send(trailers)
				}
				if endStream {
					return nil
				}
			case *HeadersFrame:
				if expect100Continue != noHeader {
					buf.Reset()
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
					send(expect100Continue)
				}
			}
		}
	}
	ct.run()
}

// Issue 26189, Issue 17739: ignore unknown 1xx responses
//
// The server sends five informational responses (110 through 114) followed by
// a 204. The client must see the 204 as the response, and the
// got1xxFuncForTests hook must have observed each 1xx in order.
func TestTransportUnknown1xx(t *testing.T) {
	var buf bytes.Buffer
	defer func() { got1xxFuncForTests = nil }()
	got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error {
		fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
		return nil
	}

	ct := newClientTester(t)
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != 204 {
			return fmt.Errorf("status code = %v; want 204", res.StatusCode)
		}
		// One line per 1xx response, each terminated by the "\n" that
		// the hook above writes.
		want := `code=110 header=map[Foo-Bar:[110]]
code=111 header=map[Foo-Bar:[111]]
code=112 header=map[Foo-Bar:[112]]
code=113 header=map[Foo-Bar:[113]]
code=114 header=map[Foo-Bar:[114]]
`
		if got := buf.String(); got != want {
			t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				for i := 110; i <= 114; i++ {
					buf.Reset()
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)})
					enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     false,
						BlockFragment: buf.Bytes(),
					})
				}
				buf.Reset()
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: buf.Bytes(),
				})
				return nil
			}
		}
	}
	ct.run()
}

// TestTransportReceiveUndeclaredTrailer checks that a trailer field the server
// did not announce in a "Trailer" response header still shows up in
// res.Trailer once the body has been read.
func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
	ct := newClientTester(t)
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != 200 {
			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
		}
		slurp, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
		}
		if len(slurp) > 0 {
			return fmt.Errorf("body = %q; want nothing", slurp)
		}
		if _, ok := res.Trailer["Some-Trailer"]; !ok {
			return fmt.Errorf("expected Some-Trailer")
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()

		// Read up to 10 frames looking for the request HEADERS frame.
		var n int
		var hf *HeadersFrame
		for hf == nil && n < 10 {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			hf, _ = f.(*HeadersFrame)
			n++
		}
		if hf == nil {
			// Without this guard the hf.StreamID accesses below
			// would dereference a nil pointer.
			return fmt.Errorf("no HEADERS frame among the first %d frames", n)
		}

		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)

		// send headers without Trailer header
		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		ct.fr.WriteHeaders(HeadersFrameParam{
			StreamID:      hf.StreamID,
			EndHeaders:    true,
			EndStream:     false,
			BlockFragment: buf.Bytes(),
		})

		// send trailers
		buf.Reset()
		enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
		ct.fr.WriteHeaders(HeadersFrameParam{
			StreamID:      hf.StreamID,
			EndHeaders:    true,
			EndStream:     true,
			BlockFragment: buf.Bytes(),
		})
		return nil
	}
	ct.run()
}

func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
	testTransportInvalidTrailer_Pseudo(t, oneHeader)
}
func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
	testTransportInvalidTrailer_Pseudo(t, splitHeader)
}

// testTransportInvalidTrailer_Pseudo sends trailers containing a
// pseudo-header field (":colon") and expects pseudoHeaderError(":colon").
func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
	testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
		enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
	})
}

func TestTransportInvalidTrailer_Capital1(t *testing.T) {
	testTransportInvalidTrailer_Capital(t, oneHeader)
}
func TestTransportInvalidTrailer_Capital2(t *testing.T) {
	testTransportInvalidTrailer_Capital(t, splitHeader)
}

// testTransportInvalidTrailer_Capital sends trailers containing a field name
// with an upper-case letter and expects headerFieldNameError("Capital").
func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
	testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
		enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
	})
}

// TestTransportInvalidTrailer_EmptyFieldName sends a trailer whose field name
// is empty and expects headerFieldNameError("").
func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
	testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
		enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
	})
}

// TestTransportInvalidTrailer_BinaryFieldValue sends a trailer whose value
// contains a newline and expects headerFieldValueError("x").
func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
	testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) {
		enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
	})
}

// testInvalidTrailer has the scripted server answer a GET with valid response
// headers followed by the trailers produced by writeTrailer (framed per
// trailers). The client must get a 200 response whose body read fails with a
// StreamError whose Cause equals wantErr, and must read no body bytes.
func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
	ct := newClientTester(t)
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != 200 {
			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
		}
		slurp, err := ioutil.ReadAll(res.Body)
		se, ok := err.(StreamError)
		if !ok || se.Cause != wantErr {
			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
		}
		if len(slurp) > 0 {
			return fmt.Errorf("body = %q; want nothing", slurp)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *HeadersFrame:
				var endStream bool
				// send writes the header block currently in buf,
				// framed per mode, reading endStream at call time.
				send := func(mode headerType) {
					hbf := buf.Bytes()
					switch mode {
					case oneHeader:
						ct.fr.WriteHeaders(HeadersFrameParam{
							StreamID:      f.StreamID,
							EndHeaders:    true,
							EndStream:     endStream,
							BlockFragment: hbf,
						})
					case splitHeader:
						if len(hbf) < 2 {
							panic("too small")
						}
						ct.fr.WriteHeaders(HeadersFrameParam{
							StreamID:      f.StreamID,
							EndHeaders:    false,
							EndStream:     endStream,
							BlockFragment: hbf[:1],
						})
						ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
					default:
						panic("bogus mode")
					}
				}
				// Response headers (1+ frames; 1 or 2 in this test, but never 0)
				{
					buf.Reset()
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
					endStream = false
					send(oneHeader)
				}
				// Trailers:
				{
					endStream = true
					buf.Reset()
					writeTrailer(enc)
					send(trailers)
				}
				return nil
			}
		}
	}
	ct.run()
}

// headerListSize returns the HTTP2 header list size of h.
//
//	http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
//	http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock
func headerListSize(h http.Header) (size uint32) {
	for k, vv := range h {
		for _, v := range vv {
			hf := hpack.HeaderField{Name: k, Value: v}
			size += hf.Size()
		}
	}
	return size
}

// padHeaders adds data to an http.Header until headerListSize(h) ==
// limit. Due to the way header list sizes are calculated, padHeaders
// cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will
// call t.Fatal if asked to do so. padHeaders first reserves enough
// space for an empty "Pad-Headers" key, then adds as many copies of
// filler as possible. Any remaining bytes necessary to push the
// header list size up to limit are added to h["Pad-Headers"].
func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
	if limit > 0xffffffff {
		t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
	}
	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
	minPadding := uint64(hf.Size())
	size := uint64(headerListSize(h))

	minlimit := size + minPadding
	if limit < minlimit {
		t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
	}

	// Use a fixed-width format for name so that fieldSize
	// remains constant.
	nameFmt := "Pad-Headers-%06d"
	hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
	fieldSize := uint64(hf.Size())

	// Add as many complete filler values as possible, leaving
	// room for at least one empty "Pad-Headers" key.
	limit = limit - minPadding
	for i := 0; size+fieldSize < limit; i++ {
		name := fmt.Sprintf(nameFmt, i)
		h.Add(name, filler)
		size += fieldSize
	}

	// Add enough bytes to reach limit.
	remain := limit - size
	lastValue := strings.Repeat("*", int(remain))
	h.Add("Pad-Headers", lastValue)
}

// TestPadHeaders checks that padHeaders brings headerListSize(h) to exactly
// the requested limit for a range of limits and filler lengths.
func TestPadHeaders(t *testing.T) {
	check := func(h http.Header, limit uint32, fillerLen int) {
		if h == nil {
			h = make(http.Header)
		}
		filler := strings.Repeat("f", fillerLen)
		padHeaders(t, h, uint64(limit), filler)
		gotSize := headerListSize(h)
		if gotSize != limit {
			t.Errorf("Got size = %v; want %v", gotSize, limit)
		}
	}
	// Try all possible combinations for small fillerLen and limit.
	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
	minLimit := hf.Size()
	for limit := minLimit; limit <= 128; limit++ {
		for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
			check(nil, limit, fillerLen)
		}
	}

	// Try a few tests with larger limits, plus cumulative
	// tests. Since these tests are cumulative, tests[i+1].limit
	// must be >= tests[i].limit + minLimit. See the comment on
	// padHeaders for more info on why the limit arg has this
	// restriction.
	tests := []struct {
		fillerLen int
		limit     uint32
	}{
		{
			fillerLen: 64,
			limit:     1024,
		},
		{
			fillerLen: 1024,
			limit:     1286,
		},
		{
			fillerLen: 256,
			limit:     2048,
		},
		{
			fillerLen: 1024,
			limit:     10 * 1024,
		},
		{
			fillerLen: 1023,
			limit:     11 * 1024,
		},
	}
	h := make(http.Header)
	for _, tc := range tests {
		check(nil, tc.limit, tc.fillerLen)
		check(h, tc.limit, tc.fillerLen)
	}
}

// TestTransportChecksRequestHeaderListSize checks that the Transport refuses,
// with errRequestHeaderListSize, requests whose headers or trailers exceed the
// header list size advertised by the server, and lets through requests that
// stay within it.
func TestTransportChecksRequestHeaderListSize(t *testing.T) {
	st := newServerTester(t,
		func(w http.ResponseWriter, r *http.Request) {
			// Consume body & force client to send
			// trailers before writing response.
			// ioutil.ReadAll returns non-nil err for
			// requests that attempt to send greater than
			// maxHeaderListSize bytes of trailers, since
			// those requests generate a stream reset.
			ioutil.ReadAll(r.Body)
			r.Body.Close()
		},
		func(ts *httptest.Server) {
			ts.Config.MaxHeaderBytes = 16 << 10
		},
		optOnlyServer,
		optQuiet,
	)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	// checkRoundTrip sends req and verifies that RoundTrip returns exactly
	// wantErr, with a 200 response when wantErr is nil and no response
	// otherwise. desc labels the case in failure messages.
	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
		// Make an arbitrary request to ensure we get the server's
		// settings frame and initialize peerMaxHeaderListSize.
		req0, err := http.NewRequest("GET", st.ts.URL, nil)
		if err != nil {
			t.Fatalf("checkRoundTrip: NewRequest: %v", err)
		}
		res0, err := tr.RoundTrip(req0)
		if err != nil {
			// res0 is nil here; returning avoids dereferencing it
			// below and skips a check that can no longer be valid.
			t.Errorf("%v: Initial RoundTrip err = %v", desc, err)
			return
		}
		res0.Body.Close()

		res, err := tr.RoundTrip(req)
		if err != wantErr {
			if res != nil {
				res.Body.Close()
			}
			t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
			return
		}
		if err == nil {
			if res == nil {
				t.Errorf("%v: response nil; want non-nil.", desc)
				return
			}
			defer res.Body.Close()
			if res.StatusCode != http.StatusOK {
				t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
			}
			return
		}
		if res != nil {
			t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
		}
	}
	// headerListSizeForRequest encodes req's headers with a throwaway
	// ClientConn and returns the summed size of the decoded header fields.
	headerListSizeForRequest := func(req *http.Request) (size uint64) {
		contentLen := actualContentLength(req)
		trailers, err := commaSeparatedTrailers(req)
		if err != nil {
			t.Fatalf("headerListSizeForRequest: %v", err)
		}
		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
		cc.henc = hpack.NewEncoder(&cc.hbuf)
		cc.mu.Lock()
		hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
		cc.mu.Unlock()
		if err != nil {
			t.Fatalf("headerListSizeForRequest: %v", err)
		}
		hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
			size += uint64(hf.Size())
		})
		if len(hdrs) > 0 {
			if _, err := hpackDec.Write(hdrs); err != nil {
				t.Fatalf("headerListSizeForRequest: %v", err)
			}
		}
		return size
	}
	// Create a new Request for each test, rather than reusing the
	// same Request, to avoid a race when modifying req.Headers.
	// See https://github.com/golang/go/issues/21316
	newRequest := func() *http.Request {
		// Body must be non-nil to enable writing trailers.
		body := strings.NewReader("hello")
		req, err := http.NewRequest("POST", st.ts.URL, body)
		if err != nil {
			t.Fatalf("newRequest: NewRequest: %v", err)
		}
		return req
	}

	// Validate peerMaxHeaderListSize.
	req := newRequest()
	checkRoundTrip(req, nil, "Initial request")
	addr := authorityAddr(req.URL.Scheme, req.URL.Host)
	cc, err := tr.connPool().GetClientConn(req, addr)
	if err != nil {
		t.Fatalf("GetClientConn: %v", err)
	}
	cc.mu.Lock()
	peerSize := cc.peerMaxHeaderListSize
	cc.mu.Unlock()
	st.scMu.Lock()
	wantSize := uint64(st.sc.maxHeaderListSize())
	st.scMu.Unlock()
	if peerSize != wantSize {
		t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
	}

	// Sanity check peerSize. (*serverConn) maxHeaderListSize adds
	// 320 bytes of padding.
	wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
	if peerSize != wantHeaderBytes {
		t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
	}

	// Pad headers & trailers, but stay under peerSize.
	req = newRequest()
	req.Header = make(http.Header)
	req.Trailer = make(http.Header)
	filler := strings.Repeat("*", 1024)
	padHeaders(t, req.Trailer, peerSize, filler)
	// cc.encodeHeaders adds some default headers to the request,
	// so we need to leave room for those.
	defaultBytes := headerListSizeForRequest(req)
	padHeaders(t, req.Header, peerSize-defaultBytes, filler)
	checkRoundTrip(req, nil, "Headers & Trailers under limit")

	// Add enough header bytes to push us over peerSize.
	req = newRequest()
	req.Header = make(http.Header)
	padHeaders(t, req.Header, peerSize, filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")

	// Push trailers over the limit.
	req = newRequest()
	req.Trailer = make(http.Header)
	padHeaders(t, req.Trailer, peerSize+1, filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")

	// Send headers with a single large value.
	req = newRequest()
	filler = strings.Repeat("*", int(peerSize))
	req.Header = make(http.Header)
	req.Header.Set("Big", filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Single large header")

	// Send trailers with a single large value.
	req = newRequest()
	req.Trailer = make(http.Header)
	req.Trailer.Set("Big", filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
}

// TestTransportChecksResponseHeaderListSize has the server send a response
// whose decoded header list is very large and expects RoundTrip to fail with
// errResponseHeaderListSize (directly or as the Cause of a StreamError).
func TestTransportChecksResponseHeaderListSize(t *testing.T) {
	ct := newClientTester(t)
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if e, ok := err.(StreamError); ok {
			err = e.Cause
		}
		if err != errResponseHeaderListSize {
			// Report how many header bytes actually got through.
			size := int64(0)
			if res != nil {
				res.Body.Close()
				for k, vv := range res.Header {
					for _, v := range vv {
						size += int64(len(k)) + int64(len(v)) + 32
					}
				}
			}
			return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *HeadersFrame:
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				large := strings.Repeat("a", 1<<10)
				for i := 0; i < 5042; i++ {
					enc.WriteField(hpack.HeaderField{Name: large, Value: large})
				}
				if size, want := buf.Len(), 6329; size != want {
					// Note: this number might change if
					// our hpack implementation
					// changes. That's fine. This is
					// just a sanity check that our
					// response can fit in a single
					// header block fragment frame.
					return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
				}
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
				return nil
			}
		}
	}
	ct.run()
}

// TestTransportCookieHeaderSplit checks that the request's Cookie header
// values reach the server as one "cookie" header field per cookie pair.
func TestTransportCookieHeaderSplit(t *testing.T) {
	ct := newClientTester(t)
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		req.Header.Add("Cookie", "a=b;c=d; e=f;")
		req.Header.Add("Cookie", "e=f;g=h; ")
		req.Header.Add("Cookie", "i=j")
		_, err := ct.tr.RoundTrip(req)
		return err
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *HeadersFrame:
				dec := hpack.NewDecoder(initialHeaderTableSize, nil)
				hfs, err := dec.DecodeFull(f.HeaderBlockFragment())
				if err != nil {
					return err
				}
				got := []string{}
				want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}
				for _, hf := range hfs {
					if hf.Name == "cookie" {
						got = append(got, hf.Value)
					}
				}
				if !reflect.DeepEqual(got, want) {
					t.Errorf("Cookies = %#v, want %#v", got, want)
				}

				var buf bytes.Buffer
				enc := hpack.NewEncoder(&buf)
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
				return nil
			}
		}
	}
	ct.run()
}

// Test that the Transport returns a typed error from Response.Body.Read calls
// when the server sends an error.
(here we use a panic, since that should generate // a stream error, but others like cancel should be similar) func TestTransportBodyReadErrorType(t *testing.T) { doPanic := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.(http.Flusher).Flush() // force headers out <-doPanic panic("boom") }, optOnlyServer, optQuiet, ) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} res, err := c.Get(st.ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() doPanic <- true buf := make([]byte, 100) n, err := res.Body.Read(buf) got, ok := err.(StreamError) want := StreamError{StreamID: 0x1, Code: 0x2} if !ok || got.StreamID != want.StreamID || got.Code != want.Code { t.Errorf("Read = %v, %#v; want error %#v", n, err, want) } } // golang.org/issue/13924 // This used to fail after many iterations, especially with -race: // go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race func TestTransportDoubleCloseOnWriteError(t *testing.T) { var ( mu sync.Mutex conn net.Conn // to close if set ) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { mu.Lock() defer mu.Unlock() if conn != nil { conn.Close() } }, optOnlyServer, ) defer st.Close() tr := &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) if err != nil { return nil, err } mu.Lock() defer mu.Unlock() conn = tc return tc, nil }, } defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} c.Get(st.ts.URL) } // Test that the http1 Transport.DisableKeepAlives option is respected // and connections are closed as soon as idle. 
// See golang.org/issue/14008 func TestTransportDisableKeepAlives(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "hi") }, optOnlyServer, ) defer st.Close() connClosed := make(chan struct{}) // closed on tls.Conn.Close tr := &Transport{ t1: &http.Transport{ DisableKeepAlives: true, }, TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) if err != nil { return nil, err } return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil }, } c := &http.Client{Transport: tr} res, err := c.Get(st.ts.URL) if err != nil { t.Fatal(err) } if _, err := ioutil.ReadAll(res.Body); err != nil { t.Fatal(err) } defer res.Body.Close() select { case <-connClosed: case <-time.After(1 * time.Second): t.Errorf("timeout") } } // Test concurrent requests with Transport.DisableKeepAlives. We can share connections, // but when things are totally idle, it still needs to close. func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { const D = 25 * time.Millisecond st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { time.Sleep(D) io.WriteString(w, "hi") }, optOnlyServer, ) defer st.Close() var dials int32 var conns sync.WaitGroup tr := &Transport{ t1: &http.Transport{ DisableKeepAlives: true, }, TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) if err != nil { return nil, err } atomic.AddInt32(&dials, 1) conns.Add(1) return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil }, } c := &http.Client{Transport: tr} var reqs sync.WaitGroup const N = 20 for i := 0; i < N; i++ { reqs.Add(1) if i == N-1 { // For the final request, try to make all the // others close. This isn't verified in the // count, other than the Log statement, since // it's so timing dependent. 
This test is // really to make sure we don't interrupt a // valid request. time.Sleep(D * 2) } go func() { defer reqs.Done() res, err := c.Get(st.ts.URL) if err != nil { t.Error(err) return } if _, err := ioutil.ReadAll(res.Body); err != nil { t.Error(err) return } res.Body.Close() }() } reqs.Wait() conns.Wait() t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N) } type noteCloseConn struct { net.Conn onceClose sync.Once closefn func() } func (c *noteCloseConn) Close() error { c.onceClose.Do(c.closefn) return c.Conn.Close() } func isTimeout(err error) bool { switch err := err.(type) { case nil: return false case *url.Error: return isTimeout(err.Err) case net.Error: return err.Timeout() } return false } // Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent. func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) { testTransportResponseHeaderTimeout(t, false) } func TestTransportResponseHeaderTimeout_Body(t *testing.T) { testTransportResponseHeaderTimeout(t, true) } func testTransportResponseHeaderTimeout(t *testing.T, body bool) { ct := newClientTester(t) ct.tr.t1 = &http.Transport{ ResponseHeaderTimeout: 5 * time.Millisecond, } ct.client = func() error { c := &http.Client{Transport: ct.tr} var err error var n int64 const bodySize = 4 << 20 if body { _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) } else { _, err = c.Get("https://dummy.tld/") } if !isTimeout(err) { t.Errorf("client expected timeout error; got %#v", err) } if body && n != bodySize { t.Errorf("only read %d bytes of body; want %d", n, bodySize) } return nil } ct.server = func() error { ct.greet() for { f, err := ct.fr.ReadFrame() if err != nil { t.Logf("ReadFrame: %v", err) return nil } switch f := f.(type) { case *DataFrame: dataLen := len(f.Data()) if dataLen > 0 { if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { return err } if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); 
err != nil { return err } } case *RSTStreamFrame: if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { return nil } } } } ct.run() } func TestTransportDisableCompression(t *testing.T) { const body = "sup" st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { want := http.Header{ "User-Agent": []string{"Go-http-client/2.0"}, } if !reflect.DeepEqual(r.Header, want) { t.Errorf("request headers = %v; want %v", r.Header, want) } }, optOnlyServer) defer st.Close() tr := &Transport{ TLSClientConfig: tlsConfigInsecure, t1: &http.Transport{ DisableCompression: true, }, } defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() } // RFC 7540 section 8.1.2.2 func TestTransportRejectsConnHeaders(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { var got []string for k := range r.Header { got = append(got, k) } sort.Strings(got) w.Header().Set("Got-Header", strings.Join(got, ",")) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() tests := []struct { key string value []string want string }{ { key: "Upgrade", value: []string{"anything"}, want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]", }, { key: "Connection", value: []string{"foo"}, want: "ERROR: http2: invalid Connection request header: [\"foo\"]", }, { key: "Connection", value: []string{"close"}, want: "Accept-Encoding,User-Agent", }, { key: "Connection", value: []string{"CLoSe"}, want: "Accept-Encoding,User-Agent", }, { key: "Connection", value: []string{"close", "something-else"}, want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]", }, { key: "Connection", value: []string{"keep-alive"}, want: "Accept-Encoding,User-Agent", }, { key: "Connection", value: []string{"Keep-ALIVE"}, want: "Accept-Encoding,User-Agent", 
}, { key: "Proxy-Connection", // just deleted and ignored value: []string{"keep-alive"}, want: "Accept-Encoding,User-Agent", }, { key: "Transfer-Encoding", value: []string{""}, want: "Accept-Encoding,User-Agent", }, { key: "Transfer-Encoding", value: []string{"foo"}, want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]", }, { key: "Transfer-Encoding", value: []string{"chunked"}, want: "Accept-Encoding,User-Agent", }, { key: "Transfer-Encoding", value: []string{"chunKed"}, // Kelvin sign want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]", }, { key: "Transfer-Encoding", value: []string{"chunked", "other"}, want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]", }, { key: "Content-Length", value: []string{"123"}, want: "Accept-Encoding,User-Agent", }, { key: "Keep-Alive", value: []string{"doop"}, want: "Accept-Encoding,User-Agent", }, } for _, tt := range tests { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Header[tt.key] = tt.value res, err := tr.RoundTrip(req) var got string if err != nil { got = fmt.Sprintf("ERROR: %v", err) } else { got = res.Header.Get("Got-Header") res.Body.Close() } if got != tt.want { t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want) } } } // Reject content-length headers containing a sign. 
// See https://golang.org/issue/39017 func TestTransportRejectsContentLengthWithSign(t *testing.T) { tests := []struct { name string cl []string wantCL string }{ { name: "proper content-length", cl: []string{"3"}, wantCL: "3", }, { name: "ignore cl with plus sign", cl: []string{"+3"}, wantCL: "", }, { name: "ignore cl with minus sign", cl: []string{"-3"}, wantCL: "", }, { name: "max int64, for safe uint64->int64 conversion", cl: []string{"9223372036854775807"}, wantCL: "9223372036854775807", }, { name: "overflows int64, so ignored", cl: []string{"9223372036854775808"}, wantCL: "", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", tt.cl[0]) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() req, _ := http.NewRequest("HEAD", st.ts.URL, nil) res, err := tr.RoundTrip(req) var got string if err != nil { got = fmt.Sprintf("ERROR: %v", err) } else { got = res.Header.Get("Content-Length") res.Body.Close() } if got != tt.wantCL { t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL) } }) } } // golang.org/issue/14048 func TestTransportFailsOnInvalidHeaders(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { var got []string for k := range r.Header { got = append(got, k) } sort.Strings(got) w.Header().Set("Got-Header", strings.Join(got, ",")) }, optOnlyServer) defer st.Close() tests := [...]struct { h http.Header wantErr string }{ 0: { h: http.Header{"with space": {"foo"}}, wantErr: `invalid HTTP header name "with space"`, }, 1: { h: http.Header{"name": {"Брэд"}}, wantErr: "", // okay }, 2: { h: http.Header{"имя": {"Brad"}}, wantErr: `invalid HTTP header name "имя"`, }, 3: { h: http.Header{"foo": {"foo\x01bar"}}, wantErr: `invalid HTTP header value for header "foo"`, }, } tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer 
tr.CloseIdleConnections() for i, tt := range tests { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Header = tt.h res, err := tr.RoundTrip(req) var bad bool if tt.wantErr == "" { if err != nil { bad = true t.Errorf("case %d: error = %v; want no error", i, err) } } else { if !strings.Contains(fmt.Sprint(err), tt.wantErr) { bad = true t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr) } } if err == nil { if bad { t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header")) } res.Body.Close() } } } // Tests that gzipReader doesn't crash on a second Read call following // the first Read call's gzip.NewReader returning an error. func TestGzipReader_DoubleReadCrash(t *testing.T) { gz := &gzipReader{ body: ioutil.NopCloser(strings.NewReader("0123456789")), } var buf [1]byte n, err1 := gz.Read(buf[:]) if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") { t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1) } n, err2 := gz.Read(buf[:]) if n != 0 || err2 != err1 { t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1) } } func TestGzipReader_ReadAfterClose(t *testing.T) { body := bytes.Buffer{} w := gzip.NewWriter(&body) w.Write([]byte("012345679")) w.Close() gz := &gzipReader{ body: ioutil.NopCloser(&body), } var buf [1]byte n, err := gz.Read(buf[:]) if n != 1 || err != nil { t.Fatalf("first Read = %v, %v; want 1, nil", n, err) } if err := gz.Close(); err != nil { t.Fatalf("gz Close error: %v", err) } n, err = gz.Read(buf[:]) if n != 0 || err != fs.ErrClosed { t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err) } } func TestTransportNewTLSConfig(t *testing.T) { tests := [...]struct { conf *tls.Config host string want *tls.Config }{ // Normal case. 
0: { conf: nil, host: "foo.com", want: &tls.Config{ ServerName: "foo.com", NextProtos: []string{NextProtoTLS}, }, }, // User-provided name (bar.com) takes precedence: 1: { conf: &tls.Config{ ServerName: "bar.com", }, host: "foo.com", want: &tls.Config{ ServerName: "bar.com", NextProtos: []string{NextProtoTLS}, }, }, // NextProto is prepended: 2: { conf: &tls.Config{ NextProtos: []string{"foo", "bar"}, }, host: "example.com", want: &tls.Config{ ServerName: "example.com", NextProtos: []string{NextProtoTLS, "foo", "bar"}, }, }, // NextProto is not duplicated: 3: { conf: &tls.Config{ NextProtos: []string{"foo", "bar", NextProtoTLS}, }, host: "example.com", want: &tls.Config{ ServerName: "example.com", NextProtos: []string{"foo", "bar", NextProtoTLS}, }, }, } for i, tt := range tests { // Ignore the session ticket keys part, which ends up populating // unexported fields in the Config: if tt.conf != nil { tt.conf.SessionTicketsDisabled = true } tr := &Transport{TLSClientConfig: tt.conf} got := tr.newTLSConfig(tt.host) got.SessionTicketsDisabled = false if !reflect.DeepEqual(got, tt.want) { t.Errorf("%d. got %#v; want %#v", i, got, tt.want) } } } // The Google GFE responds to HEAD requests with a HEADERS frame // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) 
func TestTransportReadHeadResponse(t *testing.T) { ct := newClientTester(t) clientDone := make(chan struct{}) ct.client = func() error { defer close(clientDone) req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err != nil { return err } if res.ContentLength != 123 { return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) } slurp, err := ioutil.ReadAll(res.Body) if err != nil { return fmt.Errorf("ReadAll: %v", err) } if len(slurp) > 0 { return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) } return nil } ct.server = func() error { ct.greet() for { f, err := ct.fr.ReadFrame() if err != nil { t.Logf("ReadFrame: %v", err) return nil } hf, ok := f.(*HeadersFrame) if !ok { continue } var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, // as the GFE does BlockFragment: buf.Bytes(), }) ct.fr.WriteData(hf.StreamID, true, nil) <-clientDone return nil } } ct.run() } func TestTransportReadHeadResponseWithBody(t *testing.T) { // This test use not valid response format. // Discarding logger output to not spam tests output. 
log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stderr) response := "redirecting to /elsewhere" ct := newClientTester(t) clientDone := make(chan struct{}) ct.client = func() error { defer close(clientDone) req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err != nil { return err } if res.ContentLength != int64(len(response)) { return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) } slurp, err := ioutil.ReadAll(res.Body) if err != nil { return fmt.Errorf("ReadAll: %v", err) } if len(slurp) > 0 { return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) } return nil } ct.server = func() error { ct.greet() for { f, err := ct.fr.ReadFrame() if err != nil { t.Logf("ReadFrame: %v", err) return nil } hf, ok := f.(*HeadersFrame) if !ok { continue } var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) ct.fr.WriteData(hf.StreamID, true, []byte(response)) <-clientDone return nil } } ct.run() } type neverEnding byte func (b neverEnding) Read(p []byte) (int, error) { for i := range p { p[i] = byte(b) } return len(p), nil } // golang.org/issue/15425: test that a handler closing the request // body doesn't terminate the stream to the peer. 
(It just stops // readability from the handler's side, and eventually the client // runs out of flow control tokens) func TestTransportHandlerBodyClose(t *testing.T) { const bodySize = 10 << 20 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { r.Body.Close() io.Copy(w, io.LimitReader(neverEnding('A'), bodySize)) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() g0 := runtime.NumGoroutine() const numReq = 10 for i := 0; i < numReq; i++ { req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } n, err := io.Copy(ioutil.Discard, res.Body) res.Body.Close() if n != bodySize || err != nil { t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) } } tr.CloseIdleConnections() if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool { gd := runtime.NumGoroutine() - g0 return gd < numReq/2 }) { t.Errorf("appeared to leak goroutines") } } // https://golang.org/issue/15930 func TestTransportFlowControl(t *testing.T) { const bufLen = 64 << 10 var total int64 = 100 << 20 // 100MB if testing.Short() { total = 10 << 20 } var wrote int64 // updated atomically st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { b := make([]byte, bufLen) for wrote < total { n, err := w.Write(b) atomic.AddInt64(&wrote, int64(n)) if err != nil { t.Errorf("ResponseWriter.Write error: %v", err) break } w.(http.Flusher).Flush() } }, optOnlyServer) tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal("NewRequest error:", err) } resp, err := tr.RoundTrip(req) if err != nil { t.Fatal("RoundTrip error:", err) } defer resp.Body.Close() var read int64 b := make([]byte, bufLen) for { n, err := resp.Body.Read(b) if err == io.EOF { 
break } if err != nil { t.Fatal("Read error:", err) } read += int64(n) const max = transportDefaultStreamFlow if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max { t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read) } // Let the server get ahead of the client. time.Sleep(1 * time.Millisecond) } } // golang.org/issue/14627 -- if the server sends a GOAWAY frame, make // the Transport remember it and return it back to users (via // RoundTrip or request body reads) if needed (e.g. if the server // proceeds to close the TCP connection before the client gets its // response) func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) { testTransportUsesGoAwayDebugError(t, false) } func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { testTransportUsesGoAwayDebugError(t, true) } func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { ct := newClientTester(t) clientDone := make(chan struct{}) const goAwayErrCode = ErrCodeHTTP11Required // arbitrary const goAwayDebugData = "some debug data" ct.client = func() error { defer close(clientDone) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if failMidBody { if err != nil { return fmt.Errorf("unexpected client RoundTrip error: %v", err) } _, err = io.Copy(ioutil.Discard, res.Body) res.Body.Close() } want := GoAwayError{ LastStreamID: 5, ErrCode: goAwayErrCode, DebugData: goAwayDebugData, } if !reflect.DeepEqual(err, want) { t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) } return nil } ct.server = func() error { ct.greet() for { f, err := ct.fr.ReadFrame() if err != nil { t.Logf("ReadFrame: %v", err) return nil } hf, ok := f.(*HeadersFrame) if !ok { continue } if failMidBody { var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) 
ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) } // Write two GOAWAY frames, to test that the Transport takes // the interesting parts of both. ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) ct.fr.WriteGoAway(5, goAwayErrCode, nil) ct.sc.(*net.TCPConn).CloseWrite() if runtime.GOOS == "plan9" { // CloseWrite not supported on Plan 9; Issue 17906 ct.sc.(*net.TCPConn).Close() } <-clientDone return nil } } ct.run() } func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { ct := newClientTester(t) ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err != nil { return err } if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) } res.Body.Close() // leaving 4999 bytes unread return nil } ct.server = func() error { ct.greet() var hf *HeadersFrame for { f, err := ct.fr.ReadFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } switch f.(type) { case *WindowUpdateFrame, *SettingsFrame: continue } var ok bool hf, ok = f.(*HeadersFrame) if !ok { return fmt.Errorf("Got %T; want HeadersFrame", f) } break } var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) initialInflow := ct.inflowWindow(0) // Two cases: // - Send one DATA frame with 5000 bytes. // - Send two DATA frames with 1 and 4999 bytes each. // // In both cases, the client should consume one byte of data, // refund that byte, then refund the following 4999 bytes. 
// // In the second case, the server waits for the client to reset the // stream before sending the second DATA frame. This tests the case // where the client receives a DATA frame after it has reset the stream. if oneDataFrame { ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) } else { ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) } wantRST := true wantWUF := true if !oneDataFrame { wantWUF = false // flow control update is small, and will not be sent } for wantRST || wantWUF { f, err := ct.readNonSettingsFrame() if err != nil { return err } switch f := f.(type) { case *RSTStreamFrame: if !wantRST { return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } if f.ErrCode != ErrCodeCancel { return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } wantRST = false case *WindowUpdateFrame: if !wantWUF { return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } if f.Increment != 5000 { return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } wantWUF = false default: return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } } if !oneDataFrame { ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) f, err := ct.readNonSettingsFrame() if err != nil { return err } wuf, ok := f.(*WindowUpdateFrame) if !ok || wuf.Increment != 5000 { return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) } } if err := ct.writeReadPing(); err != nil { return err } if got, want := ct.inflowWindow(0), initialInflow; got != want { return fmt.Errorf("connection flow tokens = %v, want %v", got, want) } return nil } ct.run() } // See golang.org/issue/16481 func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) { testTransportReturnsUnusedFlowControl(t, true) } // See golang.org/issue/20469 func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { 
testTransportReturnsUnusedFlowControl(t, false) } // Issue 16612: adjust flow control on open streams when transport // receives SETTINGS with INITIAL_WINDOW_SIZE from server. func TestTransportAdjustsFlowControl(t *testing.T) { ct := newClientTester(t) clientDone := make(chan struct{}) const bodySize = 1 << 20 ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() if runtime.GOOS == "plan9" { // CloseWrite not supported on Plan 9; Issue 17906 defer ct.cc.(*net.TCPConn).Close() } defer close(clientDone) req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) res, err := ct.tr.RoundTrip(req) if err != nil { return err } res.Body.Close() return nil } ct.server = func() error { _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) if err != nil { return fmt.Errorf("reading client preface: %v", err) } var gotBytes int64 var sentSettings bool for { f, err := ct.fr.ReadFrame() if err != nil { select { case <-clientDone: return nil default: return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } } switch f := f.(type) { case *DataFrame: gotBytes += int64(len(f.Data())) // After we've got half the client's // initial flow control window's worth // of request body data, give it just // enough flow control to finish. 
if gotBytes >= initialWindowSize/2 && !sentSettings { sentSettings = true ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) ct.fr.WriteWindowUpdate(0, bodySize) ct.fr.WriteSettingsAck() } if f.StreamEnded() { var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: true, BlockFragment: buf.Bytes(), }) } } } } ct.run() } // See golang.org/issue/16556 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { ct := newClientTester(t) unblockClient := make(chan bool, 1) ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err != nil { return err } defer res.Body.Close() <-unblockClient return nil } ct.server = func() error { ct.greet() var hf *HeadersFrame for { f, err := ct.fr.ReadFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } switch f.(type) { case *WindowUpdateFrame, *SettingsFrame: continue } var ok bool hf, ok = f.(*HeadersFrame) if !ok { return fmt.Errorf("Got %T; want HeadersFrame", f) } break } initialConnWindow := ct.inflowWindow(0) var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) initialStreamWindow := ct.inflowWindow(hf.StreamID) pad := make([]byte, 5) ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream if err := ct.writeReadPing(); err != nil { return err } // Padding flow control should have been returned. 
if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { t.Errorf("conn inflow window = %v, want %v", got, want) } if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { t.Errorf("stream inflow window = %v, want %v", got, want) } unblockClient <- true return nil } ct.run() } // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a // StreamError as a result of the response HEADERS func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { ct := newClientTester(t) ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err == nil { res.Body.Close() return errors.New("unexpected successful GET") } want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} if !reflect.DeepEqual(want, err) { t.Errorf("RoundTrip error = %#v; want %#v", err, want) } return nil } ct.server = func() error { ct.greet() hf, err := ct.firstHeaders() if err != nil { return err } var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) for { fr, err := ct.readFrame() if err != nil { return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) } if _, ok := fr.(*SettingsFrame); ok { continue } if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) } break } return nil } ct.run() } // byteAndEOFReader returns is in an io.Reader which reads one byte // (the underlying byte) and io.EOF at once in its Read call. 
type byteAndEOFReader byte func (b byteAndEOFReader) Read(p []byte) (n int, err error) { if len(p) == 0 { panic("unexpected useless call") } p[0] = byte(b) return 1, io.EOF } // Issue 16788: the Transport had a regression where it started // sending a spurious DATA frame with a duplicate END_STREAM bit after // the request body writer goroutine had already read an EOF from the // Request.Body and included the END_STREAM on a data-carrying DATA // frame. // // Notably, to trigger this, the requests need to use a Request.Body // which returns (non-0, io.EOF) and also needs to set the ContentLength // explicitly. func TestTransportBodyDoubleEndStream(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // Nothing. }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() for i := 0; i < 2; i++ { req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a')) req.ContentLength = 1 res, err := tr.RoundTrip(req) if err != nil { t.Fatalf("failure on req %d: %v", i+1, err) } defer res.Body.Close() } } // golang.org/issue/16847, golang.org/issue/19103 func TestTransportRequestPathPseudo(t *testing.T) { type result struct { path string err string } tests := []struct { req *http.Request want result }{ 0: { req: &http.Request{ Method: "GET", URL: &url.URL{ Host: "foo.com", Path: "/foo", }, }, want: result{path: "/foo"}, }, // In Go 1.7, we accepted paths of "//foo". // In Go 1.8, we rejected it (issue 16847). // In Go 1.9, we accepted it again (issue 19103). 
1: { req: &http.Request{ Method: "GET", URL: &url.URL{ Host: "foo.com", Path: "//foo", }, }, want: result{path: "//foo"}, }, // Opaque with //$Matching_Hostname/path 2: { req: &http.Request{ Method: "GET", URL: &url.URL{ Scheme: "https", Opaque: "//foo.com/path", Host: "foo.com", Path: "/ignored", }, }, want: result{path: "/path"}, }, // Opaque with some other Request.Host instead: 3: { req: &http.Request{ Method: "GET", Host: "bar.com", URL: &url.URL{ Scheme: "https", Opaque: "//bar.com/path", Host: "foo.com", Path: "/ignored", }, }, want: result{path: "/path"}, }, // Opaque without the leading "//": 4: { req: &http.Request{ Method: "GET", URL: &url.URL{ Opaque: "/path", Host: "foo.com", Path: "/ignored", }, }, want: result{path: "/path"}, }, // Opaque we can't handle: 5: { req: &http.Request{ Method: "GET", URL: &url.URL{ Scheme: "https", Opaque: "//unknown_host/path", Host: "foo.com", Path: "/ignored", }, }, want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`}, }, // A CONNECT request: 6: { req: &http.Request{ Method: "CONNECT", URL: &url.URL{ Host: "foo.com", }, }, want: result{}, }, } for i, tt := range tests { cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} cc.henc = hpack.NewEncoder(&cc.hbuf) cc.mu.Lock() hdrs, err := cc.encodeHeaders(tt.req, false, "", -1) cc.mu.Unlock() var got result hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) { if f.Name == ":path" { got.path = f.Value } }) if err != nil { got.err = err.Error() } else if len(hdrs) > 0 { if _, err := hpackDec.Write(hdrs); err != nil { t.Errorf("%d. bogus hpack: %v", i, err) continue } } if got != tt.want { t.Errorf("%d. got %+v; want %+v", i, got, tt.want) } } } // golang.org/issue/17071 -- don't sniff the first byte of the request body // before we've determined that the ClientConn is usable. 
func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { const body = "foo" req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body))) cc := &ClientConn{ closed: true, reqHeaderMu: make(chan struct{}, 1), } _, err := cc.RoundTrip(req) if err != errClientConnUnusable { t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) } slurp, err := ioutil.ReadAll(req.Body) if err != nil { t.Errorf("ReadAll = %v", err) } if string(slurp) != body { t.Errorf("Body = %q; want %q", slurp, body) } } func TestClientConnPing(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) if err != nil { t.Fatal(err) } if err = cc.Ping(context.Background()); err != nil { t.Fatal(err) } } // Issue 16974: if the server sent a DATA frame after the user // canceled the Transport's Request, the Transport previously wrote to a // closed pipe, got an error, and ended up closing the whole TCP // connection. func TestTransportCancelDataResponseRace(t *testing.T) { cancel := make(chan struct{}) clientGotResponse := make(chan bool, 1) const msg = "Hello." 
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/hello") { time.Sleep(50 * time.Millisecond) io.WriteString(w, msg) return } for i := 0; i < 50; i++ { io.WriteString(w, "Some data.") w.(http.Flusher).Flush() if i == 2 { <-clientGotResponse close(cancel) } time.Sleep(10 * time.Millisecond) } }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Cancel = cancel res, err := c.Do(req) clientGotResponse <- true if err != nil { t.Fatal(err) } if _, err = io.Copy(ioutil.Discard, res.Body); err == nil { t.Fatal("unexpected success") } res, err = c.Get(st.ts.URL + "/hello") if err != nil { t.Fatal(err) } slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatal(err) } if string(slurp) != msg { t.Errorf("Got = %q; want %q", slurp, msg) } } // Issue 21316: It should be safe to reuse an http.Request after the // request has completed. func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) io.WriteString(w, "body") }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() req, _ := http.NewRequest("GET", st.ts.URL, nil) resp, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil { t.Fatalf("error reading response body: %v", err) } if err := resp.Body.Close(); err != nil { t.Fatalf("error closing response body: %v", err) } // This access of req.Header should not race with code in the transport. 
req.Header = http.Header{} } func TestTransportCloseAfterLostPing(t *testing.T) { clientDone := make(chan struct{}) ct := newClientTester(t) ct.tr.PingTimeout = 1 * time.Second ct.tr.ReadIdleTimeout = 1 * time.Second ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() defer close(clientDone) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) _, err := ct.tr.RoundTrip(req) if err == nil || !strings.Contains(err.Error(), "client connection lost") { return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) } return nil } ct.server = func() error { ct.greet() <-clientDone return nil } ct.run() } func TestTransportPingWriteBlocks(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer, ) defer st.Close() tr := &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { s, c := net.Pipe() // unbuffered, unlike a TCP conn go func() { // Read initial handshake frames. // Without this, we block indefinitely in newClientConn, // and never get to the point of sending a PING. 
var buf [1024]byte s.Read(buf[:]) }() return c, nil }, PingTimeout: 1 * time.Millisecond, ReadIdleTimeout: 1 * time.Millisecond, } defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} _, err := c.Get(st.ts.URL) if err == nil { t.Fatalf("Get = nil, want error") } } func TestTransportPingWhenReading(t *testing.T) { testCases := []struct { name string readIdleTimeout time.Duration deadline time.Duration expectedPingCount int }{ { name: "two pings", readIdleTimeout: 100 * time.Millisecond, deadline: time.Second, expectedPingCount: 2, }, { name: "zero ping", readIdleTimeout: time.Second, deadline: 200 * time.Millisecond, expectedPingCount: 0, }, { name: "0 readIdleTimeout means no ping", readIdleTimeout: 0 * time.Millisecond, deadline: 500 * time.Millisecond, expectedPingCount: 0, }, } for _, tc := range testCases { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount) }) } } func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) { var pingCount int ct := newClientTester(t) ct.tr.ReadIdleTimeout = readIdleTimeout ctx, cancel := context.WithTimeout(context.Background(), deadline) defer cancel() ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() if runtime.GOOS == "plan9" { // CloseWrite not supported on Plan 9; Issue 17906 defer ct.cc.(*net.TCPConn).Close() } req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err != nil { return fmt.Errorf("RoundTrip: %v", err) } defer res.Body.Close() if res.StatusCode != 200 { return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) } _, err = ioutil.ReadAll(res.Body) if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { return nil } cancel() return err } ct.server = func() error { ct.greet() var buf bytes.Buffer enc := hpack.NewEncoder(&buf) var streamID 
uint32 for { f, err := ct.fr.ReadFrame() if err != nil { select { case <-ctx.Done(): // If the client's done, it // will have reported any // errors on its side. return nil default: return err } } switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: if !f.HeadersEnded() { return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) } enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) streamID = f.StreamID case *PingFrame: pingCount++ if pingCount == expectedPingCount { if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { return err } } if err := ct.fr.WritePing(true, f.Data); err != nil { return err } case *RSTStreamFrame: default: return fmt.Errorf("Unexpected client frame %v", f) } } } ct.run() } func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { ln := newLocalListener(t) defer ln.Close() var ( mu sync.Mutex count int conns []net.Conn ) var wg sync.WaitGroup tr := &Transport{ TLSClientConfig: tlsConfigInsecure, } tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() defer mu.Unlock() count++ cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { return nil, fmt.Errorf("dial error: %v", err) } conns = append(conns, cc) sc, err := ln.Accept() if err != nil { return nil, fmt.Errorf("accept error: %v", err) } conns = append(conns, sc) ct := &clientTester{ t: t, tr: tr, cc: cc, sc: sc, fr: NewFramer(sc, sc), } wg.Add(1) go func(count int) { defer wg.Done() server(count, ct) }(count) return cc, nil } client(tr) tr.CloseIdleConnections() ln.Close() for _, c := range conns { c.Close() } wg.Wait() } func TestTransportRetryAfterGOAWAY(t *testing.T) { client := func(tr *Transport) { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) 
res, err := tr.RoundTrip(req) if res != nil { res.Body.Close() if got := res.Header.Get("Foo"); got != "bar" { err = fmt.Errorf("foo header = %q; want bar", got) } } if err != nil { t.Errorf("RoundTrip: %v", err) } } server := func(count int, ct *clientTester) { switch count { case 1: ct.greet() hf, err := ct.firstHeaders() if err != nil { t.Errorf("server1 failed reading HEADERS: %v", err) return } t.Logf("server1 got %v", hf) if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { t.Errorf("server1 failed writing GOAWAY: %v", err) return } case 2: ct.greet() hf, err := ct.firstHeaders() if err != nil { t.Errorf("server2 failed reading HEADERS: %v", err) return } t.Logf("server2 got %v", hf) var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) err = ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) if err != nil { t.Errorf("server2 failed writing response HEADERS: %v", err) } default: t.Errorf("unexpected number of dials") return } } testClientMultipleDials(t, client, server) } func TestTransportRetryAfterRefusedStream(t *testing.T) { clientDone := make(chan struct{}) client := func(tr *Transport) { defer close(clientDone) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) resp, err := tr.RoundTrip(req) if err != nil { t.Errorf("RoundTrip: %v", err) return } resp.Body.Close() if resp.StatusCode != 204 { t.Errorf("Status = %v; want 204", resp.StatusCode) return } } server := func(_ int, ct *clientTester) { ct.greet() var buf bytes.Buffer enc := hpack.NewEncoder(&buf) var count int for { f, err := ct.fr.ReadFrame() if err != nil { select { case <-clientDone: // If the client's done, it // will have reported any // errors on its side. 
default:
					t.Error(err)
				}
				return
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					t.Errorf("headers should have END_HEADERS be ended: %v", f)
					return
				}
				count++
				if count == 1 {
					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
				} else {
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				}
			default:
				t.Errorf("Unexpected client frame %v", f)
				return
			}
		}
	}

	testClientMultipleDials(t, client, server)
}

// TestTransportRetryHasLimit checks that RoundTrip returns an error when the
// server answers every HEADERS frame with RST_STREAM(REFUSED_STREAM).
func TestTransportRetryHasLimit(t *testing.T) {
	// Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
	if testing.Short() {
		t.Skip("skipping long test in short mode")
	}
	retryBackoffHook = func(d time.Duration) *time.Timer {
		return time.NewTimer(0) // fires immediately
	}
	defer func() {
		retryBackoffHook = nil
	}()
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		resp, err := ct.tr.RoundTrip(req)
		if err == nil {
			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
		}
		t.Logf("expected error, got: %v", err)
		return nil
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it
					// will have reported any
					// errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				// Refuse every request so the client keeps retrying.
				ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}

// TestTransportResponseDataBeforeHeaders checks that a DATA frame arriving on
// a stream before any HEADERS makes RoundTrip fail with a StreamError carrying
// ErrCodeProtocol, while an earlier, well-formed stream on the same connection
// still succeeds.
func TestTransportResponseDataBeforeHeaders(t *testing.T) {
	// This test uses an invalid response format.
	// Discard logger output so it does not spam the test output.
	log.SetOutput(ioutil.Discard)
	defer log.SetOutput(os.Stderr)

	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
		// First request is normal to ensure the check is per stream and not per connection.
		_, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip expected no error, got: %v", err)
		}
		// Second request returns a DATA frame with no HEADERS.
		resp, err := ct.tr.RoundTrip(req)
		if err == nil {
			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
		}
		// Keep the asserted value under its own name so the message below
		// reports the error RoundTrip actually returned, not a zero-valued
		// StreamError when the type assertion fails.
		if se, ok := err.(StreamError); !ok || se.Code != ErrCodeProtocol {
			return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err == io.EOF {
				return nil
			} else if err != nil {
				return err
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame:
			case *HeadersFrame:
				switch f.StreamID {
				case 1:
					// Send a valid response to first request.
					var buf bytes.Buffer
					enc := hpack.NewEncoder(&buf)
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				case 3:
					// DATA with no preceding HEADERS on this stream.
					ct.fr.WriteData(f.StreamID, true, []byte("payload"))
				}
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}

// TestTransportMaxFrameReadSize checks the SETTINGS_MAX_FRAME_SIZE value the
// server observes for different Transport.MaxReadFrameSize settings.
func TestTransportMaxFrameReadSize(t *testing.T) {
	for _, test := range []struct {
		maxReadFrameSize uint32
		want             uint32
	}{{
		maxReadFrameSize: 64000,
		want:             64000,
	}, {
		maxReadFrameSize: 1024,
		want:             minMaxFrameSize,
	}} {
		ct := newClientTester(t)
		ct.tr.MaxReadFrameSize = test.maxReadFrameSize
		ct.client = func() error {
			req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
			ct.tr.RoundTrip(req)
			return nil
		}
		ct.server = func() error {
			defer ct.cc.(*net.TCPConn).Close()
			ct.greet()
			var got uint32
			ct.settings.ForeachSetting(func(s Setting) error {
				switch s.ID {
				case SettingMaxFrameSize:
					got = s.Val
				}
				return nil
			})
			if got != test.want {
				t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
			}
			return nil
		}
		ct.run()
	}
}

// TestTransportRequestsLowServerLimit issues sequential requests against a
// server configured with MaxConcurrentStreams = 1 and checks that only one
// connection was dialed.
func TestTransportRequestsLowServerLimit(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
	}, optOnlyServer, func(s *Server) {
		s.MaxConcurrentStreams = 1
	})
	defer st.Close()

	var (
		connCountMu sync.Mutex
		connCount   int
	)
	tr := &Transport{
		TLSClientConfig: tlsConfigInsecure,
		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
			connCountMu.Lock()
			defer connCountMu.Unlock()
			connCount++
			return tls.Dial(network, addr, cfg)
		},
	}
	defer tr.CloseIdleConnections()

	const reqCount = 3
	for i := 0; i < reqCount; i++ {
		req, err := http.NewRequest("GET", st.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Fatal(err)
		}
		if got, want := res.StatusCode, 200; got != want {
			t.Errorf("StatusCode = %v; want %v", got, want)
		}
		if res != nil && res.Body != nil {
			res.Body.Close()
		}
	}

	if connCount != 1 {
		t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
	}
}

// tests Transport.StrictMaxConcurrentStreams
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
	const maxConcurrent = 2

	greet := make(chan struct{})      // server sends initial SETTINGS frame
	gotRequest := make(chan struct{}) // server received a request
	clientDone := make(chan struct{})
	cancelClientRequest := make(chan struct{})

	// Collect errors from goroutines.
	var wg sync.WaitGroup
	errs := make(chan error, 100)
	defer func() {
		wg.Wait()
		close(errs)
		for err := range errs {
			t.Error(err)
		}
	}()

	// We will send maxConcurrent+2 requests. This checker goroutine waits for the
	// following stages:
	//   1. The first maxConcurrent requests are received by the server.
	//   2. The client will cancel the next request
	//   3. The server is unblocked so it can service the first maxConcurrent requests
	//   4. The client will send the final request
	wg.Add(1)
	unblockClient := make(chan struct{})
	clientRequestCancelled := make(chan struct{})
	unblockServer := make(chan struct{})
	go func() {
		defer wg.Done()
		// Stage 1.
		for k := 0; k < maxConcurrent; k++ {
			<-gotRequest
		}
		// Stage 2.
		close(unblockClient)
		<-clientRequestCancelled
		// Stage 3: give some time for the final RoundTrip call to be scheduled and
		// verify that the final request is not sent.
		time.Sleep(50 * time.Millisecond)
		select {
		case <-gotRequest:
			errs <- errors.New("last request did not stall")
			close(unblockServer)
			return
		default:
		}
		close(unblockServer)
		// Stage 4.
		<-gotRequest
	}()

	ct := newClientTester(t)
	ct.tr.StrictMaxConcurrentStreams = true
	ct.client = func() error {
		var wg sync.WaitGroup
		defer func() {
			wg.Wait()
			close(clientDone)
			ct.cc.(*net.TCPConn).CloseWrite()
			if runtime.GOOS == "plan9" {
				// CloseWrite not supported on Plan 9; Issue 17906
				ct.cc.(*net.TCPConn).Close()
			}
		}()
		for k := 0; k < maxConcurrent+2; k++ {
			wg.Add(1)
			go func(k int) {
				defer wg.Done()
				// Don't send the second request until after receiving SETTINGS from the server
				// to avoid a race where we use the default SettingMaxConcurrentStreams, which
				// is much larger than maxConcurrent. We have to send the first request before
				// waiting because the first request triggers the dial and greet.
				if k > 0 {
					<-greet
				}
				// Block until maxConcurrent requests are sent before sending any more.
				if k >= maxConcurrent {
					<-unblockClient
				}
				body := newStaticCloseChecker("")
				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
				if k == maxConcurrent {
					// This request will be canceled.
					req.Cancel = cancelClientRequest
					close(cancelClientRequest)
					_, err := ct.tr.RoundTrip(req)
					close(clientRequestCancelled)
					if err == nil {
						errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
						return
					}
				} else {
					resp, err := ct.tr.RoundTrip(req)
					if err != nil {
						errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
						return
					}
					ioutil.ReadAll(resp.Body)
					resp.Body.Close()
					if resp.StatusCode != 204 {
						errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
						return
					}
				}
				if err := body.isClosed(); err != nil {
					errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
				}
			}(k)
		}
		return nil
	}

	ct.server = func() error {
		var wg sync.WaitGroup
		defer wg.Wait()

		ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

		// Server write loop.
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		writeResp := make(chan uint32, maxConcurrent+1)

		wg.Add(1)
		go func() {
			defer wg.Done()
			<-unblockServer
			for id := range writeResp {
				buf.Reset()
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      id,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
			}
		}()

		// Server read loop.
		var nreq int
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it will have reported any errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame:
			case *SettingsFrame:
				// Wait for the client SETTINGS ack until ending the greet.
				close(greet)
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				gotRequest <- struct{}{}
				nreq++
				writeResp <- f.StreamID
				if nreq == maxConcurrent+1 {
					close(writeResp)
				}
			case *DataFrame:
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}

	ct.run()
}

// TestTransportMaxDecoderHeaderTableSize checks that the client advertises
// Transport.MaxDecoderHeaderTableSize in its initial SETTINGS frame and that it
// records the SETTINGS_HEADER_TABLE_SIZE value sent back by the server.
func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
	ct := newClientTester(t)
	var reqSize, resSize uint32 = 8192, 16384
	ct.tr.MaxDecoderHeaderTableSize = reqSize
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		cc, err := ct.tr.NewClientConn(ct.cc)
		if err != nil {
			return err
		}
		_, err = cc.RoundTrip(req)
		if err != nil {
			return err
		}
		if got, want := cc.peerMaxHeaderTableSize, resSize; got != want {
			return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want)
		}
		return nil
	}
	ct.server = func() error {
		buf := make([]byte, len(ClientPreface))
		_, err := io.ReadFull(ct.sc, buf)
		if err != nil {
			return fmt.Errorf("reading client preface: %v", err)
		}
		f, err := ct.fr.ReadFrame()
		if err != nil {
			return err
		}
		// Failures are returned rather than raised with ct.t.Fatal, matching
		// how the other server callbacks in this file report problems.
		sf, ok := f.(*SettingsFrame)
		if !ok {
			return fmt.Errorf("wanted client settings frame; got %v", f)
		}
		var found bool
		err = sf.ForeachSetting(func(s Setting) error {
			if s.ID == SettingHeaderTableSize {
				found = true
				if got, want := s.Val, reqSize; got != want {
					return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want)
				}
			}
			return nil
		})
		if err != nil {
			return err
		}
		if !found {
			return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting")
		}
		if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil {
			return err
		}
		if err := ct.fr.WriteSettingsAck(); err != nil {
			return err
		}

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *HeadersFrame:
				var buf bytes.Buffer
				enc := hpack.NewEncoder(&buf)
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
				return nil
			}
		}
	}
	ct.run()
}

// TestTransportMaxEncoderHeaderTableSize checks that the client's HPACK
// encoder dynamic table size equals Transport.MaxEncoderHeaderTableSize after
// the server has advertised a larger SETTINGS_HEADER_TABLE_SIZE.
func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
	ct := newClientTester(t)
	var peerAdvertisedMaxHeaderTableSize uint32 = 16384
	ct.tr.MaxEncoderHeaderTableSize = 8192
	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		cc, err := ct.tr.NewClientConn(ct.cc)
		if err != nil {
			return err
		}
		_, err = cc.RoundTrip(req)
		if err != nil {
			return err
		}
		if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want {
			return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
		}
		return nil
	}
	ct.server = func() error {
		buf := make([]byte, len(ClientPreface))
		_, err := io.ReadFull(ct.sc, buf)
		if err != nil {
			return fmt.Errorf("reading client preface: %v", err)
		}
		f, err := ct.fr.ReadFrame()
		if err != nil {
			return err
		}
		// Only the frame type matters here; failures are returned rather than
		// raised with ct.t.Fatal, matching the other server callbacks.
		if _, ok := f.(*SettingsFrame); !ok {
			return fmt.Errorf("wanted client settings frame; got %v", f)
		}
		if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil {
			return err
		}
		if err := ct.fr.WriteSettingsAck(); err != nil {
			return err
		}

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *HeadersFrame:
				var buf bytes.Buffer
				enc := hpack.NewEncoder(&buf)
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
				return nil
			}
		}
	}
	ct.run()
}

// TestAuthorityAddr is a table test for authorityAddr covering host names,
// IPv4 and bracketed IPv6 literals, with a port, without one, and with an
// empty port after the colon.
func TestAuthorityAddr(t *testing.T) {
	tests := []struct {
		scheme, authority string
		want              string
	}{
		{"http", "foo.com", "foo.com:80"},
		{"https", "foo.com", "foo.com:443"},
		{"https", "foo.com:", "foo.com:443"},
		{"https", "foo.com:1234", "foo.com:1234"},
		{"https", "1.2.3.4:1234", "1.2.3.4:1234"},
		{"https", "1.2.3.4", "1.2.3.4:443"},
		{"https", "1.2.3.4:", "1.2.3.4:443"},
		{"https", "[::1]:1234", "[::1]:1234"},
		{"https", "[::1]", "[::1]:443"},
		{"https", "[::1]:", "[::1]:443"},
	}
	for _, tt := range tests {
		got := authorityAddr(tt.scheme, tt.authority)
		if got != tt.want {
			t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
		}
	}
}

// Issue 20448: stop allocating for DATA frames' payload after
// Response.Body.Close is called.
func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { megabyteZero := make([]byte, 1<<20) writeErr := make(chan error, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.(http.Flusher).Flush() var sum int64 for i := 0; i < 100; i++ { n, err := w.Write(megabyteZero) sum += int64(n) if err != nil { writeErr <- err return } } t.Logf("wrote all %d bytes", sum) writeErr <- nil }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} res, err := c.Get(st.ts.URL) if err != nil { t.Fatal(err) } var buf [1]byte if _, err := res.Body.Read(buf[:]); err != nil { t.Error(err) } if err := res.Body.Close(); err != nil { t.Error(err) } trb, ok := res.Body.(transportResponseBody) if !ok { t.Fatalf("res.Body = %T; want transportResponseBody", res.Body) } if trb.cs.bufPipe.b != nil { t.Errorf("response body pipe is still open") } gotErr := <-writeErr if gotErr == nil { t.Errorf("Handler unexpectedly managed to write its entire response without getting an error") } else if gotErr != errStreamClosed { t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr) } } // Issue 18891: make sure Request.Body == NoBody means no DATA frame // is ever sent, even if empty. 
func TestTransportNoBodyMeansNoDATA(t *testing.T) { ct := newClientTester(t) unblockClient := make(chan bool) ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) ct.tr.RoundTrip(req) <-unblockClient return nil } ct.server = func() error { defer close(unblockClient) defer ct.cc.(*net.TCPConn).Close() ct.greet() for { f, err := ct.fr.ReadFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) } switch f := f.(type) { default: return fmt.Errorf("Got %T; want HeadersFrame", f) case *WindowUpdateFrame, *SettingsFrame: continue case *HeadersFrame: if !f.StreamEnded() { return fmt.Errorf("got headers frame without END_STREAM") } return nil } } } ct.run() } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { defer disableGoroutineTracking()() b.ReportAllocs() st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { for i := 0; i < nResHeader; i++ { name := fmt.Sprint("A-", i) w.Header().Set(name, "*") } }, optOnlyServer, optQuiet, ) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { b.Fatal(err) } for i := 0; i < nReqHeaders; i++ { name := fmt.Sprint("A-", i) req.Header.Set(name, "*") } b.ResetTimer() for i := 0; i < b.N; i++ { res, err := tr.RoundTrip(req) if err != nil { if res != nil { res.Body.Close() } b.Fatalf("RoundTrip err = %v; want nil", err) } res.Body.Close() if res.StatusCode != http.StatusOK { b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) } } } type infiniteReader struct{} func (r infiniteReader) Read(b []byte) (int, error) { return len(b), nil } // Issue 20521: it is not an error to receive a response and end stream // from the server without the body being consumed. 
func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() // The request body needs to be big enough to trigger flow control. req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{}) res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } if res.StatusCode != http.StatusOK { t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) } } // Verify transport doesn't crash when receiving bogus response lacking a :status header. // Issue 22880. func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { ct := newClientTester(t) ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) _, err := ct.tr.RoundTrip(req) const substr = "malformed response from server: missing status pseudo header" if !strings.Contains(fmt.Sprint(err), substr) { return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) } return nil } ct.server = func() error { ct.greet() var buf bytes.Buffer enc := hpack.NewEncoder(&buf) for { f, err := ct.fr.ReadFrame() if err != nil { return err } switch f := f.(type) { case *HeadersFrame: enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, // we'll send some DATA to try to crash the transport BlockFragment: buf.Bytes(), }) ct.fr.WriteData(f.StreamID, true, []byte("payload")) return nil } } } ct.run() } func BenchmarkClientRequestHeaders(b *testing.B) { b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) }) b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) }) b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) }) b.Run("1000 Headers", func(b 
*testing.B) { benchSimpleRoundTrip(b, 1000, 0) }) } func BenchmarkClientResponseHeaders(b *testing.B) { b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) }) b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) }) b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) }) b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) }) } func BenchmarkDownloadFrameSize(b *testing.B) { b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) }) b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) }) b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) }) b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) }) b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) }) } func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { defer disableGoroutineTracking()() const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M b.ReportAllocs() st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { // test 1GB transfer w.Header().Set("Content-Length", strconv.Itoa(transferSize)) w.Header().Set("Content-Transfer-Encoding", "binary") var data [1024 * 1024]byte for i := 0; i < transferSize/(1024*1024); i++ { w.Write(data[:]) } }, optQuiet, ) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { b.Fatal(err) } b.N = 3 b.SetBytes(transferSize) b.ResetTimer() for i := 0; i < b.N; i++ { res, err := tr.RoundTrip(req) if err != nil { if res != nil { res.Body.Close() } b.Fatalf("RoundTrip err = %v; want nil", err) } data, _ := io.ReadAll(res.Body) if len(data) != transferSize { b.Fatalf("Response length invalid") } res.Body.Close() if res.StatusCode != http.StatusOK { b.Fatalf("Response code = %v; want %v", 
res.StatusCode, http.StatusOK) } } } func activeStreams(cc *ClientConn) int { count := 0 cc.mu.Lock() defer cc.mu.Unlock() for _, cs := range cc.streams { select { case <-cs.abort: default: count++ } } return count } type closeMode int const ( closeAtHeaders closeMode = iota closeAtBody shutdown shutdownCancel ) // See golang.org/issue/17292 func testClientConnClose(t *testing.T, closeMode closeMode) { clientDone := make(chan struct{}) defer close(clientDone) handlerDone := make(chan struct{}) closeDone := make(chan struct{}) beforeHeader := func() {} bodyWrite := func(w http.ResponseWriter) {} st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { defer close(handlerDone) beforeHeader() w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() bodyWrite(w) select { case <-w.(http.CloseNotifier).CloseNotify(): // client closed connection before completion if closeMode == shutdown || closeMode == shutdownCancel { t.Error("expected request to complete") } case <-clientDone: if closeMode == closeAtHeaders || closeMode == closeAtBody { t.Error("expected connection closed by client") } } }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) } if closeMode == closeAtHeaders { beforeHeader = func() { if err := cc.Close(); err != nil { t.Error(err) } close(closeDone) } } var sendBody chan struct{} if closeMode == closeAtBody { sendBody = make(chan struct{}) bodyWrite = func(w http.ResponseWriter) { <-sendBody b := make([]byte, 32) w.Write(b) w.(http.Flusher).Flush() if err := cc.Close(); err != nil { t.Errorf("unexpected ClientConn close error: %v", err) } close(closeDone) w.Write(b) w.(http.Flusher).Flush() } } res, err := cc.RoundTrip(req) if res != nil { defer res.Body.Close() } if closeMode == closeAtHeaders { got 
:= fmt.Sprint(err) want := "http2: client connection force closed via ClientConn.Close" if got != want { t.Fatalf("RoundTrip error = %v, want %v", got, want) } } else { if err != nil { t.Fatalf("RoundTrip: %v", err) } if got, want := activeStreams(cc), 1; got != want { t.Errorf("got %d active streams, want %d", got, want) } } switch closeMode { case shutdownCancel: if err = cc.Shutdown(canceledCtx); err != context.Canceled { t.Errorf("got %v, want %v", err, context.Canceled) } if cc.closing == false { t.Error("expected closing to be true") } if cc.CanTakeNewRequest() == true { t.Error("CanTakeNewRequest to return false") } if v, want := len(cc.streams), 1; v != want { t.Errorf("expected %d active streams, got %d", want, v) } clientDone <- struct{}{} <-handlerDone case shutdown: wait := make(chan struct{}) shutdownEnterWaitStateHook = func() { close(wait) shutdownEnterWaitStateHook = func() {} } defer func() { shutdownEnterWaitStateHook = func() {} }() shutdown := make(chan struct{}, 1) go func() { if err = cc.Shutdown(context.Background()); err != nil { t.Error(err) } close(shutdown) }() // Let the shutdown to enter wait state <-wait cc.mu.Lock() if cc.closing == false { t.Error("expected closing to be true") } cc.mu.Unlock() if cc.CanTakeNewRequest() == true { t.Error("CanTakeNewRequest to return false") } if got, want := activeStreams(cc), 1; got != want { t.Errorf("got %d active streams, want %d", got, want) } // Let the active request finish clientDone <- struct{}{} // Wait for the shutdown to end select { case <-shutdown: case <-time.After(2 * time.Second): t.Fatal("expected server connection to close") } case closeAtHeaders, closeAtBody: if closeMode == closeAtBody { go close(sendBody) if _, err := io.Copy(ioutil.Discard, res.Body); err == nil { t.Error("expected a Copy error, got nil") } } <-closeDone if got, want := activeStreams(cc), 0; got != want { t.Errorf("got %d active streams, want %d", got, want) } // wait for server to get the connection close 
notice select { case <-handlerDone: case <-time.After(2 * time.Second): t.Fatal("expected server connection to close") } } } // The client closes the connection just after the server got the client's HEADERS // frame, but before the server sends its HEADERS response back. The expected // result is an error on RoundTrip explaining the client closed the connection. func TestClientConnCloseAtHeaders(t *testing.T) { testClientConnClose(t, closeAtHeaders) } // The client closes the connection between two server's response DATA frames. // The expected behavior is a response body io read error on the client. func TestClientConnCloseAtBody(t *testing.T) { testClientConnClose(t, closeAtBody) } // The client sends a GOAWAY frame before the server finished processing a request. // We expect the connection not to close until the request is completed. func TestClientConnShutdown(t *testing.T) { testClientConnClose(t, shutdown) } // The client sends a GOAWAY frame before the server finishes processing a request, // but cancels the passed context before the request is completed. The expected // behavior is the client closing the connection after the context is canceled. func TestClientConnShutdownCancel(t *testing.T) { testClientConnClose(t, shutdownCancel) } // Issue 25009: use Request.GetBody if present, even if it seems like // we might not need it. Apparently something else can still read from // the original request body. Data race? 
In any case, rewinding // unconditionally on retry is a nicer model anyway and should // simplify code in the future (after the Go 1.11 freeze) func TestTransportUsesGetBodyWhenPresent(t *testing.T) { calls := 0 someBody := func() io.ReadCloser { return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))} } req := &http.Request{ Body: someBody(), GetBody: func() (io.ReadCloser, error) { calls++ return someBody(), nil }, } req2, err := shouldRetryRequest(req, errClientConnUnusable) if err != nil { t.Fatal(err) } if calls != 1 { t.Errorf("Calls = %d; want 1", calls) } if req2 == req { t.Error("req2 changed") } if req2 == nil { t.Fatal("req2 is nil") } if req2.Body == nil { t.Fatal("req2.Body is nil") } if req2.GetBody == nil { t.Fatal("req2.GetBody is nil") } if req2.Body == req.Body { t.Error("req2.Body unchanged") } } // Issue 22891: verify that the "https" altproto we register with net/http // is a certain type: a struct with one field with our *http2.Transport in it. func TestNoDialH2RoundTripperType(t *testing.T) { t1 := new(http.Transport) t2 := new(Transport) rt := noDialH2RoundTripper{t2} if err := registerHTTPSProtocol(t1, rt); err != nil { t.Fatal(err) } rv := reflect.ValueOf(rt) if rv.Type().Kind() != reflect.Struct { t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind()) } if n := rv.Type().NumField(); n != 1 { t.Fatalf("fields = %d; net/http expects 1", n) } v := rv.Field(0) if _, ok := v.Interface().(*Transport); !ok { t.Fatalf("wrong kind %T; want *Transport", v.Interface()) } } type errReader struct { body []byte err error } func (r *errReader) Read(p []byte) (int, error) { if len(r.body) > 0 { n := copy(p, r.body) r.body = r.body[n:] return n, nil } return 0, r.err } func testTransportBodyReadError(t *testing.T, body []byte) { if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { // So far we've only seen this be flaky on Windows and Plan 9, // perhaps due to TCP behavior on shutdowns while // unread data is in flight. 
This test should be // fixed, but a skip is better than annoying people // for now. t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS) } clientDone := make(chan struct{}) ct := newClientTester(t) ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() if runtime.GOOS == "plan9" { // CloseWrite not supported on Plan 9; Issue 17906 defer ct.cc.(*net.TCPConn).Close() } defer close(clientDone) checkNoStreams := func() error { cp, ok := ct.tr.connPool().(*clientConnPool) if !ok { return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) } cp.mu.Lock() defer cp.mu.Unlock() conns, ok := cp.conns["dummy.tld:443"] if !ok { return fmt.Errorf("missing connection") } if len(conns) != 1 { return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) } if activeStreams(conns[0]) != 0 { return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) } return nil } bodyReadError := errors.New("body read error") body := &errReader{body, bodyReadError} req, err := http.NewRequest("PUT", "https://dummy.tld/", body) if err != nil { return err } _, err = ct.tr.RoundTrip(req) if err != bodyReadError { return fmt.Errorf("err = %v; want %v", err, bodyReadError) } if err = checkNoStreams(); err != nil { return err } return nil } ct.server = func() error { ct.greet() var receivedBody []byte var resetCount int for { f, err := ct.fr.ReadFrame() t.Logf("server: ReadFrame = %v, %v", f, err) if err != nil { select { case <-clientDone: // If the client's done, it // will have reported any // errors on its side. if bytes.Compare(receivedBody, body) != 0 { return fmt.Errorf("body: %q; expected %q", receivedBody, body) } if resetCount != 1 { return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) } return nil default: return err } } switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: case *DataFrame: receivedBody = append(receivedBody, f.Data()...) 
case *RSTStreamFrame: resetCount++ default: return fmt.Errorf("Unexpected client frame %v", f) } } } ct.run() } func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) } // Issue 32254: verify that the client sends END_STREAM flag eagerly with the last // (or in this test-case the only one) request body data frame, and does not send // extra zero-len data frames. func TestTransportBodyEagerEndStream(t *testing.T) { const reqBody = "some request body" const resBody = "some response body" ct := newClientTester(t) ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() if runtime.GOOS == "plan9" { // CloseWrite not supported on Plan 9; Issue 17906 defer ct.cc.(*net.TCPConn).Close() } body := strings.NewReader(reqBody) req, err := http.NewRequest("PUT", "https://dummy.tld/", body) if err != nil { return err } _, err = ct.tr.RoundTrip(req) if err != nil { return err } return nil } ct.server = func() error { ct.greet() for { f, err := ct.fr.ReadFrame() if err != nil { return err } switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: case *DataFrame: if !f.StreamEnded() { ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) return fmt.Errorf("data frame without END_STREAM %v", f) } var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.Header().StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) ct.fr.WriteData(f.StreamID, true, []byte(resBody)) return nil case *RSTStreamFrame: default: return fmt.Errorf("Unexpected client frame %v", f) } } } ct.run() } type chunkReader struct { chunks [][]byte } func (r *chunkReader) Read(p []byte) (int, error) { if len(r.chunks) > 0 { n := copy(p, r.chunks[0]) r.chunks = r.chunks[1:] return n, nil } panic("shouldn't read 
this many times") } // Issue 32254: if the request body is larger than the specified // content length, the client should refuse to send the extra part // and abort the stream. // // In _len3 case, the first Read() matches the expected content length // but the second read returns more data. // // In _len2 case, the first Read() exceeds the expected content length. func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) { body := &chunkReader{[][]byte{ []byte("123"), []byte("456"), }} testTransportBodyLargerThanSpecifiedContentLength(t, body, 3) } func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) { body := &chunkReader{[][]byte{ []byte("123"), }} testTransportBodyLargerThanSpecifiedContentLength(t, body, 2) } func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { r.Body.Read(make([]byte, 6)) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() req, _ := http.NewRequest("POST", st.ts.URL, body) req.ContentLength = contentLen _, err := tr.RoundTrip(req) if err != errReqBodyTooLong { t.Fatalf("expected %v, got %v", errReqBodyTooLong, err) } } func TestClientConnTooIdle(t *testing.T) { tests := []struct { cc func() *ClientConn want bool }{ { func() *ClientConn { return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} }, true, }, { func() *ClientConn { return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} }, false, }, { func() *ClientConn { return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} }, false, }, { func() *ClientConn { return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)} }, false, }, } for i, tt := range tests { got := tt.cc().tooIdleLocked() if got != tt.want { t.Errorf("%d. 
got %v; want %v", i, got, tt.want) } } } type fakeConnErr struct { net.Conn writeErr error closed bool } func (fce *fakeConnErr) Write(b []byte) (n int, err error) { return 0, fce.writeErr } func (fce *fakeConnErr) Close() error { fce.closed = true return nil } // issue 39337: close the connection on a failed write func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { tr := &Transport{} writeErr := errors.New("write error") fakeConn := &fakeConnErr{writeErr: writeErr} _, err := tr.NewClientConn(fakeConn) if err != writeErr { t.Fatalf("expected %v, got %v", writeErr, err) } if !fakeConn.closed { t.Error("expected closed conn") } } func TestTransportRoundtripCloseOnWriteError(t *testing.T) { req, err := http.NewRequest("GET", "https://dummy.tld/", nil) if err != nil { t.Fatal(err) } st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) if err != nil { t.Fatal(err) } writeErr := errors.New("write error") cc.wmu.Lock() cc.werr = writeErr cc.wmu.Unlock() _, err = cc.RoundTrip(req) if err != writeErr { t.Fatalf("expected %v, got %v", writeErr, err) } cc.mu.Lock() closed := cc.closed cc.mu.Unlock() if !closed { t.Fatal("expected closed") } } // Issue 31192: A failed request may be retried if the body has not been read // already. If the request body has started to be sent, one must wait until it // is completed. 
func TestTransportBodyRewindRace(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "close") w.WriteHeader(http.StatusOK) return }, optOnlyServer) defer st.Close() tr := &http.Transport{ TLSClientConfig: tlsConfigInsecure, MaxConnsPerHost: 1, } err := ConfigureTransport(tr) if err != nil { t.Fatal(err) } client := &http.Client{ Transport: tr, } const clients = 50 var wg sync.WaitGroup wg.Add(clients) for i := 0; i < clients; i++ { req, err := http.NewRequest("POST", st.ts.URL, bytes.NewBufferString("abcdef")) if err != nil { t.Fatalf("unexpect new request error: %v", err) } go func() { defer wg.Done() res, err := client.Do(req) if err == nil { res.Body.Close() } }() } wg.Wait() } // Issue 42498: A request with a body will never be sent if the stream is // reset prior to sending any data. func TestTransportServerResetStreamAtHeaders(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) return }, optOnlyServer) defer st.Close() tr := &http.Transport{ TLSClientConfig: tlsConfigInsecure, MaxConnsPerHost: 1, ExpectContinueTimeout: 10 * time.Second, } err := ConfigureTransport(tr) if err != nil { t.Fatal(err) } client := &http.Client{ Transport: tr, } req, err := http.NewRequest("POST", st.ts.URL, errorReader{io.EOF}) if err != nil { t.Fatalf("unexpect new request error: %v", err) } req.ContentLength = 0 // so transport is tempted to sniff it req.Header.Set("Expect", "100-continue") res, err := client.Do(req) if err != nil { t.Fatal(err) } res.Body.Close() } type trackingReader struct { rdr io.Reader wasRead uint32 } func (tr *trackingReader) Read(p []byte) (int, error) { atomic.StoreUint32(&tr.wasRead, 1) return tr.rdr.Read(p) } func (tr *trackingReader) WasRead() bool { return atomic.LoadUint32(&tr.wasRead) != 0 } func TestTransportExpectContinue(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 
switch r.URL.Path { case "/reject": w.WriteHeader(403) default: io.Copy(io.Discard, r.Body) } }, optOnlyServer) defer st.Close() tr := &http.Transport{ TLSClientConfig: tlsConfigInsecure, MaxConnsPerHost: 1, ExpectContinueTimeout: 10 * time.Second, } err := ConfigureTransport(tr) if err != nil { t.Fatal(err) } client := &http.Client{ Transport: tr, } testCases := []struct { Name string Path string Body *trackingReader ExpectedCode int ShouldRead bool }{ { Name: "read-all", Path: "/", Body: &trackingReader{rdr: strings.NewReader("hello")}, ExpectedCode: 200, ShouldRead: true, }, { Name: "reject", Path: "/reject", Body: &trackingReader{rdr: strings.NewReader("hello")}, ExpectedCode: 403, ShouldRead: false, }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { startTime := time.Now() req, err := http.NewRequest("POST", st.ts.URL+tc.Path, tc.Body) if err != nil { t.Fatal(err) } req.Header.Set("Expect", "100-continue") res, err := client.Do(req) if err != nil { t.Fatal(err) } res.Body.Close() if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout { t.Error("Request didn't finish before expect continue timeout") } if res.StatusCode != tc.ExpectedCode { t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode) } if tc.Body.WasRead() != tc.ShouldRead { t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead) } }) } } type closeChecker struct { io.ReadCloser closed chan struct{} } func newCloseChecker(r io.ReadCloser) *closeChecker { return &closeChecker{r, make(chan struct{})} } func newStaticCloseChecker(body string) *closeChecker { return newCloseChecker(io.NopCloser(strings.NewReader("body"))) } func (rc *closeChecker) Read(b []byte) (n int, err error) { select { default: case <-rc.closed: // TODO(dneil): Consider restructuring the request write to avoid reading // from the request body after closing it, and check for read-after-close here. 
// Currently, abortRequestBodyWrite races with writeRequestBody. return 0, errors.New("read after Body.Close") } return rc.ReadCloser.Read(b) } func (rc *closeChecker) Close() error { close(rc.closed) return rc.ReadCloser.Close() } func (rc *closeChecker) isClosed() error { // The RoundTrip contract says that it will close the request body, // but that it may do so in a separate goroutine. Wait a reasonable // amount of time before concluding that the body isn't being closed. timeout := time.Duration(10 * time.Second) select { case <-rc.closed: case <-time.After(timeout): return fmt.Errorf("body not closed after %v", timeout) } return nil } // A blockingWriteConn is a net.Conn that blocks in Write after some number of bytes are written. type blockingWriteConn struct { net.Conn writeOnce sync.Once writec chan struct{} // closed after the write limit is reached unblockc chan struct{} // closed to unblock writes count, limit int } func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn { return &blockingWriteConn{ Conn: conn, limit: limit, writec: make(chan struct{}), unblockc: make(chan struct{}), } } // wait waits until the conn blocks writing the limit+1st byte. func (c *blockingWriteConn) wait() { <-c.writec } // unblock unblocks writes to the conn. func (c *blockingWriteConn) unblock() { close(c.unblockc) } func (c *blockingWriteConn) Write(b []byte) (n int, err error) { if c.count+len(b) > c.limit { c.writeOnce.Do(func() { close(c.writec) }) <-c.unblockc } n, err = c.Conn.Write(b) c.count += n return n, err } // Write several requests to a ClientConn at the same time, looking for race conditions. 
// See golang.org/issue/48340 func TestTransportFrameBufferReuse(t *testing.T) { filler := hex.EncodeToString([]byte(randString(2048))) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if got, want := r.Header.Get("Big"), filler; got != want { t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want) } b, err := ioutil.ReadAll(r.Body) if err != nil { t.Errorf("error reading request body: %v", err) } if got, want := string(b), filler; got != want { t.Errorf("request body = %q, want %q", got, want) } if got, want := r.Trailer.Get("Big"), filler; got != want { t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want) } }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() var wg sync.WaitGroup defer wg.Wait() for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler)) if err != nil { t.Error(err) return } req.Header.Set("Big", filler) req.Trailer = make(http.Header) req.Trailer.Set("Big", filler) res, err := tr.RoundTrip(req) if err != nil { t.Error(err) return } if got, want := res.StatusCode, 200; got != want { t.Errorf("StatusCode = %v; want %v", got, want) } if res != nil && res.Body != nil { res.Body.Close() } }() } } // Ensure that a request blocking while being written to the underlying net.Conn doesn't // block access to the ClientConn pool. Test requests blocking while writing headers, the body, // and trailers. 
// See golang.org/issue/32388 func TestTransportBlockingRequestWrite(t *testing.T) { filler := hex.EncodeToString([]byte(randString(2048))) for _, test := range []struct { name string req func(url string) (*http.Request, error) }{{ name: "headers", req: func(url string) (*http.Request, error) { req, err := http.NewRequest("POST", url, nil) if err != nil { return nil, err } req.Header.Set("Big", filler) return req, err }, }, { name: "body", req: func(url string) (*http.Request, error) { req, err := http.NewRequest("POST", url, strings.NewReader(filler)) if err != nil { return nil, err } return req, err }, }, { name: "trailer", req: func(url string) (*http.Request, error) { req, err := http.NewRequest("POST", url, strings.NewReader("body")) if err != nil { return nil, err } req.Trailer = make(http.Header) req.Trailer.Set("Big", filler) return req, err }, }} { test := test t.Run(test.name, func(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if v := r.Header.Get("Big"); v != "" && v != filler { t.Errorf("request header mismatch") } if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler { t.Errorf("request body mismatch\ngot: %q\nwant: %q", string(v), filler) } if v := r.Trailer.Get("Big"); v != "" && v != filler { t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler) } }, optOnlyServer, func(s *Server) { s.MaxConcurrentStreams = 1 }) defer st.Close() // This Transport creates connections that block on writes after 1024 bytes. connc := make(chan *blockingWriteConn, 1) connCount := 0 tr := &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { connCount++ c, err := tls.Dial(network, addr, cfg) wc := newBlockingWriteConn(c, 1024) select { case connc <- wc: default: } return wc, err }, } defer tr.CloseIdleConnections() // Request 1: A small request to ensure we read the server MaxConcurrentStreams. 
{ req, err := http.NewRequest("POST", st.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } if got, want := res.StatusCode, 200; got != want { t.Errorf("StatusCode = %v; want %v", got, want) } if res != nil && res.Body != nil { res.Body.Close() } } // Request 2: A large request that blocks while being written. reqc := make(chan struct{}) go func() { defer close(reqc) req, err := test.req(st.ts.URL) if err != nil { t.Error(err) return } res, _ := tr.RoundTrip(req) if res != nil && res.Body != nil { res.Body.Close() } }() conn := <-connc conn.wait() // wait for the request to block // Request 3: A small request that is sent on a new connection, since request 2 // is hogging the only available stream on the previous connection. { req, err := http.NewRequest("POST", st.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } if got, want := res.StatusCode, 200; got != want { t.Errorf("StatusCode = %v; want %v", got, want) } if res != nil && res.Body != nil { res.Body.Close() } } // Request 2 should still be blocking at this point. 
select { case <-reqc: t.Errorf("request 2 unexpectedly completed") default: } conn.unblock() <-reqc if connCount != 2 { t.Errorf("created %v connections, want 1", connCount) } }) } } func TestTransportCloseRequestBody(t *testing.T) { var statusCode int st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(statusCode) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() ctx := context.Background() cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) if err != nil { t.Fatal(err) } for _, status := range []int{200, 401} { t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) { statusCode = status pr, pw := io.Pipe() body := newCloseChecker(pr) req, err := http.NewRequest("PUT", "https://dummy.tld/", body) if err != nil { t.Fatal(err) } res, err := cc.RoundTrip(req) if err != nil { t.Fatal(err) } res.Body.Close() pw.Close() if err := body.isClosed(); err != nil { t.Fatal(err) } }) } } // collectClientsConnPool is a ClientConnPool that wraps lower and // collects what calls were made on it. type collectClientsConnPool struct { lower ClientConnPool mu sync.Mutex getErrs int got []*ClientConn } func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { cc, err := p.lower.GetClientConn(req, addr) p.mu.Lock() defer p.mu.Unlock() if err != nil { p.getErrs++ return nil, err } p.got = append(p.got, cc) return cc, nil } func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { p.lower.MarkDead(cc) } func TestTransportRetriesOnStreamProtocolError(t *testing.T) { ct := newClientTester(t) pool := &collectClientsConnPool{ lower: &clientConnPool{t: ct.tr}, } ct.tr.ConnPool = pool gotProtoError := make(chan bool, 1) ct.tr.CountError = func(errType string) { if errType == "recv_rststream_PROTOCOL_ERROR" { select { case gotProtoError <- true: default: } } } ct.client = func() error { // Start two requests. 
The first is a long request // that will finish after the second. The second one // will result in the protocol error. We check that // after the first one closes, the connection then // shuts down. // The long, outer request. req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) res1, err := ct.tr.RoundTrip(req1) if err != nil { return err } if got, want := res1.Header.Get("Is-Long"), "1"; got != want { return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) } req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) res, err := ct.tr.RoundTrip(req) const want = "only one dial allowed in test mode" if got := fmt.Sprint(err); got != want { t.Errorf("didn't dial again: got %#q; want %#q", got, want) } if res != nil { res.Body.Close() } select { case <-gotProtoError: default: t.Errorf("didn't get stream protocol error") } if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { t.Errorf("unexpected body read %v, %v", n, err) } pool.mu.Lock() defer pool.mu.Unlock() if pool.getErrs != 1 { t.Errorf("pool get errors = %v; want 1", pool.getErrs) } if len(pool.got) == 2 { if pool.got[0] != pool.got[1] { t.Errorf("requests went on different connections") } cc := pool.got[0] cc.mu.Lock() if !cc.doNotReuse { t.Error("ClientConn not marked doNotReuse") } cc.mu.Unlock() select { case <-cc.readerDone: case <-time.After(5 * time.Second): t.Errorf("timeout waiting for reader to be done") } } else { t.Errorf("pool get success = %v; want 2", len(pool.got)) } return nil } ct.server = func() error { ct.greet() var sentErr bool var numHeaders int var firstStreamID uint32 var hbuf bytes.Buffer enc := hpack.NewEncoder(&hbuf) for { f, err := ct.fr.ReadFrame() if err == io.EOF { // Client hung up on us, as it should at the end. 
return nil } if err != nil { return nil } switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: numHeaders++ if numHeaders == 1 { firstStreamID = f.StreamID hbuf.Reset() enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, BlockFragment: hbuf.Bytes(), }) continue } if !sentErr { sentErr = true ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) ct.fr.WriteData(firstStreamID, true, nil) continue } } } } ct.run() } func TestClientConnReservations(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { s.MaxConcurrentStreams = initialMaxConcurrentStreams }) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() cc, err := tr.newClientConn(st.cc, false) if err != nil { t.Fatal(err) } req, _ := http.NewRequest("GET", st.ts.URL, nil) n := 0 for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { n++ } if n != initialMaxConcurrentStreams { t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams) } if _, err := cc.RoundTrip(req); err != nil { t.Fatalf("RoundTrip error = %v", err) } n2 := 0 for n2 <= 5 && cc.ReserveNewRequest() { n2++ } if n2 != 1 { t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2) } // Use up all the reservations for i := 0; i < n; i++ { cc.RoundTrip(req) } n2 = 0 for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { n2++ } if n2 != n { t.Errorf("after reset, reservations = %v; want %v", n2, n) } } func TestTransportTimeoutServerHangs(t *testing.T) { clientDone := make(chan struct{}) ct := newClientTester(t) ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() defer close(clientDone) req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) if err != nil { return err } ctx, cancel := 
context.WithTimeout(context.Background(), 1*time.Second) defer cancel() req = req.WithContext(ctx) req.Header.Add("Big", strings.Repeat("a", 1<<20)) _, err = ct.tr.RoundTrip(req) if err == nil { return errors.New("error should not be nil") } if ne, ok := err.(net.Error); !ok || !ne.Timeout() { return fmt.Errorf("error should be a net error timeout: %v", err) } return nil } ct.server = func() error { ct.greet() select { case <-time.After(5 * time.Second): case <-clientDone: } return nil } ct.run() } func TestTransportContentLengthWithoutBody(t *testing.T) { contentLength := "" st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", contentLength) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() for _, test := range []struct { name string contentLength string wantBody string wantErr error wantContentLength int64 }{ { name: "non-zero content length", contentLength: "42", wantErr: io.ErrUnexpectedEOF, wantContentLength: 42, }, { name: "zero content length", contentLength: "0", wantErr: nil, wantContentLength: 0, }, } { t.Run(test.name, func(t *testing.T) { contentLength = test.contentLength req, _ := http.NewRequest("GET", st.ts.URL, nil) res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() body, err := io.ReadAll(res.Body) if err != test.wantErr { t.Errorf("Expected error %v, got: %v", test.wantErr, err) } if len(body) > 0 { t.Errorf("Expected empty body, got: %v", body) } if res.ContentLength != test.wantContentLength { t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength) } }) } } func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.(http.Flusher).Flush() io.Copy(io.Discard, r.Body) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} 
defer tr.CloseIdleConnections() pr, pw := net.Pipe() req, err := http.NewRequest("GET", st.ts.URL, pr) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } // Closing the Response's Body interrupts the blocked body read. res.Body.Close() pw.Close() } func TestTransport300ResponseBody(t *testing.T) { reqc := make(chan struct{}) body := []byte("response body") st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(300) w.(http.Flusher).Flush() <-reqc w.Write(body) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() pr, pw := net.Pipe() req, err := http.NewRequest("GET", st.ts.URL, pr) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } close(reqc) got, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error reading response body: %v", err) } if !bytes.Equal(got, body) { t.Errorf("got response body %q, want %q", string(got), string(body)) } res.Body.Close() pw.Close() } func TestTransportWriteByteTimeout(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer, ) defer st.Close() tr := &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { _, c := net.Pipe() return c, nil }, WriteByteTimeout: 1 * time.Millisecond, } defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} _, err := c.Get(st.ts.URL) if !errors.Is(err, os.ErrDeadlineExceeded) { t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err) } } type slowWriteConn struct { net.Conn hasWriteDeadline bool } func (c *slowWriteConn) SetWriteDeadline(t time.Time) error { c.hasWriteDeadline = !t.IsZero() return nil } func (c *slowWriteConn) Write(b []byte) (n int, err error) { if c.hasWriteDeadline && len(b) > 1 { n, err = c.Conn.Write(b[:1]) if err != nil { return n, err } return n, 
fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded) } return c.Conn.Write(b) } func TestTransportSlowWrites(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer, ) defer st.Close() tr := &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { cfg.InsecureSkipVerify = true c, err := tls.Dial(network, addr, cfg) return &slowWriteConn{Conn: c}, err }, WriteByteTimeout: 1 * time.Millisecond, } defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} const bodySize = 1 << 20 resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize)) if err != nil { t.Fatal(err) } resp.Body.Close() } func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) { testTransportClosesConnAfterGoAway(t, 0) } func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { testTransportClosesConnAfterGoAway(t, 1) } type closeOnceConn struct { net.Conn closed uint32 } var errClosed = errors.New("Close of closed connection") func (c *closeOnceConn) Close() error { if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return c.Conn.Close() } return errClosed } // testTransportClosesConnAfterGoAway verifies that the transport // closes a connection after reading a GOAWAY from it. // // lastStream is the last stream ID in the GOAWAY frame. // When 0, the transport (unsuccessfully) retries the request (stream 1); // when 1, the transport reads the response after receiving the GOAWAY. 
func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { ct := newClientTester(t) ct.cc = &closeOnceConn{Conn: ct.cc} var wg sync.WaitGroup wg.Add(1) ct.client = func() error { defer wg.Done() req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err == nil { res.Body.Close() } if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) } if err = ct.cc.Close(); err != errClosed { return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err) } return nil } ct.server = func() error { defer wg.Wait() ct.greet() hf, err := ct.firstHeaders() if err != nil { return fmt.Errorf("server failed reading HEADERS: %v", err) } if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil { return fmt.Errorf("server failed writing GOAWAY: %v", err) } if lastStream > 0 { // Send a valid response to first request. var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, EndStream: true, BlockFragment: buf.Bytes(), }) } return nil } ct.run() } type slowCloser struct { closing chan struct{} closed chan struct{} } func (r *slowCloser) Read([]byte) (int, error) { return 0, io.EOF } func (r *slowCloser) Close() error { close(r.closing) <-r.closed return nil } func TestTransportSlowClose(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, optOnlyServer) defer st.Close() client := st.ts.Client() body := &slowCloser{ closing: make(chan struct{}), closed: make(chan struct{}), } reqc := make(chan struct{}) go func() { defer close(reqc) res, err := client.Post(st.ts.URL, "text/plain", body) if err != nil { t.Error(err) } res.Body.Close() }() defer func() { close(body.closed) <-reqc // wait for POST request to finish }() <-body.closing // wait for POST request to call body.Close 
// This GET request should not be blocked by the in-progress POST. res, err := client.Get(st.ts.URL) if err != nil { t.Fatal(err) } res.Body.Close() } func TestTransportDialTLSContext(t *testing.T) { blockCh := make(chan struct{}) serverTLSConfigFunc := func(ts *httptest.Server) { ts.Config.TLSConfig = &tls.Config{ // Triggers the server to request the clients certificate // during TLS handshake. ClientAuth: tls.RequestClientCert, } } ts := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer, serverTLSConfigFunc, ) defer ts.Close() tr := &Transport{ TLSClientConfig: &tls.Config{ GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { // Tests that the context provided to `req` is // passed into this function. close(blockCh) <-cri.Context().Done() return nil, cri.Context().Err() }, InsecureSkipVerify: true, }, } defer tr.CloseIdleConnections() req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) if err != nil { t.Fatal(err) } ctx, cancel := context.WithCancel(context.Background()) defer cancel() req = req.WithContext(ctx) errCh := make(chan error) go func() { defer close(errCh) res, err := tr.RoundTrip(req) if err != nil { errCh <- err return } res.Body.Close() }() // Wait for GetClientCertificate handler to be called <-blockCh // Cancel the context cancel() // Expect the cancellation error here err = <-errCh if err == nil { t.Fatal("cancelling context during client certificate fetch did not error as expected") return } if !errors.Is(err, context.Canceled) { t.Fatalf("unexpected error returned after cancellation: %v", err) } } // TestDialRaceResumesDial tests that, given two concurrent requests // to the same address, when the first Dial is interrupted because // the first request's context is cancelled, the second request // resumes the dial automatically. 
func TestDialRaceResumesDial(t *testing.T) {
	// blockCh is closed by the first GetClientCertificate call; it tells
	// both the test body and the second request that a handshake has started.
	blockCh := make(chan struct{})
	serverTLSConfigFunc := func(ts *httptest.Server) {
		ts.Config.TLSConfig = &tls.Config{
			// Triggers the server to request the clients certificate
			// during TLS handshake.
			ClientAuth: tls.RequestClientCert,
		}
	}
	ts := newServerTester(t,
		func(w http.ResponseWriter, r *http.Request) {},
		optOnlyServer,
		serverTLSConfigFunc,
	)
	defer ts.Close()
	tr := &Transport{
		TLSClientConfig: &tls.Config{
			// The first call closes blockCh and then blocks until the
			// handshake context is done, returning that context's error.
			// Any later call sees blockCh closed and returns an empty
			// certificate right away.
			GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				select {
				case <-blockCh:
					// If we already errored, return without error.
					return &tls.Certificate{}, nil
				default:
				}
				close(blockCh)
				<-cri.Context().Done()
				return nil, cri.Context().Err()
			},
			InsecureSkipVerify: true,
		},
	}
	defer tr.CloseIdleConnections()
	req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Create two requests with independent cancellation.
	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()
	req1 := req.WithContext(ctx1)
	ctx2, cancel2 := context.WithCancel(context.Background())
	defer cancel2()
	req2 := req.WithContext(ctx2)
	// errCh is unbuffered and shared by both request goroutines; each
	// sends on it only when its RoundTrip fails.
	errCh := make(chan error)
	go func() {
		res, err := tr.RoundTrip(req1)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
	}()
	successCh := make(chan struct{})
	go func() {
		// Don't start request until first request
		// has initiated the handshake.
		<-blockCh
		res, err := tr.RoundTrip(req2)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
		// Close successCh to indicate that the second request
		// made it to the server successfully.
		close(successCh)
	}()
	// Wait for GetClientCertificate handler to be called
	<-blockCh
	// Cancel the context first
	cancel1()
	// Expect the cancellation error here
	err = <-errCh
	if err == nil {
		t.Fatal("cancelling context during client certificate fetch did not error as expected")
		return
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error returned after cancellation: %v", err)
	}
	// After the first error, the only acceptable outcome is the second
	// request succeeding; a second error fails the test.
	select {
	case err := <-errCh:
		t.Fatalf("unexpected second error: %v", err)
	case <-successCh:
	}
}