...

Source file src/golang.org/x/net/http2/transport_test.go

Documentation: golang.org/x/net/http2

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"compress/gzip"
    11  	"context"
    12  	"crypto/tls"
    13  	"encoding/hex"
    14  	"errors"
    15  	"flag"
    16  	"fmt"
    17  	"io"
    18  	"io/fs"
    19  	"io/ioutil"
    20  	"log"
    21  	"math/rand"
    22  	"net"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"net/http/httptrace"
    26  	"net/textproto"
    27  	"net/url"
    28  	"os"
    29  	"reflect"
    30  	"runtime"
    31  	"sort"
    32  	"strconv"
    33  	"strings"
    34  	"sync"
    35  	"sync/atomic"
    36  	"testing"
    37  	"time"
    38  
    39  	"golang.org/x/net/http2/hpack"
    40  )
    41  
// Command-line flags for these tests:
//   - extnet enables TestTransportExternal, which needs the public network.
//   - transporthost is the host TestTransportExternal requests "/" from.
//   - insecure: NOTE(review) not referenced in the visible part of this file;
//     per its TODO it may be dead code.
var (
	extNet        = flag.Bool("extnet", false, "do external network tests")
	transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
	insecure      = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove?
)

// tlsConfigInsecure disables certificate verification so the clients in
// these tests can connect to the TLS test servers.
var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
    49  
    50  var canceledCtx context.Context
    51  
    52  func init() {
    53  	ctx, cancel := context.WithCancel(context.Background())
    54  	cancel()
    55  	canceledCtx = ctx
    56  }
    57  
    58  func TestTransportExternal(t *testing.T) {
    59  	if !*extNet {
    60  		t.Skip("skipping external network test")
    61  	}
    62  	req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
    63  	rt := &Transport{TLSClientConfig: tlsConfigInsecure}
    64  	res, err := rt.RoundTrip(req)
    65  	if err != nil {
    66  		t.Fatalf("%v", err)
    67  	}
    68  	res.Write(os.Stdout)
    69  }
    70  
// fakeTLSConn wraps a plain (cleartext) net.Conn and adds a ConnectionState
// method that reports a fixed TLS state. startH2cServer wraps accepted
// connections in it before handing them to Server.ServeConn.
// NOTE(review): presumably this lets the server's TLS-state checks pass on a
// non-TLS connection — confirm against Server.ServeConn.
type fakeTLSConn struct {
	net.Conn
}

// ConnectionState always reports TLS 1.2 with the
// cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 suite, whatever the
// underlying connection is.
func (c *fakeTLSConn) ConnectionState() tls.ConnectionState {
	return tls.ConnectionState{
		Version:     tls.VersionTLS12,
		CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
	}
}
    81  
// startH2cServer opens a loopback listener (via newLocalListener). In a
// background goroutine it accepts exactly one connection, wraps it in
// fakeTLSConn, and serves it with Server.ServeConn. No TLS handshake takes
// place.
//
// The handler replies "Hello, <path>, http: <r.TLS == nil>". The caller owns
// the returned listener and must close it.
//
// NOTE(review): if Accept fails (e.g. the listener is closed before anyone
// dials), t.Error is called from the goroutine, possibly after the test has
// already finished.
func startH2cServer(t *testing.T) net.Listener {
	h2Server := &Server{}
	l := newLocalListener(t)
	go func() {
		conn, err := l.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
		})})
	}()
	return l
}
    97  
    98  func TestTransportH2c(t *testing.T) {
    99  	l := startH2cServer(t)
   100  	defer l.Close()
   101  	req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil)
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  	var gotConnCnt int32
   106  	trace := &httptrace.ClientTrace{
   107  		GotConn: func(connInfo httptrace.GotConnInfo) {
   108  			if !connInfo.Reused {
   109  				atomic.AddInt32(&gotConnCnt, 1)
   110  			}
   111  		},
   112  	}
   113  	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   114  	tr := &Transport{
   115  		AllowHTTP: true,
   116  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   117  			return net.Dial(network, addr)
   118  		},
   119  	}
   120  	res, err := tr.RoundTrip(req)
   121  	if err != nil {
   122  		t.Fatal(err)
   123  	}
   124  	if res.ProtoMajor != 2 {
   125  		t.Fatal("proto not h2c")
   126  	}
   127  	body, err := ioutil.ReadAll(res.Body)
   128  	if err != nil {
   129  		t.Fatal(err)
   130  	}
   131  	if got, want := string(body), "Hello, /foobar, http: true"; got != want {
   132  		t.Fatalf("response got %v, want %v", got, want)
   133  	}
   134  	if got, want := gotConnCnt, int32(1); got != want {
   135  		t.Errorf("Too many got connections: %d", gotConnCnt)
   136  	}
   137  }
   138  
// TestTransport sends two requests through an http2.Transport to a server
// whose handler writes "sup": one with Method "GET" and one with an empty
// Method, which is expected to behave the same way. For each response it
// checks the status code and text, the headers (with Date normalized by
// cleanDate), that Response.Request is the request that was sent, that
// Response.TLS is non-nil, and the body.
func TestTransport(t *testing.T) {
	const body = "sup"
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, body)
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	u, err := url.Parse(st.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Both requests share tr, so the second may reuse the first's connection.
	for i, m := range []string{"GET", ""} {
		req := &http.Request{
			Method: m,
			URL:    u,
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Fatalf("%d: %s", i, err)
		}

		t.Logf("%d: Got res: %+v", i, res)
		if g, w := res.StatusCode, 200; g != w {
			t.Errorf("%d: StatusCode = %v; want %v", i, g, w)
		}
		if g, w := res.Status, "200 OK"; g != w {
			t.Errorf("%d: Status = %q; want %q", i, g, w)
		}
		wantHeader := http.Header{
			"Content-Length": []string{"3"},
			"Content-Type":   []string{"text/plain; charset=utf-8"},
			"Date":           []string{"XXX"}, // see cleanDate
		}
		// cleanDate is defined elsewhere; presumably it replaces the Date
		// value with "XXX" so the comparison is deterministic.
		cleanDate(res)
		if !reflect.DeepEqual(res.Header, wantHeader) {
			t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader)
		}
		if res.Request != req {
			t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req)
		}
		if res.TLS == nil {
			t.Errorf("%d: Response.TLS = nil; want non-nil", i)
		}
		slurp, err := ioutil.ReadAll(res.Body)
		if err != nil {
			t.Errorf("%d: Body read: %v", i, err)
		} else if string(slurp) != body {
			t.Errorf("%d: Body = %q; want %q", i, slurp, body)
		}
		res.Body.Close()
	}
}
   194  
   195  func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq func(*http.Request)) {
   196  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   197  		io.WriteString(w, r.RemoteAddr)
   198  	}, optOnlyServer, func(c net.Conn, st http.ConnState) {
   199  		t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
   200  	})
   201  	defer st.Close()
   202  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   203  	if useClient {
   204  		tr.ConnPool = noDialClientConnPool{new(clientConnPool)}
   205  	}
   206  	defer tr.CloseIdleConnections()
   207  	get := func() string {
   208  		req, err := http.NewRequest("GET", st.ts.URL, nil)
   209  		if err != nil {
   210  			t.Fatal(err)
   211  		}
   212  		modReq(req)
   213  		var res *http.Response
   214  		if useClient {
   215  			c := st.ts.Client()
   216  			ConfigureTransports(c.Transport.(*http.Transport))
   217  			res, err = c.Do(req)
   218  		} else {
   219  			res, err = tr.RoundTrip(req)
   220  		}
   221  		if err != nil {
   222  			t.Fatal(err)
   223  		}
   224  		defer res.Body.Close()
   225  		slurp, err := ioutil.ReadAll(res.Body)
   226  		if err != nil {
   227  			t.Fatalf("Body read: %v", err)
   228  		}
   229  		addr := strings.TrimSpace(string(slurp))
   230  		if addr == "" {
   231  			t.Fatalf("didn't get an addr in response")
   232  		}
   233  		return addr
   234  	}
   235  	first := get()
   236  	second := get()
   237  	if got := first == second; got != wantSame {
   238  		t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame)
   239  	}
   240  }
   241  
   242  func TestTransportReusesConns(t *testing.T) {
   243  	for _, test := range []struct {
   244  		name     string
   245  		modReq   func(*http.Request)
   246  		wantSame bool
   247  	}{{
   248  		name:     "ReuseConn",
   249  		modReq:   func(*http.Request) {},
   250  		wantSame: true,
   251  	}, {
   252  		name:     "RequestClose",
   253  		modReq:   func(r *http.Request) { r.Close = true },
   254  		wantSame: false,
   255  	}, {
   256  		name:     "ConnClose",
   257  		modReq:   func(r *http.Request) { r.Header.Set("Connection", "close") },
   258  		wantSame: false,
   259  	}} {
   260  		t.Run(test.name, func(t *testing.T) {
   261  			t.Run("Transport", func(t *testing.T) {
   262  				const useClient = false
   263  				testTransportReusesConns(t, useClient, test.wantSame, test.modReq)
   264  			})
   265  			t.Run("Client", func(t *testing.T) {
   266  				const useClient = true
   267  				testTransportReusesConns(t, useClient, test.wantSame, test.modReq)
   268  			})
   269  		})
   270  	}
   271  }
   272  
   273  func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) {
   274  	testTransportGetGotConnHooks(t, false)
   275  }
   276  func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) }
   277  
// testTransportGetGotConnHooks sends two sequential GETs to an
// HTTP/2-enabled test server and checks the httptrace hooks. After request i,
// GetConn and GotConn must each have fired i+1 times. GotConn must report
// Reused=false/WasIdle=false for the first request and true/true for the
// second.
//
// If useClient is true, requests go through the server's http.Client, whose
// http.Transport is upgraded with ConfigureTransports. Otherwise they go
// directly through an http2.Transport.
func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, r.RemoteAddr)
	}, func(s *httptest.Server) {
		s.EnableHTTP2 = true
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	client := st.ts.Client()
	// NOTE(review): the error from ConfigureTransports is ignored here.
	ConfigureTransports(client.Transport.(*http.Transport))

	// Hook call counters shared across both requests; updated atomically.
	var (
		getConns int32
		gotConns int32
	)
	for i := 0; i < 2; i++ {
		trace := &httptrace.ClientTrace{
			GetConn: func(hostport string) {
				atomic.AddInt32(&getConns, 1)
			},
			GotConn: func(connInfo httptrace.GotConnInfo) {
				// Only the first GotConn should see a brand-new connection;
				// later ones should see the same connection reused from idle.
				got := atomic.AddInt32(&gotConns, 1)
				wantReused, wantWasIdle := false, false
				if got > 1 {
					wantReused, wantWasIdle = true, true
				}
				if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle {
					t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle)
				}
			},
		}
		req, err := http.NewRequest("GET", st.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

		var res *http.Response
		if useClient {
			res, err = client.Do(req)
		} else {
			res, err = tr.RoundTrip(req)
		}
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		if get := atomic.LoadInt32(&getConns); get != int32(i+1) {
			t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1)
		}
		if got := atomic.LoadInt32(&gotConns); got != int32(i+1) {
			t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1)
		}
	}
}
   334  
   335  type testNetConn struct {
   336  	net.Conn
   337  	closed  bool
   338  	onClose func()
   339  }
   340  
   341  func (c *testNetConn) Close() error {
   342  	if !c.closed {
   343  		// We can call Close multiple times on the same net.Conn.
   344  		c.onClose()
   345  	}
   346  	c.closed = true
   347  	return c.Conn.Close()
   348  }
   349  
   350  // Tests that the Transport only keeps one pending dial open per destination address.
   351  // https://golang.org/issue/13397
   352  func TestTransportGroupsPendingDials(t *testing.T) {
   353  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   354  	}, optOnlyServer)
   355  	defer st.Close()
   356  	var (
   357  		mu         sync.Mutex
   358  		dialCount  int
   359  		closeCount int
   360  	)
   361  	tr := &Transport{
   362  		TLSClientConfig: tlsConfigInsecure,
   363  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   364  			mu.Lock()
   365  			dialCount++
   366  			mu.Unlock()
   367  			c, err := tls.Dial(network, addr, cfg)
   368  			return &testNetConn{
   369  				Conn: c,
   370  				onClose: func() {
   371  					mu.Lock()
   372  					closeCount++
   373  					mu.Unlock()
   374  				},
   375  			}, err
   376  		},
   377  	}
   378  	defer tr.CloseIdleConnections()
   379  	var wg sync.WaitGroup
   380  	for i := 0; i < 10; i++ {
   381  		wg.Add(1)
   382  		go func() {
   383  			defer wg.Done()
   384  			req, err := http.NewRequest("GET", st.ts.URL, nil)
   385  			if err != nil {
   386  				t.Error(err)
   387  				return
   388  			}
   389  			res, err := tr.RoundTrip(req)
   390  			if err != nil {
   391  				t.Error(err)
   392  				return
   393  			}
   394  			res.Body.Close()
   395  		}()
   396  	}
   397  	wg.Wait()
   398  	tr.CloseIdleConnections()
   399  	if dialCount != 1 {
   400  		t.Errorf("saw %d dials; want 1", dialCount)
   401  	}
   402  	if closeCount != 1 {
   403  		t.Errorf("saw %d closes; want 1", closeCount)
   404  	}
   405  }
   406  
   407  func retry(tries int, delay time.Duration, fn func() error) error {
   408  	var err error
   409  	for i := 0; i < tries; i++ {
   410  		err = fn()
   411  		if err == nil {
   412  			return nil
   413  		}
   414  		time.Sleep(delay)
   415  	}
   416  	return err
   417  }
   418  
// TestTransportAbortClosesPipes checks that when the connection is closed
// from the server side while a response body is being read, the client's
// body Read returns an error instead of blocking.
//
// The handler flushes the response headers and then blocks on shutdown. A
// client goroutine sends the request, calls st.closeConn (defined elsewhere;
// presumably it closes the server's end of the connection), and expects
// ioutil.ReadAll on the body to fail. The test fails if that doesn't happen
// within 3 seconds.
func TestTransportAbortClosesPipes(t *testing.T) {
	shutdown := make(chan struct{})
	st := newServerTester(t,
		func(w http.ResponseWriter, r *http.Request) {
			w.(http.Flusher).Flush()
			<-shutdown
		},
		optOnlyServer,
	)
	defer st.Close()
	// Deferred calls run LIFO, so this unblocks the handler before st.Close.
	defer close(shutdown) // we must shutdown before st.Close() to avoid hanging

	// errCh is closed when the client goroutine finishes; a non-nil value
	// means it hit a failure.
	// NOTE(review): errCh is unbuffered, so on the timeout path a later send
	// from the goroutine would block forever.
	errCh := make(chan error)
	go func() {
		defer close(errCh)
		tr := &Transport{TLSClientConfig: tlsConfigInsecure}
		req, err := http.NewRequest("GET", st.ts.URL, nil)
		if err != nil {
			errCh <- err
			return
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			errCh <- err
			return
		}
		defer res.Body.Close()
		st.closeConn()
		_, err = ioutil.ReadAll(res.Body)
		if err == nil {
			errCh <- errors.New("expected error from res.Body.Read")
			return
		}
	}()

	select {
	case err := <-errCh:
		// Either a reported failure, or nil from close(errCh) on success.
		if err != nil {
			t.Fatal(err)
		}
	// deadlock? that's a bug.
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
   464  
   465  // TODO: merge this with TestTransportBody to make TestTransportRequest? This
   466  // could be a table-driven test with extra goodies.
   467  func TestTransportPath(t *testing.T) {
   468  	gotc := make(chan *url.URL, 1)
   469  	st := newServerTester(t,
   470  		func(w http.ResponseWriter, r *http.Request) {
   471  			gotc <- r.URL
   472  		},
   473  		optOnlyServer,
   474  	)
   475  	defer st.Close()
   476  
   477  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   478  	defer tr.CloseIdleConnections()
   479  	const (
   480  		path  = "/testpath"
   481  		query = "q=1"
   482  	)
   483  	surl := st.ts.URL + path + "?" + query
   484  	req, err := http.NewRequest("POST", surl, nil)
   485  	if err != nil {
   486  		t.Fatal(err)
   487  	}
   488  	c := &http.Client{Transport: tr}
   489  	res, err := c.Do(req)
   490  	if err != nil {
   491  		t.Fatal(err)
   492  	}
   493  	defer res.Body.Close()
   494  	got := <-gotc
   495  	if got.Path != path {
   496  		t.Errorf("Read Path = %q; want %q", got.Path, path)
   497  	}
   498  	if got.RawQuery != query {
   499  		t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
   500  	}
   501  }
   502  
   503  func randString(n int) string {
   504  	rnd := rand.New(rand.NewSource(int64(n)))
   505  	b := make([]byte, n)
   506  	for i := range b {
   507  		b[i] = byte(rnd.Intn(256))
   508  	}
   509  	return string(b)
   510  }
   511  
   512  type panicReader struct{}
   513  
   514  func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") }
   515  func (panicReader) Close() error             { panic("unexpected Close") }
   516  
   517  func TestActualContentLength(t *testing.T) {
   518  	tests := []struct {
   519  		req  *http.Request
   520  		want int64
   521  	}{
   522  		// Verify we don't read from Body:
   523  		0: {
   524  			req:  &http.Request{Body: panicReader{}},
   525  			want: -1,
   526  		},
   527  		// nil Body means 0, regardless of ContentLength:
   528  		1: {
   529  			req:  &http.Request{Body: nil, ContentLength: 5},
   530  			want: 0,
   531  		},
   532  		// ContentLength is used if set.
   533  		2: {
   534  			req:  &http.Request{Body: panicReader{}, ContentLength: 5},
   535  			want: 5,
   536  		},
   537  		// http.NoBody means 0, not -1.
   538  		3: {
   539  			req:  &http.Request{Body: http.NoBody},
   540  			want: 0,
   541  		},
   542  	}
   543  	for i, tt := range tests {
   544  		got := actualContentLength(tt.req)
   545  		if got != tt.want {
   546  			t.Errorf("test[%d]: got %d; want %d", i, got, tt.want)
   547  		}
   548  	}
   549  }
   550  
   551  func TestTransportBody(t *testing.T) {
   552  	bodyTests := []struct {
   553  		body         string
   554  		noContentLen bool
   555  	}{
   556  		{body: "some message"},
   557  		{body: "some message", noContentLen: true},
   558  		{body: strings.Repeat("a", 1<<20), noContentLen: true},
   559  		{body: strings.Repeat("a", 1<<20)},
   560  		{body: randString(16<<10 - 1)},
   561  		{body: randString(16 << 10)},
   562  		{body: randString(16<<10 + 1)},
   563  		{body: randString(512<<10 - 1)},
   564  		{body: randString(512 << 10)},
   565  		{body: randString(512<<10 + 1)},
   566  		{body: randString(1<<20 - 1)},
   567  		{body: randString(1 << 20)},
   568  		{body: randString(1<<20 + 2)},
   569  	}
   570  
   571  	type reqInfo struct {
   572  		req   *http.Request
   573  		slurp []byte
   574  		err   error
   575  	}
   576  	gotc := make(chan reqInfo, 1)
   577  	st := newServerTester(t,
   578  		func(w http.ResponseWriter, r *http.Request) {
   579  			slurp, err := ioutil.ReadAll(r.Body)
   580  			if err != nil {
   581  				gotc <- reqInfo{err: err}
   582  			} else {
   583  				gotc <- reqInfo{req: r, slurp: slurp}
   584  			}
   585  		},
   586  		optOnlyServer,
   587  	)
   588  	defer st.Close()
   589  
   590  	for i, tt := range bodyTests {
   591  		tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   592  		defer tr.CloseIdleConnections()
   593  
   594  		var body io.Reader = strings.NewReader(tt.body)
   595  		if tt.noContentLen {
   596  			body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods
   597  		}
   598  		req, err := http.NewRequest("POST", st.ts.URL, body)
   599  		if err != nil {
   600  			t.Fatalf("#%d: %v", i, err)
   601  		}
   602  		c := &http.Client{Transport: tr}
   603  		res, err := c.Do(req)
   604  		if err != nil {
   605  			t.Fatalf("#%d: %v", i, err)
   606  		}
   607  		defer res.Body.Close()
   608  		ri := <-gotc
   609  		if ri.err != nil {
   610  			t.Errorf("#%d: read error: %v", i, ri.err)
   611  			continue
   612  		}
   613  		if got := string(ri.slurp); got != tt.body {
   614  			t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
   615  		}
   616  		wantLen := int64(len(tt.body))
   617  		if tt.noContentLen && tt.body != "" {
   618  			wantLen = -1
   619  		}
   620  		if ri.req.ContentLength != wantLen {
   621  			t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
   622  		}
   623  	}
   624  }
   625  
   626  func shortString(v string) string {
   627  	const maxLen = 100
   628  	if len(v) <= maxLen {
   629  		return v
   630  	}
   631  	return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:])
   632  }
   633  
   634  func TestTransportDialTLS(t *testing.T) {
   635  	var mu sync.Mutex // guards following
   636  	var gotReq, didDial bool
   637  
   638  	ts := newServerTester(t,
   639  		func(w http.ResponseWriter, r *http.Request) {
   640  			mu.Lock()
   641  			gotReq = true
   642  			mu.Unlock()
   643  		},
   644  		optOnlyServer,
   645  	)
   646  	defer ts.Close()
   647  	tr := &Transport{
   648  		DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
   649  			mu.Lock()
   650  			didDial = true
   651  			mu.Unlock()
   652  			cfg.InsecureSkipVerify = true
   653  			c, err := tls.Dial(netw, addr, cfg)
   654  			if err != nil {
   655  				return nil, err
   656  			}
   657  			return c, c.Handshake()
   658  		},
   659  	}
   660  	defer tr.CloseIdleConnections()
   661  	client := &http.Client{Transport: tr}
   662  	res, err := client.Get(ts.ts.URL)
   663  	if err != nil {
   664  		t.Fatal(err)
   665  	}
   666  	res.Body.Close()
   667  	mu.Lock()
   668  	if !gotReq {
   669  		t.Error("didn't get request")
   670  	}
   671  	if !didDial {
   672  		t.Error("didn't use dial hook")
   673  	}
   674  }
   675  
   676  func TestConfigureTransport(t *testing.T) {
   677  	t1 := &http.Transport{}
   678  	err := ConfigureTransport(t1)
   679  	if err != nil {
   680  		t.Fatal(err)
   681  	}
   682  	if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) {
   683  		// Laziness, to avoid buildtags.
   684  		t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
   685  	}
   686  	wantNextProtos := []string{"h2", "http/1.1"}
   687  	if t1.TLSClientConfig == nil {
   688  		t.Errorf("nil t1.TLSClientConfig")
   689  	} else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) {
   690  		t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos)
   691  	}
   692  	if err := ConfigureTransport(t1); err == nil {
   693  		t.Error("unexpected success on second call to ConfigureTransport")
   694  	}
   695  
   696  	// And does it work?
   697  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   698  		io.WriteString(w, r.Proto)
   699  	}, optOnlyServer)
   700  	defer st.Close()
   701  
   702  	t1.TLSClientConfig.InsecureSkipVerify = true
   703  	c := &http.Client{Transport: t1}
   704  	res, err := c.Get(st.ts.URL)
   705  	if err != nil {
   706  		t.Fatal(err)
   707  	}
   708  	slurp, err := ioutil.ReadAll(res.Body)
   709  	if err != nil {
   710  		t.Fatal(err)
   711  	}
   712  	if got, want := string(slurp), "HTTP/2.0"; got != want {
   713  		t.Errorf("body = %q; want %q", got, want)
   714  	}
   715  }
   716  
   717  type capitalizeReader struct {
   718  	r io.Reader
   719  }
   720  
   721  func (cr capitalizeReader) Read(p []byte) (n int, err error) {
   722  	n, err = cr.r.Read(p)
   723  	for i, b := range p[:n] {
   724  		if b >= 'a' && b <= 'z' {
   725  			p[i] = b - ('a' - 'A')
   726  		}
   727  	}
   728  	return
   729  }
   730  
   731  type flushWriter struct {
   732  	w io.Writer
   733  }
   734  
   735  func (fw flushWriter) Write(p []byte) (n int, err error) {
   736  	n, err = fw.w.Write(p)
   737  	if f, ok := fw.w.(http.Flusher); ok {
   738  		f.Flush()
   739  	}
   740  	return
   741  }
   742  
// clientTester exercises an http2.Transport (tr) against a hand-scripted
// server. The Transport's single permitted dial returns cc. The test plays
// the server by reading and writing frames on sc through fr. settings holds
// the client's SETTINGS frame as recorded by greet. client and server are
// the two sides' scripts, run concurrently by run; each returns an error to
// fail the test.
type clientTester struct {
	t        *testing.T
	tr       *Transport
	sc, cc   net.Conn // server and client conn
	fr       *Framer  // server's framer
	settings *SettingsFrame
	client   func() error
	server   func() error
}
   752  
   753  func newClientTester(t *testing.T) *clientTester {
   754  	var dialOnce struct {
   755  		sync.Mutex
   756  		dialed bool
   757  	}
   758  	ct := &clientTester{
   759  		t: t,
   760  	}
   761  	ct.tr = &Transport{
   762  		TLSClientConfig: tlsConfigInsecure,
   763  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   764  			dialOnce.Lock()
   765  			defer dialOnce.Unlock()
   766  			if dialOnce.dialed {
   767  				return nil, errors.New("only one dial allowed in test mode")
   768  			}
   769  			dialOnce.dialed = true
   770  			return ct.cc, nil
   771  		},
   772  	}
   773  
   774  	ln := newLocalListener(t)
   775  	cc, err := net.Dial("tcp", ln.Addr().String())
   776  	if err != nil {
   777  		t.Fatal(err)
   778  	}
   779  	sc, err := ln.Accept()
   780  	if err != nil {
   781  		t.Fatal(err)
   782  	}
   783  	ln.Close()
   784  	ct.cc = cc
   785  	ct.sc = sc
   786  	ct.fr = NewFramer(sc, sc)
   787  	return ct
   788  }
   789  
   790  func newLocalListener(t *testing.T) net.Listener {
   791  	ln, err := net.Listen("tcp4", "127.0.0.1:0")
   792  	if err == nil {
   793  		return ln
   794  	}
   795  	ln, err = net.Listen("tcp6", "[::1]:0")
   796  	if err != nil {
   797  		t.Fatal(err)
   798  	}
   799  	return ln
   800  }
   801  
   802  func (ct *clientTester) greet(settings ...Setting) {
   803  	buf := make([]byte, len(ClientPreface))
   804  	_, err := io.ReadFull(ct.sc, buf)
   805  	if err != nil {
   806  		ct.t.Fatalf("reading client preface: %v", err)
   807  	}
   808  	f, err := ct.fr.ReadFrame()
   809  	if err != nil {
   810  		ct.t.Fatalf("Reading client settings frame: %v", err)
   811  	}
   812  	var ok bool
   813  	if ct.settings, ok = f.(*SettingsFrame); !ok {
   814  		ct.t.Fatalf("Wanted client settings frame; got %v", f)
   815  	}
   816  	if err := ct.fr.WriteSettings(settings...); err != nil {
   817  		ct.t.Fatal(err)
   818  	}
   819  	if err := ct.fr.WriteSettingsAck(); err != nil {
   820  		ct.t.Fatal(err)
   821  	}
   822  }
   823  
   824  func (ct *clientTester) readNonSettingsFrame() (Frame, error) {
   825  	for {
   826  		f, err := ct.fr.ReadFrame()
   827  		if err != nil {
   828  			return nil, err
   829  		}
   830  		if _, ok := f.(*SettingsFrame); ok {
   831  			continue
   832  		}
   833  		return f, nil
   834  	}
   835  }
   836  
   837  // writeReadPing sends a PING and immediately reads the PING ACK.
   838  // It will fail if any other unread data was pending on the connection,
   839  // aside from SETTINGS frames.
   840  func (ct *clientTester) writeReadPing() error {
   841  	data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
   842  	if err := ct.fr.WritePing(false, data); err != nil {
   843  		return fmt.Errorf("Error writing PING: %v", err)
   844  	}
   845  	f, err := ct.readNonSettingsFrame()
   846  	if err != nil {
   847  		return err
   848  	}
   849  	p, ok := f.(*PingFrame)
   850  	if !ok {
   851  		return fmt.Errorf("got a %v, want a PING ACK", f)
   852  	}
   853  	if p.Flags&FlagPingAck == 0 {
   854  		return fmt.Errorf("got a PING, want a PING ACK")
   855  	}
   856  	if p.Data != data {
   857  		return fmt.Errorf("got PING data = %x, want %x", p.Data, data)
   858  	}
   859  	return nil
   860  }
   861  
// inflowWindow returns inflow.avail + inflow.unsent for the Transport's only
// client connection: the connection-level value when streamID is 0, or the
// value for that stream otherwise. (Presumably this is the receive-side
// flow-control window; confirm against the inflow type.)
//
// It expects the connection pool to hold exactly one connection. It reports
// a test error and returns -1 if the pool holds a different number, or if
// the stream is unknown.
//
// Locks are taken in the order pool.mu, then cc.mu.
func (ct *clientTester) inflowWindow(streamID uint32) int32 {
	pool := ct.tr.connPoolOrDef.(*clientConnPool)
	pool.mu.Lock()
	defer pool.mu.Unlock()
	if n := len(pool.keys); n != 1 {
		ct.t.Errorf("clientConnPool contains %v keys, expected 1", n)
		return -1
	}
	// pool.keys has exactly one entry (checked above), so this body runs
	// once and the deferred Unlock inside the loop does not accumulate.
	for cc := range pool.keys {
		cc.mu.Lock()
		defer cc.mu.Unlock()
		if streamID == 0 {
			return cc.inflow.avail + cc.inflow.unsent
		}
		cs := cc.streams[streamID]
		if cs == nil {
			ct.t.Errorf("no stream with id %v", streamID)
			return -1
		}
		return cs.inflow.avail + cs.inflow.unsent
	}
	return -1
}
   885  
   886  func (ct *clientTester) cleanup() {
   887  	ct.tr.CloseIdleConnections()
   888  
   889  	// close both connections, ignore the error if its already closed
   890  	ct.sc.Close()
   891  	ct.cc.Close()
   892  }
   893  
// run executes ct.client and ct.server concurrently and waits for both to
// return.
//
// The first side to return an error reports it through t.Errorf, prefixed
// with "client" or "server", and calls cleanup, which closes both
// connections (presumably so the other side's pending I/O fails and it
// returns). Errors from the other side after that are not reported. If
// neither side fails, cleanup runs once after both have returned.
func (ct *clientTester) run() {
	var errOnce sync.Once
	var wg sync.WaitGroup

	// run wraps one side's function: it marks completion on wg and routes
	// the first error (from either side) through errOnce.
	run := func(which string, fn func() error) {
		defer wg.Done()
		if err := fn(); err != nil {
			errOnce.Do(func() {
				ct.t.Errorf("%s: %v", which, err)
				ct.cleanup()
			})
		}
	}

	wg.Add(2)
	go run("client", ct.client)
	go run("server", ct.server)
	wg.Wait()

	errOnce.Do(ct.cleanup) // clean up if no error
}
   915  
   916  func (ct *clientTester) readFrame() (Frame, error) {
   917  	return ct.fr.ReadFrame()
   918  }
   919  
   920  func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
   921  	for {
   922  		f, err := ct.readFrame()
   923  		if err != nil {
   924  			return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
   925  		}
   926  		switch f.(type) {
   927  		case *WindowUpdateFrame, *SettingsFrame:
   928  			continue
   929  		}
   930  		hf, ok := f.(*HeadersFrame)
   931  		if !ok {
   932  			return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
   933  		}
   934  		return hf, nil
   935  	}
   936  }
   937  
   938  type countingReader struct {
   939  	n *int64
   940  }
   941  
   942  func (r countingReader) Read(p []byte) (n int, err error) {
   943  	for i := range p {
   944  		p[i] = byte(i)
   945  	}
   946  	atomic.AddInt64(r.n, int64(len(p)))
   947  	return len(p), err
   948  }
   949  
   950  func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
   951  func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
   952  
   953  func testTransportReqBodyAfterResponse(t *testing.T, status int) {
   954  	const bodySize = 10 << 20
   955  	clientDone := make(chan struct{})
   956  	ct := newClientTester(t)
   957  	recvLen := make(chan int64, 1)
   958  	ct.client = func() error {
   959  		defer ct.cc.(*net.TCPConn).CloseWrite()
   960  		if runtime.GOOS == "plan9" {
   961  			// CloseWrite not supported on Plan 9; Issue 17906
   962  			defer ct.cc.(*net.TCPConn).Close()
   963  		}
   964  		defer close(clientDone)
   965  
   966  		body := &pipe{b: new(bytes.Buffer)}
   967  		io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
   968  		req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
   969  		if err != nil {
   970  			return err
   971  		}
   972  		res, err := ct.tr.RoundTrip(req)
   973  		if err != nil {
   974  			return fmt.Errorf("RoundTrip: %v", err)
   975  		}
   976  		if res.StatusCode != status {
   977  			return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
   978  		}
   979  		io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
   980  		body.CloseWithError(io.EOF)
   981  		slurp, err := ioutil.ReadAll(res.Body)
   982  		if err != nil {
   983  			return fmt.Errorf("Slurp: %v", err)
   984  		}
   985  		if len(slurp) > 0 {
   986  			return fmt.Errorf("unexpected body: %q", slurp)
   987  		}
   988  		res.Body.Close()
   989  		if status == 200 {
   990  			if got := <-recvLen; got != bodySize {
   991  				return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
   992  			}
   993  		} else {
   994  			if got := <-recvLen; got == 0 || got >= bodySize {
   995  				return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
   996  			}
   997  		}
   998  		return nil
   999  	}
  1000  	ct.server = func() error {
  1001  		ct.greet()
  1002  		defer close(recvLen)
  1003  		var buf bytes.Buffer
  1004  		enc := hpack.NewEncoder(&buf)
  1005  		var dataRecv int64
  1006  		var closed bool
  1007  		for {
  1008  			f, err := ct.fr.ReadFrame()
  1009  			if err != nil {
  1010  				select {
  1011  				case <-clientDone:
  1012  					// If the client's done, it
  1013  					// will have reported any
  1014  					// errors on its side.
  1015  					return nil
  1016  				default:
  1017  					return err
  1018  				}
  1019  			}
  1020  			//println(fmt.Sprintf("server got frame: %v", f))
  1021  			ended := false
  1022  			switch f := f.(type) {
  1023  			case *WindowUpdateFrame, *SettingsFrame:
  1024  			case *HeadersFrame:
  1025  				if !f.HeadersEnded() {
  1026  					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
  1027  				}
  1028  				if f.StreamEnded() {
  1029  					return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
  1030  				}
  1031  			case *DataFrame:
  1032  				dataLen := len(f.Data())
  1033  				if dataLen > 0 {
  1034  					if dataRecv == 0 {
  1035  						enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
  1036  						ct.fr.WriteHeaders(HeadersFrameParam{
  1037  							StreamID:      f.StreamID,
  1038  							EndHeaders:    true,
  1039  							EndStream:     false,
  1040  							BlockFragment: buf.Bytes(),
  1041  						})
  1042  					}
  1043  					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
  1044  						return err
  1045  					}
  1046  					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
  1047  						return err
  1048  					}
  1049  				}
  1050  				dataRecv += int64(dataLen)
  1051  
  1052  				if !closed && ((status != 200 && dataRecv > 0) ||
  1053  					(status == 200 && f.StreamEnded())) {
  1054  					closed = true
  1055  					if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
  1056  						return err
  1057  					}
  1058  				}
  1059  
  1060  				if f.StreamEnded() {
  1061  					ended = true
  1062  				}
  1063  			case *RSTStreamFrame:
  1064  				if status == 200 {
  1065  					return fmt.Errorf("Unexpected client frame %v", f)
  1066  				}
  1067  				ended = true
  1068  			default:
  1069  				return fmt.Errorf("Unexpected client frame %v", f)
  1070  			}
  1071  			if ended {
  1072  				select {
  1073  				case recvLen <- dataRecv:
  1074  				default:
  1075  				}
  1076  			}
  1077  		}
  1078  	}
  1079  	ct.run()
  1080  }
  1081  
  1082  // See golang.org/issue/13444
  1083  func TestTransportFullDuplex(t *testing.T) {
  1084  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  1085  		w.WriteHeader(200) // redundant but for clarity
  1086  		w.(http.Flusher).Flush()
  1087  		io.Copy(flushWriter{w}, capitalizeReader{r.Body})
  1088  		fmt.Fprintf(w, "bye.\n")
  1089  	}, optOnlyServer)
  1090  	defer st.Close()
  1091  
  1092  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  1093  	defer tr.CloseIdleConnections()
  1094  	c := &http.Client{Transport: tr}
  1095  
  1096  	pr, pw := io.Pipe()
  1097  	req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
  1098  	if err != nil {
  1099  		t.Fatal(err)
  1100  	}
  1101  	req.ContentLength = -1
  1102  	res, err := c.Do(req)
  1103  	if err != nil {
  1104  		t.Fatal(err)
  1105  	}
  1106  	defer res.Body.Close()
  1107  	if res.StatusCode != 200 {
  1108  		t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
  1109  	}
  1110  	bs := bufio.NewScanner(res.Body)
  1111  	want := func(v string) {
  1112  		if !bs.Scan() {
  1113  			t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
  1114  		}
  1115  	}
  1116  	write := func(v string) {
  1117  		_, err := io.WriteString(pw, v)
  1118  		if err != nil {
  1119  			t.Fatalf("pipe write: %v", err)
  1120  		}
  1121  	}
  1122  	write("foo\n")
  1123  	want("FOO")
  1124  	write("bar\n")
  1125  	want("BAR")
  1126  	pw.Close()
  1127  	want("bye.")
  1128  	if err := bs.Err(); err != nil {
  1129  		t.Fatal(err)
  1130  	}
  1131  }
  1132  
  1133  func TestTransportConnectRequest(t *testing.T) {
  1134  	gotc := make(chan *http.Request, 1)
  1135  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  1136  		gotc <- r
  1137  	}, optOnlyServer)
  1138  	defer st.Close()
  1139  
  1140  	u, err := url.Parse(st.ts.URL)
  1141  	if err != nil {
  1142  		t.Fatal(err)
  1143  	}
  1144  
  1145  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  1146  	defer tr.CloseIdleConnections()
  1147  	c := &http.Client{Transport: tr}
  1148  
  1149  	tests := []struct {
  1150  		req  *http.Request
  1151  		want string
  1152  	}{
  1153  		{
  1154  			req: &http.Request{
  1155  				Method: "CONNECT",
  1156  				Header: http.Header{},
  1157  				URL:    u,
  1158  			},
  1159  			want: u.Host,
  1160  		},
  1161  		{
  1162  			req: &http.Request{
  1163  				Method: "CONNECT",
  1164  				Header: http.Header{},
  1165  				URL:    u,
  1166  				Host:   "example.com:123",
  1167  			},
  1168  			want: "example.com:123",
  1169  		},
  1170  	}
  1171  
  1172  	for i, tt := range tests {
  1173  		res, err := c.Do(tt.req)
  1174  		if err != nil {
  1175  			t.Errorf("%d. RoundTrip = %v", i, err)
  1176  			continue
  1177  		}
  1178  		res.Body.Close()
  1179  		req := <-gotc
  1180  		if req.Method != "CONNECT" {
  1181  			t.Errorf("method = %q; want CONNECT", req.Method)
  1182  		}
  1183  		if req.Host != tt.want {
  1184  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
  1185  		}
  1186  		if req.URL.Host != tt.want {
  1187  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
  1188  		}
  1189  	}
  1190  }
  1191  
  1192  type headerType int
  1193  
  1194  const (
  1195  	noHeader headerType = iota // omitted
  1196  	oneHeader
  1197  	splitHeader // broken into continuation on purpose
  1198  )
  1199  
  1200  const (
  1201  	f0 = noHeader
  1202  	f1 = oneHeader
  1203  	f2 = splitHeader
  1204  	d0 = false
  1205  	d1 = true
  1206  )
  1207  
  1208  // Test all 36 combinations of response frame orders:
  1209  //
//	(3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers)
  1211  //
  1212  // Generated by http://play.golang.org/p/SScqYKJYXd
  1213  func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
  1214  func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
  1215  func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
  1216  func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
  1217  func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
  1218  func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
  1219  func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
  1220  func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
  1221  func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
  1222  func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
  1223  func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
  1224  func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
  1225  func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
  1226  func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
  1227  func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
  1228  func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
  1229  func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
  1230  func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
  1231  func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
  1232  func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
  1233  func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
  1234  func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
  1235  func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
  1236  func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
  1237  func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
  1238  func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
  1239  func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
  1240  func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
  1241  func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
  1242  func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
  1243  func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
  1244  func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
  1245  func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
  1246  func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
  1247  func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
  1248  func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
  1249  
  1250  func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
  1251  	const reqBody = "some request body"
  1252  	const resBody = "some response body"
  1253  
  1254  	if resHeader == noHeader {
  1255  		// TODO: test 100-continue followed by immediate
  1256  		// server stream reset, without headers in the middle?
  1257  		panic("invalid combination")
  1258  	}
  1259  
  1260  	ct := newClientTester(t)
  1261  	ct.client = func() error {
  1262  		req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
  1263  		if expect100Continue != noHeader {
  1264  			req.Header.Set("Expect", "100-continue")
  1265  		}
  1266  		res, err := ct.tr.RoundTrip(req)
  1267  		if err != nil {
  1268  			return fmt.Errorf("RoundTrip: %v", err)
  1269  		}
  1270  		defer res.Body.Close()
  1271  		if res.StatusCode != 200 {
  1272  			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  1273  		}
  1274  		slurp, err := ioutil.ReadAll(res.Body)
  1275  		if err != nil {
  1276  			return fmt.Errorf("Slurp: %v", err)
  1277  		}
  1278  		wantBody := resBody
  1279  		if !withData {
  1280  			wantBody = ""
  1281  		}
  1282  		if string(slurp) != wantBody {
  1283  			return fmt.Errorf("body = %q; want %q", slurp, wantBody)
  1284  		}
  1285  		if trailers == noHeader {
  1286  			if len(res.Trailer) > 0 {
  1287  				t.Errorf("Trailer = %v; want none", res.Trailer)
  1288  			}
  1289  		} else {
  1290  			want := http.Header{"Some-Trailer": {"some-value"}}
  1291  			if !reflect.DeepEqual(res.Trailer, want) {
  1292  				t.Errorf("Trailer = %v; want %v", res.Trailer, want)
  1293  			}
  1294  		}
  1295  		return nil
  1296  	}
  1297  	ct.server = func() error {
  1298  		ct.greet()
  1299  		var buf bytes.Buffer
  1300  		enc := hpack.NewEncoder(&buf)
  1301  
  1302  		for {
  1303  			f, err := ct.fr.ReadFrame()
  1304  			if err != nil {
  1305  				return err
  1306  			}
  1307  			endStream := false
  1308  			send := func(mode headerType) {
  1309  				hbf := buf.Bytes()
  1310  				switch mode {
  1311  				case oneHeader:
  1312  					ct.fr.WriteHeaders(HeadersFrameParam{
  1313  						StreamID:      f.Header().StreamID,
  1314  						EndHeaders:    true,
  1315  						EndStream:     endStream,
  1316  						BlockFragment: hbf,
  1317  					})
  1318  				case splitHeader:
  1319  					if len(hbf) < 2 {
  1320  						panic("too small")
  1321  					}
  1322  					ct.fr.WriteHeaders(HeadersFrameParam{
  1323  						StreamID:      f.Header().StreamID,
  1324  						EndHeaders:    false,
  1325  						EndStream:     endStream,
  1326  						BlockFragment: hbf[:1],
  1327  					})
  1328  					ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
  1329  				default:
  1330  					panic("bogus mode")
  1331  				}
  1332  			}
  1333  			switch f := f.(type) {
  1334  			case *WindowUpdateFrame, *SettingsFrame:
  1335  			case *DataFrame:
  1336  				if !f.StreamEnded() {
  1337  					// No need to send flow control tokens. The test request body is tiny.
  1338  					continue
  1339  				}
  1340  				// Response headers (1+ frames; 1 or 2 in this test, but never 0)
  1341  				{
  1342  					buf.Reset()
  1343  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1344  					enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
  1345  					enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
  1346  					if trailers != noHeader {
  1347  						enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
  1348  					}
  1349  					endStream = withData == false && trailers == noHeader
  1350  					send(resHeader)
  1351  				}
  1352  				if withData {
  1353  					endStream = trailers == noHeader
  1354  					ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
  1355  				}
  1356  				if trailers != noHeader {
  1357  					endStream = true
  1358  					buf.Reset()
  1359  					enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
  1360  					send(trailers)
  1361  				}
  1362  				if endStream {
  1363  					return nil
  1364  				}
  1365  			case *HeadersFrame:
  1366  				if expect100Continue != noHeader {
  1367  					buf.Reset()
  1368  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
  1369  					send(expect100Continue)
  1370  				}
  1371  			}
  1372  		}
  1373  	}
  1374  	ct.run()
  1375  }
  1376  
  1377  // Issue 26189, Issue 17739: ignore unknown 1xx responses
  1378  func TestTransportUnknown1xx(t *testing.T) {
  1379  	var buf bytes.Buffer
  1380  	defer func() { got1xxFuncForTests = nil }()
  1381  	got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error {
  1382  		fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
  1383  		return nil
  1384  	}
  1385  
  1386  	ct := newClientTester(t)
  1387  	ct.client = func() error {
  1388  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1389  		res, err := ct.tr.RoundTrip(req)
  1390  		if err != nil {
  1391  			return fmt.Errorf("RoundTrip: %v", err)
  1392  		}
  1393  		defer res.Body.Close()
  1394  		if res.StatusCode != 204 {
  1395  			return fmt.Errorf("status code = %v; want 204", res.StatusCode)
  1396  		}
  1397  		want := `code=110 header=map[Foo-Bar:[110]]
  1398  code=111 header=map[Foo-Bar:[111]]
  1399  code=112 header=map[Foo-Bar:[112]]
  1400  code=113 header=map[Foo-Bar:[113]]
  1401  code=114 header=map[Foo-Bar:[114]]
  1402  `
  1403  		if got := buf.String(); got != want {
  1404  			t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
  1405  		}
  1406  		return nil
  1407  	}
  1408  	ct.server = func() error {
  1409  		ct.greet()
  1410  		var buf bytes.Buffer
  1411  		enc := hpack.NewEncoder(&buf)
  1412  
  1413  		for {
  1414  			f, err := ct.fr.ReadFrame()
  1415  			if err != nil {
  1416  				return err
  1417  			}
  1418  			switch f := f.(type) {
  1419  			case *WindowUpdateFrame, *SettingsFrame:
  1420  			case *HeadersFrame:
  1421  				for i := 110; i <= 114; i++ {
  1422  					buf.Reset()
  1423  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)})
  1424  					enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)})
  1425  					ct.fr.WriteHeaders(HeadersFrameParam{
  1426  						StreamID:      f.StreamID,
  1427  						EndHeaders:    true,
  1428  						EndStream:     false,
  1429  						BlockFragment: buf.Bytes(),
  1430  					})
  1431  				}
  1432  				buf.Reset()
  1433  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
  1434  				ct.fr.WriteHeaders(HeadersFrameParam{
  1435  					StreamID:      f.StreamID,
  1436  					EndHeaders:    true,
  1437  					EndStream:     false,
  1438  					BlockFragment: buf.Bytes(),
  1439  				})
  1440  				return nil
  1441  			}
  1442  		}
  1443  	}
  1444  	ct.run()
  1445  
  1446  }
  1447  
  1448  func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
  1449  	ct := newClientTester(t)
  1450  	ct.client = func() error {
  1451  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1452  		res, err := ct.tr.RoundTrip(req)
  1453  		if err != nil {
  1454  			return fmt.Errorf("RoundTrip: %v", err)
  1455  		}
  1456  		defer res.Body.Close()
  1457  		if res.StatusCode != 200 {
  1458  			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  1459  		}
  1460  		slurp, err := ioutil.ReadAll(res.Body)
  1461  		if err != nil {
  1462  			return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
  1463  		}
  1464  		if len(slurp) > 0 {
  1465  			return fmt.Errorf("body = %q; want nothing", slurp)
  1466  		}
  1467  		if _, ok := res.Trailer["Some-Trailer"]; !ok {
  1468  			return fmt.Errorf("expected Some-Trailer")
  1469  		}
  1470  		return nil
  1471  	}
  1472  	ct.server = func() error {
  1473  		ct.greet()
  1474  
  1475  		var n int
  1476  		var hf *HeadersFrame
  1477  		for hf == nil && n < 10 {
  1478  			f, err := ct.fr.ReadFrame()
  1479  			if err != nil {
  1480  				return err
  1481  			}
  1482  			hf, _ = f.(*HeadersFrame)
  1483  			n++
  1484  		}
  1485  
  1486  		var buf bytes.Buffer
  1487  		enc := hpack.NewEncoder(&buf)
  1488  
  1489  		// send headers without Trailer header
  1490  		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1491  		ct.fr.WriteHeaders(HeadersFrameParam{
  1492  			StreamID:      hf.StreamID,
  1493  			EndHeaders:    true,
  1494  			EndStream:     false,
  1495  			BlockFragment: buf.Bytes(),
  1496  		})
  1497  
  1498  		// send trailers
  1499  		buf.Reset()
  1500  		enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
  1501  		ct.fr.WriteHeaders(HeadersFrameParam{
  1502  			StreamID:      hf.StreamID,
  1503  			EndHeaders:    true,
  1504  			EndStream:     true,
  1505  			BlockFragment: buf.Bytes(),
  1506  		})
  1507  		return nil
  1508  	}
  1509  	ct.run()
  1510  }
  1511  
  1512  func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
  1513  	testTransportInvalidTrailer_Pseudo(t, oneHeader)
  1514  }
  1515  func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
  1516  	testTransportInvalidTrailer_Pseudo(t, splitHeader)
  1517  }
  1518  func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
  1519  	testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
  1520  		enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
  1521  		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
  1522  	})
  1523  }
  1524  
  1525  func TestTransportInvalidTrailer_Capital1(t *testing.T) {
  1526  	testTransportInvalidTrailer_Capital(t, oneHeader)
  1527  }
  1528  func TestTransportInvalidTrailer_Capital2(t *testing.T) {
  1529  	testTransportInvalidTrailer_Capital(t, splitHeader)
  1530  }
  1531  func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
  1532  	testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
  1533  		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
  1534  		enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
  1535  	})
  1536  }
  1537  func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
  1538  	testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
  1539  		enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
  1540  	})
  1541  }
  1542  func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
  1543  	testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) {
  1544  		enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
  1545  	})
  1546  }
  1547  
  1548  func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
  1549  	ct := newClientTester(t)
  1550  	ct.client = func() error {
  1551  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1552  		res, err := ct.tr.RoundTrip(req)
  1553  		if err != nil {
  1554  			return fmt.Errorf("RoundTrip: %v", err)
  1555  		}
  1556  		defer res.Body.Close()
  1557  		if res.StatusCode != 200 {
  1558  			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  1559  		}
  1560  		slurp, err := ioutil.ReadAll(res.Body)
  1561  		se, ok := err.(StreamError)
  1562  		if !ok || se.Cause != wantErr {
  1563  			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
  1564  		}
  1565  		if len(slurp) > 0 {
  1566  			return fmt.Errorf("body = %q; want nothing", slurp)
  1567  		}
  1568  		return nil
  1569  	}
  1570  	ct.server = func() error {
  1571  		ct.greet()
  1572  		var buf bytes.Buffer
  1573  		enc := hpack.NewEncoder(&buf)
  1574  
  1575  		for {
  1576  			f, err := ct.fr.ReadFrame()
  1577  			if err != nil {
  1578  				return err
  1579  			}
  1580  			switch f := f.(type) {
  1581  			case *HeadersFrame:
  1582  				var endStream bool
  1583  				send := func(mode headerType) {
  1584  					hbf := buf.Bytes()
  1585  					switch mode {
  1586  					case oneHeader:
  1587  						ct.fr.WriteHeaders(HeadersFrameParam{
  1588  							StreamID:      f.StreamID,
  1589  							EndHeaders:    true,
  1590  							EndStream:     endStream,
  1591  							BlockFragment: hbf,
  1592  						})
  1593  					case splitHeader:
  1594  						if len(hbf) < 2 {
  1595  							panic("too small")
  1596  						}
  1597  						ct.fr.WriteHeaders(HeadersFrameParam{
  1598  							StreamID:      f.StreamID,
  1599  							EndHeaders:    false,
  1600  							EndStream:     endStream,
  1601  							BlockFragment: hbf[:1],
  1602  						})
  1603  						ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
  1604  					default:
  1605  						panic("bogus mode")
  1606  					}
  1607  				}
  1608  				// Response headers (1+ frames; 1 or 2 in this test, but never 0)
  1609  				{
  1610  					buf.Reset()
  1611  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1612  					enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
  1613  					endStream = false
  1614  					send(oneHeader)
  1615  				}
  1616  				// Trailers:
  1617  				{
  1618  					endStream = true
  1619  					buf.Reset()
  1620  					writeTrailer(enc)
  1621  					send(trailers)
  1622  				}
  1623  				return nil
  1624  			}
  1625  		}
  1626  	}
  1627  	ct.run()
  1628  }
  1629  
  1630  // headerListSize returns the HTTP2 header list size of h.
  1631  //
  1632  //	http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
  1633  //	http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock
  1634  func headerListSize(h http.Header) (size uint32) {
  1635  	for k, vv := range h {
  1636  		for _, v := range vv {
  1637  			hf := hpack.HeaderField{Name: k, Value: v}
  1638  			size += hf.Size()
  1639  		}
  1640  	}
  1641  	return size
  1642  }
  1643  
  1644  // padHeaders adds data to an http.Header until headerListSize(h) ==
  1645  // limit. Due to the way header list sizes are calculated, padHeaders
  1646  // cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will
  1647  // call t.Fatal if asked to do so. PadHeaders first reserves enough
  1648  // space for an empty "Pad-Headers" key, then adds as many copies of
  1649  // filler as possible. Any remaining bytes necessary to push the
  1650  // header list size up to limit are added to h["Pad-Headers"].
  1651  func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
  1652  	if limit > 0xffffffff {
  1653  		t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
  1654  	}
  1655  	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
  1656  	minPadding := uint64(hf.Size())
  1657  	size := uint64(headerListSize(h))
  1658  
  1659  	minlimit := size + minPadding
  1660  	if limit < minlimit {
  1661  		t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
  1662  	}
  1663  
  1664  	// Use a fixed-width format for name so that fieldSize
  1665  	// remains constant.
  1666  	nameFmt := "Pad-Headers-%06d"
  1667  	hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
  1668  	fieldSize := uint64(hf.Size())
  1669  
  1670  	// Add as many complete filler values as possible, leaving
  1671  	// room for at least one empty "Pad-Headers" key.
  1672  	limit = limit - minPadding
  1673  	for i := 0; size+fieldSize < limit; i++ {
  1674  		name := fmt.Sprintf(nameFmt, i)
  1675  		h.Add(name, filler)
  1676  		size += fieldSize
  1677  	}
  1678  
  1679  	// Add enough bytes to reach limit.
  1680  	remain := limit - size
  1681  	lastValue := strings.Repeat("*", int(remain))
  1682  	h.Add("Pad-Headers", lastValue)
  1683  }
  1684  
  1685  func TestPadHeaders(t *testing.T) {
  1686  	check := func(h http.Header, limit uint32, fillerLen int) {
  1687  		if h == nil {
  1688  			h = make(http.Header)
  1689  		}
  1690  		filler := strings.Repeat("f", fillerLen)
  1691  		padHeaders(t, h, uint64(limit), filler)
  1692  		gotSize := headerListSize(h)
  1693  		if gotSize != limit {
  1694  			t.Errorf("Got size = %v; want %v", gotSize, limit)
  1695  		}
  1696  	}
  1697  	// Try all possible combinations for small fillerLen and limit.
  1698  	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
  1699  	minLimit := hf.Size()
  1700  	for limit := minLimit; limit <= 128; limit++ {
  1701  		for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
  1702  			check(nil, limit, fillerLen)
  1703  		}
  1704  	}
  1705  
  1706  	// Try a few tests with larger limits, plus cumulative
  1707  	// tests. Since these tests are cumulative, tests[i+1].limit
  1708  	// must be >= tests[i].limit + minLimit. See the comment on
  1709  	// padHeaders for more info on why the limit arg has this
  1710  	// restriction.
  1711  	tests := []struct {
  1712  		fillerLen int
  1713  		limit     uint32
  1714  	}{
  1715  		{
  1716  			fillerLen: 64,
  1717  			limit:     1024,
  1718  		},
  1719  		{
  1720  			fillerLen: 1024,
  1721  			limit:     1286,
  1722  		},
  1723  		{
  1724  			fillerLen: 256,
  1725  			limit:     2048,
  1726  		},
  1727  		{
  1728  			fillerLen: 1024,
  1729  			limit:     10 * 1024,
  1730  		},
  1731  		{
  1732  			fillerLen: 1023,
  1733  			limit:     11 * 1024,
  1734  		},
  1735  	}
  1736  	h := make(http.Header)
  1737  	for _, tc := range tests {
  1738  		check(nil, tc.limit, tc.fillerLen)
  1739  		check(h, tc.limit, tc.fillerLen)
  1740  	}
  1741  }
  1742  
  1743  func TestTransportChecksRequestHeaderListSize(t *testing.T) {
  1744  	st := newServerTester(t,
  1745  		func(w http.ResponseWriter, r *http.Request) {
  1746  			// Consume body & force client to send
  1747  			// trailers before writing response.
  1748  			// ioutil.ReadAll returns non-nil err for
  1749  			// requests that attempt to send greater than
  1750  			// maxHeaderListSize bytes of trailers, since
  1751  			// those requests generate a stream reset.
  1752  			ioutil.ReadAll(r.Body)
  1753  			r.Body.Close()
  1754  		},
  1755  		func(ts *httptest.Server) {
  1756  			ts.Config.MaxHeaderBytes = 16 << 10
  1757  		},
  1758  		optOnlyServer,
  1759  		optQuiet,
  1760  	)
  1761  	defer st.Close()
  1762  
  1763  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  1764  	defer tr.CloseIdleConnections()
  1765  
  1766  	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
  1767  		// Make an arbitrary request to ensure we get the server's
  1768  		// settings frame and initialize peerMaxHeaderListSize.
  1769  		req0, err := http.NewRequest("GET", st.ts.URL, nil)
  1770  		if err != nil {
  1771  			t.Fatalf("newRequest: NewRequest: %v", err)
  1772  		}
  1773  		res0, err := tr.RoundTrip(req0)
  1774  		if err != nil {
  1775  			t.Errorf("%v: Initial RoundTrip err = %v", desc, err)
  1776  		}
  1777  		res0.Body.Close()
  1778  
  1779  		res, err := tr.RoundTrip(req)
  1780  		if err != wantErr {
  1781  			if res != nil {
  1782  				res.Body.Close()
  1783  			}
  1784  			t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
  1785  			return
  1786  		}
  1787  		if err == nil {
  1788  			if res == nil {
  1789  				t.Errorf("%v: response nil; want non-nil.", desc)
  1790  				return
  1791  			}
  1792  			defer res.Body.Close()
  1793  			if res.StatusCode != http.StatusOK {
  1794  				t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
  1795  			}
  1796  			return
  1797  		}
  1798  		if res != nil {
  1799  			t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
  1800  		}
  1801  	}
  1802  	headerListSizeForRequest := func(req *http.Request) (size uint64) {
  1803  		contentLen := actualContentLength(req)
  1804  		trailers, err := commaSeparatedTrailers(req)
  1805  		if err != nil {
  1806  			t.Fatalf("headerListSizeForRequest: %v", err)
  1807  		}
  1808  		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
  1809  		cc.henc = hpack.NewEncoder(&cc.hbuf)
  1810  		cc.mu.Lock()
  1811  		hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
  1812  		cc.mu.Unlock()
  1813  		if err != nil {
  1814  			t.Fatalf("headerListSizeForRequest: %v", err)
  1815  		}
  1816  		hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
  1817  			size += uint64(hf.Size())
  1818  		})
  1819  		if len(hdrs) > 0 {
  1820  			if _, err := hpackDec.Write(hdrs); err != nil {
  1821  				t.Fatalf("headerListSizeForRequest: %v", err)
  1822  			}
  1823  		}
  1824  		return size
  1825  	}
  1826  	// Create a new Request for each test, rather than reusing the
  1827  	// same Request, to avoid a race when modifying req.Headers.
  1828  	// See https://github.com/golang/go/issues/21316
  1829  	newRequest := func() *http.Request {
  1830  		// Body must be non-nil to enable writing trailers.
  1831  		body := strings.NewReader("hello")
  1832  		req, err := http.NewRequest("POST", st.ts.URL, body)
  1833  		if err != nil {
  1834  			t.Fatalf("newRequest: NewRequest: %v", err)
  1835  		}
  1836  		return req
  1837  	}
  1838  
  1839  	// Validate peerMaxHeaderListSize.
  1840  	req := newRequest()
  1841  	checkRoundTrip(req, nil, "Initial request")
  1842  	addr := authorityAddr(req.URL.Scheme, req.URL.Host)
  1843  	cc, err := tr.connPool().GetClientConn(req, addr)
  1844  	if err != nil {
  1845  		t.Fatalf("GetClientConn: %v", err)
  1846  	}
  1847  	cc.mu.Lock()
  1848  	peerSize := cc.peerMaxHeaderListSize
  1849  	cc.mu.Unlock()
  1850  	st.scMu.Lock()
  1851  	wantSize := uint64(st.sc.maxHeaderListSize())
  1852  	st.scMu.Unlock()
  1853  	if peerSize != wantSize {
  1854  		t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
  1855  	}
  1856  
  1857  	// Sanity check peerSize. (*serverConn) maxHeaderListSize adds
  1858  	// 320 bytes of padding.
  1859  	wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
  1860  	if peerSize != wantHeaderBytes {
  1861  		t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
  1862  	}
  1863  
  1864  	// Pad headers & trailers, but stay under peerSize.
  1865  	req = newRequest()
  1866  	req.Header = make(http.Header)
  1867  	req.Trailer = make(http.Header)
  1868  	filler := strings.Repeat("*", 1024)
  1869  	padHeaders(t, req.Trailer, peerSize, filler)
  1870  	// cc.encodeHeaders adds some default headers to the request,
  1871  	// so we need to leave room for those.
  1872  	defaultBytes := headerListSizeForRequest(req)
  1873  	padHeaders(t, req.Header, peerSize-defaultBytes, filler)
  1874  	checkRoundTrip(req, nil, "Headers & Trailers under limit")
  1875  
  1876  	// Add enough header bytes to push us over peerSize.
  1877  	req = newRequest()
  1878  	req.Header = make(http.Header)
  1879  	padHeaders(t, req.Header, peerSize, filler)
  1880  	checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")
  1881  
  1882  	// Push trailers over the limit.
  1883  	req = newRequest()
  1884  	req.Trailer = make(http.Header)
  1885  	padHeaders(t, req.Trailer, peerSize+1, filler)
  1886  	checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")
  1887  
  1888  	// Send headers with a single large value.
  1889  	req = newRequest()
  1890  	filler = strings.Repeat("*", int(peerSize))
  1891  	req.Header = make(http.Header)
  1892  	req.Header.Set("Big", filler)
  1893  	checkRoundTrip(req, errRequestHeaderListSize, "Single large header")
  1894  
  1895  	// Send trailers with a single large value.
  1896  	req = newRequest()
  1897  	req.Trailer = make(http.Header)
  1898  	req.Trailer.Set("Big", filler)
  1899  	checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
  1900  }
  1901  
  1902  func TestTransportChecksResponseHeaderListSize(t *testing.T) {
  1903  	ct := newClientTester(t)
  1904  	ct.client = func() error {
  1905  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1906  		res, err := ct.tr.RoundTrip(req)
  1907  		if e, ok := err.(StreamError); ok {
  1908  			err = e.Cause
  1909  		}
  1910  		if err != errResponseHeaderListSize {
  1911  			size := int64(0)
  1912  			if res != nil {
  1913  				res.Body.Close()
  1914  				for k, vv := range res.Header {
  1915  					for _, v := range vv {
  1916  						size += int64(len(k)) + int64(len(v)) + 32
  1917  					}
  1918  				}
  1919  			}
  1920  			return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
  1921  		}
  1922  		return nil
  1923  	}
  1924  	ct.server = func() error {
  1925  		ct.greet()
  1926  		var buf bytes.Buffer
  1927  		enc := hpack.NewEncoder(&buf)
  1928  
  1929  		for {
  1930  			f, err := ct.fr.ReadFrame()
  1931  			if err != nil {
  1932  				return err
  1933  			}
  1934  			switch f := f.(type) {
  1935  			case *HeadersFrame:
  1936  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1937  				large := strings.Repeat("a", 1<<10)
  1938  				for i := 0; i < 5042; i++ {
  1939  					enc.WriteField(hpack.HeaderField{Name: large, Value: large})
  1940  				}
  1941  				if size, want := buf.Len(), 6329; size != want {
  1942  					// Note: this number might change if
  1943  					// our hpack implementation
  1944  					// changes. That's fine. This is
  1945  					// just a sanity check that our
  1946  					// response can fit in a single
  1947  					// header block fragment frame.
  1948  					return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
  1949  				}
  1950  				ct.fr.WriteHeaders(HeadersFrameParam{
  1951  					StreamID:      f.StreamID,
  1952  					EndHeaders:    true,
  1953  					EndStream:     true,
  1954  					BlockFragment: buf.Bytes(),
  1955  				})
  1956  				return nil
  1957  			}
  1958  		}
  1959  	}
  1960  	ct.run()
  1961  }
  1962  
  1963  func TestTransportCookieHeaderSplit(t *testing.T) {
  1964  	ct := newClientTester(t)
  1965  	ct.client = func() error {
  1966  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1967  		req.Header.Add("Cookie", "a=b;c=d;  e=f;")
  1968  		req.Header.Add("Cookie", "e=f;g=h; ")
  1969  		req.Header.Add("Cookie", "i=j")
  1970  		_, err := ct.tr.RoundTrip(req)
  1971  		return err
  1972  	}
  1973  	ct.server = func() error {
  1974  		ct.greet()
  1975  		for {
  1976  			f, err := ct.fr.ReadFrame()
  1977  			if err != nil {
  1978  				return err
  1979  			}
  1980  			switch f := f.(type) {
  1981  			case *HeadersFrame:
  1982  				dec := hpack.NewDecoder(initialHeaderTableSize, nil)
  1983  				hfs, err := dec.DecodeFull(f.HeaderBlockFragment())
  1984  				if err != nil {
  1985  					return err
  1986  				}
  1987  				got := []string{}
  1988  				want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}
  1989  				for _, hf := range hfs {
  1990  					if hf.Name == "cookie" {
  1991  						got = append(got, hf.Value)
  1992  					}
  1993  				}
  1994  				if !reflect.DeepEqual(got, want) {
  1995  					t.Errorf("Cookies = %#v, want %#v", got, want)
  1996  				}
  1997  
  1998  				var buf bytes.Buffer
  1999  				enc := hpack.NewEncoder(&buf)
  2000  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  2001  				ct.fr.WriteHeaders(HeadersFrameParam{
  2002  					StreamID:      f.StreamID,
  2003  					EndHeaders:    true,
  2004  					EndStream:     true,
  2005  					BlockFragment: buf.Bytes(),
  2006  				})
  2007  				return nil
  2008  			}
  2009  		}
  2010  	}
  2011  	ct.run()
  2012  }
  2013  
  2014  // Test that the Transport returns a typed error from Response.Body.Read calls
  2015  // when the server sends an error. (here we use a panic, since that should generate
  2016  // a stream error, but others like cancel should be similar)
  2017  func TestTransportBodyReadErrorType(t *testing.T) {
  2018  	doPanic := make(chan bool, 1)
  2019  	st := newServerTester(t,
  2020  		func(w http.ResponseWriter, r *http.Request) {
  2021  			w.(http.Flusher).Flush() // force headers out
  2022  			<-doPanic
  2023  			panic("boom")
  2024  		},
  2025  		optOnlyServer,
  2026  		optQuiet,
  2027  	)
  2028  	defer st.Close()
  2029  
  2030  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2031  	defer tr.CloseIdleConnections()
  2032  	c := &http.Client{Transport: tr}
  2033  
  2034  	res, err := c.Get(st.ts.URL)
  2035  	if err != nil {
  2036  		t.Fatal(err)
  2037  	}
  2038  	defer res.Body.Close()
  2039  	doPanic <- true
  2040  	buf := make([]byte, 100)
  2041  	n, err := res.Body.Read(buf)
  2042  	got, ok := err.(StreamError)
  2043  	want := StreamError{StreamID: 0x1, Code: 0x2}
  2044  	if !ok || got.StreamID != want.StreamID || got.Code != want.Code {
  2045  		t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
  2046  	}
  2047  }
  2048  
  2049  // golang.org/issue/13924
  2050  // This used to fail after many iterations, especially with -race:
  2051  // go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race
  2052  func TestTransportDoubleCloseOnWriteError(t *testing.T) {
  2053  	var (
  2054  		mu   sync.Mutex
  2055  		conn net.Conn // to close if set
  2056  	)
  2057  
  2058  	st := newServerTester(t,
  2059  		func(w http.ResponseWriter, r *http.Request) {
  2060  			mu.Lock()
  2061  			defer mu.Unlock()
  2062  			if conn != nil {
  2063  				conn.Close()
  2064  			}
  2065  		},
  2066  		optOnlyServer,
  2067  	)
  2068  	defer st.Close()
  2069  
  2070  	tr := &Transport{
  2071  		TLSClientConfig: tlsConfigInsecure,
  2072  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  2073  			tc, err := tls.Dial(network, addr, cfg)
  2074  			if err != nil {
  2075  				return nil, err
  2076  			}
  2077  			mu.Lock()
  2078  			defer mu.Unlock()
  2079  			conn = tc
  2080  			return tc, nil
  2081  		},
  2082  	}
  2083  	defer tr.CloseIdleConnections()
  2084  	c := &http.Client{Transport: tr}
  2085  	c.Get(st.ts.URL)
  2086  }
  2087  
  2088  // Test that the http1 Transport.DisableKeepAlives option is respected
  2089  // and connections are closed as soon as idle.
  2090  // See golang.org/issue/14008
  2091  func TestTransportDisableKeepAlives(t *testing.T) {
  2092  	st := newServerTester(t,
  2093  		func(w http.ResponseWriter, r *http.Request) {
  2094  			io.WriteString(w, "hi")
  2095  		},
  2096  		optOnlyServer,
  2097  	)
  2098  	defer st.Close()
  2099  
  2100  	connClosed := make(chan struct{}) // closed on tls.Conn.Close
  2101  	tr := &Transport{
  2102  		t1: &http.Transport{
  2103  			DisableKeepAlives: true,
  2104  		},
  2105  		TLSClientConfig: tlsConfigInsecure,
  2106  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  2107  			tc, err := tls.Dial(network, addr, cfg)
  2108  			if err != nil {
  2109  				return nil, err
  2110  			}
  2111  			return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
  2112  		},
  2113  	}
  2114  	c := &http.Client{Transport: tr}
  2115  	res, err := c.Get(st.ts.URL)
  2116  	if err != nil {
  2117  		t.Fatal(err)
  2118  	}
  2119  	if _, err := ioutil.ReadAll(res.Body); err != nil {
  2120  		t.Fatal(err)
  2121  	}
  2122  	defer res.Body.Close()
  2123  
  2124  	select {
  2125  	case <-connClosed:
  2126  	case <-time.After(1 * time.Second):
  2127  		t.Errorf("timeout")
  2128  	}
  2129  
  2130  }
  2131  
  2132  // Test concurrent requests with Transport.DisableKeepAlives. We can share connections,
  2133  // but when things are totally idle, it still needs to close.
  2134  func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
  2135  	const D = 25 * time.Millisecond
  2136  	st := newServerTester(t,
  2137  		func(w http.ResponseWriter, r *http.Request) {
  2138  			time.Sleep(D)
  2139  			io.WriteString(w, "hi")
  2140  		},
  2141  		optOnlyServer,
  2142  	)
  2143  	defer st.Close()
  2144  
  2145  	var dials int32
  2146  	var conns sync.WaitGroup
  2147  	tr := &Transport{
  2148  		t1: &http.Transport{
  2149  			DisableKeepAlives: true,
  2150  		},
  2151  		TLSClientConfig: tlsConfigInsecure,
  2152  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  2153  			tc, err := tls.Dial(network, addr, cfg)
  2154  			if err != nil {
  2155  				return nil, err
  2156  			}
  2157  			atomic.AddInt32(&dials, 1)
  2158  			conns.Add(1)
  2159  			return &noteCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
  2160  		},
  2161  	}
  2162  	c := &http.Client{Transport: tr}
  2163  	var reqs sync.WaitGroup
  2164  	const N = 20
  2165  	for i := 0; i < N; i++ {
  2166  		reqs.Add(1)
  2167  		if i == N-1 {
  2168  			// For the final request, try to make all the
  2169  			// others close. This isn't verified in the
  2170  			// count, other than the Log statement, since
  2171  			// it's so timing dependent. This test is
  2172  			// really to make sure we don't interrupt a
  2173  			// valid request.
  2174  			time.Sleep(D * 2)
  2175  		}
  2176  		go func() {
  2177  			defer reqs.Done()
  2178  			res, err := c.Get(st.ts.URL)
  2179  			if err != nil {
  2180  				t.Error(err)
  2181  				return
  2182  			}
  2183  			if _, err := ioutil.ReadAll(res.Body); err != nil {
  2184  				t.Error(err)
  2185  				return
  2186  			}
  2187  			res.Body.Close()
  2188  		}()
  2189  	}
  2190  	reqs.Wait()
  2191  	conns.Wait()
  2192  	t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
  2193  }
  2194  
  2195  type noteCloseConn struct {
  2196  	net.Conn
  2197  	onceClose sync.Once
  2198  	closefn   func()
  2199  }
  2200  
  2201  func (c *noteCloseConn) Close() error {
  2202  	c.onceClose.Do(c.closefn)
  2203  	return c.Conn.Close()
  2204  }
  2205  
  2206  func isTimeout(err error) bool {
  2207  	switch err := err.(type) {
  2208  	case nil:
  2209  		return false
  2210  	case *url.Error:
  2211  		return isTimeout(err.Err)
  2212  	case net.Error:
  2213  		return err.Timeout()
  2214  	}
  2215  	return false
  2216  }
  2217  
  2218  // Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent.
  2219  func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
  2220  	testTransportResponseHeaderTimeout(t, false)
  2221  }
  2222  func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
  2223  	testTransportResponseHeaderTimeout(t, true)
  2224  }
  2225  
  2226  func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
  2227  	ct := newClientTester(t)
  2228  	ct.tr.t1 = &http.Transport{
  2229  		ResponseHeaderTimeout: 5 * time.Millisecond,
  2230  	}
  2231  	ct.client = func() error {
  2232  		c := &http.Client{Transport: ct.tr}
  2233  		var err error
  2234  		var n int64
  2235  		const bodySize = 4 << 20
  2236  		if body {
  2237  			_, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
  2238  		} else {
  2239  			_, err = c.Get("https://dummy.tld/")
  2240  		}
  2241  		if !isTimeout(err) {
  2242  			t.Errorf("client expected timeout error; got %#v", err)
  2243  		}
  2244  		if body && n != bodySize {
  2245  			t.Errorf("only read %d bytes of body; want %d", n, bodySize)
  2246  		}
  2247  		return nil
  2248  	}
  2249  	ct.server = func() error {
  2250  		ct.greet()
  2251  		for {
  2252  			f, err := ct.fr.ReadFrame()
  2253  			if err != nil {
  2254  				t.Logf("ReadFrame: %v", err)
  2255  				return nil
  2256  			}
  2257  			switch f := f.(type) {
  2258  			case *DataFrame:
  2259  				dataLen := len(f.Data())
  2260  				if dataLen > 0 {
  2261  					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
  2262  						return err
  2263  					}
  2264  					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
  2265  						return err
  2266  					}
  2267  				}
  2268  			case *RSTStreamFrame:
  2269  				if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
  2270  					return nil
  2271  				}
  2272  			}
  2273  		}
  2274  	}
  2275  	ct.run()
  2276  }
  2277  
  2278  func TestTransportDisableCompression(t *testing.T) {
  2279  	const body = "sup"
  2280  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2281  		want := http.Header{
  2282  			"User-Agent": []string{"Go-http-client/2.0"},
  2283  		}
  2284  		if !reflect.DeepEqual(r.Header, want) {
  2285  			t.Errorf("request headers = %v; want %v", r.Header, want)
  2286  		}
  2287  	}, optOnlyServer)
  2288  	defer st.Close()
  2289  
  2290  	tr := &Transport{
  2291  		TLSClientConfig: tlsConfigInsecure,
  2292  		t1: &http.Transport{
  2293  			DisableCompression: true,
  2294  		},
  2295  	}
  2296  	defer tr.CloseIdleConnections()
  2297  
  2298  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  2299  	if err != nil {
  2300  		t.Fatal(err)
  2301  	}
  2302  	res, err := tr.RoundTrip(req)
  2303  	if err != nil {
  2304  		t.Fatal(err)
  2305  	}
  2306  	defer res.Body.Close()
  2307  }
  2308  
  2309  // RFC 7540 section 8.1.2.2
  2310  func TestTransportRejectsConnHeaders(t *testing.T) {
  2311  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2312  		var got []string
  2313  		for k := range r.Header {
  2314  			got = append(got, k)
  2315  		}
  2316  		sort.Strings(got)
  2317  		w.Header().Set("Got-Header", strings.Join(got, ","))
  2318  	}, optOnlyServer)
  2319  	defer st.Close()
  2320  
  2321  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2322  	defer tr.CloseIdleConnections()
  2323  
  2324  	tests := []struct {
  2325  		key   string
  2326  		value []string
  2327  		want  string
  2328  	}{
  2329  		{
  2330  			key:   "Upgrade",
  2331  			value: []string{"anything"},
  2332  			want:  "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
  2333  		},
  2334  		{
  2335  			key:   "Connection",
  2336  			value: []string{"foo"},
  2337  			want:  "ERROR: http2: invalid Connection request header: [\"foo\"]",
  2338  		},
  2339  		{
  2340  			key:   "Connection",
  2341  			value: []string{"close"},
  2342  			want:  "Accept-Encoding,User-Agent",
  2343  		},
  2344  		{
  2345  			key:   "Connection",
  2346  			value: []string{"CLoSe"},
  2347  			want:  "Accept-Encoding,User-Agent",
  2348  		},
  2349  		{
  2350  			key:   "Connection",
  2351  			value: []string{"close", "something-else"},
  2352  			want:  "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
  2353  		},
  2354  		{
  2355  			key:   "Connection",
  2356  			value: []string{"keep-alive"},
  2357  			want:  "Accept-Encoding,User-Agent",
  2358  		},
  2359  		{
  2360  			key:   "Connection",
  2361  			value: []string{"Keep-ALIVE"},
  2362  			want:  "Accept-Encoding,User-Agent",
  2363  		},
  2364  		{
  2365  			key:   "Proxy-Connection", // just deleted and ignored
  2366  			value: []string{"keep-alive"},
  2367  			want:  "Accept-Encoding,User-Agent",
  2368  		},
  2369  		{
  2370  			key:   "Transfer-Encoding",
  2371  			value: []string{""},
  2372  			want:  "Accept-Encoding,User-Agent",
  2373  		},
  2374  		{
  2375  			key:   "Transfer-Encoding",
  2376  			value: []string{"foo"},
  2377  			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
  2378  		},
  2379  		{
  2380  			key:   "Transfer-Encoding",
  2381  			value: []string{"chunked"},
  2382  			want:  "Accept-Encoding,User-Agent",
  2383  		},
  2384  		{
  2385  			key:   "Transfer-Encoding",
  2386  			value: []string{"chunKed"}, // Kelvin sign
  2387  			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]",
  2388  		},
  2389  		{
  2390  			key:   "Transfer-Encoding",
  2391  			value: []string{"chunked", "other"},
  2392  			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
  2393  		},
  2394  		{
  2395  			key:   "Content-Length",
  2396  			value: []string{"123"},
  2397  			want:  "Accept-Encoding,User-Agent",
  2398  		},
  2399  		{
  2400  			key:   "Keep-Alive",
  2401  			value: []string{"doop"},
  2402  			want:  "Accept-Encoding,User-Agent",
  2403  		},
  2404  	}
  2405  
  2406  	for _, tt := range tests {
  2407  		req, _ := http.NewRequest("GET", st.ts.URL, nil)
  2408  		req.Header[tt.key] = tt.value
  2409  		res, err := tr.RoundTrip(req)
  2410  		var got string
  2411  		if err != nil {
  2412  			got = fmt.Sprintf("ERROR: %v", err)
  2413  		} else {
  2414  			got = res.Header.Get("Got-Header")
  2415  			res.Body.Close()
  2416  		}
  2417  		if got != tt.want {
  2418  			t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
  2419  		}
  2420  	}
  2421  }
  2422  
  2423  // Reject content-length headers containing a sign.
  2424  // See https://golang.org/issue/39017
  2425  func TestTransportRejectsContentLengthWithSign(t *testing.T) {
  2426  	tests := []struct {
  2427  		name   string
  2428  		cl     []string
  2429  		wantCL string
  2430  	}{
  2431  		{
  2432  			name:   "proper content-length",
  2433  			cl:     []string{"3"},
  2434  			wantCL: "3",
  2435  		},
  2436  		{
  2437  			name:   "ignore cl with plus sign",
  2438  			cl:     []string{"+3"},
  2439  			wantCL: "",
  2440  		},
  2441  		{
  2442  			name:   "ignore cl with minus sign",
  2443  			cl:     []string{"-3"},
  2444  			wantCL: "",
  2445  		},
  2446  		{
  2447  			name:   "max int64, for safe uint64->int64 conversion",
  2448  			cl:     []string{"9223372036854775807"},
  2449  			wantCL: "9223372036854775807",
  2450  		},
  2451  		{
  2452  			name:   "overflows int64, so ignored",
  2453  			cl:     []string{"9223372036854775808"},
  2454  			wantCL: "",
  2455  		},
  2456  	}
  2457  
  2458  	for _, tt := range tests {
  2459  		tt := tt
  2460  		t.Run(tt.name, func(t *testing.T) {
  2461  			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2462  				w.Header().Set("Content-Length", tt.cl[0])
  2463  			}, optOnlyServer)
  2464  			defer st.Close()
  2465  			tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2466  			defer tr.CloseIdleConnections()
  2467  
  2468  			req, _ := http.NewRequest("HEAD", st.ts.URL, nil)
  2469  			res, err := tr.RoundTrip(req)
  2470  
  2471  			var got string
  2472  			if err != nil {
  2473  				got = fmt.Sprintf("ERROR: %v", err)
  2474  			} else {
  2475  				got = res.Header.Get("Content-Length")
  2476  				res.Body.Close()
  2477  			}
  2478  
  2479  			if got != tt.wantCL {
  2480  				t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL)
  2481  			}
  2482  		})
  2483  	}
  2484  }
  2485  
  2486  // golang.org/issue/14048
  2487  func TestTransportFailsOnInvalidHeaders(t *testing.T) {
  2488  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2489  		var got []string
  2490  		for k := range r.Header {
  2491  			got = append(got, k)
  2492  		}
  2493  		sort.Strings(got)
  2494  		w.Header().Set("Got-Header", strings.Join(got, ","))
  2495  	}, optOnlyServer)
  2496  	defer st.Close()
  2497  
  2498  	tests := [...]struct {
  2499  		h       http.Header
  2500  		wantErr string
  2501  	}{
  2502  		0: {
  2503  			h:       http.Header{"with space": {"foo"}},
  2504  			wantErr: `invalid HTTP header name "with space"`,
  2505  		},
  2506  		1: {
  2507  			h:       http.Header{"name": {"Брэд"}},
  2508  			wantErr: "", // okay
  2509  		},
  2510  		2: {
  2511  			h:       http.Header{"имя": {"Brad"}},
  2512  			wantErr: `invalid HTTP header name "имя"`,
  2513  		},
  2514  		3: {
  2515  			h:       http.Header{"foo": {"foo\x01bar"}},
  2516  			wantErr: `invalid HTTP header value for header "foo"`,
  2517  		},
  2518  	}
  2519  
  2520  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2521  	defer tr.CloseIdleConnections()
  2522  
  2523  	for i, tt := range tests {
  2524  		req, _ := http.NewRequest("GET", st.ts.URL, nil)
  2525  		req.Header = tt.h
  2526  		res, err := tr.RoundTrip(req)
  2527  		var bad bool
  2528  		if tt.wantErr == "" {
  2529  			if err != nil {
  2530  				bad = true
  2531  				t.Errorf("case %d: error = %v; want no error", i, err)
  2532  			}
  2533  		} else {
  2534  			if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
  2535  				bad = true
  2536  				t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
  2537  			}
  2538  		}
  2539  		if err == nil {
  2540  			if bad {
  2541  				t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
  2542  			}
  2543  			res.Body.Close()
  2544  		}
  2545  	}
  2546  }
  2547  
  2548  // Tests that gzipReader doesn't crash on a second Read call following
  2549  // the first Read call's gzip.NewReader returning an error.
  2550  func TestGzipReader_DoubleReadCrash(t *testing.T) {
  2551  	gz := &gzipReader{
  2552  		body: ioutil.NopCloser(strings.NewReader("0123456789")),
  2553  	}
  2554  	var buf [1]byte
  2555  	n, err1 := gz.Read(buf[:])
  2556  	if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") {
  2557  		t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1)
  2558  	}
  2559  	n, err2 := gz.Read(buf[:])
  2560  	if n != 0 || err2 != err1 {
  2561  		t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
  2562  	}
  2563  }
  2564  
  2565  func TestGzipReader_ReadAfterClose(t *testing.T) {
  2566  	body := bytes.Buffer{}
  2567  	w := gzip.NewWriter(&body)
  2568  	w.Write([]byte("012345679"))
  2569  	w.Close()
  2570  	gz := &gzipReader{
  2571  		body: ioutil.NopCloser(&body),
  2572  	}
  2573  	var buf [1]byte
  2574  	n, err := gz.Read(buf[:])
  2575  	if n != 1 || err != nil {
  2576  		t.Fatalf("first Read = %v, %v; want 1, nil", n, err)
  2577  	}
  2578  	if err := gz.Close(); err != nil {
  2579  		t.Fatalf("gz Close error: %v", err)
  2580  	}
  2581  	n, err = gz.Read(buf[:])
  2582  	if n != 0 || err != fs.ErrClosed {
  2583  		t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err)
  2584  	}
  2585  }
  2586  
  2587  func TestTransportNewTLSConfig(t *testing.T) {
  2588  	tests := [...]struct {
  2589  		conf *tls.Config
  2590  		host string
  2591  		want *tls.Config
  2592  	}{
  2593  		// Normal case.
  2594  		0: {
  2595  			conf: nil,
  2596  			host: "foo.com",
  2597  			want: &tls.Config{
  2598  				ServerName: "foo.com",
  2599  				NextProtos: []string{NextProtoTLS},
  2600  			},
  2601  		},
  2602  
  2603  		// User-provided name (bar.com) takes precedence:
  2604  		1: {
  2605  			conf: &tls.Config{
  2606  				ServerName: "bar.com",
  2607  			},
  2608  			host: "foo.com",
  2609  			want: &tls.Config{
  2610  				ServerName: "bar.com",
  2611  				NextProtos: []string{NextProtoTLS},
  2612  			},
  2613  		},
  2614  
  2615  		// NextProto is prepended:
  2616  		2: {
  2617  			conf: &tls.Config{
  2618  				NextProtos: []string{"foo", "bar"},
  2619  			},
  2620  			host: "example.com",
  2621  			want: &tls.Config{
  2622  				ServerName: "example.com",
  2623  				NextProtos: []string{NextProtoTLS, "foo", "bar"},
  2624  			},
  2625  		},
  2626  
  2627  		// NextProto is not duplicated:
  2628  		3: {
  2629  			conf: &tls.Config{
  2630  				NextProtos: []string{"foo", "bar", NextProtoTLS},
  2631  			},
  2632  			host: "example.com",
  2633  			want: &tls.Config{
  2634  				ServerName: "example.com",
  2635  				NextProtos: []string{"foo", "bar", NextProtoTLS},
  2636  			},
  2637  		},
  2638  	}
  2639  	for i, tt := range tests {
  2640  		// Ignore the session ticket keys part, which ends up populating
  2641  		// unexported fields in the Config:
  2642  		if tt.conf != nil {
  2643  			tt.conf.SessionTicketsDisabled = true
  2644  		}
  2645  
  2646  		tr := &Transport{TLSClientConfig: tt.conf}
  2647  		got := tr.newTLSConfig(tt.host)
  2648  
  2649  		got.SessionTicketsDisabled = false
  2650  
  2651  		if !reflect.DeepEqual(got, tt.want) {
  2652  			t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
  2653  		}
  2654  	}
  2655  }
  2656  
  2657  // The Google GFE responds to HEAD requests with a HEADERS frame
  2658  // without END_STREAM, followed by a 0-length DATA frame with
  2659  // END_STREAM. Make sure we don't get confused by that. (We did.)
  2660  func TestTransportReadHeadResponse(t *testing.T) {
  2661  	ct := newClientTester(t)
  2662  	clientDone := make(chan struct{})
  2663  	ct.client = func() error {
  2664  		defer close(clientDone)
  2665  		req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
  2666  		res, err := ct.tr.RoundTrip(req)
  2667  		if err != nil {
  2668  			return err
  2669  		}
  2670  		if res.ContentLength != 123 {
  2671  			return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength)
  2672  		}
  2673  		slurp, err := ioutil.ReadAll(res.Body)
  2674  		if err != nil {
  2675  			return fmt.Errorf("ReadAll: %v", err)
  2676  		}
  2677  		if len(slurp) > 0 {
  2678  			return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
  2679  		}
  2680  		return nil
  2681  	}
  2682  	ct.server = func() error {
  2683  		ct.greet()
  2684  		for {
  2685  			f, err := ct.fr.ReadFrame()
  2686  			if err != nil {
  2687  				t.Logf("ReadFrame: %v", err)
  2688  				return nil
  2689  			}
  2690  			hf, ok := f.(*HeadersFrame)
  2691  			if !ok {
  2692  				continue
  2693  			}
  2694  			var buf bytes.Buffer
  2695  			enc := hpack.NewEncoder(&buf)
  2696  			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  2697  			enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
  2698  			ct.fr.WriteHeaders(HeadersFrameParam{
  2699  				StreamID:      hf.StreamID,
  2700  				EndHeaders:    true,
  2701  				EndStream:     false, // as the GFE does
  2702  				BlockFragment: buf.Bytes(),
  2703  			})
  2704  			ct.fr.WriteData(hf.StreamID, true, nil)
  2705  
  2706  			<-clientDone
  2707  			return nil
  2708  		}
  2709  	}
  2710  	ct.run()
  2711  }
  2712  
  2713  func TestTransportReadHeadResponseWithBody(t *testing.T) {
  2714  	// This test use not valid response format.
  2715  	// Discarding logger output to not spam tests output.
  2716  	log.SetOutput(ioutil.Discard)
  2717  	defer log.SetOutput(os.Stderr)
  2718  
  2719  	response := "redirecting to /elsewhere"
  2720  	ct := newClientTester(t)
  2721  	clientDone := make(chan struct{})
  2722  	ct.client = func() error {
  2723  		defer close(clientDone)
  2724  		req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
  2725  		res, err := ct.tr.RoundTrip(req)
  2726  		if err != nil {
  2727  			return err
  2728  		}
  2729  		if res.ContentLength != int64(len(response)) {
  2730  			return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response))
  2731  		}
  2732  		slurp, err := ioutil.ReadAll(res.Body)
  2733  		if err != nil {
  2734  			return fmt.Errorf("ReadAll: %v", err)
  2735  		}
  2736  		if len(slurp) > 0 {
  2737  			return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
  2738  		}
  2739  		return nil
  2740  	}
  2741  	ct.server = func() error {
  2742  		ct.greet()
  2743  		for {
  2744  			f, err := ct.fr.ReadFrame()
  2745  			if err != nil {
  2746  				t.Logf("ReadFrame: %v", err)
  2747  				return nil
  2748  			}
  2749  			hf, ok := f.(*HeadersFrame)
  2750  			if !ok {
  2751  				continue
  2752  			}
  2753  			var buf bytes.Buffer
  2754  			enc := hpack.NewEncoder(&buf)
  2755  			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  2756  			enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))})
  2757  			ct.fr.WriteHeaders(HeadersFrameParam{
  2758  				StreamID:      hf.StreamID,
  2759  				EndHeaders:    true,
  2760  				EndStream:     false,
  2761  				BlockFragment: buf.Bytes(),
  2762  			})
  2763  			ct.fr.WriteData(hf.StreamID, true, []byte(response))
  2764  
  2765  			<-clientDone
  2766  			return nil
  2767  		}
  2768  	}
  2769  	ct.run()
  2770  }
  2771  
  2772  type neverEnding byte
  2773  
  2774  func (b neverEnding) Read(p []byte) (int, error) {
  2775  	for i := range p {
  2776  		p[i] = byte(b)
  2777  	}
  2778  	return len(p), nil
  2779  }
  2780  
  2781  // golang.org/issue/15425: test that a handler closing the request
  2782  // body doesn't terminate the stream to the peer. (It just stops
  2783  // readability from the handler's side, and eventually the client
  2784  // runs out of flow control tokens)
  2785  func TestTransportHandlerBodyClose(t *testing.T) {
  2786  	const bodySize = 10 << 20
  2787  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2788  		r.Body.Close()
  2789  		io.Copy(w, io.LimitReader(neverEnding('A'), bodySize))
  2790  	}, optOnlyServer)
  2791  	defer st.Close()
  2792  
  2793  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2794  	defer tr.CloseIdleConnections()
  2795  
  2796  	g0 := runtime.NumGoroutine()
  2797  
  2798  	const numReq = 10
  2799  	for i := 0; i < numReq; i++ {
  2800  		req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
  2801  		if err != nil {
  2802  			t.Fatal(err)
  2803  		}
  2804  		res, err := tr.RoundTrip(req)
  2805  		if err != nil {
  2806  			t.Fatal(err)
  2807  		}
  2808  		n, err := io.Copy(ioutil.Discard, res.Body)
  2809  		res.Body.Close()
  2810  		if n != bodySize || err != nil {
  2811  			t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize)
  2812  		}
  2813  	}
  2814  	tr.CloseIdleConnections()
  2815  
  2816  	if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool {
  2817  		gd := runtime.NumGoroutine() - g0
  2818  		return gd < numReq/2
  2819  	}) {
  2820  		t.Errorf("appeared to leak goroutines")
  2821  	}
  2822  }
  2823  
  2824  // https://golang.org/issue/15930
  2825  func TestTransportFlowControl(t *testing.T) {
  2826  	const bufLen = 64 << 10
  2827  	var total int64 = 100 << 20 // 100MB
  2828  	if testing.Short() {
  2829  		total = 10 << 20
  2830  	}
  2831  
  2832  	var wrote int64 // updated atomically
  2833  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2834  		b := make([]byte, bufLen)
  2835  		for wrote < total {
  2836  			n, err := w.Write(b)
  2837  			atomic.AddInt64(&wrote, int64(n))
  2838  			if err != nil {
  2839  				t.Errorf("ResponseWriter.Write error: %v", err)
  2840  				break
  2841  			}
  2842  			w.(http.Flusher).Flush()
  2843  		}
  2844  	}, optOnlyServer)
  2845  
  2846  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2847  	defer tr.CloseIdleConnections()
  2848  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  2849  	if err != nil {
  2850  		t.Fatal("NewRequest error:", err)
  2851  	}
  2852  	resp, err := tr.RoundTrip(req)
  2853  	if err != nil {
  2854  		t.Fatal("RoundTrip error:", err)
  2855  	}
  2856  	defer resp.Body.Close()
  2857  
  2858  	var read int64
  2859  	b := make([]byte, bufLen)
  2860  	for {
  2861  		n, err := resp.Body.Read(b)
  2862  		if err == io.EOF {
  2863  			break
  2864  		}
  2865  		if err != nil {
  2866  			t.Fatal("Read error:", err)
  2867  		}
  2868  		read += int64(n)
  2869  
  2870  		const max = transportDefaultStreamFlow
  2871  		if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
  2872  			t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
  2873  		}
  2874  
  2875  		// Let the server get ahead of the client.
  2876  		time.Sleep(1 * time.Millisecond)
  2877  	}
  2878  }
  2879  
  2880  // golang.org/issue/14627 -- if the server sends a GOAWAY frame, make
  2881  // the Transport remember it and return it back to users (via
  2882  // RoundTrip or request body reads) if needed (e.g. if the server
  2883  // proceeds to close the TCP connection before the client gets its
  2884  // response)
  2885  func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
  2886  	testTransportUsesGoAwayDebugError(t, false)
  2887  }
  2888  
  2889  func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
  2890  	testTransportUsesGoAwayDebugError(t, true)
  2891  }
  2892  
// testTransportUsesGoAwayDebugError runs one request through a clientTester.
// The server answers the request HEADERS with two GOAWAY frames and then
// half-closes its side of the TCP connection:
//   - the first GOAWAY carries the debug data with ErrCodeNo;
//   - the second carries goAwayErrCode with no debug data.
//
// The client must report a GoAwayError that combines the interesting parts
// of both frames.
//
// If failMidBody is true, the server first sends response HEADERS
// (content-length 123, no DATA). RoundTrip is then expected to succeed, and
// the GoAwayError is expected from reading the response body instead.
func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
	ct := newClientTester(t)
	// Closed when the client finishes, so the server side can return.
	clientDone := make(chan struct{})

	const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
	const goAwayDebugData = "some debug data"

	ct.client = func() error {
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if failMidBody {
			if err != nil {
				return fmt.Errorf("unexpected client RoundTrip error: %v", err)
			}
			// Headers arrived; the GOAWAY error should surface from the body read.
			_, err = io.Copy(ioutil.Discard, res.Body)
			res.Body.Close()
		}
		// At this point err is either RoundTrip's error or the body-read error.
		// NOTE(review): if RoundTrip unexpectedly succeeds while failMidBody
		// is false, res.Body is never closed (the mismatch is still reported).
		want := GoAwayError{
			LastStreamID: 5,
			ErrCode:      goAwayErrCode,
			DebugData:    goAwayDebugData,
		}
		if !reflect.DeepEqual(err, want) {
			t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				t.Logf("ReadFrame: %v", err)
				return nil
			}
			// Ignore everything until the client's request HEADERS.
			hf, ok := f.(*HeadersFrame)
			if !ok {
				continue
			}
			if failMidBody {
				// Promise a 123-byte body that is never sent.
				var buf bytes.Buffer
				enc := hpack.NewEncoder(&buf)
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      hf.StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: buf.Bytes(),
				})
			}
			// Write two GOAWAY frames, to test that the Transport takes
			// the interesting parts of both.
			ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
			ct.fr.WriteGoAway(5, goAwayErrCode, nil)
			ct.sc.(*net.TCPConn).CloseWrite()
			if runtime.GOOS == "plan9" {
				// CloseWrite not supported on Plan 9; Issue 17906
				ct.sc.(*net.TCPConn).Close()
			}
			<-clientDone
			return nil
		}
	}
	ct.run()
}
  2960  
// testTransportReturnsUnusedFlowControl checks that flow control consumed by
// response DATA the client never reads is handed back once the client
// closes the response body early.
//
// The server sends a 5000-byte response body in one of two ways:
//   - oneDataFrame: a single 5000-byte DATA frame;
//   - otherwise: a 1-byte DATA frame, then, after the client has reset the
//     stream, a 4999-byte DATA frame.
//
// The client reads one byte and closes the body. The server then expects a
// RST_STREAM with ErrCodeCancel and a WINDOW_UPDATE with Increment 5000.
// Finally, ct.inflowWindow(0) must be back at its value from before any DATA
// was sent.
func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
	ct := newClientTester(t)

	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}

		if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
			return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
		}
		res.Body.Close() // leave the other 4999 bytes unread (two-frame case: not even sent yet)

		return nil
	}
	ct.server = func() error {
		ct.greet()

		// Skip WINDOW_UPDATE and SETTINGS frames until the request HEADERS.
		var hf *HeadersFrame
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
			}
			switch f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				continue
			}
			var ok bool
			hf, ok = f.(*HeadersFrame)
			if !ok {
				return fmt.Errorf("Got %T; want HeadersFrame", f)
			}
			break // exits the for loop (the switch above has already ended)
		}

		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
		ct.fr.WriteHeaders(HeadersFrameParam{
			StreamID:      hf.StreamID,
			EndHeaders:    true,
			EndStream:     false,
			BlockFragment: buf.Bytes(),
		})
		// Baseline for the final check. NOTE(review): inflowWindow is
		// defined elsewhere; presumably it reports the tester's view of the
		// client's receive window for a stream ID (0 = connection).
		initialInflow := ct.inflowWindow(0)

		// Two cases:
		// - Send one DATA frame with 5000 bytes.
		// - Send two DATA frames with 1 and 4999 bytes each.
		//
		// In both cases, the client should consume one byte of data,
		// refund that byte, then refund the following 4999 bytes.
		//
		// In the second case, the server waits for the client to reset the
		// stream before sending the second DATA frame. This tests the case
		// where the client receives a DATA frame after it has reset the stream.
		if oneDataFrame {
			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
		} else {
			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
		}

		// Collect the client's reaction to the early Close: always a
		// RST_STREAM, plus a 5000-byte WINDOW_UPDATE in the one-frame case.
		// The frames may arrive in either order.
		wantRST := true
		wantWUF := true
		if !oneDataFrame {
			wantWUF = false // flow control update is small, and will not be sent
		}
		for wantRST || wantWUF {
			f, err := ct.readNonSettingsFrame()
			if err != nil {
				return err
			}
			switch f := f.(type) {
			case *RSTStreamFrame:
				if !wantRST {
					return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
				}
				if f.ErrCode != ErrCodeCancel {
					return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
				}
				wantRST = false
			case *WindowUpdateFrame:
				if !wantWUF {
					return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
				}
				if f.Increment != 5000 {
					return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
				}
				wantWUF = false
			default:
				return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
			}
		}
		if !oneDataFrame {
			// Send the rest of the body to the already-reset stream. The
			// client must refund all 5000 bytes (1 + 4999) in one
			// WINDOW_UPDATE.
			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
			f, err := ct.readNonSettingsFrame()
			if err != nil {
				return err
			}
			wuf, ok := f.(*WindowUpdateFrame)
			if !ok || wuf.Increment != 5000 {
				return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f))
			}
		}
		// Presumably a PING round trip, so that everything above has been
		// processed before the windows are compared.
		if err := ct.writeReadPing(); err != nil {
			return err
		}
		if got, want := ct.inflowWindow(0), initialInflow; got != want {
			return fmt.Errorf("connection flow tokens = %v, want %v", got, want)
		}
		return nil
	}
	ct.run()
}
  3079  
  3080  // See golang.org/issue/16481
  3081  func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
  3082  	testTransportReturnsUnusedFlowControl(t, true)
  3083  }
  3084  
  3085  // See golang.org/issue/20469
  3086  func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
  3087  	testTransportReturnsUnusedFlowControl(t, false)
  3088  }
  3089  
// Issue 16612: adjust flow control on open streams when transport
// receives SETTINGS with INITIAL_WINDOW_SIZE from server.
//
// The client POSTs a bodySize (1MB) request body. The server reads the
// client preface by hand and deliberately sends no SETTINGS at first.
// Once it has received half of initialWindowSize bytes of DATA, it sends:
//   - SETTINGS raising INITIAL_WINDOW_SIZE to bodySize;
//   - a connection-level WINDOW_UPDATE of bodySize;
//   - a SETTINGS ACK.
//
// The server never sends a stream-level WINDOW_UPDATE. The upload can
// therefore presumably only finish if the Transport applies the new initial
// window to the already-open stream. The server replies 200 once the request
// stream ends.
func TestTransportAdjustsFlowControl(t *testing.T) {
	ct := newClientTester(t)
	clientDone := make(chan struct{})

	const bodySize = 1 << 20

	ct.client = func() error {
		// Defers run in reverse order: clientDone is closed first, then the
		// connection is closed (Plan 9) and half-closed. The server's
		// ReadFrame then fails while clientDone is already closed.
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)

		req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}
		res.Body.Close()
		return nil
	}
	ct.server = func() error {
		// No ct.greet(): the server's SETTINGS are withheld on purpose.
		_, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
		if err != nil {
			return fmt.Errorf("reading client preface: %v", err)
		}

		var gotBytes int64
		var sentSettings bool
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// Expected: the client closed the connection when it finished.
					return nil
				default:
					return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
				}
			}
			switch f := f.(type) {
			case *DataFrame:
				gotBytes += int64(len(f.Data()))
				// After we've got half the client's
				// initial flow control window's worth
				// of request body data, give it just
				// enough flow control to finish.
				if gotBytes >= initialWindowSize/2 && !sentSettings {
					sentSettings = true

					ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
					ct.fr.WriteWindowUpdate(0, bodySize)
					ct.fr.WriteSettingsAck()
				}

				if f.StreamEnded() {
					var buf bytes.Buffer
					enc := hpack.NewEncoder(&buf)
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				}
			}
		}
	}
	ct.run()
}
  3163  
// See golang.org/issue/16556
//
// The server sends one padded DATA frame (5000 data bytes, 5 bytes of
// padding) that the client does not read. Afterwards, both the connection
// and the stream inflow windows must be exactly 5000 below their earlier
// values. In other words, the flow control spent on padding must have been
// refunded immediately, while the unread data stays accounted for.
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
	ct := newClientTester(t)

	// Keeps the client from closing the response body (which would refund
	// the 5000 data bytes) until the server has checked the windows.
	unblockClient := make(chan bool, 1)

	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}
		defer res.Body.Close()
		<-unblockClient
		return nil
	}
	ct.server = func() error {
		ct.greet()

		// Skip WINDOW_UPDATE and SETTINGS frames until the request HEADERS.
		var hf *HeadersFrame
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
			}
			switch f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				continue
			}
			var ok bool
			hf, ok = f.(*HeadersFrame)
			if !ok {
				return fmt.Errorf("Got %T; want HeadersFrame", f)
			}
			break // exits the for loop (the switch above has already ended)
		}

		initialConnWindow := ct.inflowWindow(0)

		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
		ct.fr.WriteHeaders(HeadersFrameParam{
			StreamID:      hf.StreamID,
			EndHeaders:    true,
			EndStream:     false,
			BlockFragment: buf.Bytes(),
		})
		initialStreamWindow := ct.inflowWindow(hf.StreamID)
		pad := make([]byte, 5)
		ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream
		// Presumably a PING round trip, so the client has processed the DATA
		// frame before the windows are inspected.
		if err := ct.writeReadPing(); err != nil {
			return err
		}
		// Padding flow control should have been returned: only the 5000
		// unread data bytes may still be consumed.
		if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want {
			t.Errorf("conn inflow window = %v, want %v", got, want)
		}
		if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want {
			t.Errorf("stream inflow window = %v, want %v", got, want)
		}
		unblockClient <- true
		return nil
	}
	ct.run()
}
  3231  
  3232  // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
  3233  // StreamError as a result of the response HEADERS
  3234  func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
  3235  	ct := newClientTester(t)
  3236  
  3237  	ct.client = func() error {
  3238  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  3239  		res, err := ct.tr.RoundTrip(req)
  3240  		if err == nil {
  3241  			res.Body.Close()
  3242  			return errors.New("unexpected successful GET")
  3243  		}
  3244  		want := StreamError{1, ErrCodeProtocol, headerFieldNameError("  content-type")}
  3245  		if !reflect.DeepEqual(want, err) {
  3246  			t.Errorf("RoundTrip error = %#v; want %#v", err, want)
  3247  		}
  3248  		return nil
  3249  	}
  3250  	ct.server = func() error {
  3251  		ct.greet()
  3252  
  3253  		hf, err := ct.firstHeaders()
  3254  		if err != nil {
  3255  			return err
  3256  		}
  3257  
  3258  		var buf bytes.Buffer
  3259  		enc := hpack.NewEncoder(&buf)
  3260  		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  3261  		enc.WriteField(hpack.HeaderField{Name: "  content-type", Value: "bogus"}) // bogus spaces
  3262  		ct.fr.WriteHeaders(HeadersFrameParam{
  3263  			StreamID:      hf.StreamID,
  3264  			EndHeaders:    true,
  3265  			EndStream:     false,
  3266  			BlockFragment: buf.Bytes(),
  3267  		})
  3268  
  3269  		for {
  3270  			fr, err := ct.readFrame()
  3271  			if err != nil {
  3272  				return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
  3273  			}
  3274  			if _, ok := fr.(*SettingsFrame); ok {
  3275  				continue
  3276  			}
  3277  			if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
  3278  				t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
  3279  			}
  3280  			break
  3281  		}
  3282  
  3283  		return nil
  3284  	}
  3285  	ct.run()
  3286  }
  3287  
  3288  // byteAndEOFReader returns is in an io.Reader which reads one byte
  3289  // (the underlying byte) and io.EOF at once in its Read call.
  3290  type byteAndEOFReader byte
  3291  
  3292  func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
  3293  	if len(p) == 0 {
  3294  		panic("unexpected useless call")
  3295  	}
  3296  	p[0] = byte(b)
  3297  	return 1, io.EOF
  3298  }
  3299  
  3300  // Issue 16788: the Transport had a regression where it started
  3301  // sending a spurious DATA frame with a duplicate END_STREAM bit after
  3302  // the request body writer goroutine had already read an EOF from the
  3303  // Request.Body and included the END_STREAM on a data-carrying DATA
  3304  // frame.
  3305  //
  3306  // Notably, to trigger this, the requests need to use a Request.Body
  3307  // which returns (non-0, io.EOF) and also needs to set the ContentLength
  3308  // explicitly.
  3309  func TestTransportBodyDoubleEndStream(t *testing.T) {
  3310  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  3311  		// Nothing.
  3312  	}, optOnlyServer)
  3313  	defer st.Close()
  3314  
  3315  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3316  	defer tr.CloseIdleConnections()
  3317  
  3318  	for i := 0; i < 2; i++ {
  3319  		req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a'))
  3320  		req.ContentLength = 1
  3321  		res, err := tr.RoundTrip(req)
  3322  		if err != nil {
  3323  			t.Fatalf("failure on req %d: %v", i+1, err)
  3324  		}
  3325  		defer res.Body.Close()
  3326  	}
  3327  }
  3328  
  3329  // golang.org/issue/16847, golang.org/issue/19103
  3330  func TestTransportRequestPathPseudo(t *testing.T) {
  3331  	type result struct {
  3332  		path string
  3333  		err  string
  3334  	}
  3335  	tests := []struct {
  3336  		req  *http.Request
  3337  		want result
  3338  	}{
  3339  		0: {
  3340  			req: &http.Request{
  3341  				Method: "GET",
  3342  				URL: &url.URL{
  3343  					Host: "foo.com",
  3344  					Path: "/foo",
  3345  				},
  3346  			},
  3347  			want: result{path: "/foo"},
  3348  		},
  3349  		// In Go 1.7, we accepted paths of "//foo".
  3350  		// In Go 1.8, we rejected it (issue 16847).
  3351  		// In Go 1.9, we accepted it again (issue 19103).
  3352  		1: {
  3353  			req: &http.Request{
  3354  				Method: "GET",
  3355  				URL: &url.URL{
  3356  					Host: "foo.com",
  3357  					Path: "//foo",
  3358  				},
  3359  			},
  3360  			want: result{path: "//foo"},
  3361  		},
  3362  
  3363  		// Opaque with //$Matching_Hostname/path
  3364  		2: {
  3365  			req: &http.Request{
  3366  				Method: "GET",
  3367  				URL: &url.URL{
  3368  					Scheme: "https",
  3369  					Opaque: "//foo.com/path",
  3370  					Host:   "foo.com",
  3371  					Path:   "/ignored",
  3372  				},
  3373  			},
  3374  			want: result{path: "/path"},
  3375  		},
  3376  
  3377  		// Opaque with some other Request.Host instead:
  3378  		3: {
  3379  			req: &http.Request{
  3380  				Method: "GET",
  3381  				Host:   "bar.com",
  3382  				URL: &url.URL{
  3383  					Scheme: "https",
  3384  					Opaque: "//bar.com/path",
  3385  					Host:   "foo.com",
  3386  					Path:   "/ignored",
  3387  				},
  3388  			},
  3389  			want: result{path: "/path"},
  3390  		},
  3391  
  3392  		// Opaque without the leading "//":
  3393  		4: {
  3394  			req: &http.Request{
  3395  				Method: "GET",
  3396  				URL: &url.URL{
  3397  					Opaque: "/path",
  3398  					Host:   "foo.com",
  3399  					Path:   "/ignored",
  3400  				},
  3401  			},
  3402  			want: result{path: "/path"},
  3403  		},
  3404  
  3405  		// Opaque we can't handle:
  3406  		5: {
  3407  			req: &http.Request{
  3408  				Method: "GET",
  3409  				URL: &url.URL{
  3410  					Scheme: "https",
  3411  					Opaque: "//unknown_host/path",
  3412  					Host:   "foo.com",
  3413  					Path:   "/ignored",
  3414  				},
  3415  			},
  3416  			want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
  3417  		},
  3418  
  3419  		// A CONNECT request:
  3420  		6: {
  3421  			req: &http.Request{
  3422  				Method: "CONNECT",
  3423  				URL: &url.URL{
  3424  					Host: "foo.com",
  3425  				},
  3426  			},
  3427  			want: result{},
  3428  		},
  3429  	}
  3430  	for i, tt := range tests {
  3431  		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
  3432  		cc.henc = hpack.NewEncoder(&cc.hbuf)
  3433  		cc.mu.Lock()
  3434  		hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
  3435  		cc.mu.Unlock()
  3436  		var got result
  3437  		hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
  3438  			if f.Name == ":path" {
  3439  				got.path = f.Value
  3440  			}
  3441  		})
  3442  		if err != nil {
  3443  			got.err = err.Error()
  3444  		} else if len(hdrs) > 0 {
  3445  			if _, err := hpackDec.Write(hdrs); err != nil {
  3446  				t.Errorf("%d. bogus hpack: %v", i, err)
  3447  				continue
  3448  			}
  3449  		}
  3450  		if got != tt.want {
  3451  			t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
  3452  		}
  3453  
  3454  	}
  3455  
  3456  }
  3457  
  3458  // golang.org/issue/17071 -- don't sniff the first byte of the request body
  3459  // before we've determined that the ClientConn is usable.
  3460  func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
  3461  	const body = "foo"
  3462  	req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body)))
  3463  	cc := &ClientConn{
  3464  		closed:      true,
  3465  		reqHeaderMu: make(chan struct{}, 1),
  3466  	}
  3467  	_, err := cc.RoundTrip(req)
  3468  	if err != errClientConnUnusable {
  3469  		t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err)
  3470  	}
  3471  	slurp, err := ioutil.ReadAll(req.Body)
  3472  	if err != nil {
  3473  		t.Errorf("ReadAll = %v", err)
  3474  	}
  3475  	if string(slurp) != body {
  3476  		t.Errorf("Body = %q; want %q", slurp, body)
  3477  	}
  3478  }
  3479  
  3480  func TestClientConnPing(t *testing.T) {
  3481  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
  3482  	defer st.Close()
  3483  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3484  	defer tr.CloseIdleConnections()
  3485  	ctx := context.Background()
  3486  	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
  3487  	if err != nil {
  3488  		t.Fatal(err)
  3489  	}
  3490  	if err = cc.Ping(context.Background()); err != nil {
  3491  		t.Fatal(err)
  3492  	}
  3493  }
  3494  
  3495  // Issue 16974: if the server sent a DATA frame after the user
  3496  // canceled the Transport's Request, the Transport previously wrote to a
  3497  // closed pipe, got an error, and ended up closing the whole TCP
  3498  // connection.
  3499  func TestTransportCancelDataResponseRace(t *testing.T) {
  3500  	cancel := make(chan struct{})
  3501  	clientGotResponse := make(chan bool, 1)
  3502  
  3503  	const msg = "Hello."
  3504  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  3505  		if strings.Contains(r.URL.Path, "/hello") {
  3506  			time.Sleep(50 * time.Millisecond)
  3507  			io.WriteString(w, msg)
  3508  			return
  3509  		}
  3510  		for i := 0; i < 50; i++ {
  3511  			io.WriteString(w, "Some data.")
  3512  			w.(http.Flusher).Flush()
  3513  			if i == 2 {
  3514  				<-clientGotResponse
  3515  				close(cancel)
  3516  			}
  3517  			time.Sleep(10 * time.Millisecond)
  3518  		}
  3519  	}, optOnlyServer)
  3520  	defer st.Close()
  3521  
  3522  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3523  	defer tr.CloseIdleConnections()
  3524  
  3525  	c := &http.Client{Transport: tr}
  3526  	req, _ := http.NewRequest("GET", st.ts.URL, nil)
  3527  	req.Cancel = cancel
  3528  	res, err := c.Do(req)
  3529  	clientGotResponse <- true
  3530  	if err != nil {
  3531  		t.Fatal(err)
  3532  	}
  3533  	if _, err = io.Copy(ioutil.Discard, res.Body); err == nil {
  3534  		t.Fatal("unexpected success")
  3535  	}
  3536  
  3537  	res, err = c.Get(st.ts.URL + "/hello")
  3538  	if err != nil {
  3539  		t.Fatal(err)
  3540  	}
  3541  	slurp, err := ioutil.ReadAll(res.Body)
  3542  	if err != nil {
  3543  		t.Fatal(err)
  3544  	}
  3545  	if string(slurp) != msg {
  3546  		t.Errorf("Got = %q; want %q", slurp, msg)
  3547  	}
  3548  }
  3549  
  3550  // Issue 21316: It should be safe to reuse an http.Request after the
  3551  // request has completed.
  3552  func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
  3553  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  3554  		w.WriteHeader(200)
  3555  		io.WriteString(w, "body")
  3556  	}, optOnlyServer)
  3557  	defer st.Close()
  3558  
  3559  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3560  	defer tr.CloseIdleConnections()
  3561  
  3562  	req, _ := http.NewRequest("GET", st.ts.URL, nil)
  3563  	resp, err := tr.RoundTrip(req)
  3564  	if err != nil {
  3565  		t.Fatal(err)
  3566  	}
  3567  	if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
  3568  		t.Fatalf("error reading response body: %v", err)
  3569  	}
  3570  	if err := resp.Body.Close(); err != nil {
  3571  		t.Fatalf("error closing response body: %v", err)
  3572  	}
  3573  
  3574  	// This access of req.Header should not race with code in the transport.
  3575  	req.Header = http.Header{}
  3576  }
  3577  
  3578  func TestTransportCloseAfterLostPing(t *testing.T) {
  3579  	clientDone := make(chan struct{})
  3580  	ct := newClientTester(t)
  3581  	ct.tr.PingTimeout = 1 * time.Second
  3582  	ct.tr.ReadIdleTimeout = 1 * time.Second
  3583  	ct.client = func() error {
  3584  		defer ct.cc.(*net.TCPConn).CloseWrite()
  3585  		defer close(clientDone)
  3586  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  3587  		_, err := ct.tr.RoundTrip(req)
  3588  		if err == nil || !strings.Contains(err.Error(), "client connection lost") {
  3589  			return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
  3590  		}
  3591  		return nil
  3592  	}
  3593  	ct.server = func() error {
  3594  		ct.greet()
  3595  		<-clientDone
  3596  		return nil
  3597  	}
  3598  	ct.run()
  3599  }
  3600  
  3601  func TestTransportPingWriteBlocks(t *testing.T) {
  3602  	st := newServerTester(t,
  3603  		func(w http.ResponseWriter, r *http.Request) {},
  3604  		optOnlyServer,
  3605  	)
  3606  	defer st.Close()
  3607  	tr := &Transport{
  3608  		TLSClientConfig: tlsConfigInsecure,
  3609  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  3610  			s, c := net.Pipe() // unbuffered, unlike a TCP conn
  3611  			go func() {
  3612  				// Read initial handshake frames.
  3613  				// Without this, we block indefinitely in newClientConn,
  3614  				// and never get to the point of sending a PING.
  3615  				var buf [1024]byte
  3616  				s.Read(buf[:])
  3617  			}()
  3618  			return c, nil
  3619  		},
  3620  		PingTimeout:     1 * time.Millisecond,
  3621  		ReadIdleTimeout: 1 * time.Millisecond,
  3622  	}
  3623  	defer tr.CloseIdleConnections()
  3624  	c := &http.Client{Transport: tr}
  3625  	_, err := c.Get(st.ts.URL)
  3626  	if err == nil {
  3627  		t.Fatalf("Get = nil, want error")
  3628  	}
  3629  }
  3630  
  3631  func TestTransportPingWhenReading(t *testing.T) {
  3632  	testCases := []struct {
  3633  		name              string
  3634  		readIdleTimeout   time.Duration
  3635  		deadline          time.Duration
  3636  		expectedPingCount int
  3637  	}{
  3638  		{
  3639  			name:              "two pings",
  3640  			readIdleTimeout:   100 * time.Millisecond,
  3641  			deadline:          time.Second,
  3642  			expectedPingCount: 2,
  3643  		},
  3644  		{
  3645  			name:              "zero ping",
  3646  			readIdleTimeout:   time.Second,
  3647  			deadline:          200 * time.Millisecond,
  3648  			expectedPingCount: 0,
  3649  		},
  3650  		{
  3651  			name:              "0 readIdleTimeout means no ping",
  3652  			readIdleTimeout:   0 * time.Millisecond,
  3653  			deadline:          500 * time.Millisecond,
  3654  			expectedPingCount: 0,
  3655  		},
  3656  	}
  3657  
  3658  	for _, tc := range testCases {
  3659  		tc := tc // capture range variable
  3660  		t.Run(tc.name, func(t *testing.T) {
  3661  			testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
  3662  		})
  3663  	}
  3664  }
  3665  
  3666  func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) {
  3667  	var pingCount int
  3668  	ct := newClientTester(t)
  3669  	ct.tr.ReadIdleTimeout = readIdleTimeout
  3670  
  3671  	ctx, cancel := context.WithTimeout(context.Background(), deadline)
  3672  	defer cancel()
  3673  	ct.client = func() error {
  3674  		defer ct.cc.(*net.TCPConn).CloseWrite()
  3675  		if runtime.GOOS == "plan9" {
  3676  			// CloseWrite not supported on Plan 9; Issue 17906
  3677  			defer ct.cc.(*net.TCPConn).Close()
  3678  		}
  3679  		req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
  3680  		res, err := ct.tr.RoundTrip(req)
  3681  		if err != nil {
  3682  			return fmt.Errorf("RoundTrip: %v", err)
  3683  		}
  3684  		defer res.Body.Close()
  3685  		if res.StatusCode != 200 {
  3686  			return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
  3687  		}
  3688  		_, err = ioutil.ReadAll(res.Body)
  3689  		if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) {
  3690  			return nil
  3691  		}
  3692  
  3693  		cancel()
  3694  		return err
  3695  	}
  3696  
  3697  	ct.server = func() error {
  3698  		ct.greet()
  3699  		var buf bytes.Buffer
  3700  		enc := hpack.NewEncoder(&buf)
  3701  		var streamID uint32
  3702  		for {
  3703  			f, err := ct.fr.ReadFrame()
  3704  			if err != nil {
  3705  				select {
  3706  				case <-ctx.Done():
  3707  					// If the client's done, it
  3708  					// will have reported any
  3709  					// errors on its side.
  3710  					return nil
  3711  				default:
  3712  					return err
  3713  				}
  3714  			}
  3715  			switch f := f.(type) {
  3716  			case *WindowUpdateFrame, *SettingsFrame:
  3717  			case *HeadersFrame:
  3718  				if !f.HeadersEnded() {
  3719  					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
  3720  				}
  3721  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
  3722  				ct.fr.WriteHeaders(HeadersFrameParam{
  3723  					StreamID:      f.StreamID,
  3724  					EndHeaders:    true,
  3725  					EndStream:     false,
  3726  					BlockFragment: buf.Bytes(),
  3727  				})
  3728  				streamID = f.StreamID
  3729  			case *PingFrame:
  3730  				pingCount++
  3731  				if pingCount == expectedPingCount {
  3732  					if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
  3733  						return err
  3734  					}
  3735  				}
  3736  				if err := ct.fr.WritePing(true, f.Data); err != nil {
  3737  					return err
  3738  				}
  3739  			case *RSTStreamFrame:
  3740  			default:
  3741  				return fmt.Errorf("Unexpected client frame %v", f)
  3742  			}
  3743  		}
  3744  	}
  3745  	ct.run()
  3746  }
  3747  
// testClientMultipleDials runs client against a Transport whose DialTLS hook
// connects to a local listener. For every dial it starts server on its own
// goroutine with a clientTester wrapping the accepted (server) side of that
// connection. server receives the 1-based dial number (1 for the first dial,
// 2 for the second, ...) so it can behave differently per connection.
//
// After client returns, idle connections, the listener, and both ends of
// every dialed connection are closed, and the function waits for all server
// goroutines to finish.
func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
	ln := newLocalListener(t)
	defer ln.Close()

	var (
		// mu guards count and conns, which DialTLS mutates.
		mu    sync.Mutex
		count int
		// conns holds both ends of every dialed connection so they can be
		// closed once the client is done.
		conns []net.Conn
	)
	// wg tracks the per-dial server goroutines.
	var wg sync.WaitGroup
	tr := &Transport{
		TLSClientConfig: tlsConfigInsecure,
	}
	// DialTLS ignores network/addr/cfg and dials the local listener with plain
	// TCP; the returned connection is not TLS-wrapped. mu is held across the
	// dial and accept, so concurrent dials are serialized.
	tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
		mu.Lock()
		defer mu.Unlock()
		count++
		cc, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			return nil, fmt.Errorf("dial error: %v", err)
		}
		conns = append(conns, cc)
		sc, err := ln.Accept()
		if err != nil {
			return nil, fmt.Errorf("accept error: %v", err)
		}
		conns = append(conns, sc)
		ct := &clientTester{
			t:  t,
			tr: tr,
			cc: cc,
			sc: sc,
			fr: NewFramer(sc, sc),
		}
		wg.Add(1)
		// Pass count by value so the goroutine sees this dial's number rather
		// than a later value of the shared counter.
		go func(count int) {
			defer wg.Done()
			server(count, ct)
		}(count)
		return cc, nil
	}

	client(tr)
	tr.CloseIdleConnections()
	ln.Close()
	// Close both ends of every connection, presumably so that server
	// goroutines blocked reading frames can return before wg.Wait.
	// NOTE(review): conns is read here without mu; this assumes no dial is
	// still in flight once client has returned.
	for _, c := range conns {
		c.Close()
	}
	wg.Wait()
}
  3798  
  3799  func TestTransportRetryAfterGOAWAY(t *testing.T) {
  3800  	client := func(tr *Transport) {
  3801  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  3802  		res, err := tr.RoundTrip(req)
  3803  		if res != nil {
  3804  			res.Body.Close()
  3805  			if got := res.Header.Get("Foo"); got != "bar" {
  3806  				err = fmt.Errorf("foo header = %q; want bar", got)
  3807  			}
  3808  		}
  3809  		if err != nil {
  3810  			t.Errorf("RoundTrip: %v", err)
  3811  		}
  3812  	}
  3813  
  3814  	server := func(count int, ct *clientTester) {
  3815  		switch count {
  3816  		case 1:
  3817  			ct.greet()
  3818  			hf, err := ct.firstHeaders()
  3819  			if err != nil {
  3820  				t.Errorf("server1 failed reading HEADERS: %v", err)
  3821  				return
  3822  			}
  3823  			t.Logf("server1 got %v", hf)
  3824  			if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
  3825  				t.Errorf("server1 failed writing GOAWAY: %v", err)
  3826  				return
  3827  			}
  3828  		case 2:
  3829  			ct.greet()
  3830  			hf, err := ct.firstHeaders()
  3831  			if err != nil {
  3832  				t.Errorf("server2 failed reading HEADERS: %v", err)
  3833  				return
  3834  			}
  3835  			t.Logf("server2 got %v", hf)
  3836  
  3837  			var buf bytes.Buffer
  3838  			enc := hpack.NewEncoder(&buf)
  3839  			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  3840  			enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
  3841  			err = ct.fr.WriteHeaders(HeadersFrameParam{
  3842  				StreamID:      hf.StreamID,
  3843  				EndHeaders:    true,
  3844  				EndStream:     false,
  3845  				BlockFragment: buf.Bytes(),
  3846  			})
  3847  			if err != nil {
  3848  				t.Errorf("server2 failed writing response HEADERS: %v", err)
  3849  			}
  3850  		default:
  3851  			t.Errorf("unexpected number of dials")
  3852  			return
  3853  		}
  3854  	}
  3855  
  3856  	testClientMultipleDials(t, client, server)
  3857  }
  3858  
// TestTransportRetryAfterRefusedStream checks that a request whose first
// attempt is reset by the server with RST_STREAM(REFUSED_STREAM) is retried,
// and that the caller receives the retried attempt's 204 response.
func TestTransportRetryAfterRefusedStream(t *testing.T) {
	// clientDone is closed when the client finishes; after that, a read error
	// on the server side is expected and not reported.
	clientDone := make(chan struct{})
	client := func(tr *Transport) {
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		resp, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			return
		}
		resp.Body.Close()
		if resp.StatusCode != 204 {
			t.Errorf("Status = %v; want 204", resp.StatusCode)
			return
		}
	}

	// server ignores the dial number and instead counts the HEADERS frames
	// seen on its own connection.
	server := func(_ int, ct *clientTester) {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		var count int
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it
					// will have reported any
					// errors on its side.
				default:
					t.Error(err)
				}
				return
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					t.Errorf("headers should have END_HEADERS be ended: %v", f)
					return
				}
				count++
				if count == 1 {
					// Refuse the first attempt so the client has to retry it.
					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
				} else {
					// NOTE(review): buf is never reset, so a third HEADERS frame
					// on this connection would also re-send the previously
					// encoded fields. The test relies on exactly one retry.
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				}
			default:
				t.Errorf("Unexpected client frame %v", f)
				return
			}
		}
	}

	testClientMultipleDials(t, client, server)
}
  3922  
// TestTransportRetryHasLimit checks that when the server refuses every stream
// with RST_STREAM(REFUSED_STREAM), RoundTrip eventually gives up and returns
// an error instead of retrying forever.
func TestTransportRetryHasLimit(t *testing.T) {
	// Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
	// NOTE(review): retryBackoffHook below makes every backoff timer fire
	// immediately, so this delay may no longer apply; confirm whether the
	// short-mode skip is still needed.
	if testing.Short() {
		t.Skip("skipping long test in short mode")
	}
	// Replace the retry backoff with a timer that fires at once; restored
	// to nil when the test returns.
	retryBackoffHook = func(d time.Duration) *time.Timer {
		return time.NewTimer(0) // fires immediately
	}
	defer func() {
		retryBackoffHook = nil
	}()
	// clientDone is closed when the client finishes; after that, a server
	// read error is expected and not reported.
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		resp, err := ct.tr.RoundTrip(req)
		if err == nil {
			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
		}
		t.Logf("expected error, got: %v", err)
		return nil
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it
					// will have reported any
					// errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				// Refuse every attempt; the write error is ignored.
				ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
  3980  
  3981  func TestTransportResponseDataBeforeHeaders(t *testing.T) {
  3982  	// This test use not valid response format.
  3983  	// Discarding logger output to not spam tests output.
  3984  	log.SetOutput(ioutil.Discard)
  3985  	defer log.SetOutput(os.Stderr)
  3986  
  3987  	ct := newClientTester(t)
  3988  	ct.client = func() error {
  3989  		defer ct.cc.(*net.TCPConn).CloseWrite()
  3990  		if runtime.GOOS == "plan9" {
  3991  			// CloseWrite not supported on Plan 9; Issue 17906
  3992  			defer ct.cc.(*net.TCPConn).Close()
  3993  		}
  3994  		req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
  3995  		// First request is normal to ensure the check is per stream and not per connection.
  3996  		_, err := ct.tr.RoundTrip(req)
  3997  		if err != nil {
  3998  			return fmt.Errorf("RoundTrip expected no error, got: %v", err)
  3999  		}
  4000  		// Second request returns a DATA frame with no HEADERS.
  4001  		resp, err := ct.tr.RoundTrip(req)
  4002  		if err == nil {
  4003  			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
  4004  		}
  4005  		if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
  4006  			return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
  4007  		}
  4008  		return nil
  4009  	}
  4010  	ct.server = func() error {
  4011  		ct.greet()
  4012  		for {
  4013  			f, err := ct.fr.ReadFrame()
  4014  			if err == io.EOF {
  4015  				return nil
  4016  			} else if err != nil {
  4017  				return err
  4018  			}
  4019  			switch f := f.(type) {
  4020  			case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame:
  4021  			case *HeadersFrame:
  4022  				switch f.StreamID {
  4023  				case 1:
  4024  					// Send a valid response to first request.
  4025  					var buf bytes.Buffer
  4026  					enc := hpack.NewEncoder(&buf)
  4027  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  4028  					ct.fr.WriteHeaders(HeadersFrameParam{
  4029  						StreamID:      f.StreamID,
  4030  						EndHeaders:    true,
  4031  						EndStream:     true,
  4032  						BlockFragment: buf.Bytes(),
  4033  					})
  4034  				case 3:
  4035  					ct.fr.WriteData(f.StreamID, true, []byte("payload"))
  4036  				}
  4037  			default:
  4038  				return fmt.Errorf("Unexpected client frame %v", f)
  4039  			}
  4040  		}
  4041  	}
  4042  	ct.run()
  4043  }
  4044  
  4045  func TestTransportMaxFrameReadSize(t *testing.T) {
  4046  	for _, test := range []struct {
  4047  		maxReadFrameSize uint32
  4048  		want             uint32
  4049  	}{{
  4050  		maxReadFrameSize: 64000,
  4051  		want:             64000,
  4052  	}, {
  4053  		maxReadFrameSize: 1024,
  4054  		want:             minMaxFrameSize,
  4055  	}} {
  4056  		ct := newClientTester(t)
  4057  		ct.tr.MaxReadFrameSize = test.maxReadFrameSize
  4058  		ct.client = func() error {
  4059  			req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
  4060  			ct.tr.RoundTrip(req)
  4061  			return nil
  4062  		}
  4063  		ct.server = func() error {
  4064  			defer ct.cc.(*net.TCPConn).Close()
  4065  			ct.greet()
  4066  			var got uint32
  4067  			ct.settings.ForeachSetting(func(s Setting) error {
  4068  				switch s.ID {
  4069  				case SettingMaxFrameSize:
  4070  					got = s.Val
  4071  				}
  4072  				return nil
  4073  			})
  4074  			if got != test.want {
  4075  				t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
  4076  			}
  4077  			return nil
  4078  		}
  4079  		ct.run()
  4080  	}
  4081  }
  4082  
  4083  func TestTransportRequestsLowServerLimit(t *testing.T) {
  4084  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  4085  	}, optOnlyServer, func(s *Server) {
  4086  		s.MaxConcurrentStreams = 1
  4087  	})
  4088  	defer st.Close()
  4089  
  4090  	var (
  4091  		connCountMu sync.Mutex
  4092  		connCount   int
  4093  	)
  4094  	tr := &Transport{
  4095  		TLSClientConfig: tlsConfigInsecure,
  4096  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  4097  			connCountMu.Lock()
  4098  			defer connCountMu.Unlock()
  4099  			connCount++
  4100  			return tls.Dial(network, addr, cfg)
  4101  		},
  4102  	}
  4103  	defer tr.CloseIdleConnections()
  4104  
  4105  	const reqCount = 3
  4106  	for i := 0; i < reqCount; i++ {
  4107  		req, err := http.NewRequest("GET", st.ts.URL, nil)
  4108  		if err != nil {
  4109  			t.Fatal(err)
  4110  		}
  4111  		res, err := tr.RoundTrip(req)
  4112  		if err != nil {
  4113  			t.Fatal(err)
  4114  		}
  4115  		if got, want := res.StatusCode, 200; got != want {
  4116  			t.Errorf("StatusCode = %v; want %v", got, want)
  4117  		}
  4118  		if res != nil && res.Body != nil {
  4119  			res.Body.Close()
  4120  		}
  4121  	}
  4122  
  4123  	if connCount != 1 {
  4124  		t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
  4125  	}
  4126  }
  4127  
// TestTransportRequestsStallAtServerLimit tests Transport.StrictMaxConcurrentStreams.
//
// The server advertises SETTINGS_MAX_CONCURRENT_STREAMS = maxConcurrent and
// withholds all responses until unblockServer is closed. The client issues
// maxConcurrent+2 requests:
//   - the first maxConcurrent are sent,
//   - the next one has an already-closed Cancel channel and must fail,
//   - the last one must not reach the server until the server starts
//     answering the earlier requests.
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
	const maxConcurrent = 2

	greet := make(chan struct{})      // server sends initial SETTINGS frame
	gotRequest := make(chan struct{}) // server received a request
	clientDone := make(chan struct{})
	cancelClientRequest := make(chan struct{})

	// Collect errors from goroutines. They are reported from the test
	// goroutine in this deferred function, after the checker goroutine exits.
	var wg sync.WaitGroup
	errs := make(chan error, 100)
	defer func() {
		wg.Wait()
		close(errs)
		for err := range errs {
			t.Error(err)
		}
	}()

	// We will send maxConcurrent+2 requests. This checker goroutine waits for the
	// following stages:
	//   1. The first maxConcurrent requests are received by the server.
	//   2. The client will cancel the next request
	//   3. The server is unblocked so it can service the first maxConcurrent requests
	//   4. The client will send the final request
	wg.Add(1)
	unblockClient := make(chan struct{})
	clientRequestCancelled := make(chan struct{})
	unblockServer := make(chan struct{})
	go func() {
		defer wg.Done()
		// Stage 1.
		for k := 0; k < maxConcurrent; k++ {
			<-gotRequest
		}
		// Stage 2.
		close(unblockClient)
		<-clientRequestCancelled
		// Stage 3: give some time for the final RoundTrip call to be scheduled and
		// verify that the final request is not sent.
		time.Sleep(50 * time.Millisecond)
		select {
		case <-gotRequest:
			errs <- errors.New("last request did not stall")
			close(unblockServer)
			return
		default:
		}
		close(unblockServer)
		// Stage 4.
		<-gotRequest
	}()

	ct := newClientTester(t)
	ct.tr.StrictMaxConcurrentStreams = true
	ct.client = func() error {
		var wg sync.WaitGroup
		defer func() {
			wg.Wait()
			close(clientDone)
			ct.cc.(*net.TCPConn).CloseWrite()
			if runtime.GOOS == "plan9" {
				// CloseWrite not supported on Plan 9; Issue 17906
				ct.cc.(*net.TCPConn).Close()
			}
		}()
		for k := 0; k < maxConcurrent+2; k++ {
			wg.Add(1)
			go func(k int) {
				defer wg.Done()
				// Don't send the second request until after receiving SETTINGS from the server
				// to avoid a race where we use the default SettingMaxConcurrentStreams, which
				// is much larger than maxConcurrent. We have to send the first request before
				// waiting because the first request triggers the dial and greet.
				if k > 0 {
					<-greet
				}
				// Block until maxConcurrent requests are sent before sending any more.
				if k >= maxConcurrent {
					<-unblockClient
				}
				body := newStaticCloseChecker("")
				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
				if k == maxConcurrent {
					// This request will be canceled.
					req.Cancel = cancelClientRequest
					close(cancelClientRequest)
					_, err := ct.tr.RoundTrip(req)
					close(clientRequestCancelled)
					if err == nil {
						errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
						return
					}
				} else {
					resp, err := ct.tr.RoundTrip(req)
					if err != nil {
						errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
						return
					}
					ioutil.ReadAll(resp.Body)
					resp.Body.Close()
					if resp.StatusCode != 204 {
						errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
						return
					}
				}
				// Every request body must have been closed by RoundTrip,
				// whether the request succeeded or was canceled.
				if err := body.isClosed(); err != nil {
					errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
				}
			}(k)
		}
		return nil
	}

	ct.server = func() error {
		var wg sync.WaitGroup
		defer wg.Wait()

		ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

		// Server write loop.
		// writeResp is buffered so the read loop can queue stream IDs without
		// blocking while the writer goroutine waits for unblockServer.
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		writeResp := make(chan uint32, maxConcurrent+1)

		wg.Add(1)
		go func() {
			defer wg.Done()
			<-unblockServer
			for id := range writeResp {
				buf.Reset()
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      id,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
			}
		}()

		// Server read loop.
		var nreq int
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it will have reported any errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame:
			case *SettingsFrame:
				// Wait for the client SETTINGS ack until ending the greet.
				// NOTE(review): a second SETTINGS frame from the client would
				// close greet twice and panic.
				close(greet)
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				// gotRequest is unbuffered, so this blocks until the checker
				// goroutine consumes the signal.
				gotRequest <- struct{}{}
				nreq++
				writeResp <- f.StreamID
				// Only maxConcurrent+1 requests are expected to reach the
				// server (the canceled one is not), so stop the writer here.
				if nreq == maxConcurrent+1 {
					close(writeResp)
				}
			case *DataFrame:
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}

	ct.run()
}
  4307  
  4308  func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
  4309  	ct := newClientTester(t)
  4310  	var reqSize, resSize uint32 = 8192, 16384
  4311  	ct.tr.MaxDecoderHeaderTableSize = reqSize
  4312  	ct.client = func() error {
  4313  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  4314  		cc, err := ct.tr.NewClientConn(ct.cc)
  4315  		if err != nil {
  4316  			return err
  4317  		}
  4318  		_, err = cc.RoundTrip(req)
  4319  		if err != nil {
  4320  			return err
  4321  		}
  4322  		if got, want := cc.peerMaxHeaderTableSize, resSize; got != want {
  4323  			return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want)
  4324  		}
  4325  		return nil
  4326  	}
  4327  	ct.server = func() error {
  4328  		buf := make([]byte, len(ClientPreface))
  4329  		_, err := io.ReadFull(ct.sc, buf)
  4330  		if err != nil {
  4331  			return fmt.Errorf("reading client preface: %v", err)
  4332  		}
  4333  		f, err := ct.fr.ReadFrame()
  4334  		if err != nil {
  4335  			return err
  4336  		}
  4337  		sf, ok := f.(*SettingsFrame)
  4338  		if !ok {
  4339  			ct.t.Fatalf("wanted client settings frame; got %v", f)
  4340  			_ = sf // stash it away?
  4341  		}
  4342  		var found bool
  4343  		err = sf.ForeachSetting(func(s Setting) error {
  4344  			if s.ID == SettingHeaderTableSize {
  4345  				found = true
  4346  				if got, want := s.Val, reqSize; got != want {
  4347  					return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want)
  4348  				}
  4349  			}
  4350  			return nil
  4351  		})
  4352  		if err != nil {
  4353  			return err
  4354  		}
  4355  		if !found {
  4356  			return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting")
  4357  		}
  4358  		if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil {
  4359  			ct.t.Fatal(err)
  4360  		}
  4361  		if err := ct.fr.WriteSettingsAck(); err != nil {
  4362  			ct.t.Fatal(err)
  4363  		}
  4364  
  4365  		for {
  4366  			f, err := ct.fr.ReadFrame()
  4367  			if err != nil {
  4368  				return err
  4369  			}
  4370  			switch f := f.(type) {
  4371  			case *HeadersFrame:
  4372  				var buf bytes.Buffer
  4373  				enc := hpack.NewEncoder(&buf)
  4374  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  4375  				ct.fr.WriteHeaders(HeadersFrameParam{
  4376  					StreamID:      f.StreamID,
  4377  					EndHeaders:    true,
  4378  					EndStream:     true,
  4379  					BlockFragment: buf.Bytes(),
  4380  				})
  4381  				return nil
  4382  			}
  4383  		}
  4384  	}
  4385  	ct.run()
  4386  }
  4387  
  4388  func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
  4389  	ct := newClientTester(t)
  4390  	var peerAdvertisedMaxHeaderTableSize uint32 = 16384
  4391  	ct.tr.MaxEncoderHeaderTableSize = 8192
  4392  	ct.client = func() error {
  4393  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  4394  		cc, err := ct.tr.NewClientConn(ct.cc)
  4395  		if err != nil {
  4396  			return err
  4397  		}
  4398  		_, err = cc.RoundTrip(req)
  4399  		if err != nil {
  4400  			return err
  4401  		}
  4402  		if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want {
  4403  			return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
  4404  		}
  4405  		return nil
  4406  	}
  4407  	ct.server = func() error {
  4408  		buf := make([]byte, len(ClientPreface))
  4409  		_, err := io.ReadFull(ct.sc, buf)
  4410  		if err != nil {
  4411  			return fmt.Errorf("reading client preface: %v", err)
  4412  		}
  4413  		f, err := ct.fr.ReadFrame()
  4414  		if err != nil {
  4415  			return err
  4416  		}
  4417  		sf, ok := f.(*SettingsFrame)
  4418  		if !ok {
  4419  			ct.t.Fatalf("wanted client settings frame; got %v", f)
  4420  			_ = sf // stash it away?
  4421  		}
  4422  		if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil {
  4423  			ct.t.Fatal(err)
  4424  		}
  4425  		if err := ct.fr.WriteSettingsAck(); err != nil {
  4426  			ct.t.Fatal(err)
  4427  		}
  4428  
  4429  		for {
  4430  			f, err := ct.fr.ReadFrame()
  4431  			if err != nil {
  4432  				return err
  4433  			}
  4434  			switch f := f.(type) {
  4435  			case *HeadersFrame:
  4436  				var buf bytes.Buffer
  4437  				enc := hpack.NewEncoder(&buf)
  4438  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  4439  				ct.fr.WriteHeaders(HeadersFrameParam{
  4440  					StreamID:      f.StreamID,
  4441  					EndHeaders:    true,
  4442  					EndStream:     true,
  4443  					BlockFragment: buf.Bytes(),
  4444  				})
  4445  				return nil
  4446  			}
  4447  		}
  4448  	}
  4449  	ct.run()
  4450  }
  4451  
  4452  func TestAuthorityAddr(t *testing.T) {
  4453  	tests := []struct {
  4454  		scheme, authority string
  4455  		want              string
  4456  	}{
  4457  		{"http", "foo.com", "foo.com:80"},
  4458  		{"https", "foo.com", "foo.com:443"},
  4459  		{"https", "foo.com:", "foo.com:443"},
  4460  		{"https", "foo.com:1234", "foo.com:1234"},
  4461  		{"https", "1.2.3.4:1234", "1.2.3.4:1234"},
  4462  		{"https", "1.2.3.4", "1.2.3.4:443"},
  4463  		{"https", "1.2.3.4:", "1.2.3.4:443"},
  4464  		{"https", "[::1]:1234", "[::1]:1234"},
  4465  		{"https", "[::1]", "[::1]:443"},
  4466  		{"https", "[::1]:", "[::1]:443"},
  4467  	}
  4468  	for _, tt := range tests {
  4469  		got := authorityAddr(tt.scheme, tt.authority)
  4470  		if got != tt.want {
  4471  			t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
  4472  		}
  4473  	}
  4474  }
  4475  
  4476  // Issue 20448: stop allocating for DATA frames' payload after
  4477  // Response.Body.Close is called.
  4478  func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
  4479  	megabyteZero := make([]byte, 1<<20)
  4480  
  4481  	writeErr := make(chan error, 1)
  4482  
  4483  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  4484  		w.(http.Flusher).Flush()
  4485  		var sum int64
  4486  		for i := 0; i < 100; i++ {
  4487  			n, err := w.Write(megabyteZero)
  4488  			sum += int64(n)
  4489  			if err != nil {
  4490  				writeErr <- err
  4491  				return
  4492  			}
  4493  		}
  4494  		t.Logf("wrote all %d bytes", sum)
  4495  		writeErr <- nil
  4496  	}, optOnlyServer)
  4497  	defer st.Close()
  4498  
  4499  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4500  	defer tr.CloseIdleConnections()
  4501  	c := &http.Client{Transport: tr}
  4502  	res, err := c.Get(st.ts.URL)
  4503  	if err != nil {
  4504  		t.Fatal(err)
  4505  	}
  4506  	var buf [1]byte
  4507  	if _, err := res.Body.Read(buf[:]); err != nil {
  4508  		t.Error(err)
  4509  	}
  4510  	if err := res.Body.Close(); err != nil {
  4511  		t.Error(err)
  4512  	}
  4513  
  4514  	trb, ok := res.Body.(transportResponseBody)
  4515  	if !ok {
  4516  		t.Fatalf("res.Body = %T; want transportResponseBody", res.Body)
  4517  	}
  4518  	if trb.cs.bufPipe.b != nil {
  4519  		t.Errorf("response body pipe is still open")
  4520  	}
  4521  
  4522  	gotErr := <-writeErr
  4523  	if gotErr == nil {
  4524  		t.Errorf("Handler unexpectedly managed to write its entire response without getting an error")
  4525  	} else if gotErr != errStreamClosed {
  4526  		t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr)
  4527  	}
  4528  }
  4529  
  4530  // Issue 18891: make sure Request.Body == NoBody means no DATA frame
  4531  // is ever sent, even if empty.
  4532  func TestTransportNoBodyMeansNoDATA(t *testing.T) {
  4533  	ct := newClientTester(t)
  4534  
  4535  	unblockClient := make(chan bool)
  4536  
  4537  	ct.client = func() error {
  4538  		req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
  4539  		ct.tr.RoundTrip(req)
  4540  		<-unblockClient
  4541  		return nil
  4542  	}
  4543  	ct.server = func() error {
  4544  		defer close(unblockClient)
  4545  		defer ct.cc.(*net.TCPConn).Close()
  4546  		ct.greet()
  4547  
  4548  		for {
  4549  			f, err := ct.fr.ReadFrame()
  4550  			if err != nil {
  4551  				return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
  4552  			}
  4553  			switch f := f.(type) {
  4554  			default:
  4555  				return fmt.Errorf("Got %T; want HeadersFrame", f)
  4556  			case *WindowUpdateFrame, *SettingsFrame:
  4557  				continue
  4558  			case *HeadersFrame:
  4559  				if !f.StreamEnded() {
  4560  					return fmt.Errorf("got headers frame without END_STREAM")
  4561  				}
  4562  				return nil
  4563  			}
  4564  		}
  4565  	}
  4566  	ct.run()
  4567  }
  4568  
  4569  func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
  4570  	defer disableGoroutineTracking()()
  4571  	b.ReportAllocs()
  4572  	st := newServerTester(b,
  4573  		func(w http.ResponseWriter, r *http.Request) {
  4574  			for i := 0; i < nResHeader; i++ {
  4575  				name := fmt.Sprint("A-", i)
  4576  				w.Header().Set(name, "*")
  4577  			}
  4578  		},
  4579  		optOnlyServer,
  4580  		optQuiet,
  4581  	)
  4582  	defer st.Close()
  4583  
  4584  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4585  	defer tr.CloseIdleConnections()
  4586  
  4587  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  4588  	if err != nil {
  4589  		b.Fatal(err)
  4590  	}
  4591  
  4592  	for i := 0; i < nReqHeaders; i++ {
  4593  		name := fmt.Sprint("A-", i)
  4594  		req.Header.Set(name, "*")
  4595  	}
  4596  
  4597  	b.ResetTimer()
  4598  
  4599  	for i := 0; i < b.N; i++ {
  4600  		res, err := tr.RoundTrip(req)
  4601  		if err != nil {
  4602  			if res != nil {
  4603  				res.Body.Close()
  4604  			}
  4605  			b.Fatalf("RoundTrip err = %v; want nil", err)
  4606  		}
  4607  		res.Body.Close()
  4608  		if res.StatusCode != http.StatusOK {
  4609  			b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
  4610  		}
  4611  	}
  4612  }
  4613  
  4614  type infiniteReader struct{}
  4615  
  4616  func (r infiniteReader) Read(b []byte) (int, error) {
  4617  	return len(b), nil
  4618  }
  4619  
  4620  // Issue 20521: it is not an error to receive a response and end stream
  4621  // from the server without the body being consumed.
  4622  func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
  4623  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  4624  		w.WriteHeader(http.StatusOK)
  4625  	}, optOnlyServer)
  4626  	defer st.Close()
  4627  
  4628  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4629  	defer tr.CloseIdleConnections()
  4630  
  4631  	// The request body needs to be big enough to trigger flow control.
  4632  	req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{})
  4633  	res, err := tr.RoundTrip(req)
  4634  	if err != nil {
  4635  		t.Fatal(err)
  4636  	}
  4637  	if res.StatusCode != http.StatusOK {
  4638  		t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
  4639  	}
  4640  }
  4641  
  4642  // Verify transport doesn't crash when receiving bogus response lacking a :status header.
  4643  // Issue 22880.
  4644  func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
  4645  	ct := newClientTester(t)
  4646  	ct.client = func() error {
  4647  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  4648  		_, err := ct.tr.RoundTrip(req)
  4649  		const substr = "malformed response from server: missing status pseudo header"
  4650  		if !strings.Contains(fmt.Sprint(err), substr) {
  4651  			return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
  4652  		}
  4653  		return nil
  4654  	}
  4655  	ct.server = func() error {
  4656  		ct.greet()
  4657  		var buf bytes.Buffer
  4658  		enc := hpack.NewEncoder(&buf)
  4659  
  4660  		for {
  4661  			f, err := ct.fr.ReadFrame()
  4662  			if err != nil {
  4663  				return err
  4664  			}
  4665  			switch f := f.(type) {
  4666  			case *HeadersFrame:
  4667  				enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header
  4668  				ct.fr.WriteHeaders(HeadersFrameParam{
  4669  					StreamID:      f.StreamID,
  4670  					EndHeaders:    true,
  4671  					EndStream:     false, // we'll send some DATA to try to crash the transport
  4672  					BlockFragment: buf.Bytes(),
  4673  				})
  4674  				ct.fr.WriteData(f.StreamID, true, []byte("payload"))
  4675  				return nil
  4676  			}
  4677  		}
  4678  	}
  4679  	ct.run()
  4680  }
  4681  
  4682  func BenchmarkClientRequestHeaders(b *testing.B) {
  4683  	b.Run("   0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
  4684  	b.Run("  10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) })
  4685  	b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) })
  4686  	b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) })
  4687  }
  4688  
  4689  func BenchmarkClientResponseHeaders(b *testing.B) {
  4690  	b.Run("   0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
  4691  	b.Run("  10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) })
  4692  	b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) })
  4693  	b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
  4694  }
  4695  
  4696  func BenchmarkDownloadFrameSize(b *testing.B) {
  4697  	b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) })
  4698  	b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) })
  4699  	b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) })
  4700  	b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) })
  4701  	b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
  4702  }
  4703  func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
  4704  	defer disableGoroutineTracking()()
  4705  	const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M
  4706  	b.ReportAllocs()
  4707  	st := newServerTester(b,
  4708  		func(w http.ResponseWriter, r *http.Request) {
  4709  			// test 1GB transfer
  4710  			w.Header().Set("Content-Length", strconv.Itoa(transferSize))
  4711  			w.Header().Set("Content-Transfer-Encoding", "binary")
  4712  			var data [1024 * 1024]byte
  4713  			for i := 0; i < transferSize/(1024*1024); i++ {
  4714  				w.Write(data[:])
  4715  			}
  4716  		}, optQuiet,
  4717  	)
  4718  	defer st.Close()
  4719  
  4720  	tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize}
  4721  	defer tr.CloseIdleConnections()
  4722  
  4723  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  4724  	if err != nil {
  4725  		b.Fatal(err)
  4726  	}
  4727  
  4728  	b.N = 3
  4729  	b.SetBytes(transferSize)
  4730  	b.ResetTimer()
  4731  
  4732  	for i := 0; i < b.N; i++ {
  4733  		res, err := tr.RoundTrip(req)
  4734  		if err != nil {
  4735  			if res != nil {
  4736  				res.Body.Close()
  4737  			}
  4738  			b.Fatalf("RoundTrip err = %v; want nil", err)
  4739  		}
  4740  		data, _ := io.ReadAll(res.Body)
  4741  		if len(data) != transferSize {
  4742  			b.Fatalf("Response length invalid")
  4743  		}
  4744  		res.Body.Close()
  4745  		if res.StatusCode != http.StatusOK {
  4746  			b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
  4747  		}
  4748  	}
  4749  }
  4750  
  4751  func activeStreams(cc *ClientConn) int {
  4752  	count := 0
  4753  	cc.mu.Lock()
  4754  	defer cc.mu.Unlock()
  4755  	for _, cs := range cc.streams {
  4756  		select {
  4757  		case <-cs.abort:
  4758  		default:
  4759  			count++
  4760  		}
  4761  	}
  4762  	return count
  4763  }
  4764  
// closeMode selects how testClientConnClose tears down the ClientConn while a
// request is in flight.
type closeMode int

const (
	// closeAtHeaders calls ClientConn.Close before the handler writes the
	// response headers; RoundTrip is expected to fail.
	closeAtHeaders closeMode = iota
	// closeAtBody calls ClientConn.Close between two response body writes;
	// reading the response body is expected to fail.
	closeAtBody
	// shutdown calls ClientConn.Shutdown and expects the in-flight request
	// to complete before Shutdown returns.
	shutdown
	// shutdownCancel calls ClientConn.Shutdown with an already-canceled
	// context and expects context.Canceled.
	shutdownCancel
)
  4773  
  4774  // See golang.org/issue/17292
  4775  func testClientConnClose(t *testing.T, closeMode closeMode) {
  4776  	clientDone := make(chan struct{})
  4777  	defer close(clientDone)
  4778  	handlerDone := make(chan struct{})
  4779  	closeDone := make(chan struct{})
  4780  	beforeHeader := func() {}
  4781  	bodyWrite := func(w http.ResponseWriter) {}
  4782  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  4783  		defer close(handlerDone)
  4784  		beforeHeader()
  4785  		w.WriteHeader(http.StatusOK)
  4786  		w.(http.Flusher).Flush()
  4787  		bodyWrite(w)
  4788  		select {
  4789  		case <-w.(http.CloseNotifier).CloseNotify():
  4790  			// client closed connection before completion
  4791  			if closeMode == shutdown || closeMode == shutdownCancel {
  4792  				t.Error("expected request to complete")
  4793  			}
  4794  		case <-clientDone:
  4795  			if closeMode == closeAtHeaders || closeMode == closeAtBody {
  4796  				t.Error("expected connection closed by client")
  4797  			}
  4798  		}
  4799  	}, optOnlyServer)
  4800  	defer st.Close()
  4801  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4802  	defer tr.CloseIdleConnections()
  4803  	ctx := context.Background()
  4804  	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
  4805  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  4806  	if err != nil {
  4807  		t.Fatal(err)
  4808  	}
  4809  	if closeMode == closeAtHeaders {
  4810  		beforeHeader = func() {
  4811  			if err := cc.Close(); err != nil {
  4812  				t.Error(err)
  4813  			}
  4814  			close(closeDone)
  4815  		}
  4816  	}
  4817  	var sendBody chan struct{}
  4818  	if closeMode == closeAtBody {
  4819  		sendBody = make(chan struct{})
  4820  		bodyWrite = func(w http.ResponseWriter) {
  4821  			<-sendBody
  4822  			b := make([]byte, 32)
  4823  			w.Write(b)
  4824  			w.(http.Flusher).Flush()
  4825  			if err := cc.Close(); err != nil {
  4826  				t.Errorf("unexpected ClientConn close error: %v", err)
  4827  			}
  4828  			close(closeDone)
  4829  			w.Write(b)
  4830  			w.(http.Flusher).Flush()
  4831  		}
  4832  	}
  4833  	res, err := cc.RoundTrip(req)
  4834  	if res != nil {
  4835  		defer res.Body.Close()
  4836  	}
  4837  	if closeMode == closeAtHeaders {
  4838  		got := fmt.Sprint(err)
  4839  		want := "http2: client connection force closed via ClientConn.Close"
  4840  		if got != want {
  4841  			t.Fatalf("RoundTrip error = %v, want %v", got, want)
  4842  		}
  4843  	} else {
  4844  		if err != nil {
  4845  			t.Fatalf("RoundTrip: %v", err)
  4846  		}
  4847  		if got, want := activeStreams(cc), 1; got != want {
  4848  			t.Errorf("got %d active streams, want %d", got, want)
  4849  		}
  4850  	}
  4851  	switch closeMode {
  4852  	case shutdownCancel:
  4853  		if err = cc.Shutdown(canceledCtx); err != context.Canceled {
  4854  			t.Errorf("got %v, want %v", err, context.Canceled)
  4855  		}
  4856  		if cc.closing == false {
  4857  			t.Error("expected closing to be true")
  4858  		}
  4859  		if cc.CanTakeNewRequest() == true {
  4860  			t.Error("CanTakeNewRequest to return false")
  4861  		}
  4862  		if v, want := len(cc.streams), 1; v != want {
  4863  			t.Errorf("expected %d active streams, got %d", want, v)
  4864  		}
  4865  		clientDone <- struct{}{}
  4866  		<-handlerDone
  4867  	case shutdown:
  4868  		wait := make(chan struct{})
  4869  		shutdownEnterWaitStateHook = func() {
  4870  			close(wait)
  4871  			shutdownEnterWaitStateHook = func() {}
  4872  		}
  4873  		defer func() { shutdownEnterWaitStateHook = func() {} }()
  4874  		shutdown := make(chan struct{}, 1)
  4875  		go func() {
  4876  			if err = cc.Shutdown(context.Background()); err != nil {
  4877  				t.Error(err)
  4878  			}
  4879  			close(shutdown)
  4880  		}()
  4881  		// Let the shutdown to enter wait state
  4882  		<-wait
  4883  		cc.mu.Lock()
  4884  		if cc.closing == false {
  4885  			t.Error("expected closing to be true")
  4886  		}
  4887  		cc.mu.Unlock()
  4888  		if cc.CanTakeNewRequest() == true {
  4889  			t.Error("CanTakeNewRequest to return false")
  4890  		}
  4891  		if got, want := activeStreams(cc), 1; got != want {
  4892  			t.Errorf("got %d active streams, want %d", got, want)
  4893  		}
  4894  		// Let the active request finish
  4895  		clientDone <- struct{}{}
  4896  		// Wait for the shutdown to end
  4897  		select {
  4898  		case <-shutdown:
  4899  		case <-time.After(2 * time.Second):
  4900  			t.Fatal("expected server connection to close")
  4901  		}
  4902  	case closeAtHeaders, closeAtBody:
  4903  		if closeMode == closeAtBody {
  4904  			go close(sendBody)
  4905  			if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
  4906  				t.Error("expected a Copy error, got nil")
  4907  			}
  4908  		}
  4909  		<-closeDone
  4910  		if got, want := activeStreams(cc), 0; got != want {
  4911  			t.Errorf("got %d active streams, want %d", got, want)
  4912  		}
  4913  		// wait for server to get the connection close notice
  4914  		select {
  4915  		case <-handlerDone:
  4916  		case <-time.After(2 * time.Second):
  4917  			t.Fatal("expected server connection to close")
  4918  		}
  4919  	}
  4920  }
  4921  
// TestClientConnCloseAtHeaders: the client closes the connection just after
// the server has received the client's HEADERS frame, but before the server
// sends its HEADERS response back. RoundTrip is expected to fail with an error
// explaining that the client closed the connection.
func TestClientConnCloseAtHeaders(t *testing.T) {
	testClientConnClose(t, closeAtHeaders)
}

// TestClientConnCloseAtBody: the client closes the connection between two of
// the server's response DATA frames. Reading the response body is expected to
// fail with an I/O error on the client.
func TestClientConnCloseAtBody(t *testing.T) {
	testClientConnClose(t, closeAtBody)
}

// TestClientConnShutdown: the client sends a GOAWAY frame before the server
// finishes processing a request. The connection is expected to stay open
// until the request completes.
func TestClientConnShutdown(t *testing.T) {
	testClientConnClose(t, shutdown)
}

// TestClientConnShutdownCancel: the client sends a GOAWAY frame before the
// server finishes processing a request, but the context passed to Shutdown is
// already canceled. Shutdown is expected to return context.Canceled without
// waiting for the request to complete.
func TestClientConnShutdownCancel(t *testing.T) {
	testClientConnClose(t, shutdownCancel)
}
  4947  
  4948  // Issue 25009: use Request.GetBody if present, even if it seems like
  4949  // we might not need it. Apparently something else can still read from
  4950  // the original request body. Data race? In any case, rewinding
  4951  // unconditionally on retry is a nicer model anyway and should
  4952  // simplify code in the future (after the Go 1.11 freeze)
  4953  func TestTransportUsesGetBodyWhenPresent(t *testing.T) {
  4954  	calls := 0
  4955  	someBody := func() io.ReadCloser {
  4956  		return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))}
  4957  	}
  4958  	req := &http.Request{
  4959  		Body: someBody(),
  4960  		GetBody: func() (io.ReadCloser, error) {
  4961  			calls++
  4962  			return someBody(), nil
  4963  		},
  4964  	}
  4965  
  4966  	req2, err := shouldRetryRequest(req, errClientConnUnusable)
  4967  	if err != nil {
  4968  		t.Fatal(err)
  4969  	}
  4970  	if calls != 1 {
  4971  		t.Errorf("Calls = %d; want 1", calls)
  4972  	}
  4973  	if req2 == req {
  4974  		t.Error("req2 changed")
  4975  	}
  4976  	if req2 == nil {
  4977  		t.Fatal("req2 is nil")
  4978  	}
  4979  	if req2.Body == nil {
  4980  		t.Fatal("req2.Body is nil")
  4981  	}
  4982  	if req2.GetBody == nil {
  4983  		t.Fatal("req2.GetBody is nil")
  4984  	}
  4985  	if req2.Body == req.Body {
  4986  		t.Error("req2.Body unchanged")
  4987  	}
  4988  }
  4989  
  4990  // Issue 22891: verify that the "https" altproto we register with net/http
  4991  // is a certain type: a struct with one field with our *http2.Transport in it.
  4992  func TestNoDialH2RoundTripperType(t *testing.T) {
  4993  	t1 := new(http.Transport)
  4994  	t2 := new(Transport)
  4995  	rt := noDialH2RoundTripper{t2}
  4996  	if err := registerHTTPSProtocol(t1, rt); err != nil {
  4997  		t.Fatal(err)
  4998  	}
  4999  	rv := reflect.ValueOf(rt)
  5000  	if rv.Type().Kind() != reflect.Struct {
  5001  		t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind())
  5002  	}
  5003  	if n := rv.Type().NumField(); n != 1 {
  5004  		t.Fatalf("fields = %d; net/http expects 1", n)
  5005  	}
  5006  	v := rv.Field(0)
  5007  	if _, ok := v.Interface().(*Transport); !ok {
  5008  		t.Fatalf("wrong kind %T; want *Transport", v.Interface())
  5009  	}
  5010  }
  5011  
  5012  type errReader struct {
  5013  	body []byte
  5014  	err  error
  5015  }
  5016  
  5017  func (r *errReader) Read(p []byte) (int, error) {
  5018  	if len(r.body) > 0 {
  5019  		n := copy(p, r.body)
  5020  		r.body = r.body[n:]
  5021  		return n, nil
  5022  	}
  5023  	return 0, r.err
  5024  }
  5025  
  5026  func testTransportBodyReadError(t *testing.T, body []byte) {
  5027  	if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
  5028  		// So far we've only seen this be flaky on Windows and Plan 9,
  5029  		// perhaps due to TCP behavior on shutdowns while
  5030  		// unread data is in flight. This test should be
  5031  		// fixed, but a skip is better than annoying people
  5032  		// for now.
  5033  		t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS)
  5034  	}
  5035  	clientDone := make(chan struct{})
  5036  	ct := newClientTester(t)
  5037  	ct.client = func() error {
  5038  		defer ct.cc.(*net.TCPConn).CloseWrite()
  5039  		if runtime.GOOS == "plan9" {
  5040  			// CloseWrite not supported on Plan 9; Issue 17906
  5041  			defer ct.cc.(*net.TCPConn).Close()
  5042  		}
  5043  		defer close(clientDone)
  5044  
  5045  		checkNoStreams := func() error {
  5046  			cp, ok := ct.tr.connPool().(*clientConnPool)
  5047  			if !ok {
  5048  				return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
  5049  			}
  5050  			cp.mu.Lock()
  5051  			defer cp.mu.Unlock()
  5052  			conns, ok := cp.conns["dummy.tld:443"]
  5053  			if !ok {
  5054  				return fmt.Errorf("missing connection")
  5055  			}
  5056  			if len(conns) != 1 {
  5057  				return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
  5058  			}
  5059  			if activeStreams(conns[0]) != 0 {
  5060  				return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
  5061  			}
  5062  			return nil
  5063  		}
  5064  		bodyReadError := errors.New("body read error")
  5065  		body := &errReader{body, bodyReadError}
  5066  		req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
  5067  		if err != nil {
  5068  			return err
  5069  		}
  5070  		_, err = ct.tr.RoundTrip(req)
  5071  		if err != bodyReadError {
  5072  			return fmt.Errorf("err = %v; want %v", err, bodyReadError)
  5073  		}
  5074  		if err = checkNoStreams(); err != nil {
  5075  			return err
  5076  		}
  5077  		return nil
  5078  	}
  5079  	ct.server = func() error {
  5080  		ct.greet()
  5081  		var receivedBody []byte
  5082  		var resetCount int
  5083  		for {
  5084  			f, err := ct.fr.ReadFrame()
  5085  			t.Logf("server: ReadFrame = %v, %v", f, err)
  5086  			if err != nil {
  5087  				select {
  5088  				case <-clientDone:
  5089  					// If the client's done, it
  5090  					// will have reported any
  5091  					// errors on its side.
  5092  					if bytes.Compare(receivedBody, body) != 0 {
  5093  						return fmt.Errorf("body: %q; expected %q", receivedBody, body)
  5094  					}
  5095  					if resetCount != 1 {
  5096  						return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
  5097  					}
  5098  					return nil
  5099  				default:
  5100  					return err
  5101  				}
  5102  			}
  5103  			switch f := f.(type) {
  5104  			case *WindowUpdateFrame, *SettingsFrame:
  5105  			case *HeadersFrame:
  5106  			case *DataFrame:
  5107  				receivedBody = append(receivedBody, f.Data()...)
  5108  			case *RSTStreamFrame:
  5109  				resetCount++
  5110  			default:
  5111  				return fmt.Errorf("Unexpected client frame %v", f)
  5112  			}
  5113  		}
  5114  	}
  5115  	ct.run()
  5116  }
  5117  
  5118  func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
  5119  func TestTransportBodyReadError_Some(t *testing.T)        { testTransportBodyReadError(t, []byte("123")) }
  5120  
  5121  // Issue 32254: verify that the client sends END_STREAM flag eagerly with the last
  5122  // (or in this test-case the only one) request body data frame, and does not send
  5123  // extra zero-len data frames.
  5124  func TestTransportBodyEagerEndStream(t *testing.T) {
  5125  	const reqBody = "some request body"
  5126  	const resBody = "some response body"
  5127  
  5128  	ct := newClientTester(t)
  5129  	ct.client = func() error {
  5130  		defer ct.cc.(*net.TCPConn).CloseWrite()
  5131  		if runtime.GOOS == "plan9" {
  5132  			// CloseWrite not supported on Plan 9; Issue 17906
  5133  			defer ct.cc.(*net.TCPConn).Close()
  5134  		}
  5135  		body := strings.NewReader(reqBody)
  5136  		req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
  5137  		if err != nil {
  5138  			return err
  5139  		}
  5140  		_, err = ct.tr.RoundTrip(req)
  5141  		if err != nil {
  5142  			return err
  5143  		}
  5144  		return nil
  5145  	}
  5146  	ct.server = func() error {
  5147  		ct.greet()
  5148  
  5149  		for {
  5150  			f, err := ct.fr.ReadFrame()
  5151  			if err != nil {
  5152  				return err
  5153  			}
  5154  
  5155  			switch f := f.(type) {
  5156  			case *WindowUpdateFrame, *SettingsFrame:
  5157  			case *HeadersFrame:
  5158  			case *DataFrame:
  5159  				if !f.StreamEnded() {
  5160  					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
  5161  					return fmt.Errorf("data frame without END_STREAM %v", f)
  5162  				}
  5163  				var buf bytes.Buffer
  5164  				enc := hpack.NewEncoder(&buf)
  5165  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  5166  				ct.fr.WriteHeaders(HeadersFrameParam{
  5167  					StreamID:      f.Header().StreamID,
  5168  					EndHeaders:    true,
  5169  					EndStream:     false,
  5170  					BlockFragment: buf.Bytes(),
  5171  				})
  5172  				ct.fr.WriteData(f.StreamID, true, []byte(resBody))
  5173  				return nil
  5174  			case *RSTStreamFrame:
  5175  			default:
  5176  				return fmt.Errorf("Unexpected client frame %v", f)
  5177  			}
  5178  		}
  5179  	}
  5180  	ct.run()
  5181  }
  5182  
  5183  type chunkReader struct {
  5184  	chunks [][]byte
  5185  }
  5186  
  5187  func (r *chunkReader) Read(p []byte) (int, error) {
  5188  	if len(r.chunks) > 0 {
  5189  		n := copy(p, r.chunks[0])
  5190  		r.chunks = r.chunks[1:]
  5191  		return n, nil
  5192  	}
  5193  	panic("shouldn't read this many times")
  5194  }
  5195  
  5196  // Issue 32254: if the request body is larger than the specified
  5197  // content length, the client should refuse to send the extra part
  5198  // and abort the stream.
  5199  //
  5200  // In _len3 case, the first Read() matches the expected content length
  5201  // but the second read returns more data.
  5202  //
  5203  // In _len2 case, the first Read() exceeds the expected content length.
  5204  func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) {
  5205  	body := &chunkReader{[][]byte{
  5206  		[]byte("123"),
  5207  		[]byte("456"),
  5208  	}}
  5209  	testTransportBodyLargerThanSpecifiedContentLength(t, body, 3)
  5210  }
  5211  
  5212  func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) {
  5213  	body := &chunkReader{[][]byte{
  5214  		[]byte("123"),
  5215  	}}
  5216  	testTransportBodyLargerThanSpecifiedContentLength(t, body, 2)
  5217  }
  5218  
  5219  func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) {
  5220  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5221  		r.Body.Read(make([]byte, 6))
  5222  	}, optOnlyServer)
  5223  	defer st.Close()
  5224  
  5225  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  5226  	defer tr.CloseIdleConnections()
  5227  
  5228  	req, _ := http.NewRequest("POST", st.ts.URL, body)
  5229  	req.ContentLength = contentLen
  5230  	_, err := tr.RoundTrip(req)
  5231  	if err != errReqBodyTooLong {
  5232  		t.Fatalf("expected %v, got %v", errReqBodyTooLong, err)
  5233  	}
  5234  }
  5235  
  5236  func TestClientConnTooIdle(t *testing.T) {
  5237  	tests := []struct {
  5238  		cc   func() *ClientConn
  5239  		want bool
  5240  	}{
  5241  		{
  5242  			func() *ClientConn {
  5243  				return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)}
  5244  			},
  5245  			true,
  5246  		},
  5247  		{
  5248  			func() *ClientConn {
  5249  				return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}}
  5250  			},
  5251  			false,
  5252  		},
  5253  		{
  5254  			func() *ClientConn {
  5255  				return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)}
  5256  			},
  5257  			false,
  5258  		},
  5259  		{
  5260  			func() *ClientConn {
  5261  				return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)}
  5262  			},
  5263  			false,
  5264  		},
  5265  	}
  5266  	for i, tt := range tests {
  5267  		got := tt.cc().tooIdleLocked()
  5268  		if got != tt.want {
  5269  			t.Errorf("%d. got %v; want %v", i, got, tt.want)
  5270  		}
  5271  	}
  5272  }
  5273  
  5274  type fakeConnErr struct {
  5275  	net.Conn
  5276  	writeErr error
  5277  	closed   bool
  5278  }
  5279  
  5280  func (fce *fakeConnErr) Write(b []byte) (n int, err error) {
  5281  	return 0, fce.writeErr
  5282  }
  5283  
  5284  func (fce *fakeConnErr) Close() error {
  5285  	fce.closed = true
  5286  	return nil
  5287  }
  5288  
  5289  // issue 39337: close the connection on a failed write
  5290  func TestTransportNewClientConnCloseOnWriteError(t *testing.T) {
  5291  	tr := &Transport{}
  5292  	writeErr := errors.New("write error")
  5293  	fakeConn := &fakeConnErr{writeErr: writeErr}
  5294  	_, err := tr.NewClientConn(fakeConn)
  5295  	if err != writeErr {
  5296  		t.Fatalf("expected %v, got %v", writeErr, err)
  5297  	}
  5298  	if !fakeConn.closed {
  5299  		t.Error("expected closed conn")
  5300  	}
  5301  }
  5302  
  5303  func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
  5304  	req, err := http.NewRequest("GET", "https://dummy.tld/", nil)
  5305  	if err != nil {
  5306  		t.Fatal(err)
  5307  	}
  5308  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
  5309  	defer st.Close()
  5310  
  5311  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  5312  	defer tr.CloseIdleConnections()
  5313  	ctx := context.Background()
  5314  	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
  5315  	if err != nil {
  5316  		t.Fatal(err)
  5317  	}
  5318  
  5319  	writeErr := errors.New("write error")
  5320  	cc.wmu.Lock()
  5321  	cc.werr = writeErr
  5322  	cc.wmu.Unlock()
  5323  
  5324  	_, err = cc.RoundTrip(req)
  5325  	if err != writeErr {
  5326  		t.Fatalf("expected %v, got %v", writeErr, err)
  5327  	}
  5328  
  5329  	cc.mu.Lock()
  5330  	closed := cc.closed
  5331  	cc.mu.Unlock()
  5332  	if !closed {
  5333  		t.Fatal("expected closed")
  5334  	}
  5335  }
  5336  
  5337  // Issue 31192: A failed request may be retried if the body has not been read
  5338  // already. If the request body has started to be sent, one must wait until it
  5339  // is completed.
  5340  func TestTransportBodyRewindRace(t *testing.T) {
  5341  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5342  		w.Header().Set("Connection", "close")
  5343  		w.WriteHeader(http.StatusOK)
  5344  		return
  5345  	}, optOnlyServer)
  5346  	defer st.Close()
  5347  
  5348  	tr := &http.Transport{
  5349  		TLSClientConfig: tlsConfigInsecure,
  5350  		MaxConnsPerHost: 1,
  5351  	}
  5352  	err := ConfigureTransport(tr)
  5353  	if err != nil {
  5354  		t.Fatal(err)
  5355  	}
  5356  	client := &http.Client{
  5357  		Transport: tr,
  5358  	}
  5359  
  5360  	const clients = 50
  5361  
  5362  	var wg sync.WaitGroup
  5363  	wg.Add(clients)
  5364  	for i := 0; i < clients; i++ {
  5365  		req, err := http.NewRequest("POST", st.ts.URL, bytes.NewBufferString("abcdef"))
  5366  		if err != nil {
  5367  			t.Fatalf("unexpect new request error: %v", err)
  5368  		}
  5369  
  5370  		go func() {
  5371  			defer wg.Done()
  5372  			res, err := client.Do(req)
  5373  			if err == nil {
  5374  				res.Body.Close()
  5375  			}
  5376  		}()
  5377  	}
  5378  
  5379  	wg.Wait()
  5380  }
  5381  
  5382  // Issue 42498: A request with a body will never be sent if the stream is
  5383  // reset prior to sending any data.
  5384  func TestTransportServerResetStreamAtHeaders(t *testing.T) {
  5385  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5386  		w.WriteHeader(http.StatusUnauthorized)
  5387  		return
  5388  	}, optOnlyServer)
  5389  	defer st.Close()
  5390  
  5391  	tr := &http.Transport{
  5392  		TLSClientConfig:       tlsConfigInsecure,
  5393  		MaxConnsPerHost:       1,
  5394  		ExpectContinueTimeout: 10 * time.Second,
  5395  	}
  5396  
  5397  	err := ConfigureTransport(tr)
  5398  	if err != nil {
  5399  		t.Fatal(err)
  5400  	}
  5401  	client := &http.Client{
  5402  		Transport: tr,
  5403  	}
  5404  
  5405  	req, err := http.NewRequest("POST", st.ts.URL, errorReader{io.EOF})
  5406  	if err != nil {
  5407  		t.Fatalf("unexpect new request error: %v", err)
  5408  	}
  5409  	req.ContentLength = 0 // so transport is tempted to sniff it
  5410  	req.Header.Set("Expect", "100-continue")
  5411  	res, err := client.Do(req)
  5412  	if err != nil {
  5413  		t.Fatal(err)
  5414  	}
  5415  	res.Body.Close()
  5416  }
  5417  
  5418  type trackingReader struct {
  5419  	rdr     io.Reader
  5420  	wasRead uint32
  5421  }
  5422  
  5423  func (tr *trackingReader) Read(p []byte) (int, error) {
  5424  	atomic.StoreUint32(&tr.wasRead, 1)
  5425  	return tr.rdr.Read(p)
  5426  }
  5427  
  5428  func (tr *trackingReader) WasRead() bool {
  5429  	return atomic.LoadUint32(&tr.wasRead) != 0
  5430  }
  5431  
  5432  func TestTransportExpectContinue(t *testing.T) {
  5433  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5434  		switch r.URL.Path {
  5435  		case "/reject":
  5436  			w.WriteHeader(403)
  5437  		default:
  5438  			io.Copy(io.Discard, r.Body)
  5439  		}
  5440  	}, optOnlyServer)
  5441  	defer st.Close()
  5442  
  5443  	tr := &http.Transport{
  5444  		TLSClientConfig:       tlsConfigInsecure,
  5445  		MaxConnsPerHost:       1,
  5446  		ExpectContinueTimeout: 10 * time.Second,
  5447  	}
  5448  
  5449  	err := ConfigureTransport(tr)
  5450  	if err != nil {
  5451  		t.Fatal(err)
  5452  	}
  5453  	client := &http.Client{
  5454  		Transport: tr,
  5455  	}
  5456  
  5457  	testCases := []struct {
  5458  		Name         string
  5459  		Path         string
  5460  		Body         *trackingReader
  5461  		ExpectedCode int
  5462  		ShouldRead   bool
  5463  	}{
  5464  		{
  5465  			Name:         "read-all",
  5466  			Path:         "/",
  5467  			Body:         &trackingReader{rdr: strings.NewReader("hello")},
  5468  			ExpectedCode: 200,
  5469  			ShouldRead:   true,
  5470  		},
  5471  		{
  5472  			Name:         "reject",
  5473  			Path:         "/reject",
  5474  			Body:         &trackingReader{rdr: strings.NewReader("hello")},
  5475  			ExpectedCode: 403,
  5476  			ShouldRead:   false,
  5477  		},
  5478  	}
  5479  
  5480  	for _, tc := range testCases {
  5481  		t.Run(tc.Name, func(t *testing.T) {
  5482  			startTime := time.Now()
  5483  
  5484  			req, err := http.NewRequest("POST", st.ts.URL+tc.Path, tc.Body)
  5485  			if err != nil {
  5486  				t.Fatal(err)
  5487  			}
  5488  			req.Header.Set("Expect", "100-continue")
  5489  			res, err := client.Do(req)
  5490  			if err != nil {
  5491  				t.Fatal(err)
  5492  			}
  5493  			res.Body.Close()
  5494  
  5495  			if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout {
  5496  				t.Error("Request didn't finish before expect continue timeout")
  5497  			}
  5498  			if res.StatusCode != tc.ExpectedCode {
  5499  				t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode)
  5500  			}
  5501  			if tc.Body.WasRead() != tc.ShouldRead {
  5502  				t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead)
  5503  			}
  5504  		})
  5505  	}
  5506  }
  5507  
  5508  type closeChecker struct {
  5509  	io.ReadCloser
  5510  	closed chan struct{}
  5511  }
  5512  
  5513  func newCloseChecker(r io.ReadCloser) *closeChecker {
  5514  	return &closeChecker{r, make(chan struct{})}
  5515  }
  5516  
  5517  func newStaticCloseChecker(body string) *closeChecker {
  5518  	return newCloseChecker(io.NopCloser(strings.NewReader("body")))
  5519  }
  5520  
  5521  func (rc *closeChecker) Read(b []byte) (n int, err error) {
  5522  	select {
  5523  	default:
  5524  	case <-rc.closed:
  5525  		// TODO(dneil): Consider restructuring the request write to avoid reading
  5526  		// from the request body after closing it, and check for read-after-close here.
  5527  		// Currently, abortRequestBodyWrite races with writeRequestBody.
  5528  		return 0, errors.New("read after Body.Close")
  5529  	}
  5530  	return rc.ReadCloser.Read(b)
  5531  }
  5532  
  5533  func (rc *closeChecker) Close() error {
  5534  	close(rc.closed)
  5535  	return rc.ReadCloser.Close()
  5536  }
  5537  
  5538  func (rc *closeChecker) isClosed() error {
  5539  	// The RoundTrip contract says that it will close the request body,
  5540  	// but that it may do so in a separate goroutine. Wait a reasonable
  5541  	// amount of time before concluding that the body isn't being closed.
  5542  	timeout := time.Duration(10 * time.Second)
  5543  	select {
  5544  	case <-rc.closed:
  5545  	case <-time.After(timeout):
  5546  		return fmt.Errorf("body not closed after %v", timeout)
  5547  	}
  5548  	return nil
  5549  }
  5550  
  5551  // A blockingWriteConn is a net.Conn that blocks in Write after some number of bytes are written.
  5552  type blockingWriteConn struct {
  5553  	net.Conn
  5554  	writeOnce    sync.Once
  5555  	writec       chan struct{} // closed after the write limit is reached
  5556  	unblockc     chan struct{} // closed to unblock writes
  5557  	count, limit int
  5558  }
  5559  
  5560  func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn {
  5561  	return &blockingWriteConn{
  5562  		Conn:     conn,
  5563  		limit:    limit,
  5564  		writec:   make(chan struct{}),
  5565  		unblockc: make(chan struct{}),
  5566  	}
  5567  }
  5568  
  5569  // wait waits until the conn blocks writing the limit+1st byte.
  5570  func (c *blockingWriteConn) wait() {
  5571  	<-c.writec
  5572  }
  5573  
  5574  // unblock unblocks writes to the conn.
  5575  func (c *blockingWriteConn) unblock() {
  5576  	close(c.unblockc)
  5577  }
  5578  
  5579  func (c *blockingWriteConn) Write(b []byte) (n int, err error) {
  5580  	if c.count+len(b) > c.limit {
  5581  		c.writeOnce.Do(func() {
  5582  			close(c.writec)
  5583  		})
  5584  		<-c.unblockc
  5585  	}
  5586  	n, err = c.Conn.Write(b)
  5587  	c.count += n
  5588  	return n, err
  5589  }
  5590  
  5591  // Write several requests to a ClientConn at the same time, looking for race conditions.
  5592  // See golang.org/issue/48340
  5593  func TestTransportFrameBufferReuse(t *testing.T) {
  5594  	filler := hex.EncodeToString([]byte(randString(2048)))
  5595  
  5596  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5597  		if got, want := r.Header.Get("Big"), filler; got != want {
  5598  			t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want)
  5599  		}
  5600  		b, err := ioutil.ReadAll(r.Body)
  5601  		if err != nil {
  5602  			t.Errorf("error reading request body: %v", err)
  5603  		}
  5604  		if got, want := string(b), filler; got != want {
  5605  			t.Errorf("request body = %q, want %q", got, want)
  5606  		}
  5607  		if got, want := r.Trailer.Get("Big"), filler; got != want {
  5608  			t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want)
  5609  		}
  5610  	}, optOnlyServer)
  5611  	defer st.Close()
  5612  
  5613  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  5614  	defer tr.CloseIdleConnections()
  5615  
  5616  	var wg sync.WaitGroup
  5617  	defer wg.Wait()
  5618  	for i := 0; i < 10; i++ {
  5619  		wg.Add(1)
  5620  		go func() {
  5621  			defer wg.Done()
  5622  			req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler))
  5623  			if err != nil {
  5624  				t.Error(err)
  5625  				return
  5626  			}
  5627  			req.Header.Set("Big", filler)
  5628  			req.Trailer = make(http.Header)
  5629  			req.Trailer.Set("Big", filler)
  5630  			res, err := tr.RoundTrip(req)
  5631  			if err != nil {
  5632  				t.Error(err)
  5633  				return
  5634  			}
  5635  			if got, want := res.StatusCode, 200; got != want {
  5636  				t.Errorf("StatusCode = %v; want %v", got, want)
  5637  			}
  5638  			if res != nil && res.Body != nil {
  5639  				res.Body.Close()
  5640  			}
  5641  		}()
  5642  	}
  5643  
  5644  }
  5645  
  5646  // Ensure that a request blocking while being written to the underlying net.Conn doesn't
  5647  // block access to the ClientConn pool. Test requests blocking while writing headers, the body,
  5648  // and trailers.
  5649  // See golang.org/issue/32388
  5650  func TestTransportBlockingRequestWrite(t *testing.T) {
  5651  	filler := hex.EncodeToString([]byte(randString(2048)))
  5652  	for _, test := range []struct {
  5653  		name string
  5654  		req  func(url string) (*http.Request, error)
  5655  	}{{
  5656  		name: "headers",
  5657  		req: func(url string) (*http.Request, error) {
  5658  			req, err := http.NewRequest("POST", url, nil)
  5659  			if err != nil {
  5660  				return nil, err
  5661  			}
  5662  			req.Header.Set("Big", filler)
  5663  			return req, err
  5664  		},
  5665  	}, {
  5666  		name: "body",
  5667  		req: func(url string) (*http.Request, error) {
  5668  			req, err := http.NewRequest("POST", url, strings.NewReader(filler))
  5669  			if err != nil {
  5670  				return nil, err
  5671  			}
  5672  			return req, err
  5673  		},
  5674  	}, {
  5675  		name: "trailer",
  5676  		req: func(url string) (*http.Request, error) {
  5677  			req, err := http.NewRequest("POST", url, strings.NewReader("body"))
  5678  			if err != nil {
  5679  				return nil, err
  5680  			}
  5681  			req.Trailer = make(http.Header)
  5682  			req.Trailer.Set("Big", filler)
  5683  			return req, err
  5684  		},
  5685  	}} {
  5686  		test := test
  5687  		t.Run(test.name, func(t *testing.T) {
  5688  			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5689  				if v := r.Header.Get("Big"); v != "" && v != filler {
  5690  					t.Errorf("request header mismatch")
  5691  				}
  5692  				if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler {
  5693  					t.Errorf("request body mismatch\ngot:  %q\nwant: %q", string(v), filler)
  5694  				}
  5695  				if v := r.Trailer.Get("Big"); v != "" && v != filler {
  5696  					t.Errorf("request trailer mismatch\ngot:  %q\nwant: %q", string(v), filler)
  5697  				}
  5698  			}, optOnlyServer, func(s *Server) {
  5699  				s.MaxConcurrentStreams = 1
  5700  			})
  5701  			defer st.Close()
  5702  
  5703  			// This Transport creates connections that block on writes after 1024 bytes.
  5704  			connc := make(chan *blockingWriteConn, 1)
  5705  			connCount := 0
  5706  			tr := &Transport{
  5707  				TLSClientConfig: tlsConfigInsecure,
  5708  				DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  5709  					connCount++
  5710  					c, err := tls.Dial(network, addr, cfg)
  5711  					wc := newBlockingWriteConn(c, 1024)
  5712  					select {
  5713  					case connc <- wc:
  5714  					default:
  5715  					}
  5716  					return wc, err
  5717  				},
  5718  			}
  5719  			defer tr.CloseIdleConnections()
  5720  
  5721  			// Request 1: A small request to ensure we read the server MaxConcurrentStreams.
  5722  			{
  5723  				req, err := http.NewRequest("POST", st.ts.URL, nil)
  5724  				if err != nil {
  5725  					t.Fatal(err)
  5726  				}
  5727  				res, err := tr.RoundTrip(req)
  5728  				if err != nil {
  5729  					t.Fatal(err)
  5730  				}
  5731  				if got, want := res.StatusCode, 200; got != want {
  5732  					t.Errorf("StatusCode = %v; want %v", got, want)
  5733  				}
  5734  				if res != nil && res.Body != nil {
  5735  					res.Body.Close()
  5736  				}
  5737  			}
  5738  
  5739  			// Request 2: A large request that blocks while being written.
  5740  			reqc := make(chan struct{})
  5741  			go func() {
  5742  				defer close(reqc)
  5743  				req, err := test.req(st.ts.URL)
  5744  				if err != nil {
  5745  					t.Error(err)
  5746  					return
  5747  				}
  5748  				res, _ := tr.RoundTrip(req)
  5749  				if res != nil && res.Body != nil {
  5750  					res.Body.Close()
  5751  				}
  5752  			}()
  5753  			conn := <-connc
  5754  			conn.wait() // wait for the request to block
  5755  
  5756  			// Request 3: A small request that is sent on a new connection, since request 2
  5757  			// is hogging the only available stream on the previous connection.
  5758  			{
  5759  				req, err := http.NewRequest("POST", st.ts.URL, nil)
  5760  				if err != nil {
  5761  					t.Fatal(err)
  5762  				}
  5763  				res, err := tr.RoundTrip(req)
  5764  				if err != nil {
  5765  					t.Fatal(err)
  5766  				}
  5767  				if got, want := res.StatusCode, 200; got != want {
  5768  					t.Errorf("StatusCode = %v; want %v", got, want)
  5769  				}
  5770  				if res != nil && res.Body != nil {
  5771  					res.Body.Close()
  5772  				}
  5773  			}
  5774  
  5775  			// Request 2 should still be blocking at this point.
  5776  			select {
  5777  			case <-reqc:
  5778  				t.Errorf("request 2 unexpectedly completed")
  5779  			default:
  5780  			}
  5781  
  5782  			conn.unblock()
  5783  			<-reqc
  5784  
  5785  			if connCount != 2 {
  5786  				t.Errorf("created %v connections, want 1", connCount)
  5787  			}
  5788  		})
  5789  	}
  5790  }
  5791  
  5792  func TestTransportCloseRequestBody(t *testing.T) {
  5793  	var statusCode int
  5794  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5795  		w.WriteHeader(statusCode)
  5796  	}, optOnlyServer)
  5797  	defer st.Close()
  5798  
  5799  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  5800  	defer tr.CloseIdleConnections()
  5801  	ctx := context.Background()
  5802  	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
  5803  	if err != nil {
  5804  		t.Fatal(err)
  5805  	}
  5806  
  5807  	for _, status := range []int{200, 401} {
  5808  		t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
  5809  			statusCode = status
  5810  			pr, pw := io.Pipe()
  5811  			body := newCloseChecker(pr)
  5812  			req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
  5813  			if err != nil {
  5814  				t.Fatal(err)
  5815  			}
  5816  			res, err := cc.RoundTrip(req)
  5817  			if err != nil {
  5818  				t.Fatal(err)
  5819  			}
  5820  			res.Body.Close()
  5821  			pw.Close()
  5822  			if err := body.isClosed(); err != nil {
  5823  				t.Fatal(err)
  5824  			}
  5825  		})
  5826  	}
  5827  }
  5828  
  5829  // collectClientsConnPool is a ClientConnPool that wraps lower and
  5830  // collects what calls were made on it.
  5831  type collectClientsConnPool struct {
  5832  	lower ClientConnPool
  5833  
  5834  	mu      sync.Mutex
  5835  	getErrs int
  5836  	got     []*ClientConn
  5837  }
  5838  
  5839  func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
  5840  	cc, err := p.lower.GetClientConn(req, addr)
  5841  	p.mu.Lock()
  5842  	defer p.mu.Unlock()
  5843  	if err != nil {
  5844  		p.getErrs++
  5845  		return nil, err
  5846  	}
  5847  	p.got = append(p.got, cc)
  5848  	return cc, nil
  5849  }
  5850  
  5851  func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
  5852  	p.lower.MarkDead(cc)
  5853  }
  5854  
// TestTransportRetriesOnStreamProtocolError checks how the Transport reacts
// when the server resets one stream with PROTOCOL_ERROR while another stream
// on the same connection is still open.
//
// The client func asserts that:
//   - the reset is reported through Transport.CountError as
//     "recv_rststream_PROTOCOL_ERROR";
//   - the failed request is retried, and the retry's GetClientConn fails with
//     the test dialer's "only one dial allowed in test mode" error (counted
//     once in pool.getErrs);
//   - the first response body still ends cleanly with io.EOF;
//   - both successful pool lookups returned the same ClientConn, which is
//     marked doNotReuse and whose reader eventually finishes (readerDone).
func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
	ct := newClientTester(t)
	// Wrap the real pool so the test can see every conn lookup and failure.
	pool := &collectClientsConnPool{
		lower: &clientConnPool{t: ct.tr},
	}
	ct.tr.ConnPool = pool

	// gotProtoError records, without ever blocking the transport, that a
	// RST_STREAM(PROTOCOL_ERROR) was counted.
	gotProtoError := make(chan bool, 1)
	ct.tr.CountError = func(errType string) {
		if errType == "recv_rststream_PROTOCOL_ERROR" {
			select {
			case gotProtoError <- true:
			default:
			}
		}
	}
	ct.client = func() error {
		// Start two requests. The first is a long request
		// that will finish after the second. The second one
		// will result in the protocol error.  We check that
		// after the first one closes, the connection then
		// shuts down.

		// The long, outer request.
		req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil)
		res1, err := ct.tr.RoundTrip(req1)
		if err != nil {
			return err
		}
		if got, want := res1.Header.Get("Is-Long"), "1"; got != want {
			return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want)
		}

		// The second request is reset by the server. Its retry needs a new
		// conn, and the test dialer refuses a second dial, so this error
		// string is what RoundTrip is expected to report.
		req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil)
		res, err := ct.tr.RoundTrip(req)
		const want = "only one dial allowed in test mode"
		if got := fmt.Sprint(err); got != want {
			t.Errorf("didn't dial again: got %#q; want %#q", got, want)
		}
		if res != nil {
			res.Body.Close()
		}
		select {
		case <-gotProtoError:
		default:
			t.Errorf("didn't get stream protocol error")
		}

		// The server ended the first stream with an empty END_STREAM DATA
		// frame, so its body should read as immediately at EOF.
		if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 {
			t.Errorf("unexpected body read %v, %v", n, err)
		}

		pool.mu.Lock()
		defer pool.mu.Unlock()
		if pool.getErrs != 1 {
			t.Errorf("pool get errors = %v; want 1", pool.getErrs)
		}
		// Expect two successful lookups: one per original request, both on
		// the single dialed conn.
		if len(pool.got) == 2 {
			if pool.got[0] != pool.got[1] {
				t.Errorf("requests went on different connections")
			}
			cc := pool.got[0]
			cc.mu.Lock()
			if !cc.doNotReuse {
				t.Error("ClientConn not marked doNotReuse")
			}
			cc.mu.Unlock()

			select {
			case <-cc.readerDone:
			case <-time.After(5 * time.Second):
				t.Errorf("timeout waiting for reader to be done")
			}
		} else {
			t.Errorf("pool get success = %v; want 2", len(pool.got))
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var sentErr bool
		var numHeaders int
		var firstStreamID uint32

		var hbuf bytes.Buffer
		enc := hpack.NewEncoder(&hbuf)

		for {
			f, err := ct.fr.ReadFrame()
			if err == io.EOF {
				// Client hung up on us, as it should at the end.
				return nil
			}
			// Any other read error also ends the server side without
			// reporting a failure.
			if err != nil {
				return nil
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				// Ignored.
			case *HeadersFrame:
				numHeaders++
				if numHeaders == 1 {
					// First request: send response headers without
					// END_STREAM so that stream stays open.
					firstStreamID = f.StreamID
					hbuf.Reset()
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     false,
						BlockFragment: hbuf.Bytes(),
					})
					continue
				}
				if !sentErr {
					// Second request: reset it with PROTOCOL_ERROR, then
					// finish the first stream with an empty END_STREAM
					// DATA frame.
					sentErr = true
					ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol)
					ct.fr.WriteData(firstStreamID, true, nil)
					continue
				}
			}
		}
	}
	ct.run()
}
  5979  
  5980  func TestClientConnReservations(t *testing.T) {
  5981  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5982  	}, func(s *Server) {
  5983  		s.MaxConcurrentStreams = initialMaxConcurrentStreams
  5984  	})
  5985  	defer st.Close()
  5986  
  5987  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  5988  	defer tr.CloseIdleConnections()
  5989  
  5990  	cc, err := tr.newClientConn(st.cc, false)
  5991  	if err != nil {
  5992  		t.Fatal(err)
  5993  	}
  5994  
  5995  	req, _ := http.NewRequest("GET", st.ts.URL, nil)
  5996  	n := 0
  5997  	for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
  5998  		n++
  5999  	}
  6000  	if n != initialMaxConcurrentStreams {
  6001  		t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
  6002  	}
  6003  	if _, err := cc.RoundTrip(req); err != nil {
  6004  		t.Fatalf("RoundTrip error = %v", err)
  6005  	}
  6006  	n2 := 0
  6007  	for n2 <= 5 && cc.ReserveNewRequest() {
  6008  		n2++
  6009  	}
  6010  	if n2 != 1 {
  6011  		t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
  6012  	}
  6013  
  6014  	// Use up all the reservations
  6015  	for i := 0; i < n; i++ {
  6016  		cc.RoundTrip(req)
  6017  	}
  6018  
  6019  	n2 = 0
  6020  	for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
  6021  		n2++
  6022  	}
  6023  	if n2 != n {
  6024  		t.Errorf("after reset, reservations = %v; want %v", n2, n)
  6025  	}
  6026  }
  6027  
  6028  func TestTransportTimeoutServerHangs(t *testing.T) {
  6029  	clientDone := make(chan struct{})
  6030  	ct := newClientTester(t)
  6031  	ct.client = func() error {
  6032  		defer ct.cc.(*net.TCPConn).CloseWrite()
  6033  		defer close(clientDone)
  6034  
  6035  		req, err := http.NewRequest("PUT", "https://dummy.tld/", nil)
  6036  		if err != nil {
  6037  			return err
  6038  		}
  6039  
  6040  		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
  6041  		defer cancel()
  6042  		req = req.WithContext(ctx)
  6043  		req.Header.Add("Big", strings.Repeat("a", 1<<20))
  6044  		_, err = ct.tr.RoundTrip(req)
  6045  		if err == nil {
  6046  			return errors.New("error should not be nil")
  6047  		}
  6048  		if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
  6049  			return fmt.Errorf("error should be a net error timeout: %v", err)
  6050  		}
  6051  		return nil
  6052  	}
  6053  	ct.server = func() error {
  6054  		ct.greet()
  6055  		select {
  6056  		case <-time.After(5 * time.Second):
  6057  		case <-clientDone:
  6058  		}
  6059  		return nil
  6060  	}
  6061  	ct.run()
  6062  }
  6063  
  6064  func TestTransportContentLengthWithoutBody(t *testing.T) {
  6065  	contentLength := ""
  6066  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  6067  		w.Header().Set("Content-Length", contentLength)
  6068  	}, optOnlyServer)
  6069  	defer st.Close()
  6070  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  6071  	defer tr.CloseIdleConnections()
  6072  
  6073  	for _, test := range []struct {
  6074  		name              string
  6075  		contentLength     string
  6076  		wantBody          string
  6077  		wantErr           error
  6078  		wantContentLength int64
  6079  	}{
  6080  		{
  6081  			name:              "non-zero content length",
  6082  			contentLength:     "42",
  6083  			wantErr:           io.ErrUnexpectedEOF,
  6084  			wantContentLength: 42,
  6085  		},
  6086  		{
  6087  			name:              "zero content length",
  6088  			contentLength:     "0",
  6089  			wantErr:           nil,
  6090  			wantContentLength: 0,
  6091  		},
  6092  	} {
  6093  		t.Run(test.name, func(t *testing.T) {
  6094  			contentLength = test.contentLength
  6095  
  6096  			req, _ := http.NewRequest("GET", st.ts.URL, nil)
  6097  			res, err := tr.RoundTrip(req)
  6098  			if err != nil {
  6099  				t.Fatal(err)
  6100  			}
  6101  			defer res.Body.Close()
  6102  			body, err := io.ReadAll(res.Body)
  6103  
  6104  			if err != test.wantErr {
  6105  				t.Errorf("Expected error %v, got: %v", test.wantErr, err)
  6106  			}
  6107  			if len(body) > 0 {
  6108  				t.Errorf("Expected empty body, got: %v", body)
  6109  			}
  6110  			if res.ContentLength != test.wantContentLength {
  6111  				t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength)
  6112  			}
  6113  		})
  6114  	}
  6115  }
  6116  
  6117  func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
  6118  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  6119  		w.WriteHeader(200)
  6120  		w.(http.Flusher).Flush()
  6121  		io.Copy(io.Discard, r.Body)
  6122  	}, optOnlyServer)
  6123  	defer st.Close()
  6124  
  6125  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  6126  	defer tr.CloseIdleConnections()
  6127  
  6128  	pr, pw := net.Pipe()
  6129  	req, err := http.NewRequest("GET", st.ts.URL, pr)
  6130  	if err != nil {
  6131  		t.Fatal(err)
  6132  	}
  6133  	res, err := tr.RoundTrip(req)
  6134  	if err != nil {
  6135  		t.Fatal(err)
  6136  	}
  6137  	// Closing the Response's Body interrupts the blocked body read.
  6138  	res.Body.Close()
  6139  	pw.Close()
  6140  }
  6141  
  6142  func TestTransport300ResponseBody(t *testing.T) {
  6143  	reqc := make(chan struct{})
  6144  	body := []byte("response body")
  6145  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  6146  		w.WriteHeader(300)
  6147  		w.(http.Flusher).Flush()
  6148  		<-reqc
  6149  		w.Write(body)
  6150  	}, optOnlyServer)
  6151  	defer st.Close()
  6152  
  6153  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  6154  	defer tr.CloseIdleConnections()
  6155  
  6156  	pr, pw := net.Pipe()
  6157  	req, err := http.NewRequest("GET", st.ts.URL, pr)
  6158  	if err != nil {
  6159  		t.Fatal(err)
  6160  	}
  6161  	res, err := tr.RoundTrip(req)
  6162  	if err != nil {
  6163  		t.Fatal(err)
  6164  	}
  6165  	close(reqc)
  6166  	got, err := io.ReadAll(res.Body)
  6167  	if err != nil {
  6168  		t.Fatalf("error reading response body: %v", err)
  6169  	}
  6170  	if !bytes.Equal(got, body) {
  6171  		t.Errorf("got response body %q, want %q", string(got), string(body))
  6172  	}
  6173  	res.Body.Close()
  6174  	pw.Close()
  6175  }
  6176  
  6177  func TestTransportWriteByteTimeout(t *testing.T) {
  6178  	st := newServerTester(t,
  6179  		func(w http.ResponseWriter, r *http.Request) {},
  6180  		optOnlyServer,
  6181  	)
  6182  	defer st.Close()
  6183  	tr := &Transport{
  6184  		TLSClientConfig: tlsConfigInsecure,
  6185  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  6186  			_, c := net.Pipe()
  6187  			return c, nil
  6188  		},
  6189  		WriteByteTimeout: 1 * time.Millisecond,
  6190  	}
  6191  	defer tr.CloseIdleConnections()
  6192  	c := &http.Client{Transport: tr}
  6193  
  6194  	_, err := c.Get(st.ts.URL)
  6195  	if !errors.Is(err, os.ErrDeadlineExceeded) {
  6196  		t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
  6197  	}
  6198  }
  6199  
  6200  type slowWriteConn struct {
  6201  	net.Conn
  6202  	hasWriteDeadline bool
  6203  }
  6204  
  6205  func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
  6206  	c.hasWriteDeadline = !t.IsZero()
  6207  	return nil
  6208  }
  6209  
  6210  func (c *slowWriteConn) Write(b []byte) (n int, err error) {
  6211  	if c.hasWriteDeadline && len(b) > 1 {
  6212  		n, err = c.Conn.Write(b[:1])
  6213  		if err != nil {
  6214  			return n, err
  6215  		}
  6216  		return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
  6217  	}
  6218  	return c.Conn.Write(b)
  6219  }
  6220  
  6221  func TestTransportSlowWrites(t *testing.T) {
  6222  	st := newServerTester(t,
  6223  		func(w http.ResponseWriter, r *http.Request) {},
  6224  		optOnlyServer,
  6225  	)
  6226  	defer st.Close()
  6227  	tr := &Transport{
  6228  		TLSClientConfig: tlsConfigInsecure,
  6229  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  6230  			cfg.InsecureSkipVerify = true
  6231  			c, err := tls.Dial(network, addr, cfg)
  6232  			return &slowWriteConn{Conn: c}, err
  6233  		},
  6234  		WriteByteTimeout: 1 * time.Millisecond,
  6235  	}
  6236  	defer tr.CloseIdleConnections()
  6237  	c := &http.Client{Transport: tr}
  6238  
  6239  	const bodySize = 1 << 20
  6240  	resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
  6241  	if err != nil {
  6242  		t.Fatal(err)
  6243  	}
  6244  	resp.Body.Close()
  6245  }
  6246  
  6247  func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) {
  6248  	testTransportClosesConnAfterGoAway(t, 0)
  6249  }
  6250  func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) {
  6251  	testTransportClosesConnAfterGoAway(t, 1)
  6252  }
  6253  
  6254  type closeOnceConn struct {
  6255  	net.Conn
  6256  	closed uint32
  6257  }
  6258  
  6259  var errClosed = errors.New("Close of closed connection")
  6260  
  6261  func (c *closeOnceConn) Close() error {
  6262  	if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
  6263  		return c.Conn.Close()
  6264  	}
  6265  	return errClosed
  6266  }
  6267  
// testTransportClosesConnAfterGoAway verifies that the transport
// closes a connection after reading a GOAWAY from it.
//
// lastStream is the last stream ID in the GOAWAY frame.
// When 0, the transport (unsuccessfully) retries the request (stream 1);
// when 1, the transport reads the response after receiving the GOAWAY.
//
// The check itself works like this: the client conn is wrapped in a
// closeOnceConn, so when the test's own ct.cc.Close() returns errClosed, the
// transport must already have closed the conn.
func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) {
	ct := newClientTester(t)
	ct.cc = &closeOnceConn{Conn: ct.cc}

	// wg makes the server func (via its deferred Wait) stay running until the
	// client func has finished its checks. Presumably this keeps the server
	// side from tearing down its end first.
	var wg sync.WaitGroup
	wg.Add(1)
	ct.client = func() error {
		defer wg.Done()
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err == nil {
			res.Body.Close()
		}
		// An error is expected exactly when the GOAWAY excluded stream 1.
		if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
			t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
		}
		// errClosed means the transport had already closed the conn.
		if err = ct.cc.Close(); err != errClosed {
			return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err)
		}
		return nil
	}

	ct.server = func() error {
		defer wg.Wait()
		ct.greet()
		hf, err := ct.firstHeaders()
		if err != nil {
			return fmt.Errorf("server failed reading HEADERS: %v", err)
		}
		if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil {
			return fmt.Errorf("server failed writing GOAWAY: %v", err)
		}
		if lastStream > 0 {
			// Send a valid response to first request.
			var buf bytes.Buffer
			enc := hpack.NewEncoder(&buf)
			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
			ct.fr.WriteHeaders(HeadersFrameParam{
				StreamID:      hf.StreamID,
				EndHeaders:    true,
				EndStream:     true,
				BlockFragment: buf.Bytes(),
			})
		}
		return nil
	}

	ct.run()
}
  6323  
  6324  type slowCloser struct {
  6325  	closing chan struct{}
  6326  	closed  chan struct{}
  6327  }
  6328  
  6329  func (r *slowCloser) Read([]byte) (int, error) {
  6330  	return 0, io.EOF
  6331  }
  6332  
  6333  func (r *slowCloser) Close() error {
  6334  	close(r.closing)
  6335  	<-r.closed
  6336  	return nil
  6337  }
  6338  
  6339  func TestTransportSlowClose(t *testing.T) {
  6340  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  6341  	}, optOnlyServer)
  6342  	defer st.Close()
  6343  
  6344  	client := st.ts.Client()
  6345  	body := &slowCloser{
  6346  		closing: make(chan struct{}),
  6347  		closed:  make(chan struct{}),
  6348  	}
  6349  
  6350  	reqc := make(chan struct{})
  6351  	go func() {
  6352  		defer close(reqc)
  6353  		res, err := client.Post(st.ts.URL, "text/plain", body)
  6354  		if err != nil {
  6355  			t.Error(err)
  6356  		}
  6357  		res.Body.Close()
  6358  	}()
  6359  	defer func() {
  6360  		close(body.closed)
  6361  		<-reqc // wait for POST request to finish
  6362  	}()
  6363  
  6364  	<-body.closing // wait for POST request to call body.Close
  6365  	// This GET request should not be blocked by the in-progress POST.
  6366  	res, err := client.Get(st.ts.URL)
  6367  	if err != nil {
  6368  		t.Fatal(err)
  6369  	}
  6370  	res.Body.Close()
  6371  }
  6372  
  6373  func TestTransportDialTLSContext(t *testing.T) {
  6374  	blockCh := make(chan struct{})
  6375  	serverTLSConfigFunc := func(ts *httptest.Server) {
  6376  		ts.Config.TLSConfig = &tls.Config{
  6377  			// Triggers the server to request the clients certificate
  6378  			// during TLS handshake.
  6379  			ClientAuth: tls.RequestClientCert,
  6380  		}
  6381  	}
  6382  	ts := newServerTester(t,
  6383  		func(w http.ResponseWriter, r *http.Request) {},
  6384  		optOnlyServer,
  6385  		serverTLSConfigFunc,
  6386  	)
  6387  	defer ts.Close()
  6388  	tr := &Transport{
  6389  		TLSClientConfig: &tls.Config{
  6390  			GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
  6391  				// Tests that the context provided to `req` is
  6392  				// passed into this function.
  6393  				close(blockCh)
  6394  				<-cri.Context().Done()
  6395  				return nil, cri.Context().Err()
  6396  			},
  6397  			InsecureSkipVerify: true,
  6398  		},
  6399  	}
  6400  	defer tr.CloseIdleConnections()
  6401  	req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
  6402  	if err != nil {
  6403  		t.Fatal(err)
  6404  	}
  6405  	ctx, cancel := context.WithCancel(context.Background())
  6406  	defer cancel()
  6407  	req = req.WithContext(ctx)
  6408  	errCh := make(chan error)
  6409  	go func() {
  6410  		defer close(errCh)
  6411  		res, err := tr.RoundTrip(req)
  6412  		if err != nil {
  6413  			errCh <- err
  6414  			return
  6415  		}
  6416  		res.Body.Close()
  6417  	}()
  6418  	// Wait for GetClientCertificate handler to be called
  6419  	<-blockCh
  6420  	// Cancel the context
  6421  	cancel()
  6422  	// Expect the cancellation error here
  6423  	err = <-errCh
  6424  	if err == nil {
  6425  		t.Fatal("cancelling context during client certificate fetch did not error as expected")
  6426  		return
  6427  	}
  6428  	if !errors.Is(err, context.Canceled) {
  6429  		t.Fatalf("unexpected error returned after cancellation: %v", err)
  6430  	}
  6431  }
  6432  
// TestDialRaceResumesDial tests that, given two concurrent requests
// to the same address, when the first Dial is interrupted because
// the first request's context is cancelled, the second request
// resumes the dial automatically.
//
// The server requests a client certificate, so GetClientCertificate runs
// during each handshake. The first call closes blockCh and blocks until its
// handshake context is done. Any later call (blockCh already closed) succeeds
// at once with an empty certificate, which lets the resumed dial complete.
func TestDialRaceResumesDial(t *testing.T) {
	blockCh := make(chan struct{})
	serverTLSConfigFunc := func(ts *httptest.Server) {
		ts.Config.TLSConfig = &tls.Config{
			// Triggers the server to request the client's certificate
			// during TLS handshake.
			ClientAuth: tls.RequestClientCert,
		}
	}
	ts := newServerTester(t,
		func(w http.ResponseWriter, r *http.Request) {},
		optOnlyServer,
		serverTLSConfigFunc,
	)
	defer ts.Close()
	tr := &Transport{
		TLSClientConfig: &tls.Config{
			GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				select {
				case <-blockCh:
					// If we already errored, return without error.
					return &tls.Certificate{}, nil
				default:
				}
				close(blockCh)
				<-cri.Context().Done()
				return nil, cri.Context().Err()
			},
			InsecureSkipVerify: true,
		},
	}
	defer tr.CloseIdleConnections()
	req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Create two requests with independent cancellation.
	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()
	req1 := req.WithContext(ctx1)
	ctx2, cancel2 := context.WithCancel(context.Background())
	defer cancel2()
	req2 := req.WithContext(ctx2)
	// errCh is unbuffered and never closed; both goroutines send only
	// non-nil errors on it.
	errCh := make(chan error)
	go func() {
		res, err := tr.RoundTrip(req1)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
	}()
	successCh := make(chan struct{})
	go func() {
		// Don't start request until first request
		// has initiated the handshake.
		<-blockCh
		res, err := tr.RoundTrip(req2)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
		// Close successCh to indicate that the second request
		// made it to the server successfully.
		close(successCh)
	}()
	// Wait for GetClientCertificate handler to be called
	<-blockCh
	// Cancel the context first
	cancel1()
	// Expect the cancellation error here
	err = <-errCh
	// Defensive: nothing sends nil on errCh. The return after t.Fatal is
	// unreachable, because t.Fatal stops the goroutine via runtime.Goexit.
	if err == nil {
		t.Fatal("cancelling context during client certificate fetch did not error as expected")
		return
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error returned after cancellation: %v", err)
	}
	// The second request must now succeed rather than fail with the first
	// request's cancellation.
	select {
	case err := <-errCh:
		t.Fatalf("unexpected second error: %v", err)
	case <-successCh:
	}
}
  6523  

View as plain text