Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "io"
19 "log"
20 "net"
21 . "net/http"
22 "net/http/httptest"
23 "net/http/httptrace"
24 "net/http/httputil"
25 "net/textproto"
26 "net/url"
27 "os"
28 "reflect"
29 "runtime"
30 "sort"
31 "strings"
32 "sync"
33 "sync/atomic"
34 "testing"
35 "time"
36 )
37
// testMode selects the protocol configuration a client/server test runs in.
type testMode string

// The modes understood by run and newClientServerTest.
const (
	http1Mode  testMode = "h1"     // HTTP/1.1 over plain TCP
	https1Mode testMode = "https1" // HTTP/1.1 over TLS
	http2Mode  testMode = "h2"     // HTTP/2 over TLS
)

// testNotParallelOpt is the type of the testNotParallel option.
type testNotParallelOpt struct{}

// testNotParallel, when passed as an option to run, stops run from marking
// the test and its per-mode subtests as parallel.
var testNotParallel testNotParallelOpt

// TBRun is a testing.TB that can also start subtests whose callback receives
// the same concrete type T (for example *testing.T).
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
56
57
58
59
60
61
62
63
64 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
65 t.Helper()
66 modes := []testMode{http1Mode, http2Mode}
67 parallel := true
68 for _, opt := range opts {
69 switch opt := opt.(type) {
70 case []testMode:
71 modes = opt
72 case testNotParallelOpt:
73 parallel = false
74 default:
75 t.Fatalf("unknown option type %T", opt)
76 }
77 }
78 if t, ok := any(t).(*testing.T); ok && parallel {
79 setParallel(t)
80 }
81 for _, mode := range modes {
82 t.Run(string(mode), func(t T) {
83 t.Helper()
84 if t, ok := any(t).(*testing.T); ok && parallel {
85 setParallel(t)
86 }
87 t.Cleanup(func() {
88 afterTest(t)
89 })
90 f(t, mode)
91 })
92 }
93 }
94
// clientServerTest is a running test server plus a client configured to
// talk to it, as returned by newClientServerTest.
type clientServerTest struct {
	t  testing.TB       // test that owns the server; used to report fatal errors
	h2 bool             // true when created in http2Mode
	h  Handler          // handler the server was created with
	ts *httptest.Server // the server under test
	tr *Transport       // the client's Transport
	c  *Client          // client for ts, obtained from ts.Client()
}
103
104 func (t *clientServerTest) close() {
105 t.tr.CloseIdleConnections()
106 t.ts.Close()
107 }
108
109 func (t *clientServerTest) getURL(u string) string {
110 res, err := t.c.Get(u)
111 if err != nil {
112 t.t.Fatal(err)
113 }
114 defer res.Body.Close()
115 slurp, err := io.ReadAll(res.Body)
116 if err != nil {
117 t.t.Fatal(err)
118 }
119 return string(slurp)
120 }
121
122 func (t *clientServerTest) scheme() string {
123 if t.h2 {
124 return "https"
125 }
126 return "http"
127 }
128
129 var optQuietLog = func(ts *httptest.Server) {
130 ts.Config.ErrorLog = quietLog
131 }
132
133 func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
134 return func(ts *httptest.Server) {
135 ts.Config.ErrorLog = lg
136 }
137 }
138
139
140
141
142
143
144
145
146
147
148
149
// newClientServerTest starts an httptest.Server that serves h in the given
// mode and returns it together with a client for it. t.Cleanup closes the
// server and releases the client's idle connections.
//
// Each element of opts must be one of:
//   - func(*Transport): applied to the client's Transport after it has been
//     obtained from the started server (and, in http2Mode, configured for
//     HTTP/2);
//   - func(*httptest.Server): applied to the server before it is started.
//
// Any other option type, or an unknown mode, fails the test. In http2Mode
// CondSkipHTTP2 is called first (presumably it skips the test when HTTP/2
// is unavailable; confirm).
func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
	if mode == http2Mode {
		CondSkipHTTP2(t)
	}
	cst := &clientServerTest{
		t:  t,
		h2: mode == http2Mode,
		h:  h,
	}
	cst.ts = httptest.NewUnstartedServer(h)

	var transportFuncs []func(*Transport)
	for _, opt := range opts {
		switch opt := opt.(type) {
		case func(*Transport):
			// Saved for later: the client and its Transport are only
			// obtained after the server has been started below.
			transportFuncs = append(transportFuncs, opt)
		case func(*httptest.Server):
			opt(cst.ts)
		default:
			t.Fatalf("unhandled option type %T", opt)
		}
	}

	// Send server error logs to the test log unless an option has already
	// installed a logger.
	if cst.ts.Config.ErrorLog == nil {
		cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
	}

	switch mode {
	case http1Mode:
		cst.ts.Start()
	case https1Mode:
		cst.ts.StartTLS()
	case http2Mode:
		ExportHttp2ConfigureServer(cst.ts.Config, nil)
		// Have StartTLS use the server Config's TLSConfig (presumably
		// populated by ExportHttp2ConfigureServer above).
		cst.ts.TLS = cst.ts.Config.TLSConfig
		cst.ts.StartTLS()
	default:
		t.Fatalf("unknown test mode %v", mode)
	}
	cst.c = cst.ts.Client()
	cst.tr = cst.c.Transport.(*Transport)
	if mode == http2Mode {
		if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
			t.Fatal(err)
		}
	}
	for _, f := range transportFuncs {
		f(cst.tr)
	}
	t.Cleanup(func() {
		cst.close()
	})
	return cst
}
204
// testLogWriter is an io.Writer that forwards each write to the test log.
// It prefixes every message with "server log: " and trims surrounding
// whitespace.
type testLogWriter struct {
	t testing.TB
}

// Write logs b through t.Logf and always reports all of b as written.
func (w testLogWriter) Write(b []byte) (int, error) {
	msg := strings.TrimSpace(string(b))
	w.t.Logf("server log: %v", msg)
	return len(b), nil
}
213
214
215 func TestNewClientServerTest(t *testing.T) {
216 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
217 }
218 func testNewClientServerTest(t *testing.T, mode testMode) {
219 var got struct {
220 sync.Mutex
221 proto string
222 hasTLS bool
223 }
224 h := HandlerFunc(func(w ResponseWriter, r *Request) {
225 got.Lock()
226 defer got.Unlock()
227 got.proto = r.Proto
228 got.hasTLS = r.TLS != nil
229 })
230 cst := newClientServerTest(t, mode, h)
231 if _, err := cst.c.Head(cst.ts.URL); err != nil {
232 t.Fatal(err)
233 }
234 var wantProto string
235 var wantTLS bool
236 switch mode {
237 case http1Mode:
238 wantProto = "HTTP/1.1"
239 wantTLS = false
240 case https1Mode:
241 wantProto = "HTTP/1.1"
242 wantTLS = true
243 case http2Mode:
244 wantProto = "HTTP/2.0"
245 wantTLS = true
246 }
247 if got.proto != wantProto {
248 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
249 }
250 if got.hasTLS != wantTLS {
251 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
252 }
253 }
254
255 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
256 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
257 log.SetOutput(io.Discard)
258 defer log.SetOutput(os.Stderr)
259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
260 w.Header().Set("Content-Length", "intentional gibberish")
261 w.(Flusher).Flush()
262 fmt.Fprintf(w, "I am a chunked response.")
263 }))
264
265 res, err := cst.c.Get(cst.ts.URL)
266 if err != nil {
267 t.Fatalf("Get error: %v", err)
268 }
269 defer res.Body.Close()
270 if g, e := res.ContentLength, int64(-1); g != e {
271 t.Errorf("expected ContentLength of %d; got %d", e, g)
272 }
273 wantTE := []string{"chunked"}
274 if mode == http2Mode {
275 wantTE = nil
276 }
277 if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
278 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
279 }
280 if got, haveCL := res.Header["Content-Length"]; haveCL {
281 t.Errorf("Unexpected Content-Length: %q", got)
282 }
283 }
284
// reqFunc issues a request to url using c; (*Client).Get and (*Client).Head
// have this shape.
type reqFunc func(c *Client, url string) (*Response, error)

// h12Compare describes a handler whose responses should look the same to a
// client over HTTP/1.1 and HTTP/2. Its run method sends the request to an
// http1Mode server and an http2Mode server and compares the results.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // served by both servers
	ReqFunc            reqFunc                           // how to send the request; nil means (*Client).Get
	CheckResponse      func(proto string, res *Response) // optional; called after comparison, with the slurped body restored
	EarlyCheckResponse func(proto string, res *Response) // optional; called before normalization, with the unread body
	Opts               []any                             // options passed to both newClientServerTest calls
}
296
297 func (tt h12Compare) reqFunc() reqFunc {
298 if tt.ReqFunc == nil {
299 return (*Client).Get
300 }
301 return tt.ReqFunc
302 }
303
304 func (tt h12Compare) run(t *testing.T) {
305 setParallel(t)
306 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
307 defer cst1.close()
308 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
309 defer cst2.close()
310
311 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
312 if err != nil {
313 t.Errorf("HTTP/1 request: %v", err)
314 return
315 }
316 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
317 if err != nil {
318 t.Errorf("HTTP/2 request: %v", err)
319 return
320 }
321
322 if fn := tt.EarlyCheckResponse; fn != nil {
323 fn("HTTP/1.1", res1)
324 fn("HTTP/2.0", res2)
325 }
326
327 tt.normalizeRes(t, res1, "HTTP/1.1")
328 tt.normalizeRes(t, res2, "HTTP/2.0")
329 res1body, res2body := res1.Body, res2.Body
330
331 eres1 := mostlyCopy(res1)
332 eres2 := mostlyCopy(res2)
333 if !reflect.DeepEqual(eres1, eres2) {
334 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
335 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
336 }
337 if !reflect.DeepEqual(res1body, res2body) {
338 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
339 }
340 if fn := tt.CheckResponse; fn != nil {
341 res1.Body, res2.Body = res1body, res2body
342 fn("HTTP/1.1", res1)
343 fn("HTTP/2.0", res2)
344 }
345 }
346
// mostlyCopy returns a shallow copy of r with Body, TransferEncoding, TLS and
// Request cleared. These are the fields h12Compare excludes when comparing
// responses. r itself is left unmodified.
func mostlyCopy(r *Response) *Response {
	dup := new(Response)
	*dup = *r
	dup.Body, dup.TransferEncoding, dup.TLS, dup.Request = nil, nil, nil, nil
	return dup
}
355
// slurpResult stands in for a response body that normalizeRes has already
// read. It still reads as the slurped bytes through the embedded
// ReadCloser. It also records those bytes (body) and the error the read
// ended with (err), so bodies can be compared with reflect.DeepEqual.
type slurpResult struct {
	io.ReadCloser
	body []byte
	err  error
}

// String describes the recorded body and read error for failure messages.
func (sr slurpResult) String() string {
	const layout = "body %q; err %v"
	return fmt.Sprintf(layout, sr.body, sr.err)
}
363
364 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
365 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
366 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
367 } else {
368 t.Errorf("got %q response; want %q", res.Proto, wantProto)
369 }
370 slurp, err := io.ReadAll(res.Body)
371
372 res.Body.Close()
373 res.Body = slurpResult{
374 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
375 body: slurp,
376 err: err,
377 }
378 for i, v := range res.Header["Date"] {
379 res.Header["Date"][i] = strings.Repeat("x", len(v))
380 }
381 if res.Request == nil {
382 t.Errorf("for %s, no request", wantProto)
383 }
384 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
385 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
386 }
387 }
388
389
390 func TestH12_HeadContentLengthNoBody(t *testing.T) {
391 h12Compare{
392 ReqFunc: (*Client).Head,
393 Handler: func(w ResponseWriter, r *Request) {
394 },
395 }.run(t)
396 }
397
398 func TestH12_HeadContentLengthSmallBody(t *testing.T) {
399 h12Compare{
400 ReqFunc: (*Client).Head,
401 Handler: func(w ResponseWriter, r *Request) {
402 io.WriteString(w, "small")
403 },
404 }.run(t)
405 }
406
407 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
408 h12Compare{
409 ReqFunc: (*Client).Head,
410 Handler: func(w ResponseWriter, r *Request) {
411 chunk := strings.Repeat("x", 512<<10)
412 for i := 0; i < 10; i++ {
413 io.WriteString(w, chunk)
414 }
415 },
416 }.run(t)
417 }
418
419 func TestH12_200NoBody(t *testing.T) {
420 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
421 }
422
423 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
424 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
425 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
426
427 func testH12_noBody(t *testing.T, status int) {
428 h12Compare{Handler: func(w ResponseWriter, r *Request) {
429 w.WriteHeader(status)
430 }}.run(t)
431 }
432
433 func TestH12_SmallBody(t *testing.T) {
434 h12Compare{Handler: func(w ResponseWriter, r *Request) {
435 io.WriteString(w, "small body")
436 }}.run(t)
437 }
438
439 func TestH12_ExplicitContentLength(t *testing.T) {
440 h12Compare{Handler: func(w ResponseWriter, r *Request) {
441 w.Header().Set("Content-Length", "3")
442 io.WriteString(w, "foo")
443 }}.run(t)
444 }
445
446 func TestH12_FlushBeforeBody(t *testing.T) {
447 h12Compare{Handler: func(w ResponseWriter, r *Request) {
448 w.(Flusher).Flush()
449 io.WriteString(w, "foo")
450 }}.run(t)
451 }
452
453 func TestH12_FlushMidBody(t *testing.T) {
454 h12Compare{Handler: func(w ResponseWriter, r *Request) {
455 io.WriteString(w, "foo")
456 w.(Flusher).Flush()
457 io.WriteString(w, "bar")
458 }}.run(t)
459 }
460
461 func TestH12_Head_ExplicitLen(t *testing.T) {
462 h12Compare{
463 ReqFunc: (*Client).Head,
464 Handler: func(w ResponseWriter, r *Request) {
465 if r.Method != "HEAD" {
466 t.Errorf("unexpected method %q", r.Method)
467 }
468 w.Header().Set("Content-Length", "1235")
469 },
470 }.run(t)
471 }
472
473 func TestH12_Head_ImplicitLen(t *testing.T) {
474 h12Compare{
475 ReqFunc: (*Client).Head,
476 Handler: func(w ResponseWriter, r *Request) {
477 if r.Method != "HEAD" {
478 t.Errorf("unexpected method %q", r.Method)
479 }
480 io.WriteString(w, "foo")
481 },
482 }.run(t)
483 }
484
485 func TestH12_HandlerWritesTooLittle(t *testing.T) {
486 h12Compare{
487 Handler: func(w ResponseWriter, r *Request) {
488 w.Header().Set("Content-Length", "3")
489 io.WriteString(w, "12")
490 },
491 CheckResponse: func(proto string, res *Response) {
492 sr, ok := res.Body.(slurpResult)
493 if !ok {
494 t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
495 return
496 }
497 if sr.err != io.ErrUnexpectedEOF {
498 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
499 }
500 if string(sr.body) != "12" {
501 t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
502 }
503 },
504 }.run(t)
505 }
506
507
508
509
510
511
512
513 func TestH12_HandlerWritesTooMuch(t *testing.T) {
514 h12Compare{
515 Handler: func(w ResponseWriter, r *Request) {
516 w.Header().Set("Content-Length", "3")
517 w.(Flusher).Flush()
518 io.WriteString(w, "123")
519 w.(Flusher).Flush()
520 n, err := io.WriteString(w, "x")
521 if n > 0 || err == nil {
522 t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
523 }
524 },
525 }.run(t)
526 }
527
528
529
530 func TestH12_AutoGzip(t *testing.T) {
531 h12Compare{
532 Handler: func(w ResponseWriter, r *Request) {
533 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
534 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
535 }
536 w.Header().Set("Content-Encoding", "gzip")
537 gz := gzip.NewWriter(w)
538 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
539 gz.Close()
540 },
541 }.run(t)
542 }
543
544 func TestH12_AutoGzip_Disabled(t *testing.T) {
545 h12Compare{
546 Opts: []any{
547 func(tr *Transport) { tr.DisableCompression = true },
548 },
549 Handler: func(w ResponseWriter, r *Request) {
550 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
551 if ae := r.Header.Get("Accept-Encoding"); ae != "" {
552 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
553 }
554 },
555 }.run(t)
556 }
557
558
559
560
561 func Test304Responses(t *testing.T) { run(t, test304Responses) }
562 func test304Responses(t *testing.T, mode testMode) {
563 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
564 w.WriteHeader(StatusNotModified)
565 _, err := w.Write([]byte("illegal body"))
566 if err != ErrBodyNotAllowed {
567 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
568 }
569 }))
570 defer cst.close()
571 res, err := cst.c.Get(cst.ts.URL)
572 if err != nil {
573 t.Fatal(err)
574 }
575 if len(res.TransferEncoding) > 0 {
576 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
577 }
578 body, err := io.ReadAll(res.Body)
579 if err != nil {
580 t.Error(err)
581 }
582 if len(body) > 0 {
583 t.Errorf("got unexpected body %q", string(body))
584 }
585 }
586
587 func TestH12_ServerEmptyContentLength(t *testing.T) {
588 h12Compare{
589 Handler: func(w ResponseWriter, r *Request) {
590 w.Header()["Content-Type"] = []string{""}
591 io.WriteString(w, "<html><body>hi</body></html>")
592 },
593 }.run(t)
594 }
595
596 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
597 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
598 }
599
600 func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
601 h12requestContentLength(t, func() io.Reader { return nil }, 0)
602 }
603
604 func TestH12_RequestContentLength_Unknown(t *testing.T) {
605 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
606 }
607
608 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
609 h12Compare{
610 Handler: func(w ResponseWriter, r *Request) {
611 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
612 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
613 },
614 ReqFunc: func(c *Client, url string) (*Response, error) {
615 return c.Post(url, "text/plain", bodyfn())
616 },
617 CheckResponse: func(proto string, res *Response) {
618 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
619 t.Errorf("Proto %q got length %q; want %q", proto, got, want)
620 }
621 },
622 }.run(t)
623 }
624
625
626
// TestCancelRequestMidBody checks what happens when Request.Cancel is closed
// partway through reading a response body. The data already received
// ("Hello") is still returned, and the read after cancellation ends with
// ExportErrRequestCanceled.
func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
func testCancelRequestMidBody(t *testing.T, mode testMode) {
	unblock := make(chan bool)
	didFlush := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Send "Hello", then stall until the test returns before trying to
		// write the rest of the body.
		io.WriteString(w, "Hello")
		w.(Flusher).Flush()
		didFlush <- true
		<-unblock
		io.WriteString(w, ", world.")
	}))
	defer close(unblock)

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	<-didFlush

	// The handler has flushed "Hello"; read (up to 10 bytes of) it before
	// canceling.
	firstRead := make([]byte, 10)
	n, err := res.Body.Read(firstRead)
	if err != nil {
		t.Fatal(err)
	}
	firstRead = firstRead[:n]

	close(cancel)

	// Together the two reads must yield exactly "Hello", and the read
	// after cancellation must end with the cancellation error.
	rest, err := io.ReadAll(res.Body)
	all := string(firstRead) + string(rest)
	if all != "Hello" {
		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
	}
}
671
672
// TestTrailersClientToServer checks that request trailers reach the handler.
// The client declares two trailer keys up front in Request.Trailer and
// fills in their values while the request body is being read for sending.
func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
func testTrailersClientToServer(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Collect the declared trailer keys (sorted for a stable reply)
		// before reading the body.
		var decl []string
		for k := range r.Trailer {
			decl = append(decl, k)
		}
		sort.Strings(decl)

		slurp, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("Server reading request body: %v", err)
		}
		if string(slurp) != "foo" {
			t.Errorf("Server read request body %q; want foo", slurp)
		}
		// Report the declared keys and the trailer values, which are only
		// looked up after the body has been read in full.
		if r.Trailer == nil {
			io.WriteString(w, "nil Trailer")
		} else {
			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
				decl,
				r.Trailer.Get("Client-Trailer-A"),
				r.Trailer.Get("Client-Trailer-B"))
		}
	}))

	// The body reads as "foo". The eofReaderFunc parts around it (a helper
	// defined elsewhere; presumably it runs its callback when read and then
	// reports EOF) set trailer A's value before "foo" is sent and B's after.
	// req is declared first so those callbacks can refer to it.
	var req *Request
	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
		eofReaderFunc(func() {
			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
		}),
		strings.NewReader("foo"),
		eofReaderFunc(func() {
			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
		}),
	))
	// Declare the trailer keys; their values are filled in by the body
	// reader above.
	req.Trailer = Header{
		"Client-Trailer-A": nil,
		"Client-Trailer-B": nil,
	}
	// Mark the body length as unknown.
	req.ContentLength = -1
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
		t.Error(err)
	}
}
722
723
// TestTrailersServerToClient checks that response trailers declared through
// the "Trailer" header and set after the body is written reach the client.
// Undeclared ones must be dropped.
func TestTrailersServerToClient(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTrailersServerToClient(t, mode, false)
	})
}

// TestTrailersServerToClientFlush is TestTrailersServerToClient with a flush
// after the body is written.
func TestTrailersServerToClientFlush(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTrailersServerToClient(t, mode, true)
	})
}

func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
	const body = "Some body"
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Declare trailers A, B and C, using both a comma-separated list
		// and a second header value.
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")

		io.WriteString(w, body)
		if flush {
			w.(Flusher).Flush()
		}

		// Header values set after the body has been written: A and C are
		// declared trailers, B is deliberately left unset, and the
		// undeclared key must not reach the client.
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec")
		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
	}))

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	wantHeader := Header{
		"Content-Type": {"text/plain; charset=utf-8"},
	}
	wantLen := -1
	if mode == http2Mode && !flush {
		// Only the unflushed HTTP/2 response is expected to carry a
		// Content-Length; every other combination has unknown length.
		wantLen = len(body)
		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
	}
	if res.ContentLength != int64(wantLen) {
		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
	}

	// Date varies from run to run; don't compare it.
	delete(res.Header, "Date")
	if !reflect.DeepEqual(res.Header, wantHeader) {
		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
	}

	// Before the body is read, only the declared trailer keys are known.
	if got, want := res.Trailer, (Header{
		"Server-Trailer-A": nil,
		"Server-Trailer-B": nil,
		"Server-Trailer-C": nil,
	}); !reflect.DeepEqual(got, want) {
		t.Errorf("Trailer before body read = %v; want %v", got, want)
	}

	if err := wantBody(res, nil, body); err != nil {
		t.Fatal(err)
	}

	// After the body is read, the values the handler set are present. B
	// was never set, and the undeclared trailer is absent.
	if got, want := res.Trailer, (Header{
		"Server-Trailer-A": {"valuea"},
		"Server-Trailer-B": nil,
		"Server-Trailer-C": {"valuec"},
	}); !reflect.DeepEqual(got, want) {
		t.Errorf("Trailer after body read = %v; want %v", got, want)
	}
}
802
803
804 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
805 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
806 const body = "Some body"
807 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
808 io.WriteString(w, body)
809 }))
810 res, err := cst.c.Get(cst.ts.URL)
811 if err != nil {
812 t.Fatal(err)
813 }
814 res.Body.Close()
815 data, err := io.ReadAll(res.Body)
816 if len(data) != 0 || err == nil {
817 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
818 }
819 }
820
821 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
822 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
823 const reqBody = "some request body"
824 const resBody = "some response body"
825 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
826 var wg sync.WaitGroup
827 wg.Add(2)
828 didRead := make(chan bool, 1)
829
830 go func() {
831 defer wg.Done()
832 data, err := io.ReadAll(r.Body)
833 if string(data) != reqBody {
834 t.Errorf("Handler read %q; want %q", data, reqBody)
835 }
836 if err != nil {
837 t.Errorf("Handler Read: %v", err)
838 }
839 didRead <- true
840 }()
841
842 go func() {
843 defer wg.Done()
844 if mode != http2Mode {
845
846
847
848
849 <-didRead
850 }
851 io.WriteString(w, resBody)
852 }()
853 wg.Wait()
854 }))
855 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
856 req.Header.Add("Expect", "100-continue")
857 res, err := cst.c.Do(req)
858 if err != nil {
859 t.Fatal(err)
860 }
861 data, err := io.ReadAll(res.Body)
862 defer res.Body.Close()
863 if err != nil {
864 t.Fatal(err)
865 }
866 if string(data) != resBody {
867 t.Errorf("read %q; want %q", data, resBody)
868 }
869 }
870
871 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
872 func testConnectRequest(t *testing.T, mode testMode) {
873 gotc := make(chan *Request, 1)
874 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
875 gotc <- r
876 }))
877
878 u, err := url.Parse(cst.ts.URL)
879 if err != nil {
880 t.Fatal(err)
881 }
882
883 tests := []struct {
884 req *Request
885 want string
886 }{
887 {
888 req: &Request{
889 Method: "CONNECT",
890 Header: Header{},
891 URL: u,
892 },
893 want: u.Host,
894 },
895 {
896 req: &Request{
897 Method: "CONNECT",
898 Header: Header{},
899 URL: u,
900 Host: "example.com:123",
901 },
902 want: "example.com:123",
903 },
904 }
905
906 for i, tt := range tests {
907 res, err := cst.c.Do(tt.req)
908 if err != nil {
909 t.Errorf("%d. RoundTrip = %v", i, err)
910 continue
911 }
912 res.Body.Close()
913 req := <-gotc
914 if req.Method != "CONNECT" {
915 t.Errorf("method = %q; want CONNECT", req.Method)
916 }
917 if req.Host != tt.want {
918 t.Errorf("Host = %q; want %q", req.Host, tt.want)
919 }
920 if req.URL.Host != tt.want {
921 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
922 }
923 }
924 }
925
926 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
927 func testTransportUserAgent(t *testing.T, mode testMode) {
928 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
929 fmt.Fprintf(w, "%q", r.Header["User-Agent"])
930 }))
931
932 either := func(a, b string) string {
933 if mode == http2Mode {
934 return b
935 }
936 return a
937 }
938
939 tests := []struct {
940 setup func(*Request)
941 want string
942 }{
943 {
944 func(r *Request) {},
945 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
946 },
947 {
948 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
949 `["foo/1.2.3"]`,
950 },
951 {
952 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
953 `["single"]`,
954 },
955 {
956 func(r *Request) { r.Header.Set("User-Agent", "") },
957 `[]`,
958 },
959 {
960 func(r *Request) { r.Header["User-Agent"] = nil },
961 `[]`,
962 },
963 }
964 for i, tt := range tests {
965 req, _ := NewRequest("GET", cst.ts.URL, nil)
966 tt.setup(req)
967 res, err := cst.c.Do(req)
968 if err != nil {
969 t.Errorf("%d. RoundTrip = %v", i, err)
970 continue
971 }
972 slurp, err := io.ReadAll(res.Body)
973 res.Body.Close()
974 if err != nil {
975 t.Errorf("%d. read body = %v", i, err)
976 continue
977 }
978 if string(slurp) != tt.want {
979 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
980 }
981 }
982 }
983
// TestStarRequestMethod sends requests for the target "*" using a custom
// method (FOO) and OPTIONS.
func TestStarRequestMethod(t *testing.T) {
	for _, method := range []string{"FOO", "OPTIONS"} {
		t.Run(method, func(t *testing.T) {
			run(t, func(t *testing.T, mode testMode) {
				testStarRequest(t, method, mode)
			})
		})
	}
}

// testStarRequest sends a method request for "*" and checks the response.
// If the handler ran, it also checks the request the handler saw.
// For FOO the handler must run, and the response must carry its "foo"
// header and an unknown (-1) ContentLength. For OPTIONS the response must
// have no "foo" header and a ContentLength of 0, and the handler is allowed
// not to have run at all.
func testStarRequest(t *testing.T, method string, mode testMode) {
	gotc := make(chan *Request, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("foo", "bar")
		// Hand the request to the test before the flush sends the
		// response headers.
		gotc <- r
		w.(Flusher).Flush()
	}))

	u, err := url.Parse(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	u.Path = "*"

	req := &Request{
		Method: method,
		Header: Header{},
		URL:    u,
	}

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatalf("RoundTrip = %v", err)
	}
	res.Body.Close()

	wantFoo := "bar"
	wantLen := int64(-1)
	if method == "OPTIONS" {
		wantFoo = ""
		wantLen = 0
	}
	if res.StatusCode != 200 {
		t.Errorf("status code = %v; want %d", res.Status, 200)
	}
	if res.ContentLength != wantLen {
		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
	}
	if got := res.Header.Get("foo"); got != wantFoo {
		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
	}
	// Non-blocking: the handler sends on gotc before flushing, so if it ran
	// the request is already buffered by now.
	select {
	case req = <-gotc:
	default:
		req = nil
	}
	if req == nil {
		if method != "OPTIONS" {
			t.Fatalf("handler never got request")
		}
		return
	}
	if req.Method != method {
		t.Errorf("method = %q; want %q", req.Method, method)
	}
	if req.URL.Path != "*" {
		t.Errorf("URL.Path = %q; want *", req.URL.Path)
	}
	if req.RequestURI != "*" {
		t.Errorf("RequestURI = %q; want *", req.RequestURI)
	}
}
1055
1056
1057 func TestTransportDiscardsUnneededConns(t *testing.T) {
1058 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1059 }
1060 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1061 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1062 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1063 }))
1064 defer cst.close()
1065
1066 var numOpen, numClose int32
1067
1068 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1069 tr := &Transport{
1070 TLSClientConfig: tlsConfig,
1071 DialTLS: func(_, addr string) (net.Conn, error) {
1072 time.Sleep(10 * time.Millisecond)
1073 rc, err := net.Dial("tcp", addr)
1074 if err != nil {
1075 return nil, err
1076 }
1077 atomic.AddInt32(&numOpen, 1)
1078 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1079 return tls.Client(c, tlsConfig), nil
1080 },
1081 }
1082 if err := ExportHttp2ConfigureTransport(tr); err != nil {
1083 t.Fatal(err)
1084 }
1085 defer tr.CloseIdleConnections()
1086
1087 c := &Client{Transport: tr}
1088
1089 const N = 10
1090 gotBody := make(chan string, N)
1091 var wg sync.WaitGroup
1092 for i := 0; i < N; i++ {
1093 wg.Add(1)
1094 go func() {
1095 defer wg.Done()
1096 resp, err := c.Get(cst.ts.URL)
1097 if err != nil {
1098
1099
1100 time.Sleep(10 * time.Millisecond)
1101 resp, err = c.Get(cst.ts.URL)
1102 if err != nil {
1103 t.Errorf("Get: %v", err)
1104 return
1105 }
1106 }
1107 defer resp.Body.Close()
1108 slurp, err := io.ReadAll(resp.Body)
1109 if err != nil {
1110 t.Error(err)
1111 }
1112 gotBody <- string(slurp)
1113 }()
1114 }
1115 wg.Wait()
1116 close(gotBody)
1117
1118 var last string
1119 for got := range gotBody {
1120 if last == "" {
1121 last = got
1122 continue
1123 }
1124 if got != last {
1125 t.Errorf("Response body changed: %q -> %q", last, got)
1126 }
1127 }
1128
1129 var open, close int32
1130 for i := 0; i < 150; i++ {
1131 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1132 if open < 1 {
1133 t.Fatalf("open = %d; want at least", open)
1134 }
1135 if close == open-1 {
1136
1137 return
1138 }
1139 time.Sleep(10 * time.Millisecond)
1140 }
1141 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1142 }
1143
1144
// TestTransportGCRequest checks that a Request sent through the client can be
// garbage collected once its response has been read and closed. It covers a
// handler that writes a response body ("Body") and one that does not
// ("NoBody").
func TestTransportGCRequest(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
	})
}
func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.ReadAll(r.Body)
		if body {
			io.WriteString(w, "Hello.")
		}
	}))

	didGC := make(chan struct{})
	// Confine req to this closure so that the test itself holds no
	// reference to it afterwards.
	(func() {
		body := strings.NewReader("some body")
		req, _ := NewRequest("POST", cst.ts.URL, body)
		// The finalizer signals that req has been collected.
		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.ReadAll(res.Body); err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	})()
	// Force a collection every millisecond until the finalizer runs. There
	// is no deadline here; a retained Request would presumably show up as
	// a test timeout.
	for {
		select {
		case <-didGC:
			return
		case <-time.After(1 * time.Millisecond):
			runtime.GC()
		}
	}
}
1184
1185 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1186 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1187 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1188 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1189 }), optQuietLog)
1190 cst.tr.DisableKeepAlives = true
1191
1192 tests := []struct {
1193 key, val string
1194 ok bool
1195 }{
1196 {"Foo", "capital-key", true},
1197 {"Foo", "foo\x00bar", false},
1198 {"Foo", "two\nlines", false},
1199 {"bogus\nkey", "v", false},
1200 {"A space", "v", false},
1201 {"имя", "v", false},
1202 {"name", "валю", true},
1203 {"", "v", false},
1204 {"k", "", true},
1205 }
1206 for _, tt := range tests {
1207 dialedc := make(chan bool, 1)
1208 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1209 dialedc <- true
1210 return net.Dial(netw, addr)
1211 }
1212 req, _ := NewRequest("GET", cst.ts.URL, nil)
1213 req.Header[tt.key] = []string{tt.val}
1214 res, err := cst.c.Do(req)
1215 var body []byte
1216 if err == nil {
1217 body, _ = io.ReadAll(res.Body)
1218 res.Body.Close()
1219 }
1220 var dialed bool
1221 select {
1222 case <-dialedc:
1223 dialed = true
1224 default:
1225 }
1226
1227 if !tt.ok && dialed {
1228 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1229 } else if (err == nil) != tt.ok {
1230 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1231 }
1232 }
1233 }
1234
1235 func TestInterruptWithPanic(t *testing.T) {
1236 run(t, func(t *testing.T, mode testMode) {
1237 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1238 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1239 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1240 }, testNotParallel)
1241 }
1242 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1243 const msg = "hello"
1244
1245 testDone := make(chan struct{})
1246 defer close(testDone)
1247
1248 var errorLog lockedBytesBuffer
1249 gotHeaders := make(chan bool, 1)
1250 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1251 io.WriteString(w, msg)
1252 w.(Flusher).Flush()
1253
1254 select {
1255 case <-gotHeaders:
1256 case <-testDone:
1257 }
1258 panic(panicValue)
1259 }), func(ts *httptest.Server) {
1260 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1261 })
1262 res, err := cst.c.Get(cst.ts.URL)
1263 if err != nil {
1264 t.Fatal(err)
1265 }
1266 gotHeaders <- true
1267 defer res.Body.Close()
1268 slurp, err := io.ReadAll(res.Body)
1269 if string(slurp) != msg {
1270 t.Errorf("client read %q; want %q", slurp, msg)
1271 }
1272 if err == nil {
1273 t.Errorf("client read all successfully; want some error")
1274 }
1275 logOutput := func() string {
1276 errorLog.Lock()
1277 defer errorLog.Unlock()
1278 return errorLog.String()
1279 }
1280 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1281
1282 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1283 gotLog := logOutput()
1284 if !wantStackLogged {
1285 if gotLog == "" {
1286 return true
1287 }
1288 t.Fatalf("want no log output; got: %s", gotLog)
1289 }
1290 if gotLog == "" {
1291 if d > 0 {
1292 t.Logf("wanted a stack trace logged; got nothing after %v", d)
1293 }
1294 return false
1295 }
1296 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1297 if d > 0 {
1298 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1299 }
1300 return false
1301 }
1302 return true
1303 })
1304 }
1305
1306 type lockedBytesBuffer struct {
1307 sync.Mutex
1308 bytes.Buffer
1309 }
1310
1311 func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
1312 b.Lock()
1313 defer b.Unlock()
1314 return b.Buffer.Write(p)
1315 }
1316
1317
1318 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1319 h12Compare{
1320 Handler: func(w ResponseWriter, r *Request) {
1321 h := w.Header()
1322 h.Set("Content-Encoding", "gzip")
1323 h.Set("Content-Length", "23")
1324 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1325 },
1326 EarlyCheckResponse: func(proto string, res *Response) {
1327 if !res.Uncompressed {
1328 t.Errorf("%s: expected Uncompressed to be set", proto)
1329 }
1330 dump, err := httputil.DumpResponse(res, true)
1331 if err != nil {
1332 t.Errorf("%s: DumpResponse: %v", proto, err)
1333 return
1334 }
1335 if strings.Contains(string(dump), "Connection: close") {
1336 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1337 }
1338 if !strings.Contains(string(dump), "FOO") {
1339 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1340 }
1341 },
1342 }.run(t)
1343 }
1344
1345
1346 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1347 func testCloseIdleConnections(t *testing.T, mode testMode) {
1348 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1349 w.Header().Set("X-Addr", r.RemoteAddr)
1350 }))
1351 get := func() string {
1352 res, err := cst.c.Get(cst.ts.URL)
1353 if err != nil {
1354 t.Fatal(err)
1355 }
1356 res.Body.Close()
1357 v := res.Header.Get("X-Addr")
1358 if v == "" {
1359 t.Fatal("didn't get X-Addr")
1360 }
1361 return v
1362 }
1363 a1 := get()
1364 cst.tr.CloseIdleConnections()
1365 a2 := get()
1366 if a1 == a2 {
1367 t.Errorf("didn't close connection")
1368 }
1369 }
1370
1371 type noteCloseConn struct {
1372 net.Conn
1373 closeFunc func()
1374 }
1375
1376 func (x noteCloseConn) Close() error {
1377 x.closeFunc()
1378 return x.Conn.Close()
1379 }
1380
1381 type testErrorReader struct{ t *testing.T }
1382
1383 func (r testErrorReader) Read(p []byte) (n int, err error) {
1384 r.t.Error("unexpected Read call")
1385 return 0, io.EOF
1386 }
1387
1388 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1389 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1390 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1391 w.WriteHeader(StatusUnauthorized)
1392 }))
1393
1394
1395 cst.tr.ExpectContinueTimeout = 10 * time.Second
1396
1397 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1398 if err != nil {
1399 t.Fatal(err)
1400 }
1401 req.ContentLength = 0
1402 req.Header.Set("Expect", "100-continue")
1403 res, err := cst.tr.RoundTrip(req)
1404 if err != nil {
1405 t.Fatal(err)
1406 }
1407 defer res.Body.Close()
1408 if res.StatusCode != StatusUnauthorized {
1409 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1410 }
1411 }
1412
1413 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1414 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1415 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1416 w.Header().Set("Foo", "Bar")
1417 w.Header().Set("Trailer:Foo", "Baz")
1418 w.(Flusher).Flush()
1419 w.Header().Add("Trailer:Foo", "Baz2")
1420 w.Header().Set("Trailer:Bar", "Quux")
1421 }))
1422 res, err := cst.c.Get(cst.ts.URL)
1423 if err != nil {
1424 t.Fatal(err)
1425 }
1426 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1427 t.Fatal(err)
1428 }
1429 res.Body.Close()
1430 delete(res.Header, "Date")
1431 delete(res.Header, "Content-Type")
1432
1433 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1434 t.Errorf("Header = %#v; want %#v", res.Header, want)
1435 }
1436 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1437 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1438 }
1439 }
1440
1441 func TestBadResponseAfterReadingBody(t *testing.T) {
1442 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1443 }
1444 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1445 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1446 _, err := io.Copy(io.Discard, r.Body)
1447 if err != nil {
1448 t.Fatal(err)
1449 }
1450 c, _, err := w.(Hijacker).Hijack()
1451 if err != nil {
1452 t.Fatal(err)
1453 }
1454 defer c.Close()
1455 fmt.Fprintln(c, "some bogus crap")
1456 }))
1457
1458 closes := 0
1459 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1460 if err == nil {
1461 res.Body.Close()
1462 t.Fatal("expected an error to be returned from Post")
1463 }
1464 if closes != 1 {
1465 t.Errorf("closes = %d; want 1", closes)
1466 }
1467 }
1468
1469 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1470 func testWriteHeader0(t *testing.T, mode testMode) {
1471 gotpanic := make(chan bool, 1)
1472 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1473 defer close(gotpanic)
1474 defer func() {
1475 if e := recover(); e != nil {
1476 got := fmt.Sprintf("%T, %v", e, e)
1477 want := "string, invalid WriteHeader code 0"
1478 if got != want {
1479 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1480 }
1481 gotpanic <- true
1482
1483
1484
1485
1486 w.WriteHeader(503)
1487 }
1488 }()
1489 w.WriteHeader(0)
1490 }))
1491 res, err := cst.c.Get(cst.ts.URL)
1492 if err != nil {
1493 t.Fatal(err)
1494 }
1495 if res.StatusCode != 503 {
1496 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1497 }
1498 if !<-gotpanic {
1499 t.Error("expected panic in handler")
1500 }
1501 }
1502
1503
1504
1505 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1506 run(t, func(t *testing.T, mode testMode) {
1507 testWriteHeaderAfterWrite(t, mode, false)
1508 })
1509 }
1510 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1511 testWriteHeaderAfterWrite(t, http1Mode, true)
1512 }
1513 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1514 var errorLog lockedBytesBuffer
1515 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1516 if hijack {
1517 conn, _, _ := w.(Hijacker).Hijack()
1518 defer conn.Close()
1519 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1520 w.WriteHeader(0)
1521 conn.Write([]byte("bar"))
1522 return
1523 }
1524 io.WriteString(w, "foo")
1525 w.(Flusher).Flush()
1526 w.WriteHeader(0)
1527 io.WriteString(w, "bar")
1528 }), func(ts *httptest.Server) {
1529 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1530 })
1531 res, err := cst.c.Get(cst.ts.URL)
1532 if err != nil {
1533 t.Fatal(err)
1534 }
1535 defer res.Body.Close()
1536 body, err := io.ReadAll(res.Body)
1537 if err != nil {
1538 t.Fatal(err)
1539 }
1540 if got, want := string(body), "foobar"; got != want {
1541 t.Errorf("got = %q; want %q", got, want)
1542 }
1543
1544
1545 if mode == http2Mode {
1546
1547
1548 return
1549 }
1550 gotLog := strings.TrimSpace(errorLog.String())
1551 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1552 if hijack {
1553 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1554 }
1555 if !strings.HasPrefix(gotLog, wantLog) {
1556 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1557 }
1558 }
1559
1560 func TestBidiStreamReverseProxy(t *testing.T) {
1561 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1562 }
1563 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1564 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1565 if _, err := io.Copy(w, r.Body); err != nil {
1566 log.Printf("bidi backend copy: %v", err)
1567 }
1568 }))
1569
1570 backURL, err := url.Parse(backend.ts.URL)
1571 if err != nil {
1572 t.Fatal(err)
1573 }
1574 rp := httputil.NewSingleHostReverseProxy(backURL)
1575 rp.Transport = backend.tr
1576 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1577 rp.ServeHTTP(w, r)
1578 }))
1579
1580 bodyRes := make(chan any, 1)
1581 pr, pw := io.Pipe()
1582 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1583 const size = 4 << 20
1584 go func() {
1585 h := sha1.New()
1586 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1587 go pw.Close()
1588 if err != nil {
1589 bodyRes <- err
1590 } else {
1591 bodyRes <- h
1592 }
1593 }()
1594 res, err := backend.c.Do(req)
1595 if err != nil {
1596 t.Fatal(err)
1597 }
1598 defer res.Body.Close()
1599 hgot := sha1.New()
1600 n, err := io.Copy(hgot, res.Body)
1601 if err != nil {
1602 t.Fatal(err)
1603 }
1604 if n != size {
1605 t.Fatalf("got %d bytes; want %d", n, size)
1606 }
1607 select {
1608 case v := <-bodyRes:
1609 switch v := v.(type) {
1610 default:
1611 t.Fatalf("body copy: %v", err)
1612 case hash.Hash:
1613 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1614 t.Errorf("written bytes didn't match received bytes")
1615 }
1616 }
1617 case <-time.After(10 * time.Second):
1618 t.Fatal("timeout")
1619 }
1620
1621 }
1622
1623
1624 func TestH12_WebSocketUpgrade(t *testing.T) {
1625 h12Compare{
1626 Handler: func(w ResponseWriter, r *Request) {
1627 h := w.Header()
1628 h.Set("Foo", "bar")
1629 },
1630 ReqFunc: func(c *Client, url string) (*Response, error) {
1631 req, _ := NewRequest("GET", url, nil)
1632 req.Header.Set("Connection", "Upgrade")
1633 req.Header.Set("Upgrade", "WebSocket")
1634 return c.Do(req)
1635 },
1636 EarlyCheckResponse: func(proto string, res *Response) {
1637 if res.Proto != "HTTP/1.1" {
1638 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1639 }
1640 res.Proto = "HTTP/IGNORE"
1641 },
1642 }.run(t)
1643 }
1644
1645 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1646 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1647 const body = "body"
1648 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1649 gotBody, _ := io.ReadAll(r.Body)
1650 if got, want := string(gotBody), body; got != want {
1651 t.Errorf("got request body = %q; want %q", got, want)
1652 }
1653 w.Header().Set("Transfer-Encoding", "identity")
1654 w.WriteHeader(StatusOK)
1655 w.(Flusher).Flush()
1656 io.WriteString(w, body)
1657 }))
1658 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1659 res, err := cst.c.Do(req)
1660 if err != nil {
1661 t.Fatal(err)
1662 }
1663 defer res.Body.Close()
1664 gotBody, err := io.ReadAll(res.Body)
1665 if err != nil {
1666 t.Fatal(err)
1667 }
1668 if got, want := string(gotBody), body; got != want {
1669 t.Errorf("got response body = %q; want %q", got, want)
1670 }
1671 }
1672
1673 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1674 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1675 var wg sync.WaitGroup
1676 wg.Add(1)
1677 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1678 h := w.Header()
1679
1680 h.Add("Content-Length", "123")
1681 h.Add("Link", "</style.css>; rel=preload; as=style")
1682 h.Add("Link", "</script.js>; rel=preload; as=script")
1683 w.WriteHeader(StatusEarlyHints)
1684
1685 wg.Wait()
1686
1687 h.Add("Link", "</foo.js>; rel=preload; as=script")
1688 w.WriteHeader(StatusEarlyHints)
1689
1690 w.Write([]byte("Hello"))
1691 }))
1692
1693 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1694 t.Helper()
1695
1696 if len(expected) != len(got) {
1697 t.Errorf("got %d expected %d", len(got), len(expected))
1698 }
1699
1700 for i := range expected {
1701 if expected[i] != got[i] {
1702 t.Errorf("got %q expected %q", got[i], expected[i])
1703 }
1704 }
1705 }
1706
1707 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1708 t.Helper()
1709
1710 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1711 if v, ok := header[h]; ok {
1712 t.Errorf("%s is %q; must not be sent", h, v)
1713 }
1714 }
1715 }
1716
1717 var respCounter uint8
1718 trace := &httptrace.ClientTrace{
1719 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1720 switch respCounter {
1721 case 0:
1722 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1723 checkExcludedHeaders(t, header)
1724
1725 wg.Done()
1726 case 1:
1727 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1728 checkExcludedHeaders(t, header)
1729
1730 default:
1731 t.Error("Unexpected 1xx response")
1732 }
1733
1734 respCounter++
1735
1736 return nil
1737 },
1738 }
1739 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1740
1741 res, err := cst.c.Do(req)
1742 if err != nil {
1743 t.Fatal(err)
1744 }
1745 defer res.Body.Close()
1746
1747 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1748 if cl := res.Header.Get("Content-Length"); cl != "123" {
1749 t.Errorf("Content-Length is %q; want 123", cl)
1750 }
1751
1752 body, _ := io.ReadAll(res.Body)
1753 if string(body) != "Hello" {
1754 t.Errorf("Read body %q; want Hello", body)
1755 }
1756 }
1757
View as plain text