1
2
3
4
5
6
7 package httptest
8
9 import (
10 "crypto/tls"
11 "crypto/x509"
12 "flag"
13 "fmt"
14 "log"
15 "net"
16 "net/http"
17 "net/http/internal/testcert"
18 "os"
19 "strings"
20 "sync"
21 "time"
22 )
23
24
25
// A Server is an HTTP server for use in end-to-end HTTP tests. By default
// (see newLocalListener) it listens on a system-chosen port on the loopback
// interface.
type Server struct {
	// URL is the base URL of the started server: "http://" (Start) or
	// "https://" (StartTLS) followed by Listener.Addr().String().
	// It is empty until the server is started.
	URL      string
	Listener net.Listener

	// EnableHTTP2 is only consulted by StartTLS: when set, the default
	// ALPN protocol list becomes {"h2"} and the client transport sets
	// ForceAttemptHTTP2. Set it after NewUnstartedServer and before
	// StartTLS.
	EnableHTTP2 bool

	// TLS is the optional TLS configuration. StartTLS replaces it with a
	// clone (or a fresh config if nil) and fills in NextProtos and
	// Certificates when they are unset.
	TLS *tls.Config

	// Config may be changed after calling NewUnstartedServer and before
	// Start or StartTLS; starting the server wraps its ConnState hook.
	Config *http.Server

	// certificate is the parsed first certificate from TLS, set by StartTLS.
	certificate *x509.Certificate

	// wg counts the Serve goroutine started by goServe plus every
	// connection tracked by wrap; Close waits on it.
	wg sync.WaitGroup

	mu     sync.Mutex // guards closed and conns
	closed bool
	conns  map[net.Conn]http.ConnState // last reported ConnState of each tracked connection

	// client is created by Start/StartTLS and returned by Client; Close
	// closes its idle connections.
	client *http.Client
}
59
60 func newLocalListener() net.Listener {
61 if serveFlag != "" {
62 l, err := net.Listen("tcp", serveFlag)
63 if err != nil {
64 panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
65 }
66 return l
67 }
68 l, err := net.Listen("tcp", "127.0.0.1:0")
69 if err != nil {
70 if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
71 panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
72 }
73 }
74 return l
75 }
76
77
78
79
80
81
82
83
84
85
86 var serveFlag string
87
88 func init() {
89 if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
90 flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
91 }
92 }
93
// strSliceContainsPrefix reports whether at least one element of v starts
// with pre. It returns false for an empty or nil slice.
func strSliceContainsPrefix(v []string, pre string) bool {
	found := false
	for i := 0; i < len(v) && !found; i++ {
		found = strings.HasPrefix(v[i], pre)
	}
	return found
}
102
103
104
105 func NewServer(handler http.Handler) *Server {
106 ts := NewUnstartedServer(handler)
107 ts.Start()
108 return ts
109 }
110
111
112
113
114
115
116
117 func NewUnstartedServer(handler http.Handler) *Server {
118 return &Server{
119 Listener: newLocalListener(),
120 Config: &http.Server{Handler: handler},
121 }
122 }
123
124
125 func (s *Server) Start() {
126 if s.URL != "" {
127 panic("Server already started")
128 }
129 if s.client == nil {
130 s.client = &http.Client{Transport: &http.Transport{}}
131 }
132 s.URL = "http://" + s.Listener.Addr().String()
133 s.wrap()
134 s.goServe()
135 if serveFlag != "" {
136 fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
137 select {}
138 }
139 }
140
141
142 func (s *Server) StartTLS() {
143 if s.URL != "" {
144 panic("Server already started")
145 }
146 if s.client == nil {
147 s.client = &http.Client{}
148 }
149 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
150 if err != nil {
151 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
152 }
153
154 existingConfig := s.TLS
155 if existingConfig != nil {
156 s.TLS = existingConfig.Clone()
157 } else {
158 s.TLS = new(tls.Config)
159 }
160 if s.TLS.NextProtos == nil {
161 nextProtos := []string{"http/1.1"}
162 if s.EnableHTTP2 {
163 nextProtos = []string{"h2"}
164 }
165 s.TLS.NextProtos = nextProtos
166 }
167 if len(s.TLS.Certificates) == 0 {
168 s.TLS.Certificates = []tls.Certificate{cert}
169 }
170 s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
171 if err != nil {
172 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
173 }
174 certpool := x509.NewCertPool()
175 certpool.AddCert(s.certificate)
176 s.client.Transport = &http.Transport{
177 TLSClientConfig: &tls.Config{
178 RootCAs: certpool,
179 },
180 ForceAttemptHTTP2: s.EnableHTTP2,
181 }
182 s.Listener = tls.NewListener(s.Listener, s.TLS)
183 s.URL = "https://" + s.Listener.Addr().String()
184 s.wrap()
185 s.goServe()
186 }
187
188
189
190 func NewTLSServer(handler http.Handler) *Server {
191 ts := NewUnstartedServer(handler)
192 ts.StartTLS()
193 return ts
194 }
195
// closeIdleTransport matches transports (such as *http.Transport) that can
// drop their idle connections. Close type-asserts http.DefaultTransport and
// the test client's Transport against it.
type closeIdleTransport interface {
	CloseIdleConnections()
}
199
200
201
// Close shuts down the server and blocks until all outstanding requests
// on this server have completed. Calling it more than once only repeats
// the idle-connection cleanup and the wait.
func (s *Server) Close() {
	s.mu.Lock()
	if !s.closed {
		s.closed = true
		// Stop accepting connections and stop keeping existing ones alive
		// after their current request.
		s.Listener.Close()
		s.Config.SetKeepAlivesEnabled(false)
		for c, st := range s.conns {
			// Force-close connections that are idle (between requests) or
			// new (connected, no request seen yet); neither is doing work.
			// Active connections are left alone so in-flight requests can
			// finish. wrap closes them once they report StateIdle, because
			// s.closed is now set.
			if st == http.StateIdle || st == http.StateNew {
				s.closeConn(c)
			}
		}
		// If Close is still blocked after 5 seconds, log which connections
		// it is waiting on. This defer runs when Close returns (after
		// s.wg.Wait below), not at the end of this if block.
		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
		defer t.Stop()
	}
	s.mu.Unlock()

	// Not needed for the server's own correctness, but tests commonly use
	// http.DefaultTransport, so drop its idle connections as a courtesy.
	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
		t.CloseIdleConnections()
	}

	// Also drop idle connections held by the client returned by Client.
	if s.client != nil {
		if t, ok := s.client.Transport.(closeIdleTransport); ok {
			t.CloseIdleConnections()
		}
	}

	// Wait for the Serve goroutine (goServe) and every connection tracked
	// by wrap to finish.
	s.wg.Wait()
}
253
254 func (s *Server) logCloseHangDebugInfo() {
255 s.mu.Lock()
256 defer s.mu.Unlock()
257 var buf strings.Builder
258 buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
259 for c, st := range s.conns {
260 fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
261 }
262 log.Print(buf.String())
263 }
264
265
266 func (s *Server) CloseClientConnections() {
267 s.mu.Lock()
268 nconn := len(s.conns)
269 ch := make(chan struct{}, nconn)
270 for c := range s.conns {
271 go s.closeConnChan(c, ch)
272 }
273 s.mu.Unlock()
274
275
276
277
278
279
280
281 timer := time.NewTimer(5 * time.Second)
282 defer timer.Stop()
283 for i := 0; i < nconn; i++ {
284 select {
285 case <-ch:
286 case <-timer.C:
287
288 return
289 }
290 }
291 }
292
293
294
// Certificate returns the certificate parsed by StartTLS from the server's
// first TLS certificate, or nil if StartTLS has not been called.
func (s *Server) Certificate() *x509.Certificate {
	return s.certificate
}

// Client returns the HTTP client created by Start or StartTLS for making
// requests to this server (nil before the server is started). After
// StartTLS its transport trusts the server's certificate as its only root
// CA. Close closes the client's idle connections.
func (s *Server) Client() *http.Client {
	return s.client
}
305
306 func (s *Server) goServe() {
307 s.wg.Add(1)
308 go func() {
309 defer s.wg.Done()
310 s.Config.Serve(s.Listener)
311 }()
312 }
313
314
315
// wrap installs a ConnState hook on s.Config that records each
// connection's state in s.conns and counts it in s.wg, so Close can close
// idle connections and wait for the rest. Any ConnState hook already set on
// s.Config is chained and called after the tracking logic.
func (s *Server) wrap() {
	oldHook := s.Config.ConnState
	s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
		s.mu.Lock()
		defer s.mu.Unlock()

		switch cs {
		case http.StateNew:
			// A connection may only be reported as new once.
			if _, exists := s.conns[c]; exists {
				panic("invalid state transition")
			}
			if s.conns == nil {
				s.conns = make(map[net.Conn]http.ConnState)
			}
			// Track c and count it in the WaitGroup; the matching Done
			// happens when c reaches StateHijacked or StateClosed.
			s.wg.Add(1)
			s.conns[c] = cs
			if s.closed {
				// The server was already closed when this connection was
				// reported; close it instead of serving it.
				s.closeConn(c)
			}
		case http.StateActive:
			// Only new or idle connections may become active; untracked
			// connections are ignored.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateNew && oldState != http.StateIdle {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
		case http.StateIdle:
			// Only active connections may become idle.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateActive {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
			// After Close, a connection that finished its request is
			// closed rather than kept for reuse.
			if s.closed {
				s.closeConn(c)
			}
		case http.StateHijacked, http.StateClosed:
			// Stop tracking c and release its WaitGroup slot, but only if
			// it is still tracked, so Done runs at most once per Add.
			if _, ok := s.conns[c]; ok {
				delete(s.conns, c)
				// Deferred so it runs after oldHook below returns: Close
				// cannot return until the user's ConnState hook finishes.
				defer s.wg.Done()
			}
		}
		if oldHook != nil {
			oldHook(c, cs)
		}
	}
}
373
374
375
376 func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
377
378
379
380 func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
381 c.Close()
382 if done != nil {
383 done <- struct{}{}
384 }
385 }
386
View as plain text