1
2
3
4
5 package http2
6
7 import (
8 "bytes"
9 "compress/gzip"
10 "compress/zlib"
11 "context"
12 "crypto/tls"
13 "errors"
14 "flag"
15 "fmt"
16 "io"
17 "io/ioutil"
18 "log"
19 "net"
20 "net/http"
21 "net/http/httptest"
22 "os"
23 "reflect"
24 "runtime"
25 "strconv"
26 "strings"
27 "sync"
28 "testing"
29 "time"
30
31 "golang.org/x/net/http2/hpack"
32 )
33
// stderrVerbose, set with -stderr_verbose, mirrors verbose test output to
// stderr without buffering.
var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered")

// stderrv returns the writer verbose output should be mirrored to: os.Stderr
// when -stderr_verbose is set, otherwise a writer that discards everything.
func stderrv() io.Writer {
	if !*stderrVerbose {
		return ioutil.Discard
	}
	return os.Stderr
}
43
// safeBuffer is a bytes.Buffer whose methods are serialized by a mutex, so it
// can be written from one goroutine (e.g. as a log sink) while another reads it.
type safeBuffer struct {
	b bytes.Buffer
	m sync.Mutex
}

// Write appends d to the buffer while holding the lock.
func (sb *safeBuffer) Write(d []byte) (int, error) {
	sb.m.Lock()
	defer sb.m.Unlock()
	return sb.b.Write(d)
}

// Bytes returns a copy of the buffered contents.
//
// A copy is required: the slice returned by bytes.Buffer.Bytes aliases the
// buffer's storage, which a concurrent Write may modify or reallocate as soon
// as the lock is released, so handing it out would defeat the mutex.
func (sb *safeBuffer) Bytes() []byte {
	sb.m.Lock()
	defer sb.m.Unlock()
	return append([]byte(nil), sb.b.Bytes()...)
}

// Len reports the number of buffered bytes.
func (sb *safeBuffer) Len() int {
	sb.m.Lock()
	defer sb.m.Unlock()
	return sb.b.Len()
}
66
// serverTester drives an HTTP/2 server under test (ts) from the client end of
// a single TLS connection (cc), reading and writing raw frames through fr.
type serverTester struct {
	// cc is the client end of the TLS connection to ts; it stays nil when
	// newServerTester is given optOnlyServer.
	cc net.Conn
	t  testing.TB
	ts *httptest.Server
	// fr reads and writes frames on cc.
	fr *Framer
	// serverLogBuf receives a copy of the server's ErrorLog output unless
	// optQuiet was given.
	serverLogBuf safeBuffer
	// logFilter holds phrases registered with addLogFilter.
	// NOTE(review): the consumer of these phrases is not in view — confirm.
	logFilter []string
	// scMu guards sc, which testHookGetServerConn sets to the server-side
	// connection object once the server creates it.
	scMu sync.Mutex
	sc   *serverConn
	// hpackDec is created with st.onHeaderField as its emit callback;
	// decodedHeaders presumably collects what that callback receives
	// (onHeaderField is defined elsewhere — confirm).
	hpackDec       *hpack.Decoder
	decodedHeaders [][2]string

	// Framer debug logs captured by newServerTester when neither
	// logFrameReads nor logFrameWrites is set; Close writes them to the test
	// log if the test failed. Reads and writes use separate mutexes and
	// buffers, presumably so logging does not introduce synchronization
	// between the reading and writing goroutines (which could hide races).
	frameReadLogMu   sync.Mutex
	frameReadLogBuf  bytes.Buffer
	frameWriteLogMu  sync.Mutex
	frameWriteLogBuf bytes.Buffer

	// State for encoding request header blocks: hpackEnc writes into
	// headerBuf (see encodeHeader and encodeHeaderRaw).
	headerBuf bytes.Buffer
	hpackEnc  *hpack.Encoder
}
92
93 func init() {
94 testHookOnPanicMu = new(sync.Mutex)
95 goAwayTimeout = 25 * time.Millisecond
96 }
97
98 func resetHooks() {
99 testHookOnPanicMu.Lock()
100 testHookOnPanic = nil
101 testHookOnPanicMu.Unlock()
102 }
103
// serverTesterOpt is a named flag option understood by newServerTester.
type serverTesterOpt string

// Options for newServerTester: optOnlyServer skips dialing a client
// connection, optQuiet discards the server's error log, and
// optFramerReuseFrames enables frame reuse on the client Framer.
var (
	optOnlyServer        = serverTesterOpt("only_server")
	optQuiet             = serverTesterOpt("quiet_logging")
	optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")
)
109
110 func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
111 resetHooks()
112
113 ts := httptest.NewUnstartedServer(handler)
114
115 tlsConfig := &tls.Config{
116 InsecureSkipVerify: true,
117 NextProtos: []string{NextProtoTLS},
118 }
119
120 var onlyServer, quiet, framerReuseFrames bool
121 h2server := new(Server)
122 for _, opt := range opts {
123 switch v := opt.(type) {
124 case func(*tls.Config):
125 v(tlsConfig)
126 case func(*httptest.Server):
127 v(ts)
128 case func(*Server):
129 v(h2server)
130 case serverTesterOpt:
131 switch v {
132 case optOnlyServer:
133 onlyServer = true
134 case optQuiet:
135 quiet = true
136 case optFramerReuseFrames:
137 framerReuseFrames = true
138 }
139 case func(net.Conn, http.ConnState):
140 ts.Config.ConnState = v
141 default:
142 t.Fatalf("unknown newServerTester option type %T", v)
143 }
144 }
145
146 ConfigureServer(ts.Config, h2server)
147
148
149
150
151
152 ts.Config.TLSConfig.MinVersion = tls.VersionTLS10
153
154 st := &serverTester{
155 t: t,
156 ts: ts,
157 }
158 st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
159 st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
160
161 ts.TLS = ts.Config.TLSConfig
162 if quiet {
163 ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
164 } else {
165 ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
166 }
167 ts.StartTLS()
168
169 if VerboseLogs {
170 t.Logf("Running test server at: %s", ts.URL)
171 }
172 testHookGetServerConn = func(v *serverConn) {
173 st.scMu.Lock()
174 defer st.scMu.Unlock()
175 st.sc = v
176 }
177 log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
178 if !onlyServer {
179 cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
180 if err != nil {
181 t.Fatal(err)
182 }
183 st.cc = cc
184 st.fr = NewFramer(cc, cc)
185 if framerReuseFrames {
186 st.fr.SetReuseFrames()
187 }
188 if !logFrameReads && !logFrameWrites {
189 st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
190 m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
191 st.frameReadLogMu.Lock()
192 fmt.Fprintf(&st.frameReadLogBuf, m, v...)
193 st.frameReadLogMu.Unlock()
194 }
195 st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
196 m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
197 st.frameWriteLogMu.Lock()
198 fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
199 st.frameWriteLogMu.Unlock()
200 }
201 st.fr.logReads = true
202 st.fr.logWrites = true
203 }
204 }
205 return st
206 }
207
208 func (st *serverTester) closeConn() {
209 st.scMu.Lock()
210 defer st.scMu.Unlock()
211 st.sc.conn.Close()
212 }
213
214 func (st *serverTester) addLogFilter(phrase string) {
215 st.logFilter = append(st.logFilter, phrase)
216 }
217
218 func (st *serverTester) stream(id uint32) *stream {
219 ch := make(chan *stream, 1)
220 st.sc.serveMsgCh <- func(int) {
221 ch <- st.sc.streams[id]
222 }
223 return <-ch
224 }
225
226 func (st *serverTester) streamState(id uint32) streamState {
227 ch := make(chan streamState, 1)
228 st.sc.serveMsgCh <- func(int) {
229 state, _ := st.sc.state(id)
230 ch <- state
231 }
232 return <-ch
233 }
234
235
236 func (st *serverTester) loopNum() int {
237 lastc := make(chan int, 1)
238 st.sc.serveMsgCh <- func(loopNum int) {
239 lastc <- loopNum
240 }
241 return <-lastc
242 }
243
244
245
246
247 func (st *serverTester) awaitIdle() {
248 remain := 50
249 last := st.loopNum()
250 for remain > 0 {
251 n := st.loopNum()
252 if n == last+1 {
253 remain--
254 } else {
255 remain = 50
256 }
257 last = n
258 }
259 }
260
261 func (st *serverTester) Close() {
262 if st.t.Failed() {
263 st.frameReadLogMu.Lock()
264 if st.frameReadLogBuf.Len() > 0 {
265 st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String())
266 }
267 st.frameReadLogMu.Unlock()
268
269 st.frameWriteLogMu.Lock()
270 if st.frameWriteLogBuf.Len() > 0 {
271 st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String())
272 }
273 st.frameWriteLogMu.Unlock()
274
275
276
277
278
279 if st.cc != nil {
280 st.cc.Close()
281 }
282 }
283 st.ts.Close()
284 if st.cc != nil {
285 st.cc.Close()
286 }
287 log.SetOutput(os.Stderr)
288 }
289
290
291
292 func (st *serverTester) greet() {
293 st.greetAndCheckSettings(func(Setting) error { return nil })
294 }
295
296 func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
297 st.writePreface()
298 st.writeInitialSettings()
299 st.wantSettings().ForeachSetting(checkSetting)
300 st.writeSettingsAck()
301
302
303 var gotSettingsAck bool
304 var gotWindowUpdate bool
305
306 for i := 0; i < 2; i++ {
307 f, err := st.readFrame()
308 if err != nil {
309 st.t.Fatal(err)
310 }
311 switch f := f.(type) {
312 case *SettingsFrame:
313 if !f.Header().Flags.Has(FlagSettingsAck) {
314 st.t.Fatal("Settings Frame didn't have ACK set")
315 }
316 gotSettingsAck = true
317
318 case *WindowUpdateFrame:
319 if f.FrameHeader.StreamID != 0 {
320 st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
321 }
322 incr := uint32(st.sc.srv.initialConnRecvWindowSize() - initialWindowSize)
323 if f.Increment != incr {
324 st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
325 }
326 gotWindowUpdate = true
327
328 default:
329 st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f)
330 }
331 }
332
333 if !gotSettingsAck {
334 st.t.Fatalf("Didn't get a settings ACK")
335 }
336 if !gotWindowUpdate {
337 st.t.Fatalf("Didn't get a window update")
338 }
339 }
340
341 func (st *serverTester) writePreface() {
342 n, err := st.cc.Write(clientPreface)
343 if err != nil {
344 st.t.Fatalf("Error writing client preface: %v", err)
345 }
346 if n != len(clientPreface) {
347 st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface))
348 }
349 }
350
351 func (st *serverTester) writeInitialSettings() {
352 if err := st.fr.WriteSettings(); err != nil {
353 if runtime.GOOS == "openbsd" && strings.HasSuffix(err.Error(), "write: broken pipe") {
354 st.t.Logf("Error writing initial SETTINGS frame from client to server: %v", err)
355 st.t.Skipf("Skipping test with known OpenBSD failure mode. (See https://go.dev/issue/52208.)")
356 }
357 st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
358 }
359 }
360
361 func (st *serverTester) writeSettingsAck() {
362 if err := st.fr.WriteSettingsAck(); err != nil {
363 st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
364 }
365 }
366
367 func (st *serverTester) writeHeaders(p HeadersFrameParam) {
368 if err := st.fr.WriteHeaders(p); err != nil {
369 st.t.Fatalf("Error writing HEADERS: %v", err)
370 }
371 }
372
373 func (st *serverTester) writePriority(id uint32, p PriorityParam) {
374 if err := st.fr.WritePriority(id, p); err != nil {
375 st.t.Fatalf("Error writing PRIORITY: %v", err)
376 }
377 }
378
379 func (st *serverTester) encodeHeaderField(k, v string) {
380 err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
381 if err != nil {
382 st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
383 }
384 }
385
386
387
388 func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
389 if len(headers)%2 == 1 {
390 panic("odd number of kv args")
391 }
392 st.headerBuf.Reset()
393 for len(headers) > 0 {
394 k, v := headers[0], headers[1]
395 st.encodeHeaderField(k, v)
396 headers = headers[2:]
397 }
398 return st.headerBuf.Bytes()
399 }
400
// encodeHeader HPACK-encodes a request header block into st.headerBuf and
// returns the encoded bytes. headers is a flat list of key/value pairs; an
// odd number of arguments panics.
//
// The pseudo-headers :method, :scheme, :authority and :path are always
// emitted first, defaulting to "GET", "https", the test server's listener
// address and "/". A pseudo-header given in headers replaces its default.
// Repeating a pseudo-header emits every given value, which lets tests build
// requests with duplicate pseudo-headers. All other keys, including
// non-default pseudo-headers, follow in first-appearance order, with each
// key's values emitted together.
//
// The returned slice aliases st.headerBuf.
func (st *serverTester) encodeHeader(headers ...string) []byte {
	if len(headers)%2 == 1 {
		panic("odd number of kv args")
	}

	st.headerBuf.Reset()
	defaultAuthority := st.ts.Listener.Addr().String()

	if len(headers) == 0 {
		// Shortcut for the all-defaults request that skips building the
		// maps below; it emits the same fields the general path would.
		st.encodeHeaderField(":method", "GET")
		st.encodeHeaderField(":scheme", "https")
		st.encodeHeaderField(":authority", defaultAuthority)
		st.encodeHeaderField(":path", "/")
		return st.headerBuf.Bytes()
	}

	if len(headers) == 2 && headers[0] == ":method" {
		// Same shortcut when only :method is overridden.
		st.encodeHeaderField(":method", headers[1])
		st.encodeHeaderField(":scheme", "https")
		st.encodeHeaderField(":authority", defaultAuthority)
		st.encodeHeaderField(":path", "/")
		return st.headerBuf.Bytes()
	}

	// keys records emission order; vals holds every value per key, seeded
	// with the default pseudo-headers.
	pseudoCount := map[string]int{}
	keys := []string{":method", ":scheme", ":authority", ":path"}
	vals := map[string][]string{
		":method":    {"GET"},
		":scheme":    {"https"},
		":authority": {defaultAuthority},
		":path":      {"/"},
	}
	for len(headers) > 0 {
		k, v := headers[0], headers[1]
		headers = headers[2:]
		if _, ok := vals[k]; !ok {
			keys = append(keys, k)
		}
		if strings.HasPrefix(k, ":") {
			pseudoCount[k]++
			if pseudoCount[k] == 1 {
				// First occurrence replaces the default value.
				vals[k] = []string{v}
			} else {
				// Later occurrences are kept, allowing tests of invalid
				// requests with duplicate pseudo-header fields.
				vals[k] = append(vals[k], v)
			}
		} else {
			vals[k] = append(vals[k], v)
		}
	}
	for _, k := range keys {
		for _, v := range vals[k] {
			st.encodeHeaderField(k, v)
		}
	}
	return st.headerBuf.Bytes()
}
466
467
468 func (st *serverTester) bodylessReq1(headers ...string) {
469 st.writeHeaders(HeadersFrameParam{
470 StreamID: 1,
471 BlockFragment: st.encodeHeader(headers...),
472 EndStream: true,
473 EndHeaders: true,
474 })
475 }
476
477 func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
478 if err := st.fr.WriteData(streamID, endStream, data); err != nil {
479 st.t.Fatalf("Error writing DATA: %v", err)
480 }
481 }
482
483 func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
484 if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
485 st.t.Fatalf("Error writing DATA: %v", err)
486 }
487 }
488
489
490
491 func (st *serverTester) writeReadPing() {
492 data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
493 if err := st.fr.WritePing(false, data); err != nil {
494 st.t.Fatalf("Error writing PING: %v", err)
495 }
496 p := st.wantPing()
497 if p.Flags&FlagPingAck == 0 {
498 st.t.Fatalf("got a PING, want a PING ACK")
499 }
500 if p.Data != data {
501 st.t.Fatalf("got PING data = %x, want %x", p.Data, data)
502 }
503 }
504
505 func (st *serverTester) readFrame() (Frame, error) {
506 return st.fr.ReadFrame()
507 }
508
509 func (st *serverTester) wantHeaders() *HeadersFrame {
510 f, err := st.readFrame()
511 if err != nil {
512 st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
513 }
514 hf, ok := f.(*HeadersFrame)
515 if !ok {
516 st.t.Fatalf("got a %T; want *HeadersFrame", f)
517 }
518 return hf
519 }
520
521 func (st *serverTester) wantContinuation() *ContinuationFrame {
522 f, err := st.readFrame()
523 if err != nil {
524 st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err)
525 }
526 cf, ok := f.(*ContinuationFrame)
527 if !ok {
528 st.t.Fatalf("got a %T; want *ContinuationFrame", f)
529 }
530 return cf
531 }
532
533 func (st *serverTester) wantData() *DataFrame {
534 f, err := st.readFrame()
535 if err != nil {
536 st.t.Fatalf("Error while expecting a DATA frame: %v", err)
537 }
538 df, ok := f.(*DataFrame)
539 if !ok {
540 st.t.Fatalf("got a %T; want *DataFrame", f)
541 }
542 return df
543 }
544
545 func (st *serverTester) wantSettings() *SettingsFrame {
546 f, err := st.readFrame()
547 if err != nil {
548 st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
549 }
550 sf, ok := f.(*SettingsFrame)
551 if !ok {
552 st.t.Fatalf("got a %T; want *SettingsFrame", f)
553 }
554 return sf
555 }
556
557 func (st *serverTester) wantPing() *PingFrame {
558 f, err := st.readFrame()
559 if err != nil {
560 st.t.Fatalf("Error while expecting a PING frame: %v", err)
561 }
562 pf, ok := f.(*PingFrame)
563 if !ok {
564 st.t.Fatalf("got a %T; want *PingFrame", f)
565 }
566 return pf
567 }
568
569 func (st *serverTester) wantGoAway() *GoAwayFrame {
570 f, err := st.readFrame()
571 if err != nil {
572 st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err)
573 }
574 gf, ok := f.(*GoAwayFrame)
575 if !ok {
576 st.t.Fatalf("got a %T; want *GoAwayFrame", f)
577 }
578 return gf
579 }
580
581 func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
582 f, err := st.readFrame()
583 if err != nil {
584 st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
585 }
586 rs, ok := f.(*RSTStreamFrame)
587 if !ok {
588 st.t.Fatalf("got a %T; want *RSTStreamFrame", f)
589 }
590 if rs.FrameHeader.StreamID != streamID {
591 st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
592 }
593 if rs.ErrCode != errCode {
594 st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
595 }
596 }
597
598 func (st *serverTester) wantWindowUpdate(streamID, incr uint32) {
599 f, err := st.readFrame()
600 if err != nil {
601 st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err)
602 }
603 wu, ok := f.(*WindowUpdateFrame)
604 if !ok {
605 st.t.Fatalf("got a %T; want *WindowUpdateFrame", f)
606 }
607 if wu.FrameHeader.StreamID != streamID {
608 st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
609 }
610 if wu.Increment != incr {
611 st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
612 }
613 }
614
// wantFlowControlConsumed checks, on the server's serve loop, that the amount
// of receive window consumed equals consumed. "Consumed" is the initial
// window size (the connection's for streamID 0, otherwise the per-stream one)
// minus sc.inflow.avail + sc.inflow.unsent (presumably the window still
// available plus the part not yet granted back via WINDOW_UPDATE). A mismatch
// is reported with t.Errorf. The call blocks until the check has run.
//
// NOTE(review): the streamID != 0 branch is empty, so avail stays 0 and the
// check compares the full initial stream window against consumed without
// inspecting any stream's flow-control state — confirm this is intended
// before relying on a non-zero streamID.
func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) {
	var initial int32
	if streamID == 0 {
		initial = st.sc.srv.initialConnRecvWindowSize()
	} else {
		initial = st.sc.srv.initialStreamRecvWindowSize()
	}
	donec := make(chan struct{})
	// Run on the serve loop so sc.inflow is read without racing it.
	st.sc.sendServeMsg(func(sc *serverConn) {
		defer close(donec)
		var avail int32
		if streamID == 0 {
			avail = sc.inflow.avail + sc.inflow.unsent
		} else {
		}
		if got, want := initial-avail, consumed; got != want {
			st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want)
		}
	})
	<-donec
}
636
637 func (st *serverTester) wantSettingsAck() {
638 f, err := st.readFrame()
639 if err != nil {
640 st.t.Fatal(err)
641 }
642 sf, ok := f.(*SettingsFrame)
643 if !ok {
644 st.t.Fatalf("Wanting a settings ACK, received a %T", f)
645 }
646 if !sf.Header().Flags.Has(FlagSettingsAck) {
647 st.t.Fatal("Settings Frame didn't have ACK set")
648 }
649 }
650
651 func (st *serverTester) wantPushPromise() *PushPromiseFrame {
652 f, err := st.readFrame()
653 if err != nil {
654 st.t.Fatal(err)
655 }
656 ppf, ok := f.(*PushPromiseFrame)
657 if !ok {
658 st.t.Fatalf("Wanted PushPromise, received %T", ppf)
659 }
660 return ppf
661 }
662
663 func TestServer(t *testing.T) {
664 gotReq := make(chan bool, 1)
665 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
666 w.Header().Set("Foo", "Bar")
667 gotReq <- true
668 })
669 defer st.Close()
670
671 covers("3.5", `
672 The server connection preface consists of a potentially empty
673 SETTINGS frame ([SETTINGS]) that MUST be the first frame the
674 server sends in the HTTP/2 connection.
675 `)
676
677 st.greet()
678 st.writeHeaders(HeadersFrameParam{
679 StreamID: 1,
680 BlockFragment: st.encodeHeader(),
681 EndStream: true,
682 EndHeaders: true,
683 })
684
685 <-gotReq
686 }
687
688 func TestServer_Request_Get(t *testing.T) {
689 testServerRequest(t, func(st *serverTester) {
690 st.writeHeaders(HeadersFrameParam{
691 StreamID: 1,
692 BlockFragment: st.encodeHeader("foo-bar", "some-value"),
693 EndStream: true,
694 EndHeaders: true,
695 })
696 }, func(r *http.Request) {
697 if r.Method != "GET" {
698 t.Errorf("Method = %q; want GET", r.Method)
699 }
700 if r.URL.Path != "/" {
701 t.Errorf("URL.Path = %q; want /", r.URL.Path)
702 }
703 if r.ContentLength != 0 {
704 t.Errorf("ContentLength = %v; want 0", r.ContentLength)
705 }
706 if r.Close {
707 t.Error("Close = true; want false")
708 }
709 if !strings.Contains(r.RemoteAddr, ":") {
710 t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
711 }
712 if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
713 t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
714 }
715 wantHeader := http.Header{
716 "Foo-Bar": []string{"some-value"},
717 }
718 if !reflect.DeepEqual(r.Header, wantHeader) {
719 t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
720 }
721 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
722 t.Errorf("Read = %d, %v; want 0, EOF", n, err)
723 }
724 })
725 }
726
727 func TestServer_Request_Get_PathSlashes(t *testing.T) {
728 testServerRequest(t, func(st *serverTester) {
729 st.writeHeaders(HeadersFrameParam{
730 StreamID: 1,
731 BlockFragment: st.encodeHeader(":path", "/%2f/"),
732 EndStream: true,
733 EndHeaders: true,
734 })
735 }, func(r *http.Request) {
736 if r.RequestURI != "/%2f/" {
737 t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI)
738 }
739 if r.URL.Path != "///" {
740 t.Errorf("URL.Path = %q; want ///", r.URL.Path)
741 }
742 })
743 }
744
745
746
747
748
749 func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
750 testServerRequest(t, func(st *serverTester) {
751 st.writeHeaders(HeadersFrameParam{
752 StreamID: 1,
753 BlockFragment: st.encodeHeader(":method", "POST"),
754 EndStream: true,
755 EndHeaders: true,
756 })
757 }, func(r *http.Request) {
758 if r.Method != "POST" {
759 t.Errorf("Method = %q; want POST", r.Method)
760 }
761 if r.ContentLength != 0 {
762 t.Errorf("ContentLength = %v; want 0", r.ContentLength)
763 }
764 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
765 t.Errorf("Read = %d, %v; want 0, EOF", n, err)
766 }
767 })
768 }
769
770 func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
771 testBodyContents(t, -1, "", func(st *serverTester) {
772 st.writeHeaders(HeadersFrameParam{
773 StreamID: 1,
774 BlockFragment: st.encodeHeader(":method", "POST"),
775 EndStream: false,
776 EndHeaders: true,
777 })
778 st.writeData(1, true, nil)
779 })
780 }
781
782 func TestServer_Request_Post_Body_OneData(t *testing.T) {
783 const content = "Some content"
784 testBodyContents(t, -1, content, func(st *serverTester) {
785 st.writeHeaders(HeadersFrameParam{
786 StreamID: 1,
787 BlockFragment: st.encodeHeader(":method", "POST"),
788 EndStream: false,
789 EndHeaders: true,
790 })
791 st.writeData(1, true, []byte(content))
792 })
793 }
794
795 func TestServer_Request_Post_Body_TwoData(t *testing.T) {
796 const content = "Some content"
797 testBodyContents(t, -1, content, func(st *serverTester) {
798 st.writeHeaders(HeadersFrameParam{
799 StreamID: 1,
800 BlockFragment: st.encodeHeader(":method", "POST"),
801 EndStream: false,
802 EndHeaders: true,
803 })
804 st.writeData(1, false, []byte(content[:5]))
805 st.writeData(1, true, []byte(content[5:]))
806 })
807 }
808
809 func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
810 const content = "Some content"
811 testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
812 st.writeHeaders(HeadersFrameParam{
813 StreamID: 1,
814 BlockFragment: st.encodeHeader(
815 ":method", "POST",
816 "content-length", strconv.Itoa(len(content)),
817 ),
818 EndStream: false,
819 EndHeaders: true,
820 })
821 st.writeData(1, true, []byte(content))
822 })
823 }
824
825 func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
826 testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes",
827 func(st *serverTester) {
828 st.writeHeaders(HeadersFrameParam{
829 StreamID: 1,
830 BlockFragment: st.encodeHeader(
831 ":method", "POST",
832 "content-length", "3",
833 ),
834 EndStream: false,
835 EndHeaders: true,
836 })
837 st.writeData(1, true, []byte("12"))
838 })
839 }
840
841 func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
842 testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes",
843 func(st *serverTester) {
844 st.writeHeaders(HeadersFrameParam{
845 StreamID: 1,
846 BlockFragment: st.encodeHeader(
847 ":method", "POST",
848 "content-length", "4",
849 ),
850 EndStream: false,
851 EndHeaders: true,
852 })
853 st.writeData(1, true, []byte("12345"))
854
855
856 st.wantRSTStream(1, ErrCodeProtocol)
857 st.wantFlowControlConsumed(0, 0)
858 })
859 }
860
861 func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) {
862 testServerRequest(t, write, func(r *http.Request) {
863 if r.Method != "POST" {
864 t.Errorf("Method = %q; want POST", r.Method)
865 }
866 if r.ContentLength != wantContentLength {
867 t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
868 }
869 all, err := ioutil.ReadAll(r.Body)
870 if err != nil {
871 t.Fatal(err)
872 }
873 if string(all) != wantBody {
874 t.Errorf("Read = %q; want %q", all, wantBody)
875 }
876 if err := r.Body.Close(); err != nil {
877 t.Fatalf("Close: %v", err)
878 }
879 })
880 }
881
882 func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
883 testServerRequest(t, write, func(r *http.Request) {
884 if r.Method != "POST" {
885 t.Errorf("Method = %q; want POST", r.Method)
886 }
887 if r.ContentLength != wantContentLength {
888 t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
889 }
890 all, err := ioutil.ReadAll(r.Body)
891 if err == nil {
892 t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
893 wantReadError, all)
894 }
895 if !strings.Contains(err.Error(), wantReadError) {
896 t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
897 }
898 if err := r.Body.Close(); err != nil {
899 t.Fatalf("Close: %v", err)
900 }
901 })
902 }
903
904
905 func TestServer_Request_Get_Host(t *testing.T) {
906 const host = "example.com"
907 testServerRequest(t, func(st *serverTester) {
908 st.writeHeaders(HeadersFrameParam{
909 StreamID: 1,
910 BlockFragment: st.encodeHeader(":authority", "", "host", host),
911 EndStream: true,
912 EndHeaders: true,
913 })
914 }, func(r *http.Request) {
915 if r.Host != host {
916 t.Errorf("Host = %q; want %q", r.Host, host)
917 }
918 })
919 }
920
921
922 func TestServer_Request_Get_Authority(t *testing.T) {
923 const host = "example.com"
924 testServerRequest(t, func(st *serverTester) {
925 st.writeHeaders(HeadersFrameParam{
926 StreamID: 1,
927 BlockFragment: st.encodeHeader(":authority", host),
928 EndStream: true,
929 EndHeaders: true,
930 })
931 }, func(r *http.Request) {
932 if r.Host != host {
933 t.Errorf("Host = %q; want %q", r.Host, host)
934 }
935 })
936 }
937
938 func TestServer_Request_WithContinuation(t *testing.T) {
939 wantHeader := http.Header{
940 "Foo-One": []string{"value-one"},
941 "Foo-Two": []string{"value-two"},
942 "Foo-Three": []string{"value-three"},
943 }
944 testServerRequest(t, func(st *serverTester) {
945 fullHeaders := st.encodeHeader(
946 "foo-one", "value-one",
947 "foo-two", "value-two",
948 "foo-three", "value-three",
949 )
950 remain := fullHeaders
951 chunks := 0
952 for len(remain) > 0 {
953 const maxChunkSize = 5
954 chunk := remain
955 if len(chunk) > maxChunkSize {
956 chunk = chunk[:maxChunkSize]
957 }
958 remain = remain[len(chunk):]
959
960 if chunks == 0 {
961 st.writeHeaders(HeadersFrameParam{
962 StreamID: 1,
963 BlockFragment: chunk,
964 EndStream: true,
965 EndHeaders: false,
966 })
967 } else {
968 err := st.fr.WriteContinuation(1, len(remain) == 0, chunk)
969 if err != nil {
970 t.Fatal(err)
971 }
972 }
973 chunks++
974 }
975 if chunks < 2 {
976 t.Fatal("too few chunks")
977 }
978 }, func(r *http.Request) {
979 if !reflect.DeepEqual(r.Header, wantHeader) {
980 t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
981 }
982 })
983 }
984
985
986 func TestServer_Request_CookieConcat(t *testing.T) {
987 const host = "example.com"
988 testServerRequest(t, func(st *serverTester) {
989 st.bodylessReq1(
990 ":authority", host,
991 "cookie", "a=b",
992 "cookie", "c=d",
993 "cookie", "e=f",
994 )
995 }, func(r *http.Request) {
996 const want = "a=b; c=d; e=f"
997 if got := r.Header.Get("Cookie"); got != want {
998 t.Errorf("Cookie = %q; want %q", got, want)
999 }
1000 })
1001 }
1002
1003 func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
1004 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
1005 }
1006
1007 func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) {
1008 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") })
1009 }
1010
1011 func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) {
1012 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") })
1013 }
1014
1015 func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) {
1016 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") })
1017 }
1018
1019 func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) {
1020 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") })
1021 }
1022
1023 func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) {
1024 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") })
1025 }
1026
1027 func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) {
1028 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") })
1029 }
1030
1031 func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
1032 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
1033 }
1034
1035 func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
1036
1037
1038 testRejectRequest(t, func(st *serverTester) {
1039 st.addLogFilter("duplicate pseudo-header")
1040 st.bodylessReq1(":method", "GET", ":method", "POST")
1041 })
1042 }
1043
1044 func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
1045
1046
1047
1048
1049
1050
1051 testRejectRequest(t, func(st *serverTester) {
1052 st.addLogFilter("pseudo-header after regular header")
1053 var buf bytes.Buffer
1054 enc := hpack.NewEncoder(&buf)
1055 enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1056 enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"})
1057 enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1058 enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1059 st.writeHeaders(HeadersFrameParam{
1060 StreamID: 1,
1061 BlockFragment: buf.Bytes(),
1062 EndStream: true,
1063 EndHeaders: true,
1064 })
1065 })
1066 }
1067
1068 func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
1069 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
1070 }
1071
1072 func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
1073 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
1074 }
1075
1076 func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
1077 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") })
1078 }
1079
1080 func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
1081 testRejectRequest(t, func(st *serverTester) {
1082 st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
1083 st.bodylessReq1(":unknown_thing", "")
1084 })
1085 }
1086
1087 func testRejectRequest(t *testing.T, send func(*serverTester)) {
1088 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1089 t.Error("server request made it to handler; should've been rejected")
1090 })
1091 defer st.Close()
1092
1093 st.greet()
1094 send(st)
1095 st.wantRSTStream(1, ErrCodeProtocol)
1096 }
1097
1098 func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) {
1099 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1100 t.Error("server request made it to handler; should've been rejected")
1101 }, optQuiet)
1102 defer st.Close()
1103
1104 st.greet()
1105 send(st)
1106 gf := st.wantGoAway()
1107 if gf.ErrCode != ErrCodeProtocol {
1108 t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol)
1109 }
1110 }
1111
1112
1113
1114
1115 func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
1116 testRejectRequestWithProtocolError(t, func(st *serverTester) {
1117 st.fr.WriteWindowUpdate(123, 456)
1118 })
1119 }
1120 func TestRejectFrameOnIdle_Data(t *testing.T) {
1121 testRejectRequestWithProtocolError(t, func(st *serverTester) {
1122 st.fr.WriteData(123, true, nil)
1123 })
1124 }
1125 func TestRejectFrameOnIdle_RSTStream(t *testing.T) {
1126 testRejectRequestWithProtocolError(t, func(st *serverTester) {
1127 st.fr.WriteRSTStream(123, ErrCodeCancel)
1128 })
1129 }
1130
1131 func TestServer_Request_Connect(t *testing.T) {
1132 testServerRequest(t, func(st *serverTester) {
1133 st.writeHeaders(HeadersFrameParam{
1134 StreamID: 1,
1135 BlockFragment: st.encodeHeaderRaw(
1136 ":method", "CONNECT",
1137 ":authority", "example.com:123",
1138 ),
1139 EndStream: true,
1140 EndHeaders: true,
1141 })
1142 }, func(r *http.Request) {
1143 if g, w := r.Method, "CONNECT"; g != w {
1144 t.Errorf("Method = %q; want %q", g, w)
1145 }
1146 if g, w := r.RequestURI, "example.com:123"; g != w {
1147 t.Errorf("RequestURI = %q; want %q", g, w)
1148 }
1149 if g, w := r.URL.Host, "example.com:123"; g != w {
1150 t.Errorf("URL.Host = %q; want %q", g, w)
1151 }
1152 })
1153 }
1154
1155 func TestServer_Request_Connect_InvalidPath(t *testing.T) {
1156 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1157 st.writeHeaders(HeadersFrameParam{
1158 StreamID: 1,
1159 BlockFragment: st.encodeHeaderRaw(
1160 ":method", "CONNECT",
1161 ":authority", "example.com:123",
1162 ":path", "/bogus",
1163 ),
1164 EndStream: true,
1165 EndHeaders: true,
1166 })
1167 })
1168 }
1169
1170 func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
1171 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1172 st.writeHeaders(HeadersFrameParam{
1173 StreamID: 1,
1174 BlockFragment: st.encodeHeaderRaw(
1175 ":method", "CONNECT",
1176 ":authority", "example.com:123",
1177 ":scheme", "https",
1178 ),
1179 EndStream: true,
1180 EndHeaders: true,
1181 })
1182 })
1183 }
1184
1185 func TestServer_Ping(t *testing.T) {
1186 st := newServerTester(t, nil)
1187 defer st.Close()
1188 st.greet()
1189
1190
1191 ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
1192 if err := st.fr.WritePing(true, ackPingData); err != nil {
1193 t.Fatal(err)
1194 }
1195
1196
1197 pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1198 if err := st.fr.WritePing(false, pingData); err != nil {
1199 t.Fatal(err)
1200 }
1201
1202 pf := st.wantPing()
1203 if !pf.Flags.Has(FlagPingAck) {
1204 t.Error("response ping doesn't have ACK set")
1205 }
1206 if pf.Data != pingData {
1207 t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
1208 }
1209 }
1210
// filterListener wraps a net.Listener and passes every successfully accepted
// connection through accept, which may wrap or replace it.
type filterListener struct {
	net.Listener
	accept func(conn net.Conn) (net.Conn, error)
}

// Accept waits for the next connection from the embedded listener and returns
// whatever fl.accept makes of it. Errors from the embedded listener are
// returned as-is, and accept is not called in that case.
func (fl *filterListener) Accept() (net.Conn, error) {
	raw, err := fl.Listener.Accept()
	if err != nil {
		return nil, err
	}
	filter := fl.accept
	return filter(raw)
}
1223
1224 func TestServer_MaxQueuedControlFrames(t *testing.T) {
1225 if testing.Short() {
1226 t.Skip("skipping in short mode")
1227 }
1228
1229 st := newServerTester(t, nil, func(ts *httptest.Server) {
1230
1231
1232 ts.Listener = &filterListener{
1233 Listener: ts.Listener,
1234 accept: func(conn net.Conn) (net.Conn, error) {
1235 return newBlockingWriteConn(conn, 10000), nil
1236 },
1237 }
1238 })
1239 defer st.Close()
1240 st.greet()
1241
1242 const extraPings = 500000
1243
1244 for i := 0; i < maxQueuedControlFrames+extraPings; i++ {
1245 pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1246 if err := st.fr.WritePing(false, pingData); err != nil {
1247 if i == 0 {
1248 t.Fatal(err)
1249 }
1250
1251
1252 t.Logf("sent %d PING frames", i)
1253 return
1254 }
1255 }
1256 t.Errorf("unexpected success sending all PING frames")
1257 }
1258
1259 func TestServer_RejectsLargeFrames(t *testing.T) {
1260 if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "zos" {
1261 t.Skip("see golang.org/issue/13434, golang.org/issue/37321")
1262 }
1263 st := newServerTester(t, nil)
1264 defer st.Close()
1265 st.greet()
1266
1267
1268
1269
1270 st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1))
1271
1272 gf := st.wantGoAway()
1273 if gf.ErrCode != ErrCodeFrameSize {
1274 t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize)
1275 }
1276 if st.serverLogBuf.Len() != 0 {
1277
1278
1279 t.Errorf("unexpected server output: %.500s\n", st.serverLogBuf.Bytes())
1280 }
1281 }
1282
// TestServer_Handler_Sends_WindowUpdate checks when the server returns
// flow-control credit (WINDOW_UPDATE on the connection, stream 0, and on
// stream 1) while a POST handler consumes its request body. Both receive
// buffer limits are set to windowSize.
func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
	// Twice 65535, the HTTP/2 default initial window size. The body below is
	// exactly windowSize bytes, so it fills both configured buffers.
	const windowSize = 65535 * 2
	puppet := newHandlerPuppet()
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// The test drives the handler step by step via puppet.do.
		// NOTE(review): handlerPuppet semantics are assumed from usage here.
		puppet.act(w, r)
	}, func(s *Server) {
		s.MaxUploadBufferPerConnection = windowSize
		s.MaxUploadBufferPerStream = windowSize
	})
	defer st.Close()
	defer puppet.done()

	st.greet()
	// Open stream 1 as a POST without END_STREAM so body data can follow.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	// PING round trip used as a synchronization point. NOTE(review): this
	// presumably also fails if another frame (e.g. an early WINDOW_UPDATE)
	// arrives first; confirm writeReadPing's contract.
	st.writeReadPing()

	// Send the first 1024 bytes and let the handler consume them. The test
	// does not expect WINDOW_UPDATEs for them until after the next write
	// below.
	data := make([]byte, windowSize)
	st.writeData(1, false, data[:1024])
	puppet.do(readBodyHandler(t, string(data[:1024])))
	st.writeReadPing()

	// Send the rest of the window. The server now returns the 1024 bytes
	// consumed earlier, on both the connection and stream 1.
	st.writeData(1, false, data[1024:])
	st.wantWindowUpdate(0, 1024)
	st.wantWindowUpdate(1, 1024)
	st.writeReadPing()

	// Let the handler consume the remainder. Its credit is returned on both
	// windows.
	puppet.do(readBodyHandler(t, string(data[1024:])))
	st.wantWindowUpdate(0, windowSize-1024)
	st.wantWindowUpdate(1, windowSize-1024)
	st.writeReadPing()
}
1329
1330
1331
// TestServer_Handler_Sends_WindowUpdate_Padding checks that the flow-control
// credit returned for a padded DATA frame covers the whole frame payload: the
// data, the 1-byte Pad Length field, and the padding itself.
func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
	const windowSize = 65535 * 2
	puppet := newHandlerPuppet()
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// Handler steps are driven from the test via puppet.do.
		puppet.act(w, r)
	}, func(s *Server) {
		s.MaxUploadBufferPerConnection = windowSize
		s.MaxUploadBufferPerStream = windowSize
	})
	defer st.Close()
	defer puppet.done()

	st.greet()
	// Open stream 1 as a POST without END_STREAM so body data can follow.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	st.writeReadPing()

	// Send half the window as one DATA frame with 4 bytes of padding.
	data := make([]byte, windowSize/2)
	pad := make([]byte, 4)
	st.writeDataPadded(1, false, data, pad)
	st.writeReadPing()

	// Once the handler has read the body, both WINDOW_UPDATEs must also
	// cover the Pad Length byte and the padding. RFC 7540 section 6.1
	// counts the entire DATA payload against flow control.
	puppet.do(readBodyHandler(t, string(data)))
	st.wantWindowUpdate(0, uint32(len(data)+1+len(pad)))
	st.wantWindowUpdate(1, uint32(len(data)+1+len(pad)))
}
1368
1369 func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
1370 st := newServerTester(t, nil)
1371 defer st.Close()
1372 st.greet()
1373 if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
1374 t.Fatal(err)
1375 }
1376 gf := st.wantGoAway()
1377 if gf.ErrCode != ErrCodeFlowControl {
1378 t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl)
1379 }
1380 if gf.LastStreamID != 0 {
1381 t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0)
1382 }
1383 }
1384
1385 func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
1386 inHandler := make(chan bool)
1387 blockHandler := make(chan bool)
1388 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1389 inHandler <- true
1390 <-blockHandler
1391 })
1392 defer st.Close()
1393 defer close(blockHandler)
1394 st.greet()
1395 st.writeHeaders(HeadersFrameParam{
1396 StreamID: 1,
1397 BlockFragment: st.encodeHeader(":method", "POST"),
1398 EndStream: false,
1399 EndHeaders: true,
1400 })
1401 <-inHandler
1402
1403 if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
1404 t.Fatal(err)
1405 }
1406 st.wantRSTStream(1, ErrCodeFlowControl)
1407 }
1408
1409
1410
1411
1412 func testServerPostUnblock(t *testing.T,
1413 handler func(http.ResponseWriter, *http.Request) error,
1414 fn func(*serverTester),
1415 checkErr func(error),
1416 otherHeaders ...string) {
1417 inHandler := make(chan bool)
1418 errc := make(chan error, 1)
1419 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1420 inHandler <- true
1421 errc <- handler(w, r)
1422 })
1423 defer st.Close()
1424 st.greet()
1425 st.writeHeaders(HeadersFrameParam{
1426 StreamID: 1,
1427 BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...),
1428 EndStream: false,
1429 EndHeaders: true,
1430 })
1431 <-inHandler
1432 fn(st)
1433 err := <-errc
1434 if checkErr != nil {
1435 checkErr(err)
1436 }
1437 }
1438
1439 func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
1440 testServerPostUnblock(t,
1441 func(w http.ResponseWriter, r *http.Request) (err error) {
1442 _, err = r.Body.Read(make([]byte, 1))
1443 return
1444 },
1445 func(st *serverTester) {
1446 if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1447 t.Fatal(err)
1448 }
1449 },
1450 func(err error) {
1451 want := StreamError{StreamID: 0x1, Code: 0x8}
1452 if !reflect.DeepEqual(err, want) {
1453 t.Errorf("Read error = %v; want %v", err, want)
1454 }
1455 },
1456 )
1457 }
1458
1459 func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1460
1461
1462 n := 50
1463 if testing.Short() {
1464 n = 5
1465 }
1466 for i := 0; i < n; i++ {
1467 testServer_RSTStream_Unblocks_Header_Write(t)
1468 }
1469 }
1470
// testServer_RSTStream_Unblocks_Header_Write runs one iteration of the
// scenario. The client resets stream 1 before the handler writes its response
// headers, and the handler's header write plus Flush must still return so the
// handler can finish.
func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
	// Every signal channel has capacity 1 and is sent on once, so no send
	// below blocks.
	inHandler := make(chan bool, 1)
	unblockHandler := make(chan bool, 1)
	headerWritten := make(chan bool, 1)
	wroteRST := make(chan bool, 1)

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		// Don't touch the response until the client has sent RST_STREAM.
		<-wroteRST
		w.Header().Set("foo", "bar")
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		// Getting here means WriteHeader/Flush did not hang on the reset
		// stream.
		headerWritten <- true
		<-unblockHandler
	})
	defer st.Close()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler
	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
		t.Fatal(err)
	}
	wroteRST <- true
	// NOTE(review): awaitIdle presumably waits until the server has processed
	// the RST_STREAM; confirm its contract.
	st.awaitIdle()
	<-headerWritten
	unblockHandler <- true
}
1504
1505 func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
1506 testServerPostUnblock(t,
1507 func(w http.ResponseWriter, r *http.Request) (err error) {
1508 _, err = r.Body.Read(make([]byte, 1))
1509 return
1510 },
1511 func(st *serverTester) { st.cc.Close() },
1512 func(err error) {
1513 if err == nil {
1514 t.Error("unexpected nil error from Request.Body.Read")
1515 }
1516 },
1517 )
1518 }
1519
// blockUntilClosed is a handler for testServerPostUnblock. It waits until the
// writer's http.CloseNotifier channel delivers a value and then returns nil.
var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error {
	closed := w.(http.CloseNotifier).CloseNotify()
	<-closed
	return nil
}
1524
1525 func TestServer_CloseNotify_After_RSTStream(t *testing.T) {
1526 testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1527 if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1528 t.Fatal(err)
1529 }
1530 }, nil)
1531 }
1532
1533 func TestServer_CloseNotify_After_ConnClose(t *testing.T) {
1534 testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil)
1535 }
1536
1537
1538
1539
1540 func TestServer_CloseNotify_After_StreamError(t *testing.T) {
1541 testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1542
1543 st.writeData(1, true, []byte("1234"))
1544 }, nil, "content-length", "3")
1545 }
1546
1547 func TestServer_StateTransitions(t *testing.T) {
1548 var st *serverTester
1549 inHandler := make(chan bool)
1550 writeData := make(chan bool)
1551 leaveHandler := make(chan bool)
1552 st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1553 inHandler <- true
1554 if st.stream(1) == nil {
1555 t.Errorf("nil stream 1 in handler")
1556 }
1557 if got, want := st.streamState(1), stateOpen; got != want {
1558 t.Errorf("in handler, state is %v; want %v", got, want)
1559 }
1560 writeData <- true
1561 if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
1562 t.Errorf("body read = %d, %v; want 0, EOF", n, err)
1563 }
1564 if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
1565 t.Errorf("in handler, state is %v; want %v", got, want)
1566 }
1567
1568 <-leaveHandler
1569 })
1570 st.greet()
1571 if st.stream(1) != nil {
1572 t.Fatal("stream 1 should be empty")
1573 }
1574 if got := st.streamState(1); got != stateIdle {
1575 t.Fatalf("stream 1 should be idle; got %v", got)
1576 }
1577
1578 st.writeHeaders(HeadersFrameParam{
1579 StreamID: 1,
1580 BlockFragment: st.encodeHeader(":method", "POST"),
1581 EndStream: false,
1582 EndHeaders: true,
1583 })
1584 <-inHandler
1585 <-writeData
1586 st.writeData(1, true, nil)
1587
1588 leaveHandler <- true
1589 hf := st.wantHeaders()
1590 if !hf.StreamEnded() {
1591 t.Fatal("expected END_STREAM flag")
1592 }
1593
1594 if got, want := st.streamState(1), stateClosed; got != want {
1595 t.Errorf("at end, state is %v; want %v", got, want)
1596 }
1597 if st.stream(1) != nil {
1598 t.Fatal("at end, stream 1 should be gone")
1599 }
1600 }
1601
1602
1603 func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
1604 testServerRejectsConn(t, func(st *serverTester) {
1605 st.writeHeaders(HeadersFrameParam{
1606 StreamID: 1,
1607 BlockFragment: st.encodeHeader(),
1608 EndStream: true,
1609 EndHeaders: false,
1610 })
1611 st.writeHeaders(HeadersFrameParam{
1612 StreamID: 3,
1613 BlockFragment: st.encodeHeader(),
1614 EndStream: true,
1615 EndHeaders: true,
1616 })
1617 })
1618 }
1619
1620
1621 func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
1622 testServerRejectsConn(t, func(st *serverTester) {
1623 st.writeHeaders(HeadersFrameParam{
1624 StreamID: 1,
1625 BlockFragment: st.encodeHeader(),
1626 EndStream: true,
1627 EndHeaders: false,
1628 })
1629 if err := st.fr.WritePing(false, [8]byte{}); err != nil {
1630 t.Fatal(err)
1631 }
1632 })
1633 }
1634
1635
1636 func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
1637 testServerRejectsConn(t, func(st *serverTester) {
1638 st.writeHeaders(HeadersFrameParam{
1639 StreamID: 1,
1640 BlockFragment: st.encodeHeader(),
1641 EndStream: true,
1642 EndHeaders: true,
1643 })
1644 st.wantHeaders()
1645 if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1646 t.Fatal(err)
1647 }
1648 })
1649 }
1650
1651
1652 func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
1653 testServerRejectsConn(t, func(st *serverTester) {
1654 st.writeHeaders(HeadersFrameParam{
1655 StreamID: 1,
1656 BlockFragment: st.encodeHeader(),
1657 EndStream: true,
1658 EndHeaders: false,
1659 })
1660 if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1661 t.Fatal(err)
1662 }
1663 })
1664 }
1665
1666
1667 func TestServer_Rejects_Headers0(t *testing.T) {
1668 testServerRejectsConn(t, func(st *serverTester) {
1669 st.fr.AllowIllegalWrites = true
1670 st.writeHeaders(HeadersFrameParam{
1671 StreamID: 0,
1672 BlockFragment: st.encodeHeader(),
1673 EndStream: true,
1674 EndHeaders: true,
1675 })
1676 })
1677 }
1678
1679
1680 func TestServer_Rejects_Continuation0(t *testing.T) {
1681 testServerRejectsConn(t, func(st *serverTester) {
1682 st.fr.AllowIllegalWrites = true
1683 if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
1684 t.Fatal(err)
1685 }
1686 })
1687 }
1688
1689
1690 func TestServer_Rejects_Priority0(t *testing.T) {
1691 testServerRejectsConn(t, func(st *serverTester) {
1692 st.fr.AllowIllegalWrites = true
1693 st.writePriority(0, PriorityParam{StreamDep: 1})
1694 })
1695 }
1696
1697
1698 func TestServer_Rejects_HeadersSelfDependence(t *testing.T) {
1699 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1700 st.fr.AllowIllegalWrites = true
1701 st.writeHeaders(HeadersFrameParam{
1702 StreamID: 1,
1703 BlockFragment: st.encodeHeader(),
1704 EndStream: true,
1705 EndHeaders: true,
1706 Priority: PriorityParam{StreamDep: 1},
1707 })
1708 })
1709 }
1710
1711
1712 func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
1713 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1714 st.fr.AllowIllegalWrites = true
1715 st.writePriority(1, PriorityParam{StreamDep: 1})
1716 })
1717 }
1718
1719 func TestServer_Rejects_PushPromise(t *testing.T) {
1720 testServerRejectsConn(t, func(st *serverTester) {
1721 pp := PushPromiseParam{
1722 StreamID: 1,
1723 PromiseID: 3,
1724 }
1725 if err := st.fr.WritePushPromise(pp); err != nil {
1726 t.Fatal(err)
1727 }
1728 })
1729 }
1730
1731
1732
1733
1734 func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
1735 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1736 st.addLogFilter("connection error: PROTOCOL_ERROR")
1737 defer st.Close()
1738 st.greet()
1739 writeReq(st)
1740
1741 st.wantGoAway()
1742
1743 fr, err := st.fr.ReadFrame()
1744 if err == nil {
1745 t.Errorf("ReadFrame got frame of type %T; want io.EOF", fr)
1746 }
1747 if err != io.EOF {
1748 t.Errorf("ReadFrame = %v; want io.EOF", err)
1749 }
1750 }
1751
1752
1753
1754 func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
1755 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1756 defer st.Close()
1757 st.greet()
1758 writeReq(st)
1759 st.wantRSTStream(1, code)
1760 }
1761
1762
1763
1764
1765 func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
1766 gotReq := make(chan bool, 1)
1767 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1768 if r.Body == nil {
1769 t.Fatal("nil Body")
1770 }
1771 checkReq(r)
1772 gotReq <- true
1773 })
1774 defer st.Close()
1775
1776 st.greet()
1777 writeReq(st)
1778 <-gotReq
1779 }
1780
1781 func getSlash(st *serverTester) { st.bodylessReq1() }
1782
1783 func TestServer_Response_NoData(t *testing.T) {
1784 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1785
1786 return nil
1787 }, func(st *serverTester) {
1788 getSlash(st)
1789 hf := st.wantHeaders()
1790 if !hf.StreamEnded() {
1791 t.Fatal("want END_STREAM flag")
1792 }
1793 if !hf.HeadersEnded() {
1794 t.Fatal("want END_HEADERS flag")
1795 }
1796 })
1797 }
1798
1799 func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
1800 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1801 w.Header().Set("Foo-Bar", "some-value")
1802 return nil
1803 }, func(st *serverTester) {
1804 getSlash(st)
1805 hf := st.wantHeaders()
1806 if !hf.StreamEnded() {
1807 t.Fatal("want END_STREAM flag")
1808 }
1809 if !hf.HeadersEnded() {
1810 t.Fatal("want END_HEADERS flag")
1811 }
1812 goth := st.decodeHeader(hf.HeaderBlockFragment())
1813 wanth := [][2]string{
1814 {":status", "200"},
1815 {"foo-bar", "some-value"},
1816 {"content-length", "0"},
1817 }
1818 if !reflect.DeepEqual(goth, wanth) {
1819 t.Errorf("Got headers %v; want %v", goth, wanth)
1820 }
1821 })
1822 }
1823
1824
1825
1826 func TestServerIgnoresContentLengthSignWhenWritingChunks(t *testing.T) {
1827 tests := []struct {
1828 name string
1829 cl string
1830 wantCL string
1831 }{
1832 {
1833 name: "proper content-length",
1834 cl: "3",
1835 wantCL: "3",
1836 },
1837 {
1838 name: "ignore cl with plus sign",
1839 cl: "+3",
1840 wantCL: "0",
1841 },
1842 {
1843 name: "ignore cl with minus sign",
1844 cl: "-3",
1845 wantCL: "0",
1846 },
1847 {
1848 name: "max int64, for safe uint64->int64 conversion",
1849 cl: "9223372036854775807",
1850 wantCL: "9223372036854775807",
1851 },
1852 {
1853 name: "overflows int64, so ignored",
1854 cl: "9223372036854775808",
1855 wantCL: "0",
1856 },
1857 }
1858
1859 for _, tt := range tests {
1860 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1861 w.Header().Set("content-length", tt.cl)
1862 return nil
1863 }, func(st *serverTester) {
1864 getSlash(st)
1865 hf := st.wantHeaders()
1866 goth := st.decodeHeader(hf.HeaderBlockFragment())
1867 wanth := [][2]string{
1868 {":status", "200"},
1869 {"content-length", tt.wantCL},
1870 }
1871 if !reflect.DeepEqual(goth, wanth) {
1872 t.Errorf("For case %q, value %q, got = %q; want %q", tt.name, tt.cl, goth, wanth)
1873 }
1874 })
1875 }
1876 }
1877
1878
1879
1880 func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) {
1881 tests := []struct {
1882 name string
1883 cl string
1884 wantCL int64
1885 }{
1886 {
1887 name: "proper content-length",
1888 cl: "3",
1889 wantCL: 3,
1890 },
1891 {
1892 name: "ignore cl with plus sign",
1893 cl: "+3",
1894 wantCL: 0,
1895 },
1896 {
1897 name: "ignore cl with minus sign",
1898 cl: "-3",
1899 wantCL: 0,
1900 },
1901 {
1902 name: "max int64, for safe uint64->int64 conversion",
1903 cl: "9223372036854775807",
1904 wantCL: 9223372036854775807,
1905 },
1906 {
1907 name: "overflows int64, so ignored",
1908 cl: "9223372036854775808",
1909 wantCL: 0,
1910 },
1911 }
1912
1913 for _, tt := range tests {
1914 tt := tt
1915 t.Run(tt.name, func(t *testing.T) {
1916 writeReq := func(st *serverTester) {
1917 st.writeHeaders(HeadersFrameParam{
1918 StreamID: 1,
1919 BlockFragment: st.encodeHeader("content-length", tt.cl),
1920 EndStream: false,
1921 EndHeaders: true,
1922 })
1923 st.writeData(1, false, []byte(""))
1924 }
1925 checkReq := func(r *http.Request) {
1926 if r.ContentLength != tt.wantCL {
1927 t.Fatalf("Got: %q\nWant: %q", r.ContentLength, tt.wantCL)
1928 }
1929 }
1930 testServerRequest(t, writeReq, checkReq)
1931 })
1932 }
1933 }
1934
1935 func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
1936 const msg = "<html>this is HTML."
1937 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1938 w.Header().Set("Content-Type", "foo/bar")
1939 io.WriteString(w, msg)
1940 return nil
1941 }, func(st *serverTester) {
1942 getSlash(st)
1943 hf := st.wantHeaders()
1944 if hf.StreamEnded() {
1945 t.Fatal("don't want END_STREAM, expecting data")
1946 }
1947 if !hf.HeadersEnded() {
1948 t.Fatal("want END_HEADERS flag")
1949 }
1950 goth := st.decodeHeader(hf.HeaderBlockFragment())
1951 wanth := [][2]string{
1952 {":status", "200"},
1953 {"content-type", "foo/bar"},
1954 {"content-length", strconv.Itoa(len(msg))},
1955 }
1956 if !reflect.DeepEqual(goth, wanth) {
1957 t.Errorf("Got headers %v; want %v", goth, wanth)
1958 }
1959 df := st.wantData()
1960 if !df.StreamEnded() {
1961 t.Error("expected DATA to have END_STREAM flag")
1962 }
1963 if got := string(df.Data()); got != msg {
1964 t.Errorf("got DATA %q; want %q", got, msg)
1965 }
1966 })
1967 }
1968
1969 func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
1970 const msg = "hi"
1971 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1972 w.Header().Set("Transfer-Encoding", "chunked")
1973 io.WriteString(w, msg)
1974 return nil
1975 }, func(st *serverTester) {
1976 getSlash(st)
1977 hf := st.wantHeaders()
1978 goth := st.decodeHeader(hf.HeaderBlockFragment())
1979 wanth := [][2]string{
1980 {":status", "200"},
1981 {"content-type", "text/plain; charset=utf-8"},
1982 {"content-length", strconv.Itoa(len(msg))},
1983 }
1984 if !reflect.DeepEqual(goth, wanth) {
1985 t.Errorf("Got headers %v; want %v", goth, wanth)
1986 }
1987 })
1988 }
1989
1990
1991 func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
1992 const msg = "<html>this is HTML."
1993 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1994 io.WriteString(w, msg)
1995 w.Header().Set("foo", "should be ignored")
1996 return nil
1997 }, func(st *serverTester) {
1998 getSlash(st)
1999 hf := st.wantHeaders()
2000 if hf.StreamEnded() {
2001 t.Fatal("unexpected END_STREAM")
2002 }
2003 if !hf.HeadersEnded() {
2004 t.Fatal("want END_HEADERS flag")
2005 }
2006 goth := st.decodeHeader(hf.HeaderBlockFragment())
2007 wanth := [][2]string{
2008 {":status", "200"},
2009 {"content-type", "text/html; charset=utf-8"},
2010 {"content-length", strconv.Itoa(len(msg))},
2011 }
2012 if !reflect.DeepEqual(goth, wanth) {
2013 t.Errorf("Got headers %v; want %v", goth, wanth)
2014 }
2015 })
2016 }
2017
2018
2019 func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
2020 const msg = "<html>this is HTML."
2021 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2022 w.Header().Set("foo", "proper value")
2023 io.WriteString(w, msg)
2024 w.Header().Set("foo", "should be ignored")
2025 return nil
2026 }, func(st *serverTester) {
2027 getSlash(st)
2028 hf := st.wantHeaders()
2029 if hf.StreamEnded() {
2030 t.Fatal("unexpected END_STREAM")
2031 }
2032 if !hf.HeadersEnded() {
2033 t.Fatal("want END_HEADERS flag")
2034 }
2035 goth := st.decodeHeader(hf.HeaderBlockFragment())
2036 wanth := [][2]string{
2037 {":status", "200"},
2038 {"foo", "proper value"},
2039 {"content-type", "text/html; charset=utf-8"},
2040 {"content-length", strconv.Itoa(len(msg))},
2041 }
2042 if !reflect.DeepEqual(goth, wanth) {
2043 t.Errorf("Got headers %v; want %v", goth, wanth)
2044 }
2045 })
2046 }
2047
2048 func TestServer_Response_Data_SniffLenType(t *testing.T) {
2049 const msg = "<html>this is HTML."
2050 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2051 io.WriteString(w, msg)
2052 return nil
2053 }, func(st *serverTester) {
2054 getSlash(st)
2055 hf := st.wantHeaders()
2056 if hf.StreamEnded() {
2057 t.Fatal("don't want END_STREAM, expecting data")
2058 }
2059 if !hf.HeadersEnded() {
2060 t.Fatal("want END_HEADERS flag")
2061 }
2062 goth := st.decodeHeader(hf.HeaderBlockFragment())
2063 wanth := [][2]string{
2064 {":status", "200"},
2065 {"content-type", "text/html; charset=utf-8"},
2066 {"content-length", strconv.Itoa(len(msg))},
2067 }
2068 if !reflect.DeepEqual(goth, wanth) {
2069 t.Errorf("Got headers %v; want %v", goth, wanth)
2070 }
2071 df := st.wantData()
2072 if !df.StreamEnded() {
2073 t.Error("expected DATA to have END_STREAM flag")
2074 }
2075 if got := string(df.Data()); got != msg {
2076 t.Errorf("got DATA %q; want %q", got, msg)
2077 }
2078 })
2079 }
2080
2081 func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
2082 const msg = "<html>this is HTML"
2083 const msg2 = ", and this is the next chunk"
2084 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2085 io.WriteString(w, msg)
2086 w.(http.Flusher).Flush()
2087 io.WriteString(w, msg2)
2088 return nil
2089 }, func(st *serverTester) {
2090 getSlash(st)
2091 hf := st.wantHeaders()
2092 if hf.StreamEnded() {
2093 t.Fatal("unexpected END_STREAM flag")
2094 }
2095 if !hf.HeadersEnded() {
2096 t.Fatal("want END_HEADERS flag")
2097 }
2098 goth := st.decodeHeader(hf.HeaderBlockFragment())
2099 wanth := [][2]string{
2100 {":status", "200"},
2101 {"content-type", "text/html; charset=utf-8"},
2102
2103 }
2104 if !reflect.DeepEqual(goth, wanth) {
2105 t.Errorf("Got headers %v; want %v", goth, wanth)
2106 }
2107 {
2108 df := st.wantData()
2109 if df.StreamEnded() {
2110 t.Error("unexpected END_STREAM flag")
2111 }
2112 if got := string(df.Data()); got != msg {
2113 t.Errorf("got DATA %q; want %q", got, msg)
2114 }
2115 }
2116 {
2117 df := st.wantData()
2118 if !df.StreamEnded() {
2119 t.Error("wanted END_STREAM flag on last data chunk")
2120 }
2121 if got := string(df.Data()); got != msg2 {
2122 t.Errorf("got DATA %q; want %q", got, msg2)
2123 }
2124 }
2125 })
2126 }
2127
2128 func TestServer_Response_LargeWrite(t *testing.T) {
2129 const size = 1 << 20
2130 const maxFrameSize = 16 << 10
2131 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2132 n, err := w.Write(bytes.Repeat([]byte("a"), size))
2133 if err != nil {
2134 return fmt.Errorf("Write error: %v", err)
2135 }
2136 if n != size {
2137 return fmt.Errorf("wrong size %d from Write", n)
2138 }
2139 return nil
2140 }, func(st *serverTester) {
2141 if err := st.fr.WriteSettings(
2142 Setting{SettingInitialWindowSize, 0},
2143 Setting{SettingMaxFrameSize, maxFrameSize},
2144 ); err != nil {
2145 t.Fatal(err)
2146 }
2147 st.wantSettingsAck()
2148
2149 getSlash(st)
2150
2151
2152 if err := st.fr.WriteWindowUpdate(1, size); err != nil {
2153 t.Fatal(err)
2154 }
2155
2156
2157 if err := st.fr.WriteWindowUpdate(0, size); err != nil {
2158 t.Fatal(err)
2159 }
2160 hf := st.wantHeaders()
2161 if hf.StreamEnded() {
2162 t.Fatal("unexpected END_STREAM flag")
2163 }
2164 if !hf.HeadersEnded() {
2165 t.Fatal("want END_HEADERS flag")
2166 }
2167 goth := st.decodeHeader(hf.HeaderBlockFragment())
2168 wanth := [][2]string{
2169 {":status", "200"},
2170 {"content-type", "text/plain; charset=utf-8"},
2171
2172 }
2173 if !reflect.DeepEqual(goth, wanth) {
2174 t.Errorf("Got headers %v; want %v", goth, wanth)
2175 }
2176 var bytes, frames int
2177 for {
2178 df := st.wantData()
2179 bytes += len(df.Data())
2180 frames++
2181 for _, b := range df.Data() {
2182 if b != 'a' {
2183 t.Fatal("non-'a' byte seen in DATA")
2184 }
2185 }
2186 if df.StreamEnded() {
2187 break
2188 }
2189 }
2190 if bytes != size {
2191 t.Errorf("Got %d bytes; want %d", bytes, size)
2192 }
2193 if want := int(size / maxFrameSize); frames < want || frames > want*2 {
2194 t.Errorf("Got %d frames; want %d", frames, size)
2195 }
2196 })
2197 }
2198
2199
2200 func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
2201
2202
2203 reads := []int{123, 1, 13, 127}
2204 size := 0
2205 for _, n := range reads {
2206 size += n
2207 }
2208
2209 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2210 w.(http.Flusher).Flush()
2211 n, err := w.Write(bytes.Repeat([]byte("a"), size))
2212 if err != nil {
2213 return fmt.Errorf("Write error: %v", err)
2214 }
2215 if n != size {
2216 return fmt.Errorf("wrong size %d from Write", n)
2217 }
2218 return nil
2219 }, func(st *serverTester) {
2220
2221
2222 if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil {
2223 t.Fatal(err)
2224 }
2225 st.wantSettingsAck()
2226
2227 getSlash(st)
2228
2229 hf := st.wantHeaders()
2230 if hf.StreamEnded() {
2231 t.Fatal("unexpected END_STREAM flag")
2232 }
2233 if !hf.HeadersEnded() {
2234 t.Fatal("want END_HEADERS flag")
2235 }
2236
2237 df := st.wantData()
2238 if got := len(df.Data()); got != reads[0] {
2239 t.Fatalf("Initial window size = %d but got DATA with %d bytes", reads[0], got)
2240 }
2241
2242 for _, quota := range reads[1:] {
2243 if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
2244 t.Fatal(err)
2245 }
2246 df := st.wantData()
2247 if int(quota) != len(df.Data()) {
2248 t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota)
2249 }
2250 }
2251 })
2252 }
2253
2254
// TestServer_Response_RST_Unblocks_LargeWrite checks that a handler stuck in a
// large Write is released with an error once the client resets the stream. The
// client grants a zero initial stream window and never sends WINDOW_UPDATE.
func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
	const size = 1 << 20
	const maxFrameSize = 16 << 10
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Send the headers first so the client sees the response start.
		w.(http.Flusher).Flush()
		_, err := w.Write(bytes.Repeat([]byte("a"), size))
		// The Write is expected to fail once the stream is reset.
		if err == nil {
			return errors.New("unexpected nil error from Write in handler")
		}
		return nil
	}, func(st *serverTester) {
		// A zero initial window leaves the handler unable to send any DATA.
		if err := st.fr.WriteSettings(
			Setting{SettingInitialWindowSize, 0},
			Setting{SettingMaxFrameSize, maxFrameSize},
		); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}

		// Reset the stream while the handler is blocked in Write.
		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
			t.Fatal(err)
		}
	})
}
2289
// TestServer_Response_Empty_Data_Not_FlowControlled checks that, with a zero
// initial stream window, the server can still end the stream with an empty
// DATA frame. Zero-length DATA consumes no flow-control credit.
func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.(http.Flusher).Flush()
		// Nothing else is written; the handler returns right after the flush.
		return nil
	}, func(st *serverTester) {
		// Keep the stream window at zero for the whole test.
		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}

		// The stream must be ended by an empty DATA frame despite the zero window.
		df := st.wantData()
		if got := len(df.Data()); got != 0 {
			t.Fatalf("unexpected %d DATA bytes; want 0", got)
		}
		if !df.StreamEnded() {
			t.Fatal("DATA didn't have END_STREAM")
		}
	})
}
2321
// TestServer_Response_Automatic100Continue sends a request with
// "Expect: 100-Continue" and checks three things:
//   - the server answers with an interim 100 status before the client has
//     sent any body;
//   - the Expect header is not visible to the handler;
//   - the normal 200 response with the handler's reply follows.
func TestServer_Response_Automatic100Continue(t *testing.T) {
	const msg = "foo"
	const reply = "bar"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		if v := r.Header.Get("Expect"); v != "" {
			t.Errorf("Expect header = %q; want empty", v)
		}
		buf := make([]byte, len(msg))
		// Blocks until the client sends the body, which it only does after
		// it has seen the 100 response.
		if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
			return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
		}
		_, err := io.WriteString(w, reply)
		return err
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-Continue"),
			EndStream:     false,
			EndHeaders:    true,
		})
		// First response HEADERS: the interim 100, without END_STREAM.
		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}
		goth := st.decodeHeader(hf.HeaderBlockFragment())
		wanth := [][2]string{
			{":status", "100"},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Fatalf("Got headers %v; want %v", goth, wanth)
		}

		// Now send the body and end the client side of the stream.
		st.writeData(1, true, []byte(msg))

		// Final response HEADERS: 200, followed by the reply DATA.
		hf = st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("expected data to follow")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}
		goth = st.decodeHeader(hf.HeaderBlockFragment())
		wanth = [][2]string{
			{":status", "200"},
			{"content-type", "text/plain; charset=utf-8"},
			{"content-length", strconv.Itoa(len(reply))},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Got headers %v; want %v", goth, wanth)
		}

		df := st.wantData()
		if string(df.Data()) != reply {
			t.Errorf("Client read %q; want %q", df.Data(), reply)
		}
		if !df.StreamEnded() {
			t.Errorf("expect data stream end")
		}
	})
}
2388
// TestServer_HandlerWriteErrorOnDisconnect checks that a handler which keeps
// writing response data eventually gets a non-nil error from Write once the
// client has closed the connection.
func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
	writeErr := make(chan error, 1)
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		line := []byte("some data.\n")
		for {
			if _, err := w.Write(line); err != nil {
				// Hand the write failure to the client side and finish cleanly.
				writeErr <- err
				return nil
			}
		}
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false,
			EndHeaders:    true,
		})
		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}
		// Drop the connection, then wait until the handler has seen a write error.
		st.cc.Close()
		<-writeErr
	})
}
2419
// TestServer_Rejects_Too_Many_Streams opens defaultMaxStreams concurrent
// streams whose handlers block. It checks that one more stream is refused
// with RST_STREAM(PROTOCOL_ERROR), and that a new stream is accepted (with its
// :path decoded correctly) once one of the blocked handlers returns.
func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
	const testPath = "/some/path"

	inHandler := make(chan uint32)
	leaveHandler := make(chan bool)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		id := w.(*responseWriter).rws.stream.id
		inHandler <- id
		// 1+(defaultMaxStreams+1)*2 is the ID of the final "good" stream: it comes
		// after the defaultMaxStreams initial streams and the one rejected stream.
		if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
			t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
		}
		<-leaveHandler
	})
	defer st.Close()
	st.greet()
	// streamID hands out successive client-initiated (odd) stream IDs.
	nextStreamID := uint32(1)
	streamID := func() uint32 {
		defer func() { nextStreamID += 2 }()
		return nextStreamID
	}
	sendReq := func(id uint32, headers ...string) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      id,
			BlockFragment: st.encodeHeader(headers...),
			EndStream:     true,
			EndHeaders:    true,
		})
	}
	// Fill every concurrent-stream slot with a handler that blocks on leaveHandler.
	for i := 0; i < defaultMaxStreams; i++ {
		sendReq(streamID())
		<-inHandler
	}
	// Release the blocked handlers when the test returns.
	defer func() {
		for i := 0; i < defaultMaxStreams; i++ {
			leaveHandler <- true
		}
	}()

	// This stream crosses the limit. It is also split into HEADERS plus a
	// CONTINUATION frame, to verify the server still tracks the HPACK decoder
	// context while rejecting it; the :path check on the final stream depends
	// on that.
	rejectID := streamID()
	headerBlock := st.encodeHeader(":path", testPath)
	frag1, frag2 := headerBlock[:3], headerBlock[3:]
	st.writeHeaders(HeadersFrameParam{
		StreamID:      rejectID,
		BlockFragment: frag1,
		EndStream:     true,
		EndHeaders:    false,
	})
	if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
		t.Fatal(err)
	}
	st.wantRSTStream(rejectID, ErrCodeProtocol)

	// Let one handler finish, and read its response headers.
	leaveHandler <- true
	st.wantHeaders()

	// A slot is now free, so another stream should be able to start.
	goodID := streamID()
	sendReq(goodID, ":path", testPath)
	if got := <-inHandler; got != goodID {
		t.Errorf("Got stream %d; want %d", got, goodID)
	}
}
2486
2487
// TestServer_Response_ManyHeaders_With_Continuation checks that a response
// whose header block is too large for one frame (5000 headers) is sent as a
// HEADERS frame without END_HEADERS followed by CONTINUATION frames, the last
// of which carries END_HEADERS.
func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		h := w.Header()
		for i := 0; i < 5000; i++ {
			h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i))
		}
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		hf := st.wantHeaders()
		if hf.HeadersEnded() {
			t.Fatal("got unwanted END_HEADERS flag")
		}
		// Count CONTINUATION frames up to and including the one with END_HEADERS.
		n := 0
		for {
			n++
			cf := st.wantContinuation()
			if cf.HeadersEnded() {
				break
			}
		}
		if n < 5 {
			t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n)
		}
	})
}
2514
2515
2516
2517
2518
2519
2520
2521
// TestServer_NoCrash_HandlerClose_Then_ClientClose covers the case where a
// handler returns while the client is still sending the request body. The
// client then sends more DATA and finally closes the connection. The server
// must answer with the expected RST_STREAMs, give back flow-control credit,
// and shut down its serve loop without panicking.
func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Return immediately without reading the request body.
		return nil
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false, // client indicates it will still send DATA
			EndHeaders:    true,
		})
		hf := st.wantHeaders()
		if !hf.HeadersEnded() || !hf.StreamEnded() {
			t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf)
		}

		// Sent because the handler finished while the client has indicated
		// it is still sending DATA.
		st.wantRSTStream(1, ErrCodeNo)

		// The handler has ended its side of the stream, but the client has
		// not closed its side. Send more data and make sure the server copes
		// with it rather than tripping an internal invariant.
		st.writeData(1, true, []byte("foo"))

		// Sent because the peer sent data on the stream anyway.
		st.wantRSTStream(1, ErrCodeStreamClosed)

		// The unread DATA must not stay charged against flow control, since
		// the handler never consumed it (see wantFlowControlConsumed for the
		// exact meaning of the arguments).
		st.wantFlowControlConsumed(0, 0)

		// Record any panic recovered by the server's serve loop so it can be
		// reported below instead of crashing the test binary.
		var (
			panMu    sync.Mutex
			panicVal interface{}
		)

		testHookOnPanicMu.Lock()
		testHookOnPanic = func(sc *serverConn, pv interface{}) bool {
			panMu.Lock()
			panicVal = pv
			panMu.Unlock()
			return true
		}
		testHookOnPanicMu.Unlock()

		// Force the serve loop to end by closing the connection, and wait for it.
		st.cc.Close()
		<-st.sc.doneServing

		panMu.Lock()
		got := panicVal
		panMu.Unlock()
		if got != nil {
			t.Errorf("Got panic: %v", got)
		}
	})
}
2587
// TestServer_Rejects_TLS10 and TestServer_Rejects_TLS11 check that HTTP/2 is
// refused over TLS versions older than 1.2.
func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }

// testRejectTLS connects with TLS capped at version max and expects the server
// to respond with GOAWAY(INADEQUATE_SECURITY).
func testRejectTLS(t *testing.T, max uint16) {
	st := newServerTester(t, nil, func(c *tls.Config) {
		// Permit TLS 1.0 explicitly so the handshake can complete at the old
		// version (presumably the default minimum would otherwise be higher);
		// HTTP/2 requires TLS 1.2 or later, so the server must then reject
		// the connection at the HTTP/2 layer.
		c.MinVersion = tls.VersionTLS10
		c.MaxVersion = max
	})
	defer st.Close()
	gf := st.wantGoAway()
	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
		t.Errorf("Got error code %v; want %v", got, want)
	}
}
2605
// TestServer_Rejects_TLSBadCipher checks that a connection negotiated with a
// cipher suite that is unacceptable for HTTP/2 (RC4, 3DES or CBC-mode suites)
// is rejected with GOAWAY(INADEQUATE_SECURITY).
func TestServer_Rejects_TLSBadCipher(t *testing.T) {
	st := newServerTester(t, nil, func(c *tls.Config) {
		// TLS 1.3 cipher suites cannot be restricted via CipherSuites, so cap
		// the version at TLS 1.2 to make sure only the suites below are used.
		c.MaxVersion = tls.VersionTLS12

		// Offer only suites that are unacceptable for HTTP/2.
		c.CipherSuites = []uint16{
			tls.TLS_RSA_WITH_RC4_128_SHA,
			tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
			tls.TLS_RSA_WITH_AES_128_CBC_SHA,
			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
			tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
			tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
			tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
			cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
		}
	})
	defer st.Close()
	gf := st.wantGoAway()
	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
		t.Errorf("Got error code %v; want %v", got, want)
	}
}
2632
// TestServer_Advertises_Common_Cipher checks that a TLS peer offering only
// TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 (the suite HTTP/2 requires support
// for) can complete the handshake and the HTTP/2 greeting.
func TestServer_Advertises_Common_Cipher(t *testing.T) {
	const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
	st := newServerTester(t, nil, func(c *tls.Config) {
		// Offer nothing but the required suite.
		c.CipherSuites = []uint16{requiredSuite}
	}, func(ts *httptest.Server) {
		var srv *http.Server = ts.Config
		// Remove any TLS config from the http.Server so the server runs with
		// its default configuration; greet below must still succeed.
		srv.TLSConfig = nil
	})
	defer st.Close()
	st.greet()
}
2647
2648 func (st *serverTester) onHeaderField(f hpack.HeaderField) {
2649 if f.Name == "date" {
2650 return
2651 }
2652 st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value})
2653 }
2654
2655 func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) {
2656 st.decodedHeaders = nil
2657 if _, err := st.hpackDec.Write(headerBlock); err != nil {
2658 st.t.Fatalf("hpack decoding error: %v", err)
2659 }
2660 if err := st.hpackDec.Close(); err != nil {
2661 st.t.Fatalf("hpack decoding error: %v", err)
2662 }
2663 return st.decodedHeaders
2664 }
2665
2666
2667
2668 func testServerResponse(t testing.TB,
2669 handler func(http.ResponseWriter, *http.Request) error,
2670 client func(*serverTester),
2671 ) {
2672 errc := make(chan error, 1)
2673 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2674 if r.Body == nil {
2675 t.Fatal("nil Body")
2676 }
2677 err := handler(w, r)
2678 select {
2679 case errc <- err:
2680 default:
2681 t.Errorf("unexpected duplicate request")
2682 }
2683 })
2684 defer st.Close()
2685
2686 st.greet()
2687 client(st)
2688
2689 if err := <-errc; err != nil {
2690 t.Fatalf("Error in handler: %v", err)
2691 }
2692 }
2693
2694
2695
2696
// readBodyHandler returns a handler that reads exactly len(want) bytes of the
// request body and reports, via t.Error/t.Errorf, either a read error or a
// mismatch with want. Any body bytes beyond len(want) are left unread, and
// nothing is written to the response.
func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) {
	return func(_ http.ResponseWriter, r *http.Request) {
		got := make([]byte, len(want))
		if _, err := io.ReadFull(r.Body, got); err != nil {
			t.Error(err)
			return
		}
		if want != string(got) {
			t.Errorf("read %q; want %q", got, want)
		}
	}
}
2710
// TestServer_MaxDecoderHeaderTableSize checks that Server.MaxDecoderHeaderTableSize
// is advertised to the client as SETTINGS_HEADER_TABLE_SIZE in the server's
// initial SETTINGS frame.
func TestServer_MaxDecoderHeaderTableSize(t *testing.T) {
	wantHeaderTableSize := uint32(initialHeaderTableSize * 2)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) {
		s.MaxDecoderHeaderTableSize = wantHeaderTableSize
	})
	defer st.Close()

	// Capture the advertised value (if any) while completing the greeting.
	var advHeaderTableSize *uint32
	st.greetAndCheckSettings(func(s Setting) error {
		switch s.ID {
		case SettingHeaderTableSize:
			advHeaderTableSize = &s.Val
		}
		return nil
	})

	if advHeaderTableSize == nil {
		t.Errorf("server didn't advertise a header table size")
	} else if got, want := *advHeaderTableSize, wantHeaderTableSize; got != want {
		t.Errorf("server advertised a header table size of %d, want %d", got, want)
	}
}
2733
// TestServer_MaxEncoderHeaderTableSize checks that Server.MaxEncoderHeaderTableSize
// is reflected in the MaxDynamicTableSize of the server connection's HPACK
// encoder once the greeting has completed.
func TestServer_MaxEncoderHeaderTableSize(t *testing.T) {
	wantHeaderTableSize := uint32(initialHeaderTableSize / 2)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) {
		s.MaxEncoderHeaderTableSize = wantHeaderTableSize
	})
	defer st.Close()

	st.greet()

	if got, want := st.sc.hpackEncoder.MaxDynamicTableSize(), wantHeaderTableSize; got != want {
		t.Errorf("server encoder is using a header table size of %d, want %d", got, want)
	}
}
2747
2748
// TestServerDoS_MaxHeaderListSize checks that the server advertises a non-zero
// SETTINGS_MAX_HEADER_LIST_SIZE. It then sends a request whose header block
// grows to about 1MB via CONTINUATION frames full of repeated cookie fields,
// and expects a 431 response rather than the server buffering it all.
func TestServerDoS_MaxHeaderListSize(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
	defer st.Close()

	// Shake hands, recording the peer's max frame size (clamped to the
	// protocol's allowed range) and the advertised max header list size.
	frameSize := defaultMaxReadFrameSize
	var advHeaderListSize *uint32
	st.greetAndCheckSettings(func(s Setting) error {
		switch s.ID {
		case SettingMaxFrameSize:
			if s.Val < minMaxFrameSize {
				frameSize = minMaxFrameSize
			} else if s.Val > maxFrameSize {
				frameSize = maxFrameSize
			} else {
				frameSize = int(s.Val)
			}
		case SettingMaxHeaderListSize:
			advHeaderListSize = &s.Val
		}
		return nil
	})

	if advHeaderListSize == nil {
		t.Errorf("server didn't advertise a max header list size")
	} else if *advHeaderListSize == 0 {
		t.Errorf("server advertised a max header list size of 0")
	}

	// Start the request with a HEADERS frame (no END_HEADERS) that already
	// carries a ~4KB cookie; the fields are accumulated in st.headerBuf.
	st.encodeHeaderField(":method", "GET")
	st.encodeHeaderField(":path", "/")
	st.encodeHeaderField(":scheme", "https")
	cookie := strings.Repeat("*", 4058)
	st.encodeHeaderField("cookie", cookie)
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.headerBuf.Bytes(),
		EndStream:     true,
		EndHeaders:    false,
	})

	// Re-encode just the cookie field so it can be repeated in the
	// CONTINUATION payload below.
	st.headerBuf.Reset()
	st.encodeHeaderField("cookie", cookie)

	// Send about 1MB of repeated cookie fields, split into CONTINUATION
	// frames no larger than frameSize; the last one sets END_HEADERS.
	// Write errors are ignored here.
	const size = 1 << 20
	b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len())
	for len(b) > 0 {
		chunk := b
		if len(chunk) > frameSize {
			chunk = chunk[:frameSize]
		}
		b = b[len(chunk):]
		st.fr.WriteContinuation(1, len(b) == 0, chunk)
	}

	// The server should answer 431 Request Header Fields Too Large.
	h := st.wantHeaders()
	if !h.HeadersEnded() {
		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
	}
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "431"},
		{"content-type", "text/html; charset=utf-8"},
		{"content-length", "63"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
}
2821
// TestServer_Response_Stream_With_Missing_Trailer checks a handler that
// declares a trailer but never sets it and writes no body. The response must
// be a HEADERS frame followed by an empty DATA frame carrying END_STREAM, so
// no trailer HEADERS frame can follow.
func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Trailer", "test-trailer")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		hf := st.wantHeaders()
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}
		df := st.wantData()
		// Use the Data accessor (as the rest of this file does) rather than
		// the raw field, so the frame's validity is checked.
		if data := df.Data(); len(data) != 0 {
			t.Fatalf("did not want data; got %q", data)
		}
		if !df.StreamEnded() {
			t.Fatal("want END_STREAM flag")
		}
	})
}
2841
// TestCompressionErrorOnWrite checks the server's limit on a single HPACK
// string length. A header value of exactly the framer's maxHeaderStringLen
// is still decoded and yields a 431 response on that stream. A value one byte
// longer is a connection-level COMPRESSION_ERROR (GOAWAY).
func TestCompressionErrorOnWrite(t *testing.T) {
	const maxStrLen = 8 << 10
	var serverConfig *http.Server
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// No response body.
	}, func(ts *httptest.Server) {
		serverConfig = ts.Config
		serverConfig.MaxHeaderBytes = maxStrLen
	})
	// The COMPRESSION_ERROR connection error is expected; presumably this
	// keeps its log line from being treated as unexpected output.
	st.addLogFilter("connection error: COMPRESSION_ERROR")
	defer st.Close()
	st.greet()

	// String-length limit in effect for this connection, derived while
	// MaxHeaderBytes was still maxStrLen.
	maxAllowed := st.sc.framer.maxHeaderStringLen()

	// Now that the connection's limit has been captured, raise
	// MaxHeaderBytes. The intent is to exercise the max string size rather
	// than the overall header-list limit. (Whether the running connection
	// re-reads this value is not visible here.)
	serverConfig.MaxHeaderBytes = 1 << 20

	// First, a header whose value is exactly the max allowed string length.
	// It decodes successfully, so the compression context stays valid, but
	// the server still answers this stream with 431.
	hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	h := st.wantHeaders()
	if !h.HeadersEnded() {
		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
	}
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "431"},
		{"content-type", "text/html; charset=utf-8"},
		{"content-length", "63"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
	df := st.wantData()
	if !strings.Contains(string(df.Data()), "HTTP Error 431") {
		t.Errorf("Unexpected data body: %q", df.Data())
	}
	if !df.StreamEnded() {
		t.Fatalf("expect data stream end")
	}

	// Now one byte over the limit: this must kill the connection.
	hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1))
	st.writeHeaders(HeadersFrameParam{
		StreamID:      3,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	ga := st.wantGoAway()
	if ga.ErrCode != ErrCodeCompression {
		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
	}
}
2910
// TestCompressionErrorOnClose checks that a header block truncated by one byte
// (so HPACK decoding cannot complete) is treated as a connection-level
// COMPRESSION_ERROR and answered with GOAWAY.
func TestCompressionErrorOnClose(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// No response body.
	})
	// The COMPRESSION_ERROR connection error is expected; presumably this
	// keeps its log line from being treated as unexpected output.
	st.addLogFilter("connection error: COMPRESSION_ERROR")
	defer st.Close()
	st.greet()

	// Drop the last byte of an otherwise valid header block.
	hbf := st.encodeHeader("foo", "bar")
	hbf = hbf[:len(hbf)-1]
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	ga := st.wantGoAway()
	if ga.ErrCode != ErrCodeCompression {
		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
	}
}
2932
2933
// TestServerReadsTrailers checks request trailers on the server side.
// Trailers declared in the "Trailer" request header appear as nil-valued keys
// in r.Trailer before the body is read. Their values are filled in once the
// body has been consumed. An undeclared trailer sent by the client is
// dropped.
func TestServerReadsTrailers(t *testing.T) {
	const testBody = "some test body"
	writeReq := func(st *serverTester) {
		// Request headers declaring three trailers across two "trailer" fields.
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
			EndStream:     false,
			EndHeaders:    true,
		})
		st.writeData(1, false, []byte(testBody))
		// Trailing HEADERS frame ends the stream; "surprise" was not declared.
		st.writeHeaders(HeadersFrameParam{
			StreamID: 1,
			BlockFragment: st.encodeHeaderRaw(
				"foo", "foov",
				"bar", "barv",
				"baz", "bazv",
				"surprise", "wasn't declared; shouldn't show up",
			),
			EndStream:  true,
			EndHeaders: true,
		})
	}
	// checkReq receives the server-side request, so it presumably runs on the
	// handler goroutine; it therefore reports with Errorf and returns rather
	// than calling Fatalf (FailNow is only valid on the test goroutine).
	checkReq := func(r *http.Request) {
		wantTrailer := http.Header{
			"Foo": nil,
			"Bar": nil,
			"Baz": nil,
		}
		if !reflect.DeepEqual(r.Trailer, wantTrailer) {
			t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
		}
		slurp, err := ioutil.ReadAll(r.Body)
		// Check the read error first: a partial body is meaningless to compare.
		if err != nil {
			t.Errorf("Body slurp: %v", err)
			return
		}
		if string(slurp) != testBody {
			t.Errorf("read body %q; want %q", slurp, testBody)
		}
		wantTrailerAfter := http.Header{
			"Foo": {"foov"},
			"Bar": {"barv"},
			"Baz": {"bazv"},
		}
		if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
			t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
		}
	}
	testServerRequest(t, writeReq, checkReq)
}
2983
2984
// TestServerWritesTrailers_WithFlush and TestServerWritesTrailers_WithoutFlush
// run testServerWritesTrailers with and without flushing after the body write.
func TestServerWritesTrailers_WithFlush(t *testing.T)    { testServerWritesTrailers(t, true) }
func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }

// testServerWritesTrailers checks response trailers. Declared trailers that
// the handler sets after writing the body are sent in a final HEADERS frame
// with END_STREAM. "Trailer:"-prefixed keys add undeclared trailers.
// Undeclared, forbidden (Transfer-Encoding, Content-Length, Trailer, Range)
// and invalidly named trailers are left out. The response headers are the
// same whether or not the handler flushes after its first Write.
func testServerWritesTrailers(t *testing.T, withFlush bool) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")
		// Declared as trailers but forbidden as trailer fields; the values set
		// for them at the end of the handler must not be sent.
		w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer")

		// Regular response headers.
		w.Header().Set("Foo", "Bar")
		w.Header().Set("Content-Length", "5")

		// The header map is captured at the first Write; later changes only
		// matter for trailers.
		io.WriteString(w, "Hello")
		if withFlush {
			w.(http.Flusher).Flush()
		}
		// Values for two of the declared trailers; Server-Trailer-B is never set.
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec")
		// Not declared as a trailer, so it must not be sent. (The "Surpise"
		// misspelling is harmless: the field is only expected to be absent.)
		w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")

		// Keys with the "Trailer:" prefix (http.TrailerPrefix) add trailers
		// that were not declared up front. Range and a name containing an
		// invalid byte are expected to be dropped.
		w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
		w.Header().Set("Trailer:post-header-trailer2", "hi2")
		w.Header().Set("Trailer:Range", "invalid")
		w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
		w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("response HEADERS had END_STREAM")
		}
		if !hf.HeadersEnded() {
			t.Fatal("response HEADERS didn't have END_HEADERS")
		}
		goth := st.decodeHeader(hf.HeaderBlockFragment())
		wanth := [][2]string{
			{":status", "200"},
			{"foo", "Bar"},
			{"trailer", "Server-Trailer-A, Server-Trailer-B"},
			{"trailer", "Server-Trailer-C"},
			{"trailer", "Transfer-Encoding, Content-Length, Trailer"},
			{"content-type", "text/plain; charset=utf-8"},
			{"content-length", "5"},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
		// The body must not end the stream, because trailers follow.
		df := st.wantData()
		if string(df.Data()) != "Hello" {
			t.Fatalf("Client read %q; want Hello", df.Data())
		}
		if df.StreamEnded() {
			t.Fatalf("data frame had STREAM_ENDED")
		}
		// The trailer HEADERS frame ends the stream.
		tf := st.wantHeaders()
		if !tf.StreamEnded() {
			t.Fatalf("trailers HEADERS lacked END_STREAM")
		}
		if !tf.HeadersEnded() {
			t.Fatalf("trailers HEADERS lacked END_HEADERS")
		}
		wanth = [][2]string{
			{"post-header-trailer", "hi1"},
			{"post-header-trailer2", "hi2"},
			{"server-trailer-a", "valuea"},
			{"server-trailer-c", "valuec"},
		}
		goth = st.decodeHeader(tf.HeaderBlockFragment())
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
	})
}
3066
// TestServerWritesUndeclaredTrailers checks, end to end with this package's
// Transport, that a trailer set only via the http.TrailerPrefix convention
// (never declared in a "Trailer" header) reaches the client's resp.Trailer.
func TestServerWritesUndeclaredTrailers(t *testing.T) {
	const trailer = "Trailer-Header"
	const value = "hi1"
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set(http.TrailerPrefix+trailer, value)
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	cl := &http.Client{Transport: tr}
	resp, err := cl.Get(st.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Trailers are only available after the body has been read to the end.
	io.Copy(io.Discard, resp.Body)
	resp.Body.Close()

	if got, want := resp.Trailer.Get(trailer), value; got != want {
		t.Errorf("trailer %v = %q, want %q", trailer, got, want)
	}
}
3090
3091
3092
// TestServerDoesntWriteInvalidHeaders checks that response headers with an
// invalid name (containing ':' or NUL) or an invalid value (containing NUL)
// are omitted, while valid headers are still sent.
func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Add("OK1", "x")
		w.Header().Add("Bad:Colon", "x")
		w.Header().Add("Bad1\x00", "x")
		w.Header().Add("Bad2", "x\x00y")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		hf := st.wantHeaders()
		if !hf.StreamEnded() {
			t.Error("response HEADERS lacked END_STREAM")
		}
		if !hf.HeadersEnded() {
			t.Fatal("response HEADERS didn't have END_HEADERS")
		}
		goth := st.decodeHeader(hf.HeaderBlockFragment())
		// Only the valid header survives (names are lower-cased on the wire).
		wanth := [][2]string{
			{":status", "200"},
			{"ok1", "x"},
			{"content-length", "0"},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
	})
}
3120
// BenchmarkServerGets measures sequential GET requests on one connection, each
// on a new stream, where the handler writes a short body.
func BenchmarkServerGets(b *testing.B) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()

	const msg = "Hello, world"
	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, msg)
	})
	defer st.Close()
	st.greet()

	// Grant enough connection-level send window up front for all b.N
	// responses, so flow control never stalls the loop.
	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
		b.Fatal(err)
	}

	for i := 0; i < b.N; i++ {
		id := 1 + uint32(i)*2 // client-initiated streams use odd IDs
		st.writeHeaders(HeadersFrameParam{
			StreamID:      id,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		st.wantHeaders()
		df := st.wantData()
		if !df.StreamEnded() {
			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
		}
	}
}
3152
// BenchmarkServerPosts measures sequential POST requests with an empty body
// (sent as an empty DATA frame with END_STREAM), each on a new stream, where
// the handler drains the body and writes a short response.
func BenchmarkServerPosts(b *testing.B) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()

	const msg = "Hello, world"
	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		// The client sends an empty body, so draining it must yield zero bytes
		// and no error.
		if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
		}
		io.WriteString(w, msg)
	})
	defer st.Close()
	st.greet()

	// Grant enough connection-level send window up front for all b.N
	// responses, so flow control never stalls the loop.
	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
		b.Fatal(err)
	}

	for i := 0; i < b.N; i++ {
		id := 1 + uint32(i)*2 // client-initiated streams use odd IDs
		st.writeHeaders(HeadersFrameParam{
			StreamID:      id,
			BlockFragment: st.encodeHeader(":method", "POST"),
			EndStream:     false,
			EndHeaders:    true,
		})
		st.writeData(id, true, nil)
		st.wantHeaders()
		df := st.wantData()
		if !df.StreamEnded() {
			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
		}
	}
}
3191
3192
3193
3194
3195 func BenchmarkServerToClientStreamDefaultOptions(b *testing.B) {
3196 benchmarkServerToClientStream(b)
3197 }
3198
3199
3200
3201 func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
3202 benchmarkServerToClientStream(b, optFramerReuseFrames)
3203 }
3204
// benchmarkServerToClientStream measures one long-lived response stream. The
// handler writes b.N one-byte messages, flushing after each. The client reads
// them back as individual DATA frames and verifies each one. It returns
// flow-control credit for both the connection and the stream each time half
// of the default initial window has been consumed. newServerOpts are passed
// through to newServerTester.
func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()
	const msgLen = 1
	// The HTTP/2 default initial flow-control window size (2^16-1).
	const windowSize = 1<<16 - 1

	// nextMsg deterministically builds the i'th message so the reader can
	// check exactly what the handler wrote.
	nextMsg := func(i int) []byte {
		msg := make([]byte, msgLen)
		msg[0] = byte(i)
		return msg
	}

	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		// The client sends an empty request body.
		if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
		}
		for i := 0; i < b.N; i++ {
			w.Write(nextMsg(i))
			w.(http.Flusher).Flush()
		}
	}, newServerOpts...)
	defer st.Close()
	st.greet()

	const id = uint32(1)

	st.writeHeaders(HeadersFrameParam{
		StreamID:      id,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})

	st.writeData(id, true, nil)
	st.wantHeaders()

	// Bytes received since the last WINDOW_UPDATE was sent.
	var pendingWindowUpdate uint32

	for i := 0; i < b.N; i++ {
		expected := nextMsg(i)
		df := st.wantData()
		data := df.Data()
		if !bytes.Equal(expected, data) {
			b.Fatalf("Bad message received; want %v; got %v", expected, data)
		}

		pendingWindowUpdate += uint32(len(data))
		if pendingWindowUpdate >= windowSize/2 {
			if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil {
				b.Fatal(err)
			}
			if err := st.fr.WriteWindowUpdate(id, pendingWindowUpdate); err != nil {
				b.Fatal(err)
			}
			pendingWindowUpdate = 0
		}
	}
	// The handler's return ends the stream with a final DATA frame.
	df := st.wantData()
	if !df.StreamEnded() {
		b.Fatalf("DATA didn't have END_STREAM; got %v", df)
	}
}
3274
3275
3276
// TestIssue53 serves a connection whose input is the client preface followed
// by bytes that do not form a valid HTTP/2 frame sequence. It checks that
// ServeConn returns having closed the connection.
func TestIssue53(t *testing.T) {
	const data = "PRI * HTTP/2.0\r\n\r\nSM" +
		"\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
	errLog := log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags)
	hello := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("hello"))
	})
	base := &http.Server{
		ErrorLog: errLog,
		Handler:  hello,
	}
	h2srv := &Server{
		MaxReadFrameSize:             1 << 16,
		PermitProhibitedCipherSuites: true,
	}
	conn := &issue53Conn{data: []byte(data)}
	h2srv.ServeConn(conn, &ServeConnOpts{BaseConfig: base})
	if !conn.closed {
		t.Fatal("connection is not closed")
	}
}
3296
3297 type issue53Conn struct {
3298 data []byte
3299 closed bool
3300 written bool
3301 }
3302
3303 func (c *issue53Conn) Read(b []byte) (n int, err error) {
3304 if len(c.data) == 0 {
3305 return 0, io.EOF
3306 }
3307 n = copy(b, c.data)
3308 c.data = c.data[n:]
3309 return
3310 }
3311
3312 func (c *issue53Conn) Write(b []byte) (n int, err error) {
3313 c.written = true
3314 return len(b), nil
3315 }
3316
3317 func (c *issue53Conn) Close() error {
3318 c.closed = true
3319 return nil
3320 }
3321
3322 func (c *issue53Conn) LocalAddr() net.Addr {
3323 return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
3324 }
3325 func (c *issue53Conn) RemoteAddr() net.Addr {
3326 return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
3327 }
3328 func (c *issue53Conn) SetDeadline(t time.Time) error { return nil }
3329 func (c *issue53Conn) SetReadDeadline(t time.Time) error { return nil }
3330 func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil }
3331
3332
// TestServeConnOptsNilReceiverBehavior checks that the context, baseConfig and
// handler accessors can be called on a nil *ServeConnOpts without panicking,
// and that each returns a non-nil value.
func TestServeConnOptsNilReceiverBehavior(t *testing.T) {
	defer func() {
		if p := recover(); p != nil {
			t.Errorf("got a panic that should not happen: %v", p)
		}
	}()

	var opts *ServeConnOpts
	if ctx := opts.context(); ctx == nil {
		t.Error("o.context should not return nil")
	}
	if cfg := opts.baseConfig(); cfg == nil {
		t.Error("o.baseConfig should not return nil")
	}
	if h := opts.handler(); h == nil {
		t.Error("o.handler should not return nil")
	}
}
3351
3352
// TestConfigureServer checks ConfigureServer's validation of an http.Server's
// TLS cipher suites. A suite list lacking an HTTP/2-required suite is
// rejected unless the minimum version is TLS 1.3. Lists that include a
// required suite are accepted regardless of order. Every accepted
// configuration ends up with PreferServerCipherSuites set.
func TestConfigureServer(t *testing.T) {
	cases := []struct {
		name      string
		tlsConfig *tls.Config
		wantErr   string // substring of the expected error; empty means success
	}{
		{name: "empty server"},
		{name: "empty CipherSuites", tlsConfig: &tls.Config{}},
		{
			name: "bad CipherSuites but MinVersion TLS 1.3",
			tlsConfig: &tls.Config{
				MinVersion:   tls.VersionTLS13,
				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
			},
		},
		{
			name:      "just the required cipher suite",
			tlsConfig: &tls.Config{CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}},
		},
		{
			name:      "just the alternative required cipher suite",
			tlsConfig: &tls.Config{CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}},
		},
		{
			name:      "missing required cipher suite",
			tlsConfig: &tls.Config{CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384}},
			wantErr:   "is missing an HTTP/2-required",
		},
		{
			name:      "required after bad",
			tlsConfig: &tls.Config{CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}},
		},
		{
			name:      "bad after required",
			tlsConfig: &tls.Config{CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA}},
		},
	}
	for _, tc := range cases {
		srv := &http.Server{TLSConfig: tc.tlsConfig}
		err := ConfigureServer(srv, nil)
		wantErr := tc.wantErr != ""
		switch {
		case err == nil && wantErr:
			t.Errorf("%s: success, but want error", tc.name)
		case err != nil && !wantErr:
			t.Errorf("%s: unexpected error: %v", tc.name, err)
		}
		if err != nil && wantErr && !strings.Contains(err.Error(), tc.wantErr) {
			t.Errorf("%s: err = %v; want substring %q", tc.name, err, tc.wantErr)
		}
		if err == nil && !srv.TLSConfig.PreferServerCipherSuites {
			t.Errorf("%s: PreferServerCipherSuite is false; want true", tc.name)
		}
	}
}
3423
// TestServerNoAutoContentLengthOnHead checks that the response to a HEAD
// request whose handler writes nothing carries only ":status 200", with no
// automatically added content-length.
func TestServerNoAutoContentLengthOnHead(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// No response body.
	})
	defer st.Close()
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "HEAD"),
		EndStream:     true,
		EndHeaders:    true,
	})
	h := st.wantHeaders()
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "200"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
}
3445
3446
// TestServerNoDuplicateContentType checks that an explicitly empty
// Content-Type set by the handler is sent as-is. The server must not sniff
// the HTML body and add a second content-type field.
func TestServerNoDuplicateContentType(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.Header()["Content-Type"] = []string{""}
		// 41-byte HTML body.
		fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
	})
	defer st.Close()
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	h := st.wantHeaders()
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "200"},
		{"content-type", ""},
		{"content-length", "41"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
}
3471
// TestServerContentLengthCanBeDisabled checks that setting the Content-Length
// entry of the header map to nil suppresses the content-length field the
// server would otherwise add for a small response body.
func TestServerContentLengthCanBeDisabled(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// A present-but-nil entry opts out of the automatic Content-Length.
		w.Header()["Content-Length"] = nil
		fmt.Fprintf(w, "OK")
	})
	defer st.Close()
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	h := st.wantHeaders()
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "200"},
		{"content-type", "text/plain; charset=utf-8"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
}
3495
3496 func disableGoroutineTracking() (restore func()) {
3497 old := DebugGoroutines
3498 DebugGoroutines = false
3499 return func() { DebugGoroutines = old }
3500 }
3501
// BenchmarkServer_GetRequest measures sequential GET requests on one
// connection, reusing a single pre-encoded header block for every stream.
func BenchmarkServer_GetRequest(b *testing.B) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()
	const msg = "Hello, world."
	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		// GET requests are sent with END_STREAM, so the body must be empty.
		n, err := io.Copy(ioutil.Discard, r.Body)
		if err != nil || n > 0 {
			b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
		}
		io.WriteString(w, msg)
	})
	defer st.Close()

	st.greet()
	// Grant enough connection-level send window up front for all b.N responses.
	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
		b.Fatal(err)
	}
	hbf := st.encodeHeader(":method", "GET")
	for i := 0; i < b.N; i++ {
		streamID := uint32(1 + 2*i) // client-initiated streams use odd IDs
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: hbf,
			EndStream:     true,
			EndHeaders:    true,
		})
		st.wantHeaders()
		st.wantData()
	}
}
3533
3534 func BenchmarkServer_PostRequest(b *testing.B) {
3535 defer disableGoroutineTracking()()
3536 b.ReportAllocs()
3537 const msg = "Hello, world."
3538 st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3539 n, err := io.Copy(ioutil.Discard, r.Body)
3540 if err != nil || n > 0 {
3541 b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3542 }
3543 io.WriteString(w, msg)
3544 })
3545 defer st.Close()
3546 st.greet()
3547
3548 if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3549 b.Fatal(err)
3550 }
3551 hbf := st.encodeHeader(":method", "POST")
3552 for i := 0; i < b.N; i++ {
3553 streamID := uint32(1 + 2*i)
3554 st.writeHeaders(HeadersFrameParam{
3555 StreamID: streamID,
3556 BlockFragment: hbf,
3557 EndStream: false,
3558 EndHeaders: true,
3559 })
3560 st.writeData(streamID, true, nil)
3561 st.wantHeaders()
3562 st.wantData()
3563 }
3564 }
3565
// connStateConn wraps a net.Conn and adds a ConnectionState method that
// reports a fixed tls.ConnectionState. Tests use it to hand a custom TLS
// state to Server.ServeConn without a real TLS handshake.
type connStateConn struct {
	// Conn carries all I/O; its methods are promoted unchanged.
	net.Conn
	// cs is returned verbatim by ConnectionState.
	cs tls.ConnectionState
}

// ConnectionState returns the canned TLS state stored in the wrapper.
func (conn connStateConn) ConnectionState() tls.ConnectionState {
	state := conn.cs
	return state
}
3572
3573
3574
3575 func TestServerHandleCustomConn(t *testing.T) {
3576 var s Server
3577 c1, c2 := net.Pipe()
3578 clientDone := make(chan struct{})
3579 handlerDone := make(chan struct{})
3580 var req *http.Request
3581 go func() {
3582 defer close(clientDone)
3583 defer c2.Close()
3584 fr := NewFramer(c2, c2)
3585 io.WriteString(c2, ClientPreface)
3586 fr.WriteSettings()
3587 fr.WriteSettingsAck()
3588 f, err := fr.ReadFrame()
3589 if err != nil {
3590 t.Error(err)
3591 return
3592 }
3593 if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
3594 t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
3595 return
3596 }
3597 f, err = fr.ReadFrame()
3598 if err != nil {
3599 t.Error(err)
3600 return
3601 }
3602 if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
3603 t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
3604 return
3605 }
3606 var henc hpackEncoder
3607 fr.WriteHeaders(HeadersFrameParam{
3608 StreamID: 1,
3609 BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
3610 EndStream: true,
3611 EndHeaders: true,
3612 })
3613 go io.Copy(ioutil.Discard, c2)
3614 <-handlerDone
3615 }()
3616 const testString = "my custom ConnectionState"
3617 fakeConnState := tls.ConnectionState{
3618 ServerName: testString,
3619 Version: tls.VersionTLS12,
3620 CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
3621 }
3622 go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
3623 BaseConfig: &http.Server{
3624 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3625 defer close(handlerDone)
3626 req = r
3627 }),
3628 }})
3629 <-clientDone
3630
3631 if req.TLS == nil {
3632 t.Fatalf("Request.TLS is nil. Got: %#v", req)
3633 }
3634 if req.TLS.ServerName != testString {
3635 t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
3636 }
3637 }
3638
3639
3640 func TestServer_Rejects_ConnHeaders(t *testing.T) {
3641 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3642 t.Error("should not get to Handler")
3643 })
3644 defer st.Close()
3645 st.greet()
3646 st.bodylessReq1("connection", "foo")
3647 hf := st.wantHeaders()
3648 goth := st.decodeHeader(hf.HeaderBlockFragment())
3649 wanth := [][2]string{
3650 {":status", "400"},
3651 {"content-type", "text/plain; charset=utf-8"},
3652 {"x-content-type-options", "nosniff"},
3653 {"content-length", "51"},
3654 }
3655 if !reflect.DeepEqual(goth, wanth) {
3656 t.Errorf("Got headers %v; want %v", goth, wanth)
3657 }
3658 }
3659
3660 type hpackEncoder struct {
3661 enc *hpack.Encoder
3662 buf bytes.Buffer
3663 }
3664
3665 func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
3666 if len(headers)%2 == 1 {
3667 panic("odd number of kv args")
3668 }
3669 he.buf.Reset()
3670 if he.enc == nil {
3671 he.enc = hpack.NewEncoder(&he.buf)
3672 }
3673 for len(headers) > 0 {
3674 k, v := headers[0], headers[1]
3675 err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
3676 if err != nil {
3677 t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
3678 }
3679 headers = headers[2:]
3680 }
3681 return he.buf.Bytes()
3682 }
3683
3684 func TestCheckValidHTTP2Request(t *testing.T) {
3685 tests := []struct {
3686 h http.Header
3687 want error
3688 }{
3689 {
3690 h: http.Header{"Te": {"trailers"}},
3691 want: nil,
3692 },
3693 {
3694 h: http.Header{"Te": {"trailers", "bogus"}},
3695 want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`),
3696 },
3697 {
3698 h: http.Header{"Foo": {""}},
3699 want: nil,
3700 },
3701 {
3702 h: http.Header{"Connection": {""}},
3703 want: errors.New(`request header "Connection" is not valid in HTTP/2`),
3704 },
3705 {
3706 h: http.Header{"Proxy-Connection": {""}},
3707 want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`),
3708 },
3709 {
3710 h: http.Header{"Keep-Alive": {""}},
3711 want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`),
3712 },
3713 {
3714 h: http.Header{"Upgrade": {""}},
3715 want: errors.New(`request header "Upgrade" is not valid in HTTP/2`),
3716 },
3717 }
3718 for i, tt := range tests {
3719 got := checkValidHTTP2RequestHeaders(tt.h)
3720 if !equalError(got, tt.want) {
3721 t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want)
3722 }
3723 }
3724 }
3725
3726
3727 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
3728 const msg = "Hello"
3729 const msg2 = "World"
3730
3731 doRead := make(chan bool, 1)
3732 defer close(doRead)
3733
3734 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3735 io.WriteString(w, msg)
3736 w.(http.Flusher).Flush()
3737
3738
3739 <-doRead
3740 r.Body.Read(make([]byte, 10))
3741
3742 io.WriteString(w, msg2)
3743
3744 }, optOnlyServer)
3745 defer st.Close()
3746
3747 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3748 defer tr.CloseIdleConnections()
3749
3750 req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
3751 req.Header.Set("Expect", "100-continue")
3752
3753 res, err := tr.RoundTrip(req)
3754 if err != nil {
3755 t.Fatal(err)
3756 }
3757 defer res.Body.Close()
3758
3759 buf := make([]byte, len(msg))
3760 if _, err := io.ReadFull(res.Body, buf); err != nil {
3761 t.Fatal(err)
3762 }
3763 if string(buf) != msg {
3764 t.Fatalf("msg = %q; want %q", buf, msg)
3765 }
3766
3767 doRead <- true
3768
3769 if _, err := io.ReadFull(res.Body, buf); err != nil {
3770 t.Fatal(err)
3771 }
3772 if string(buf) != msg2 {
3773 t.Fatalf("second msg = %q; want %q", buf, msg2)
3774 }
3775 }
3776
// funcReader turns a plain function into an io.Reader: Read just calls it.
type funcReader func([]byte) (n int, err error)

// Read forwards p to the underlying function and returns its results.
func (f funcReader) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
3780
3781
3782
3783 func TestUnreadFlowControlReturned_Server(t *testing.T) {
3784 for _, tt := range []struct {
3785 name string
3786 reqFn func(r *http.Request)
3787 }{
3788 {
3789 "body-open",
3790 func(r *http.Request) {},
3791 },
3792 {
3793 "body-closed",
3794 func(r *http.Request) {
3795 r.Body.Close()
3796 },
3797 },
3798 {
3799 "read-1-byte-and-close",
3800 func(r *http.Request) {
3801 b := make([]byte, 1)
3802 r.Body.Read(b)
3803 r.Body.Close()
3804 },
3805 },
3806 } {
3807 t.Run(tt.name, func(t *testing.T) {
3808 unblock := make(chan bool, 1)
3809 defer close(unblock)
3810
3811 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3812
3813
3814
3815 tt.reqFn(r)
3816 <-unblock
3817 }, optOnlyServer)
3818 defer st.Close()
3819
3820 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3821 defer tr.CloseIdleConnections()
3822
3823
3824 iters := 100
3825 if testing.Short() {
3826 iters = 20
3827 }
3828 for i := 0; i < iters; i++ {
3829 body := io.MultiReader(
3830 io.LimitReader(neverEnding('A'), 16<<10),
3831 funcReader(func([]byte) (n int, err error) {
3832 unblock <- true
3833 return 0, io.EOF
3834 }),
3835 )
3836 req, _ := http.NewRequest("POST", st.ts.URL, body)
3837 res, err := tr.RoundTrip(req)
3838 if err != nil {
3839 t.Fatal(tt.name, err)
3840 }
3841 res.Body.Close()
3842 }
3843 })
3844 }
3845 }
3846
3847 func TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) {
3848 unblockHandler := make(chan struct{})
3849 defer close(unblockHandler)
3850
3851 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3852 r.Body.Close()
3853 w.WriteHeader(200)
3854 w.(http.Flusher).Flush()
3855 <-unblockHandler
3856 })
3857 defer st.Close()
3858
3859 st.greet()
3860 st.writeHeaders(HeadersFrameParam{
3861 StreamID: 1,
3862 BlockFragment: st.encodeHeader(),
3863 EndHeaders: true,
3864 })
3865 st.wantHeaders()
3866 const size = inflowMinRefresh
3867 st.writeData(1, false, make([]byte, size))
3868 st.wantWindowUpdate(0, size)
3869 unblockHandler <- struct{}{}
3870 st.wantData()
3871 }
3872
3873 func TestServerIdleTimeout(t *testing.T) {
3874 if testing.Short() {
3875 t.Skip("skipping in short mode")
3876 }
3877
3878 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3879 }, func(h2s *Server) {
3880 h2s.IdleTimeout = 500 * time.Millisecond
3881 })
3882 defer st.Close()
3883
3884 st.greet()
3885 ga := st.wantGoAway()
3886 if ga.ErrCode != ErrCodeNo {
3887 t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
3888 }
3889 }
3890
3891 func TestServerIdleTimeout_AfterRequest(t *testing.T) {
3892 if testing.Short() {
3893 t.Skip("skipping in short mode")
3894 }
3895 const timeout = 250 * time.Millisecond
3896
3897 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3898 time.Sleep(timeout * 2)
3899 }, func(h2s *Server) {
3900 h2s.IdleTimeout = timeout
3901 })
3902 defer st.Close()
3903
3904 st.greet()
3905
3906
3907
3908 st.bodylessReq1()
3909 st.wantHeaders()
3910
3911
3912
3913 ga := st.wantGoAway()
3914 if ga.ErrCode != ErrCodeNo {
3915 t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
3916 }
3917 }
3918
3919
3920
3921
3922 func TestRequestBodyReadCloseRace(t *testing.T) {
3923 for i := 0; i < 100; i++ {
3924 body := &requestBody{
3925 pipe: &pipe{
3926 b: new(bytes.Buffer),
3927 },
3928 }
3929 body.pipe.CloseWithError(io.EOF)
3930
3931 done := make(chan bool, 1)
3932 buf := make([]byte, 10)
3933 go func() {
3934 time.Sleep(1 * time.Millisecond)
3935 body.Close()
3936 done <- true
3937 }()
3938 body.Read(buf)
3939 <-done
3940 }
3941 }
3942
3943 func TestIssue20704Race(t *testing.T) {
3944 if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
3945 t.Skip("skipping in short mode")
3946 }
3947 const (
3948 itemSize = 1 << 10
3949 itemCount = 100
3950 )
3951
3952 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3953 for i := 0; i < itemCount; i++ {
3954 _, err := w.Write(make([]byte, itemSize))
3955 if err != nil {
3956 return
3957 }
3958 }
3959 }, optOnlyServer)
3960 defer st.Close()
3961
3962 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3963 defer tr.CloseIdleConnections()
3964 cl := &http.Client{Transport: tr}
3965
3966 for i := 0; i < 1000; i++ {
3967 resp, err := cl.Get(st.ts.URL)
3968 if err != nil {
3969 t.Fatal(err)
3970 }
3971
3972
3973 resp.Body.Close()
3974 }
3975 }
3976
3977 func TestServer_Rejects_TooSmall(t *testing.T) {
3978 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3979 ioutil.ReadAll(r.Body)
3980 return nil
3981 }, func(st *serverTester) {
3982 st.writeHeaders(HeadersFrameParam{
3983 StreamID: 1,
3984 BlockFragment: st.encodeHeader(
3985 ":method", "POST",
3986 "content-length", "4",
3987 ),
3988 EndStream: false,
3989 EndHeaders: true,
3990 })
3991 st.writeData(1, true, []byte("12345"))
3992 st.wantRSTStream(1, ErrCodeProtocol)
3993 st.wantFlowControlConsumed(0, 0)
3994 })
3995 }
3996
3997
3998
3999 func TestServerHandlerConnectionClose(t *testing.T) {
4000 unblockHandler := make(chan bool, 1)
4001 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
4002 w.Header().Set("Connection", "close")
4003 w.Header().Set("Foo", "bar")
4004 w.(http.Flusher).Flush()
4005 <-unblockHandler
4006 return nil
4007 }, func(st *serverTester) {
4008 defer close(unblockHandler)
4009 st.writeHeaders(HeadersFrameParam{
4010 StreamID: 1,
4011 BlockFragment: st.encodeHeader(),
4012 EndStream: true,
4013 EndHeaders: true,
4014 })
4015 var sawGoAway bool
4016 var sawRes bool
4017 var sawWindowUpdate bool
4018 for {
4019 f, err := st.readFrame()
4020 if err == io.EOF {
4021 break
4022 }
4023 if err != nil {
4024 t.Fatal(err)
4025 }
4026 switch f := f.(type) {
4027 case *GoAwayFrame:
4028 sawGoAway = true
4029 if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo {
4030 t.Errorf("unexpected GOAWAY frame: %v", summarizeFrame(f))
4031 }
4032
4033
4034 st.writeHeaders(HeadersFrameParam{
4035 StreamID: 3,
4036 BlockFragment: st.encodeHeader(),
4037 EndStream: false,
4038 EndHeaders: true,
4039 })
4040 st.fr.WriteRSTStream(3, ErrCodeCancel)
4041
4042
4043
4044 st.writeHeaders(HeadersFrameParam{
4045 StreamID: 5,
4046 BlockFragment: st.encodeHeader(),
4047 EndStream: false,
4048 EndHeaders: true,
4049 })
4050
4051 st.writeData(5, true, make([]byte, 1<<19))
4052 case *HeadersFrame:
4053 goth := st.decodeHeader(f.HeaderBlockFragment())
4054 wanth := [][2]string{
4055 {":status", "200"},
4056 {"foo", "bar"},
4057 }
4058 if !reflect.DeepEqual(goth, wanth) {
4059 t.Errorf("got headers %v; want %v", goth, wanth)
4060 }
4061 sawRes = true
4062 case *DataFrame:
4063 if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 {
4064 t.Errorf("unexpected DATA frame: %v", summarizeFrame(f))
4065 }
4066 case *WindowUpdateFrame:
4067 if !sawGoAway {
4068 t.Errorf("unexpected WINDOW_UPDATE frame: %v", summarizeFrame(f))
4069 return
4070 }
4071 if f.StreamID != 0 {
4072 st.t.Fatalf("WindowUpdate StreamID = %d; want 5", f.FrameHeader.StreamID)
4073 return
4074 }
4075 sawWindowUpdate = true
4076 unblockHandler <- true
4077 default:
4078 t.Logf("unexpected frame: %v", summarizeFrame(f))
4079 }
4080 }
4081 if !sawGoAway {
4082 t.Errorf("didn't see GOAWAY")
4083 }
4084 if !sawRes {
4085 t.Errorf("didn't see response")
4086 }
4087 if !sawWindowUpdate {
4088 t.Errorf("didn't see WINDOW_UPDATE")
4089 }
4090 })
4091 }
4092
4093 func TestServer_Headers_HalfCloseRemote(t *testing.T) {
4094 var st *serverTester
4095 writeData := make(chan bool)
4096 writeHeaders := make(chan bool)
4097 leaveHandler := make(chan bool)
4098 st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4099 if st.stream(1) == nil {
4100 t.Errorf("nil stream 1 in handler")
4101 }
4102 if got, want := st.streamState(1), stateOpen; got != want {
4103 t.Errorf("in handler, state is %v; want %v", got, want)
4104 }
4105 writeData <- true
4106 if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
4107 t.Errorf("body read = %d, %v; want 0, EOF", n, err)
4108 }
4109 if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
4110 t.Errorf("in handler, state is %v; want %v", got, want)
4111 }
4112 writeHeaders <- true
4113
4114 <-leaveHandler
4115 })
4116 st.greet()
4117
4118 st.writeHeaders(HeadersFrameParam{
4119 StreamID: 1,
4120 BlockFragment: st.encodeHeader(),
4121 EndStream: false,
4122 EndHeaders: true,
4123 })
4124 <-writeData
4125 st.writeData(1, true, nil)
4126
4127 <-writeHeaders
4128
4129 st.writeHeaders(HeadersFrameParam{
4130 StreamID: 1,
4131 BlockFragment: st.encodeHeader(),
4132 EndStream: false,
4133 EndHeaders: true,
4134 })
4135
4136 defer close(leaveHandler)
4137
4138 st.wantRSTStream(1, ErrCodeStreamClosed)
4139 }
4140
// TestServerGracefulShutdown starts http.Server.Shutdown while a request is
// in flight. It checks that:
//   - the server sends GOAWAY(NO_ERROR) with LastStreamID 1;
//   - the in-flight handler's response still reaches the client, including
//     a header the handler sets after the GOAWAY was observed;
//   - the connection is closed afterwards, so a read fails.
func TestServerGracefulShutdown(t *testing.T) {
	var st *serverTester
	handlerDone := make(chan struct{})
	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		defer close(handlerDone)
		// Shutdown blocks until connections go idle, so run it in its own
		// goroutine while this request is still active.
		go st.ts.Config.Shutdown(context.Background())

		// The GOAWAY is read from the handler goroutine. The test goroutine
		// is parked on <-handlerDone and does not read frames meanwhile.
		ga := st.wantGoAway()
		if ga.ErrCode != ErrCodeNo {
			t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
		}
		if ga.LastStreamID != 1 {
			t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
		}

		// Set after the GOAWAY; the final response must still carry it.
		w.Header().Set("x-foo", "bar")
	})
	defer st.Close()

	st.greet()
	st.bodylessReq1()

	<-handlerDone
	hf := st.wantHeaders()
	goth := st.decodeHeader(hf.HeaderBlockFragment())
	wanth := [][2]string{
		{":status", "200"},
		{"x-foo", "bar"},
		{"content-length", "0"},
	}
	if !reflect.DeepEqual(goth, wanth) {
		t.Errorf("Got headers %v; want %v", goth, wanth)
	}

	// After the response, the connection is expected to be closed: the read
	// must return no data and a non-nil error.
	n, err := st.cc.Read([]byte{0})
	if n != 0 || err == nil {
		t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
	}
}
4180
4181
4182 func TestContentEncodingNoSniffing(t *testing.T) {
4183 type resp struct {
4184 name string
4185 body []byte
4186
4187
4188
4189 contentEncoding interface{}
4190 wantContentType string
4191 }
4192
4193 resps := []*resp{
4194 {
4195 name: "gzip content-encoding, gzipped",
4196 contentEncoding: "application/gzip",
4197 wantContentType: "",
4198 body: func() []byte {
4199 buf := new(bytes.Buffer)
4200 gzw := gzip.NewWriter(buf)
4201 gzw.Write([]byte("doctype html><p>Hello</p>"))
4202 gzw.Close()
4203 return buf.Bytes()
4204 }(),
4205 },
4206 {
4207 name: "zlib content-encoding, zlibbed",
4208 contentEncoding: "application/zlib",
4209 wantContentType: "",
4210 body: func() []byte {
4211 buf := new(bytes.Buffer)
4212 zw := zlib.NewWriter(buf)
4213 zw.Write([]byte("doctype html><p>Hello</p>"))
4214 zw.Close()
4215 return buf.Bytes()
4216 }(),
4217 },
4218 {
4219 name: "no content-encoding",
4220 wantContentType: "application/x-gzip",
4221 body: func() []byte {
4222 buf := new(bytes.Buffer)
4223 gzw := gzip.NewWriter(buf)
4224 gzw.Write([]byte("doctype html><p>Hello</p>"))
4225 gzw.Close()
4226 return buf.Bytes()
4227 }(),
4228 },
4229 {
4230 name: "phony content-encoding",
4231 contentEncoding: "foo/bar",
4232 body: []byte("doctype html><p>Hello</p>"),
4233 },
4234 {
4235 name: "empty but set content-encoding",
4236 contentEncoding: "",
4237 wantContentType: "audio/mpeg",
4238 body: []byte("ID3"),
4239 },
4240 }
4241
4242 for _, tt := range resps {
4243 t.Run(tt.name, func(t *testing.T) {
4244 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4245 if tt.contentEncoding != nil {
4246 w.Header().Set("Content-Encoding", tt.contentEncoding.(string))
4247 }
4248 w.Write(tt.body)
4249 }, optOnlyServer)
4250 defer st.Close()
4251
4252 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4253 defer tr.CloseIdleConnections()
4254
4255 req, _ := http.NewRequest("GET", st.ts.URL, nil)
4256 res, err := tr.RoundTrip(req)
4257 if err != nil {
4258 t.Fatalf("GET %s: %v", st.ts.URL, err)
4259 }
4260 defer res.Body.Close()
4261
4262 g := res.Header.Get("Content-Encoding")
4263 t.Logf("%s: Content-Encoding: %s", st.ts.URL, g)
4264
4265 if w := tt.contentEncoding; g != w {
4266 if w != nil {
4267 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
4268 } else if g != "" {
4269 t.Errorf("Unexpected Content-Encoding %q", g)
4270 }
4271 }
4272
4273 g = res.Header.Get("Content-Type")
4274 if w := tt.wantContentType; g != w {
4275 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
4276 }
4277 t.Logf("%s: Content-Type: %s", st.ts.URL, g)
4278 })
4279 }
4280 }
4281
// TestServerWindowUpdateOnBodyClose checks connection-level flow control when
// a request body is closed while data is still buffered. Both upload buffers
// are set to windowSize (2*65535). The test proceeds as follows:
//   - send the first half of the body;
//   - let the handler read 4 bytes;
//   - close the stream's body pipe with io.EOF;
//   - expect stream-0 WINDOW_UPDATEs totalling the first half;
//   - send the second half and expect one WINDOW_UPDATE(0, windowSize/2).
func TestServerWindowUpdateOnBodyClose(t *testing.T) {
	const windowSize = 65535 * 2
	content := make([]byte, windowSize)
	// blockCh is an unbuffered ping-pong. The handler signals after reading
	// 4 bytes, then waits until the test has closed the body.
	blockCh := make(chan bool)
	errc := make(chan error, 1)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		buf := make([]byte, 4)
		n, err := io.ReadFull(r.Body, buf)
		// NOTE(review): on these early returns the handler never signals
		// blockCh, so the test goroutine's <-blockCh below blocks forever.
		if err != nil {
			errc <- err
			return
		}
		if n != len(buf) {
			errc <- fmt.Errorf("too few bytes read: %d", n)
			return
		}
		blockCh <- true
		<-blockCh
		errc <- nil
	}, func(s *Server) {
		s.MaxUploadBufferPerConnection = windowSize
		s.MaxUploadBufferPerStream = windowSize
	})
	defer st.Close()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID: 1,
		BlockFragment: st.encodeHeader(
			":method", "POST",
			"content-length", strconv.Itoa(len(content)),
		),
		EndStream:  false,
		EndHeaders: true,
	})
	st.writeData(1, false, content[:windowSize/2])
	<-blockCh
	// Close the body from the server side while most of the first half is
	// still unread in the pipe.
	st.stream(1).body.CloseWithError(io.EOF)
	blockCh <- true

	// Sum stream-0 WINDOW_UPDATE increments until they cover the first half.
	// The loop also ends on io.EOF, in which case wantWindowUpdate below is
	// left to report the failure.
	increments := windowSize / 2
	for {
		f, err := st.readFrame()
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatal(err)
		}
		if wu, ok := f.(*WindowUpdateFrame); ok && wu.StreamID == 0 {
			increments -= int(wu.Increment)
			if increments == 0 {
				break
			}
		}
	}

	// DATA arriving after the body was closed must also be credited back.
	st.writeData(1, false, content[windowSize/2:])
	st.wantWindowUpdate(0, windowSize/2)

	if err := <-errc; err != nil {
		t.Error(err)
	}
}
4348
4349 func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) {
4350 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
4351 defer st.Close()
4352
4353 st.greet()
4354
4355 content := "some content"
4356 st.writeHeaders(HeadersFrameParam{
4357 StreamID: 1,
4358 BlockFragment: st.encodeHeader(
4359 ":method", "POST",
4360 "content-length", strconv.Itoa(len(content)),
4361 ),
4362 EndStream: false,
4363 EndHeaders: true,
4364 })
4365 st.wantHeaders()
4366
4367 st.sc.startGracefulShutdown()
4368 for {
4369 f, err := st.readFrame()
4370 if err == io.EOF {
4371 st.t.Fatal("got a EOF; want *GoAwayFrame")
4372 }
4373 if err != nil {
4374 t.Fatal(err)
4375 }
4376 if gf, ok := f.(*GoAwayFrame); ok && gf.StreamID == 0 {
4377 break
4378 }
4379 }
4380
4381 st.writeData(1, true, []byte(content))
4382 time.Sleep(200 * time.Millisecond)
4383 st.Close()
4384
4385 if bytes.Contains(st.serverLogBuf.Bytes(), []byte("PROTOCOL_ERROR")) {
4386 t.Error("got protocol error")
4387 }
4388 }
4389
4390 func TestServerSendsProcessing(t *testing.T) {
4391 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
4392 w.WriteHeader(http.StatusProcessing)
4393 w.Write([]byte("stuff"))
4394
4395 return nil
4396 }, func(st *serverTester) {
4397 getSlash(st)
4398 hf := st.wantHeaders()
4399 goth := st.decodeHeader(hf.HeaderBlockFragment())
4400 wanth := [][2]string{
4401 {":status", "102"},
4402 }
4403
4404 if !reflect.DeepEqual(goth, wanth) {
4405 t.Errorf("Got = %q; want %q", goth, wanth)
4406 }
4407
4408 hf = st.wantHeaders()
4409 goth = st.decodeHeader(hf.HeaderBlockFragment())
4410 wanth = [][2]string{
4411 {":status", "200"},
4412 {"content-type", "text/plain; charset=utf-8"},
4413 {"content-length", "5"},
4414 }
4415
4416 if !reflect.DeepEqual(goth, wanth) {
4417 t.Errorf("Got = %q; want %q", goth, wanth)
4418 }
4419 })
4420 }
4421
4422 func TestServerSendsEarlyHints(t *testing.T) {
4423 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
4424 h := w.Header()
4425 h.Add("Content-Length", "123")
4426 h.Add("Link", "</style.css>; rel=preload; as=style")
4427 h.Add("Link", "</script.js>; rel=preload; as=script")
4428 w.WriteHeader(http.StatusEarlyHints)
4429
4430 h.Add("Link", "</foo.js>; rel=preload; as=script")
4431 w.WriteHeader(http.StatusEarlyHints)
4432
4433 w.Write([]byte("stuff"))
4434
4435 return nil
4436 }, func(st *serverTester) {
4437 getSlash(st)
4438 hf := st.wantHeaders()
4439 goth := st.decodeHeader(hf.HeaderBlockFragment())
4440 wanth := [][2]string{
4441 {":status", "103"},
4442 {"link", "</style.css>; rel=preload; as=style"},
4443 {"link", "</script.js>; rel=preload; as=script"},
4444 }
4445
4446 if !reflect.DeepEqual(goth, wanth) {
4447 t.Errorf("Got = %q; want %q", goth, wanth)
4448 }
4449
4450 hf = st.wantHeaders()
4451 goth = st.decodeHeader(hf.HeaderBlockFragment())
4452 wanth = [][2]string{
4453 {":status", "103"},
4454 {"link", "</style.css>; rel=preload; as=style"},
4455 {"link", "</script.js>; rel=preload; as=script"},
4456 {"link", "</foo.js>; rel=preload; as=script"},
4457 }
4458
4459 if !reflect.DeepEqual(goth, wanth) {
4460 t.Errorf("Got = %q; want %q", goth, wanth)
4461 }
4462
4463 hf = st.wantHeaders()
4464 goth = st.decodeHeader(hf.HeaderBlockFragment())
4465 wanth = [][2]string{
4466 {":status", "200"},
4467 {"link", "</style.css>; rel=preload; as=style"},
4468 {"link", "</script.js>; rel=preload; as=script"},
4469 {"link", "</foo.js>; rel=preload; as=script"},
4470 {"content-type", "text/plain; charset=utf-8"},
4471 {"content-length", "123"},
4472 }
4473
4474 if !reflect.DeepEqual(goth, wanth) {
4475 t.Errorf("Got = %q; want %q", goth, wanth)
4476 }
4477 })
4478 }
4479
4480 func TestProtocolErrorAfterGoAway(t *testing.T) {
4481 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4482 io.Copy(io.Discard, r.Body)
4483 })
4484 defer st.Close()
4485
4486 st.greet()
4487 content := "some content"
4488 st.writeHeaders(HeadersFrameParam{
4489 StreamID: 1,
4490 BlockFragment: st.encodeHeader(
4491 ":method", "POST",
4492 "content-length", strconv.Itoa(len(content)),
4493 ),
4494 EndStream: false,
4495 EndHeaders: true,
4496 })
4497 st.writeData(1, false, []byte(content[:5]))
4498 st.writeReadPing()
4499
4500
4501
4502 if err := st.fr.WriteGoAway(1, ErrCodeNo, nil); err != nil {
4503 t.Fatal(err)
4504 }
4505 if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
4506 t.Fatal(err)
4507 }
4508
4509 for {
4510 if _, err := st.readFrame(); err != nil {
4511 if err != io.EOF {
4512 t.Errorf("unexpected readFrame error: %v", err)
4513 }
4514 break
4515 }
4516 }
4517 }
4518
4519 func TestServerInitialFlowControlWindow(t *testing.T) {
4520 for _, want := range []int32{
4521 65535,
4522 1 << 19,
4523 1 << 21,
4524
4525
4526
4527
4528
4529 65535 * 2,
4530 } {
4531 t.Run(fmt.Sprint(want), func(t *testing.T) {
4532
4533 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4534 }, func(s *Server) {
4535 s.MaxUploadBufferPerConnection = want
4536 })
4537 defer st.Close()
4538 st.writePreface()
4539 st.writeInitialSettings()
4540 st.writeSettingsAck()
4541 st.writeHeaders(HeadersFrameParam{
4542 StreamID: 1,
4543 BlockFragment: st.encodeHeader(),
4544 EndStream: true,
4545 EndHeaders: true,
4546 })
4547 window := 65535
4548 Frames:
4549 for {
4550 f, err := st.readFrame()
4551 if err != nil {
4552 st.t.Fatal(err)
4553 }
4554 switch f := f.(type) {
4555 case *WindowUpdateFrame:
4556 if f.FrameHeader.StreamID != 0 {
4557 t.Errorf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
4558 return
4559 }
4560 window += int(f.Increment)
4561 case *HeadersFrame:
4562 break Frames
4563 default:
4564 }
4565 }
4566 if window != int(want) {
4567 t.Errorf("got initial flow control window = %v, want %v", window, want)
4568 }
4569 })
4570 }
4571 }
4572
4573
4574
4575 func TestCanonicalHeaderCacheGrowth(t *testing.T) {
4576 for _, size := range []int{1, (1 << 20) - 10} {
4577 base := strings.Repeat("X", size)
4578 sc := &serverConn{
4579 serveG: newGoroutineLock(),
4580 }
4581 const count = 1000
4582 for i := 0; i < count; i++ {
4583 h := fmt.Sprintf("%v-%v", base, i)
4584 c := sc.canonicalHeader(h)
4585 if len(h) != len(c) {
4586 t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c)
4587 }
4588 }
4589 total := 0
4590 for k, v := range sc.canonHeader {
4591 total += len(k) + len(v) + 100
4592 }
4593 if total > maxCachedCanonicalHeadersKeysSize {
4594 t.Errorf("after adding %v ~%v-byte headers, canonHeader cache is ~%v bytes, want <%v", count, size, total, maxCachedCanonicalHeadersKeysSize)
4595 }
4596 }
4597 }
4598
4599
4600
4601
4602
4603
4604 func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
4605 donec := make(chan struct{})
4606 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4607 defer close(donec)
4608 buf := make([]byte, 1<<20)
4609 var i byte
4610 for {
4611 i++
4612 _, err := w.Write(buf)
4613 for j := range buf {
4614 buf[j] = byte(i)
4615 }
4616 if err != nil {
4617 return
4618 }
4619 }
4620 }, optOnlyServer)
4621 defer st.Close()
4622
4623 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4624 defer tr.CloseIdleConnections()
4625
4626 req, _ := http.NewRequest("GET", st.ts.URL, nil)
4627 res, err := tr.RoundTrip(req)
4628 if err != nil {
4629 t.Fatal(err)
4630 }
4631 res.Body.Close()
4632 <-donec
4633 }
4634
4635
4636
4637
4638
4639
4640 func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) {
4641 donec := make(chan struct{}, 1)
4642 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4643 donec <- struct{}{}
4644 defer close(donec)
4645 buf := make([]byte, 1<<20)
4646 var i byte
4647 for {
4648 i++
4649 _, err := w.Write(buf)
4650 for j := range buf {
4651 buf[j] = byte(i)
4652 }
4653 if err != nil {
4654 return
4655 }
4656 }
4657 }, optOnlyServer)
4658 defer st.Close()
4659
4660 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4661 defer tr.CloseIdleConnections()
4662
4663 req, _ := http.NewRequest("GET", st.ts.URL, nil)
4664 res, err := tr.RoundTrip(req)
4665 if err != nil {
4666 t.Fatal(err)
4667 }
4668 defer res.Body.Close()
4669 <-donec
4670 st.ts.Config.Close()
4671 <-donec
4672 }
4673
// TestServerMaxHandlerGoroutines checks how many handler goroutines may run
// at once on a server with MaxConcurrentStreams = maxHandlers. In this test,
// streams the client resets do not stop their handlers from running.
//
// Each handler reports its stop channel on handlerc and then blocks until it
// receives from that channel (true = panic with http.ErrAbortHandler) or
// donec is closed. The test expects:
//   - no more than maxHandlers handlers run at once, even when their streams
//     are reset;
//   - extra streams wait and start only once running handlers finish;
//   - a flood of open+reset streams eventually gets
//     GOAWAY(ENHANCE_YOUR_CALM).
func TestServerMaxHandlerGoroutines(t *testing.T) {
	const maxHandlers = 10
	handlerc := make(chan chan bool)
	donec := make(chan struct{})
	defer close(donec)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		stopc := make(chan bool, 1)
		select {
		case handlerc <- stopc:
		case <-donec:
		}
		select {
		case shouldPanic := <-stopc:
			if shouldPanic {
				panic(http.ErrAbortHandler)
			}
		case <-donec:
		}
	}, func(s *Server) {
		s.MaxConcurrentStreams = maxHandlers
	})
	defer st.Close()

	st.writePreface()
	st.writeInitialSettings()
	st.writeSettingsAck()

	// Phase 1: start maxHandlers handlers, resetting each stream right after
	// its handler has started. The handlers keep running.
	var stops []chan bool
	streamID := uint32(1)
	for i := 0; i < maxHandlers; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		stops = append(stops, <-handlerc)
		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
		streamID += 2
	}

	// Phase 2: open one more stream and reset it immediately. Presumably its
	// handler should never run; only two handlers are expected to start in
	// phase 4.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      streamID,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.fr.WriteRSTStream(streamID, ErrCodeCancel)
	streamID += 2

	// Phase 3: open two more streams without resetting them. Their handlers
	// must not start while all slots are taken.
	for i := 0; i < 2; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		streamID += 2
	}

	// Briefly check that no new handler has started.
	select {
	case <-handlerc:
		t.Errorf("handler unexpectedly started while maxHandlers are already running")
	case <-time.After(1 * time.Millisecond):
	}

	// Phase 4: finish two handlers, one normally and one by panicking. The
	// two streams from phase 3 should then get their handlers.
	stops[0] <- false
	stops[1] <- true
	stops = stops[2:]
	stops = append(stops, <-handlerc)
	stops = append(stops, <-handlerc)

	// Phase 5: flood the server with streams that are opened and immediately
	// reset while every handler slot is busy. The loop below expects a
	// GOAWAY with ENHANCE_YOUR_CALM.
	for i := 0; i < 5*maxHandlers; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
		streamID += 2
	}
Frames:
	for {
		f, err := st.readFrame()
		if err != nil {
			st.t.Fatal(err)
		}
		switch f := f.(type) {
		case *GoAwayFrame:
			if f.ErrCode != ErrCodeEnhanceYourCalm {
				t.Errorf("err code = %v; want %v", f.ErrCode, ErrCodeEnhanceYourCalm)
			}
			break Frames
		default:
		}
	}

	// Release the remaining handlers. Receiving from a closed channel yields
	// false, so they return without panicking.
	for _, s := range stops {
		close(s)
	}
}
4786
View as plain text