1
2
3
4
5 package http2
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "io/ioutil"
12 "net/http"
13 "reflect"
14 "runtime"
15 "strconv"
16 "sync"
17 "testing"
18 "time"
19 )
20
21 func TestServer_Push_Success(t *testing.T) {
22 const (
23 mainBody = "<html>index page</html>"
24 pushedBody = "<html>pushed page</html>"
25 userAgent = "testagent"
26 cookie = "testcookie"
27 )
28
29 var stURL string
30 checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
31 if got, want := r.Method, wantMethod; got != want {
32 return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
33 }
34 if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
35 return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
36 }
37 if got, want := "https://"+r.Host, stURL; got != want {
38 return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
39 }
40 if r.Body == nil {
41 return fmt.Errorf("nil Body")
42 }
43 if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
44 return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
45 }
46 return nil
47 }
48
49 errc := make(chan error, 3)
50 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
51 switch r.URL.RequestURI() {
52 case "/":
53
54 opt := &http.PushOptions{
55 Header: http.Header{
56 "User-Agent": {userAgent},
57 },
58 }
59 if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
60 errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
61 return
62 }
63
64 opt = &http.PushOptions{
65 Method: "HEAD",
66 Header: http.Header{
67 "User-Agent": {userAgent},
68 "Cookie": {cookie},
69 },
70 }
71 if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
72 errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
73 return
74 }
75 w.Header().Set("Content-Type", "text/html")
76 w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
77 w.WriteHeader(200)
78 io.WriteString(w, mainBody)
79 errc <- nil
80
81 case "/pushed?get":
82 wantH := http.Header{}
83 wantH.Set("User-Agent", userAgent)
84 if err := checkPromisedReq(r, "GET", wantH); err != nil {
85 errc <- fmt.Errorf("/pushed?get: %v", err)
86 return
87 }
88 w.Header().Set("Content-Type", "text/html")
89 w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
90 w.WriteHeader(200)
91 io.WriteString(w, pushedBody)
92 errc <- nil
93
94 case "/pushed?head":
95 wantH := http.Header{}
96 wantH.Set("User-Agent", userAgent)
97 wantH.Set("Cookie", cookie)
98 if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
99 errc <- fmt.Errorf("/pushed?head: %v", err)
100 return
101 }
102 w.WriteHeader(204)
103 errc <- nil
104
105 default:
106 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
107 }
108 })
109 stURL = st.ts.URL
110
111
112 st.greet()
113 getSlash(st)
114 for k := 0; k < 3; k++ {
115 select {
116 case <-time.After(2 * time.Second):
117 t.Errorf("timeout waiting for handler %d to finish", k)
118 case err := <-errc:
119 if err != nil {
120 t.Fatal(err)
121 }
122 }
123 }
124
125 checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
126 pp, ok := f.(*PushPromiseFrame)
127 if !ok {
128 return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
129 }
130 if !pp.HeadersEnded() {
131 return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
132 }
133 if got, want := pp.PromiseID, promiseID; got != want {
134 return fmt.Errorf("got PromiseID %v; want %v", got, want)
135 }
136 gotH := st.decodeHeader(pp.HeaderBlockFragment())
137 if !reflect.DeepEqual(gotH, wantH) {
138 return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
139 }
140 return nil
141 }
142 checkHeaders := func(f Frame, wantH [][2]string) error {
143 hf, ok := f.(*HeadersFrame)
144 if !ok {
145 return fmt.Errorf("got a %T; want *HeadersFrame", f)
146 }
147 gotH := st.decodeHeader(hf.HeaderBlockFragment())
148 if !reflect.DeepEqual(gotH, wantH) {
149 return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
150 }
151 return nil
152 }
153 checkData := func(f Frame, wantData string) error {
154 df, ok := f.(*DataFrame)
155 if !ok {
156 return fmt.Errorf("got a %T; want *DataFrame", f)
157 }
158 if gotData := string(df.Data()); gotData != wantData {
159 return fmt.Errorf("got response data %q; want %q", gotData, wantData)
160 }
161 return nil
162 }
163
164
165
166
167 expected := map[uint32][]func(Frame) error{
168 1: {
169 func(f Frame) error {
170 return checkPushPromise(f, 2, [][2]string{
171 {":method", "GET"},
172 {":scheme", "https"},
173 {":authority", st.ts.Listener.Addr().String()},
174 {":path", "/pushed?get"},
175 {"user-agent", userAgent},
176 })
177 },
178 func(f Frame) error {
179 return checkPushPromise(f, 4, [][2]string{
180 {":method", "HEAD"},
181 {":scheme", "https"},
182 {":authority", st.ts.Listener.Addr().String()},
183 {":path", "/pushed?head"},
184 {"cookie", cookie},
185 {"user-agent", userAgent},
186 })
187 },
188 func(f Frame) error {
189 return checkHeaders(f, [][2]string{
190 {":status", "200"},
191 {"content-type", "text/html"},
192 {"content-length", strconv.Itoa(len(mainBody))},
193 })
194 },
195 func(f Frame) error {
196 return checkData(f, mainBody)
197 },
198 },
199 2: {
200 func(f Frame) error {
201 return checkHeaders(f, [][2]string{
202 {":status", "200"},
203 {"content-type", "text/html"},
204 {"content-length", strconv.Itoa(len(pushedBody))},
205 })
206 },
207 func(f Frame) error {
208 return checkData(f, pushedBody)
209 },
210 },
211 4: {
212 func(f Frame) error {
213 return checkHeaders(f, [][2]string{
214 {":status", "204"},
215 })
216 },
217 },
218 }
219
220 consumed := map[uint32]int{}
221 for k := 0; len(expected) > 0; k++ {
222 f, err := st.readFrame()
223 if err != nil {
224 for id, left := range expected {
225 t.Errorf("stream %d: missing %d frames", id, len(left))
226 }
227 t.Fatalf("readFrame %d: %v", k, err)
228 }
229 id := f.Header().StreamID
230 label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
231 if len(expected[id]) == 0 {
232 t.Fatalf("%s: unexpected frame %#+v", label, f)
233 }
234 check := expected[id][0]
235 expected[id] = expected[id][1:]
236 if len(expected[id]) == 0 {
237 delete(expected, id)
238 }
239 if err := check(f); err != nil {
240 t.Fatalf("%s: %v", label, err)
241 }
242 consumed[id]++
243 }
244 }
245
246 func TestServer_Push_SuccessNoRace(t *testing.T) {
247
248
249 errc := make(chan error, 2)
250 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
251 switch r.URL.RequestURI() {
252 case "/":
253 opt := &http.PushOptions{
254 Header: http.Header{"User-Agent": {"testagent"}},
255 }
256 if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
257 errc <- fmt.Errorf("error pushing: %v", err)
258 return
259 }
260 w.WriteHeader(200)
261 errc <- nil
262
263 case "/pushed":
264
265 r.Header.Set("User-Agent", "newagent")
266 r.Header.Set("Cookie", "cookie")
267 w.WriteHeader(200)
268 errc <- nil
269
270 default:
271 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
272 }
273 })
274
275
276 st.greet()
277 getSlash(st)
278 for k := 0; k < 2; k++ {
279 select {
280 case <-time.After(2 * time.Second):
281 t.Errorf("timeout waiting for handler %d to finish", k)
282 case err := <-errc:
283 if err != nil {
284 t.Fatal(err)
285 }
286 }
287 }
288 }
289
290 func TestServer_Push_RejectRecursivePush(t *testing.T) {
291
292 errc := make(chan error, 3)
293 handler := func(w http.ResponseWriter, r *http.Request) error {
294 baseURL := "https://" + r.Host
295 switch r.URL.Path {
296 case "/":
297 if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
298 return fmt.Errorf("first Push()=%v, want nil", err)
299 }
300 return nil
301
302 case "/push1":
303 if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
304 return fmt.Errorf("Push()=%v, want %v", got, want)
305 }
306 return nil
307
308 default:
309 return fmt.Errorf("unexpected path: %q", r.URL.Path)
310 }
311 }
312 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
313 errc <- handler(w, r)
314 })
315 defer st.Close()
316 st.greet()
317 getSlash(st)
318 if err := <-errc; err != nil {
319 t.Errorf("First request failed: %v", err)
320 }
321 if err := <-errc; err != nil {
322 t.Errorf("Second request failed: %v", err)
323 }
324 }
325
326 func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
327
328 errc := make(chan error, 2)
329 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
330 errc <- doPush(w.(http.Pusher), r)
331 })
332 defer st.Close()
333 st.greet()
334 if err := st.fr.WriteSettings(settings...); err != nil {
335 st.t.Fatalf("WriteSettings: %v", err)
336 }
337 st.wantSettingsAck()
338 getSlash(st)
339 if err := <-errc; err != nil {
340 t.Error(err)
341 }
342
343 hf := st.wantHeaders()
344 if !hf.StreamEnded() {
345 t.Error("stream should end after headers")
346 }
347 }
348
349 func TestServer_Push_RejectIfDisabled(t *testing.T) {
350 testServer_Push_RejectSingleRequest(t,
351 func(p http.Pusher, r *http.Request) error {
352 if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
353 return fmt.Errorf("Push()=%v, want %v", got, want)
354 }
355 return nil
356 },
357 Setting{SettingEnablePush, 0})
358 }
359
360 func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
361 testServer_Push_RejectSingleRequest(t,
362 func(p http.Pusher, r *http.Request) error {
363 if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
364 return fmt.Errorf("Push()=%v, want %v", got, want)
365 }
366 return nil
367 },
368 Setting{SettingMaxConcurrentStreams, 0})
369 }
370
371 func TestServer_Push_RejectWrongScheme(t *testing.T) {
372 testServer_Push_RejectSingleRequest(t,
373 func(p http.Pusher, r *http.Request) error {
374 if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
375 return errors.New("Push() should have failed (push target URL is http)")
376 }
377 return nil
378 })
379 }
380
381 func TestServer_Push_RejectMissingHost(t *testing.T) {
382 testServer_Push_RejectSingleRequest(t,
383 func(p http.Pusher, r *http.Request) error {
384 if err := p.Push("https:pushed", nil); err == nil {
385 return errors.New("Push() should have failed (push target URL missing host)")
386 }
387 return nil
388 })
389 }
390
391 func TestServer_Push_RejectRelativePath(t *testing.T) {
392 testServer_Push_RejectSingleRequest(t,
393 func(p http.Pusher, r *http.Request) error {
394 if err := p.Push("../test", nil); err == nil {
395 return errors.New("Push() should have failed (push target is a relative path)")
396 }
397 return nil
398 })
399 }
400
401 func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
402 testServer_Push_RejectSingleRequest(t,
403 func(p http.Pusher, r *http.Request) error {
404 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
405 return errors.New("Push() should have failed (cannot promise a POST)")
406 }
407 return nil
408 })
409 }
410
411 func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
412 testServer_Push_RejectSingleRequest(t,
413 func(p http.Pusher, r *http.Request) error {
414 header := http.Header{
415 "Content-Length": {"10"},
416 "Content-Encoding": {"gzip"},
417 "Trailer": {"Foo"},
418 "Te": {"trailers"},
419 "Host": {"test.com"},
420 ":authority": {"test.com"},
421 }
422 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
423 return errors.New("Push() should have failed (forbidden headers)")
424 }
425 return nil
426 })
427 }
428
429 func TestServer_Push_StateTransitions(t *testing.T) {
430 const body = "foo"
431
432 gotPromise := make(chan bool)
433 finishedPush := make(chan bool)
434
435 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
436 switch r.URL.RequestURI() {
437 case "/":
438 if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
439 t.Errorf("Push error: %v", err)
440 }
441
442
443 <-finishedPush
444 case "/pushed":
445 <-gotPromise
446 }
447 w.Header().Set("Content-Type", "text/html")
448 w.Header().Set("Content-Length", strconv.Itoa(len(body)))
449 w.WriteHeader(200)
450 io.WriteString(w, body)
451 })
452 defer st.Close()
453
454 st.greet()
455 if st.stream(2) != nil {
456 t.Fatal("stream 2 should be empty")
457 }
458 if got, want := st.streamState(2), stateIdle; got != want {
459 t.Fatalf("streamState(2)=%v, want %v", got, want)
460 }
461 getSlash(st)
462
463 st.wantPushPromise()
464 if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
465 t.Fatalf("streamState(2)=%v, want %v", got, want)
466 }
467
468
469
470
471 close(gotPromise)
472 st.wantHeaders()
473 if df := st.wantData(); !df.StreamEnded() {
474 t.Fatal("expected END_STREAM flag on DATA")
475 }
476 if got, want := st.streamState(2), stateClosed; got != want {
477 t.Fatalf("streamState(2)=%v, want %v", got, want)
478 }
479 close(finishedPush)
480 }
481
482 func TestServer_Push_RejectAfterGoAway(t *testing.T) {
483 var readyOnce sync.Once
484 ready := make(chan struct{})
485 errc := make(chan error, 2)
486 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
487 <-ready
488 if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
489 errc <- fmt.Errorf("Push()=%v, want %v", got, want)
490 }
491 errc <- nil
492 })
493 defer st.Close()
494 st.greet()
495 getSlash(st)
496
497
498 st.fr.WriteGoAway(1, ErrCodeNo, nil)
499 go func() {
500 for {
501 select {
502 case <-ready:
503 return
504 default:
505 if runtime.GOARCH == "wasm" {
506
507 runtime.Gosched()
508 }
509 }
510 st.sc.serveMsgCh <- func(loopNum int) {
511 if !st.sc.pushEnabled {
512 readyOnce.Do(func() { close(ready) })
513 }
514 }
515 }
516 }()
517 if err := <-errc; err != nil {
518 t.Error(err)
519 }
520 }
521
522 func TestServer_Push_Underflow(t *testing.T) {
523
524
525 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
526 switch r.URL.RequestURI() {
527 case "/":
528 opt := &http.PushOptions{
529 Header: http.Header{"User-Agent": {"testagent"}},
530 }
531 if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
532 t.Errorf("error pushing: %v", err)
533 }
534 w.WriteHeader(200)
535 case "/pushed":
536 r.Header.Set("User-Agent", "newagent")
537 r.Header.Set("Cookie", "cookie")
538 w.WriteHeader(200)
539 default:
540 t.Errorf("unknown RequestURL %q", r.URL.RequestURI())
541 }
542 })
543
544 st.greet()
545 const numRequests = 4
546 for i := 0; i < numRequests; i++ {
547 st.writeHeaders(HeadersFrameParam{
548 StreamID: uint32(1 + i*2),
549 BlockFragment: st.encodeHeader(),
550 EndStream: true,
551 EndHeaders: true,
552 })
553 }
554
555 numPushPromises := 0
556 numHeaders := 0
557 for numHeaders < numRequests*2 || numPushPromises < numRequests {
558 f, err := st.readFrame()
559 if err != nil {
560 st.t.Fatal(err)
561 }
562 switch f := f.(type) {
563 case *HeadersFrame:
564 if !f.Flags.Has(FlagHeadersEndStream) {
565 t.Fatalf("got HEADERS frame with no END_STREAM, expected END_STREAM: %v", f)
566 }
567 numHeaders++
568 case *PushPromiseFrame:
569 numPushPromises++
570 }
571 }
572 }
573
View as plain text