1
2
3
4
5 package h2c
6
7 import (
8 "context"
9 "crypto/tls"
10 "fmt"
11 "io"
12 "io/ioutil"
13 "log"
14 "net"
15 "net/http"
16 "net/http/httptest"
17 "strings"
18 "testing"
19
20 "golang.org/x/net/http2"
21 )
22
23 func ExampleNewHandler() {
24 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25 fmt.Fprint(w, "Hello world")
26 })
27 h2s := &http2.Server{
28
29 }
30 h1s := &http.Server{
31 Addr: ":8080",
32 Handler: NewHandler(handler, h2s),
33 }
34 log.Fatal(h1s.ListenAndServe())
35 }
36
37 func TestContext(t *testing.T) {
38 baseCtx := context.WithValue(context.Background(), "testkey", "testvalue")
39
40 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41 if r.ProtoMajor != 2 {
42 t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
43 }
44 if r.Context().Value("testkey") != "testvalue" {
45 t.Errorf("Request doesn't have expected base context: %v", r.Context())
46 }
47 fmt.Fprint(w, "Hello world")
48 })
49
50 h2s := &http2.Server{}
51 h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
52 h1s.Config.BaseContext = func(_ net.Listener) context.Context {
53 return baseCtx
54 }
55 h1s.Start()
56 defer h1s.Close()
57
58 client := &http.Client{
59 Transport: &http2.Transport{
60 AllowHTTP: true,
61 DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
62 return net.Dial(network, addr)
63 },
64 },
65 }
66
67 resp, err := client.Get(h1s.URL)
68 if err != nil {
69 t.Fatal(err)
70 }
71 _, err = ioutil.ReadAll(resp.Body)
72 if err != nil {
73 t.Fatal(err)
74 }
75 if err := resp.Body.Close(); err != nil {
76 t.Fatal(err)
77 }
78 }
79
80 func TestPropagation(t *testing.T) {
81 var (
82 server *http.Server
83
84 headerSize = 1 << 11
85 headerLimit = 1 << 10
86 )
87
88 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
89 if r.ProtoMajor != 2 {
90 t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
91 }
92 if r.Context().Value(http.ServerContextKey).(*http.Server) != server {
93 t.Errorf("Request doesn't have expected http server: %v", r.Context())
94 }
95 if len(r.Header.Get("Long-Header")) != headerSize {
96 t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header")))
97 }
98 fmt.Fprint(w, "Hello world")
99 })
100
101 h2s := &http2.Server{}
102 h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
103
104 server = h1s.Config
105 server.MaxHeaderBytes = headerLimit
106 server.ConnState = func(conn net.Conn, state http.ConnState) {
107 t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state)
108 }
109
110 h1s.Start()
111 defer h1s.Close()
112
113 client := &http.Client{
114 Transport: &http2.Transport{
115 AllowHTTP: true,
116 DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
117 conn, err := net.Dial(network, addr)
118 if conn != nil {
119 t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr())
120 }
121 return conn, err
122 },
123 },
124 }
125
126 req, err := http.NewRequest("GET", h1s.URL, nil)
127 if err != nil {
128 t.Fatal(err)
129 }
130
131 req.Header.Set("Long-Header", strings.Repeat("A", headerSize))
132
133 _, err = client.Do(req)
134 if err == nil {
135 t.Fatal("expected server err, got nil")
136 }
137 }
138
139 func TestMaxBytesHandler(t *testing.T) {
140 const bodyLimit = 10
141 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
142 t.Errorf("got request, expected to be blocked by body limit")
143 })
144
145 h2s := &http2.Server{}
146 h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit))
147 h1s.Start()
148 defer h1s.Close()
149
150
151 body := "0123456789abcdef"
152 req, err := http.NewRequest("POST", h1s.URL, struct{ io.Reader }{strings.NewReader(body)})
153 if err != nil {
154 t.Fatal(err)
155 }
156 req.Header.Set("Http2-Settings", "")
157 req.Header.Set("Upgrade", "h2c")
158 req.Header.Set("Connection", "Upgrade, HTTP2-Settings")
159
160 resp, err := h1s.Client().Do(req)
161 if err != nil {
162 t.Fatal(err)
163 }
164 defer resp.Body.Close()
165 _, err = ioutil.ReadAll(resp.Body)
166 if err != nil {
167 t.Fatal(err)
168 }
169 if got, want := resp.StatusCode, http.StatusInternalServerError; got != want {
170 t.Errorf("resp.StatusCode = %v, want %v", got, want)
171 }
172 }
173
View as plain text