1
2
3
4
5 package gin
6
7 import (
8 "crypto/tls"
9 "fmt"
10 "html/template"
11 "io"
12 "net"
13 "net/http"
14 "net/http/httptest"
15 "reflect"
16 "strconv"
17 "sync/atomic"
18 "testing"
19 "time"
20
21 "github.com/stretchr/testify/assert"
22 "golang.org/x/net/http2"
23 )
24
// formatAsDate renders the calendar date of t as "year/month/day", with the
// month and day zero-padded to two digits. setupHTMLFiles registers it as a
// template function under the same name.
func formatAsDate(t time.Time) string {
	return fmt.Sprintf("%d/%02d/%02d", t.Year(), t.Month(), t.Day())
}
29
30 func setupHTMLFiles(t *testing.T, mode string, tls bool, loadMethod func(*Engine)) *httptest.Server {
31 SetMode(mode)
32 defer SetMode(TestMode)
33
34 var router *Engine
35 captureOutput(t, func() {
36 router = New()
37 router.Delims("{[{", "}]}")
38 router.SetFuncMap(template.FuncMap{
39 "formatAsDate": formatAsDate,
40 })
41 loadMethod(router)
42 router.GET("/test", func(c *Context) {
43 c.HTML(http.StatusOK, "hello.tmpl", map[string]string{"name": "world"})
44 })
45 router.GET("/raw", func(c *Context) {
46 c.HTML(http.StatusOK, "raw.tmpl", map[string]any{
47 "now": time.Date(2017, 07, 01, 0, 0, 0, 0, time.UTC),
48 })
49 })
50 })
51
52 var ts *httptest.Server
53
54 if tls {
55 ts = httptest.NewTLSServer(router)
56 } else {
57 ts = httptest.NewServer(router)
58 }
59
60 return ts
61 }
62
63 func TestLoadHTMLGlobDebugMode(t *testing.T) {
64 ts := setupHTMLFiles(
65 t,
66 DebugMode,
67 false,
68 func(router *Engine) {
69 router.LoadHTMLGlob("./testdata/template/*")
70 },
71 )
72 defer ts.Close()
73
74 res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
75 if err != nil {
76 t.Error(err)
77 }
78
79 resp, _ := io.ReadAll(res.Body)
80 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
81 }
82
83 func TestH2c(t *testing.T) {
84 ln, err := net.Listen("tcp", "127.0.0.1:0")
85 if err != nil {
86 t.Error(err)
87 }
88 r := Default()
89 r.UseH2C = true
90 r.GET("/", func(c *Context) {
91 c.String(200, "<h1>Hello world</h1>")
92 })
93 go func() {
94 err := http.Serve(ln, r.Handler())
95 if err != nil {
96 t.Log(err)
97 }
98 }()
99 defer ln.Close()
100
101 url := "http://" + ln.Addr().String() + "/"
102
103 httpClient := http.Client{
104 Transport: &http2.Transport{
105 AllowHTTP: true,
106 DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
107 return net.Dial(netw, addr)
108 },
109 },
110 }
111
112 res, err := httpClient.Get(url)
113 if err != nil {
114 t.Error(err)
115 }
116
117 resp, _ := io.ReadAll(res.Body)
118 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
119 }
120
121 func TestLoadHTMLGlobTestMode(t *testing.T) {
122 ts := setupHTMLFiles(
123 t,
124 TestMode,
125 false,
126 func(router *Engine) {
127 router.LoadHTMLGlob("./testdata/template/*")
128 },
129 )
130 defer ts.Close()
131
132 res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
133 if err != nil {
134 t.Error(err)
135 }
136
137 resp, _ := io.ReadAll(res.Body)
138 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
139 }
140
141 func TestLoadHTMLGlobReleaseMode(t *testing.T) {
142 ts := setupHTMLFiles(
143 t,
144 ReleaseMode,
145 false,
146 func(router *Engine) {
147 router.LoadHTMLGlob("./testdata/template/*")
148 },
149 )
150 defer ts.Close()
151
152 res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
153 if err != nil {
154 t.Error(err)
155 }
156
157 resp, _ := io.ReadAll(res.Body)
158 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
159 }
160
161 func TestLoadHTMLGlobUsingTLS(t *testing.T) {
162 ts := setupHTMLFiles(
163 t,
164 DebugMode,
165 true,
166 func(router *Engine) {
167 router.LoadHTMLGlob("./testdata/template/*")
168 },
169 )
170 defer ts.Close()
171
172
173 tr := &http.Transport{
174 TLSClientConfig: &tls.Config{
175 InsecureSkipVerify: true,
176 },
177 }
178 client := &http.Client{Transport: tr}
179 res, err := client.Get(fmt.Sprintf("%s/test", ts.URL))
180 if err != nil {
181 t.Error(err)
182 }
183
184 resp, _ := io.ReadAll(res.Body)
185 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
186 }
187
188 func TestLoadHTMLGlobFromFuncMap(t *testing.T) {
189 ts := setupHTMLFiles(
190 t,
191 DebugMode,
192 false,
193 func(router *Engine) {
194 router.LoadHTMLGlob("./testdata/template/*")
195 },
196 )
197 defer ts.Close()
198
199 res, err := http.Get(fmt.Sprintf("%s/raw", ts.URL))
200 if err != nil {
201 t.Error(err)
202 }
203
204 resp, _ := io.ReadAll(res.Body)
205 assert.Equal(t, "Date: 2017/07/01", string(resp))
206 }
207
// init puts gin into TestMode when the package is initialized, i.e. before
// any test in this file runs.
func init() {
	SetMode(TestMode)
}
211
212 func TestCreateEngine(t *testing.T) {
213 router := New()
214 assert.Equal(t, "/", router.basePath)
215 assert.Equal(t, router.engine, router)
216 assert.Empty(t, router.Handlers)
217 }
218
219 func TestLoadHTMLFilesTestMode(t *testing.T) {
220 ts := setupHTMLFiles(
221 t,
222 TestMode,
223 false,
224 func(router *Engine) {
225 router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
226 },
227 )
228 defer ts.Close()
229
230 res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
231 if err != nil {
232 t.Error(err)
233 }
234
235 resp, _ := io.ReadAll(res.Body)
236 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
237 }
238
239 func TestLoadHTMLFilesDebugMode(t *testing.T) {
240 ts := setupHTMLFiles(
241 t,
242 DebugMode,
243 false,
244 func(router *Engine) {
245 router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
246 },
247 )
248 defer ts.Close()
249
250 res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
251 if err != nil {
252 t.Error(err)
253 }
254
255 resp, _ := io.ReadAll(res.Body)
256 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
257 }
258
259 func TestLoadHTMLFilesReleaseMode(t *testing.T) {
260 ts := setupHTMLFiles(
261 t,
262 ReleaseMode,
263 false,
264 func(router *Engine) {
265 router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
266 },
267 )
268 defer ts.Close()
269
270 res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
271 if err != nil {
272 t.Error(err)
273 }
274
275 resp, _ := io.ReadAll(res.Body)
276 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
277 }
278
279 func TestLoadHTMLFilesUsingTLS(t *testing.T) {
280 ts := setupHTMLFiles(
281 t,
282 TestMode,
283 true,
284 func(router *Engine) {
285 router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
286 },
287 )
288 defer ts.Close()
289
290
291 tr := &http.Transport{
292 TLSClientConfig: &tls.Config{
293 InsecureSkipVerify: true,
294 },
295 }
296 client := &http.Client{Transport: tr}
297 res, err := client.Get(fmt.Sprintf("%s/test", ts.URL))
298 if err != nil {
299 t.Error(err)
300 }
301
302 resp, _ := io.ReadAll(res.Body)
303 assert.Equal(t, "<h1>Hello world</h1>", string(resp))
304 }
305
306 func TestLoadHTMLFilesFuncMap(t *testing.T) {
307 ts := setupHTMLFiles(
308 t,
309 TestMode,
310 false,
311 func(router *Engine) {
312 router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
313 },
314 )
315 defer ts.Close()
316
317 res, err := http.Get(fmt.Sprintf("%s/raw", ts.URL))
318 if err != nil {
319 t.Error(err)
320 }
321
322 resp, _ := io.ReadAll(res.Body)
323 assert.Equal(t, "Date: 2017/07/01", string(resp))
324 }
325
326 func TestAddRoute(t *testing.T) {
327 router := New()
328 router.addRoute("GET", "/", HandlersChain{func(_ *Context) {}})
329
330 assert.Len(t, router.trees, 1)
331 assert.NotNil(t, router.trees.get("GET"))
332 assert.Nil(t, router.trees.get("POST"))
333
334 router.addRoute("POST", "/", HandlersChain{func(_ *Context) {}})
335
336 assert.Len(t, router.trees, 2)
337 assert.NotNil(t, router.trees.get("GET"))
338 assert.NotNil(t, router.trees.get("POST"))
339
340 router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}})
341 assert.Len(t, router.trees, 2)
342 }
343
344 func TestAddRouteFails(t *testing.T) {
345 router := New()
346 assert.Panics(t, func() { router.addRoute("", "/", HandlersChain{func(_ *Context) {}}) })
347 assert.Panics(t, func() { router.addRoute("GET", "a", HandlersChain{func(_ *Context) {}}) })
348 assert.Panics(t, func() { router.addRoute("GET", "/", HandlersChain{}) })
349
350 router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}})
351 assert.Panics(t, func() {
352 router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}})
353 })
354 }
355
356 func TestCreateDefaultRouter(t *testing.T) {
357 router := Default()
358 assert.Len(t, router.Handlers, 2)
359 }
360
361 func TestNoRouteWithoutGlobalHandlers(t *testing.T) {
362 var middleware0 HandlerFunc = func(c *Context) {}
363 var middleware1 HandlerFunc = func(c *Context) {}
364
365 router := New()
366
367 router.NoRoute(middleware0)
368 assert.Nil(t, router.Handlers)
369 assert.Len(t, router.noRoute, 1)
370 assert.Len(t, router.allNoRoute, 1)
371 compareFunc(t, router.noRoute[0], middleware0)
372 compareFunc(t, router.allNoRoute[0], middleware0)
373
374 router.NoRoute(middleware1, middleware0)
375 assert.Len(t, router.noRoute, 2)
376 assert.Len(t, router.allNoRoute, 2)
377 compareFunc(t, router.noRoute[0], middleware1)
378 compareFunc(t, router.allNoRoute[0], middleware1)
379 compareFunc(t, router.noRoute[1], middleware0)
380 compareFunc(t, router.allNoRoute[1], middleware0)
381 }
382
383 func TestNoRouteWithGlobalHandlers(t *testing.T) {
384 var middleware0 HandlerFunc = func(c *Context) {}
385 var middleware1 HandlerFunc = func(c *Context) {}
386 var middleware2 HandlerFunc = func(c *Context) {}
387
388 router := New()
389 router.Use(middleware2)
390
391 router.NoRoute(middleware0)
392 assert.Len(t, router.allNoRoute, 2)
393 assert.Len(t, router.Handlers, 1)
394 assert.Len(t, router.noRoute, 1)
395
396 compareFunc(t, router.Handlers[0], middleware2)
397 compareFunc(t, router.noRoute[0], middleware0)
398 compareFunc(t, router.allNoRoute[0], middleware2)
399 compareFunc(t, router.allNoRoute[1], middleware0)
400
401 router.Use(middleware1)
402 assert.Len(t, router.allNoRoute, 3)
403 assert.Len(t, router.Handlers, 2)
404 assert.Len(t, router.noRoute, 1)
405
406 compareFunc(t, router.Handlers[0], middleware2)
407 compareFunc(t, router.Handlers[1], middleware1)
408 compareFunc(t, router.noRoute[0], middleware0)
409 compareFunc(t, router.allNoRoute[0], middleware2)
410 compareFunc(t, router.allNoRoute[1], middleware1)
411 compareFunc(t, router.allNoRoute[2], middleware0)
412 }
413
414 func TestNoMethodWithoutGlobalHandlers(t *testing.T) {
415 var middleware0 HandlerFunc = func(c *Context) {}
416 var middleware1 HandlerFunc = func(c *Context) {}
417
418 router := New()
419
420 router.NoMethod(middleware0)
421 assert.Empty(t, router.Handlers)
422 assert.Len(t, router.noMethod, 1)
423 assert.Len(t, router.allNoMethod, 1)
424 compareFunc(t, router.noMethod[0], middleware0)
425 compareFunc(t, router.allNoMethod[0], middleware0)
426
427 router.NoMethod(middleware1, middleware0)
428 assert.Len(t, router.noMethod, 2)
429 assert.Len(t, router.allNoMethod, 2)
430 compareFunc(t, router.noMethod[0], middleware1)
431 compareFunc(t, router.allNoMethod[0], middleware1)
432 compareFunc(t, router.noMethod[1], middleware0)
433 compareFunc(t, router.allNoMethod[1], middleware0)
434 }
435
// TestRebuild404Handlers is an empty placeholder: it makes no assertions and
// therefore always passes.
// NOTE(review): presumably meant to cover how the no-route handler chain is
// rebuilt; confirm the intent and either implement or remove it.
func TestRebuild404Handlers(t *testing.T) {
}
438
439 func TestNoMethodWithGlobalHandlers(t *testing.T) {
440 var middleware0 HandlerFunc = func(c *Context) {}
441 var middleware1 HandlerFunc = func(c *Context) {}
442 var middleware2 HandlerFunc = func(c *Context) {}
443
444 router := New()
445 router.Use(middleware2)
446
447 router.NoMethod(middleware0)
448 assert.Len(t, router.allNoMethod, 2)
449 assert.Len(t, router.Handlers, 1)
450 assert.Len(t, router.noMethod, 1)
451
452 compareFunc(t, router.Handlers[0], middleware2)
453 compareFunc(t, router.noMethod[0], middleware0)
454 compareFunc(t, router.allNoMethod[0], middleware2)
455 compareFunc(t, router.allNoMethod[1], middleware0)
456
457 router.Use(middleware1)
458 assert.Len(t, router.allNoMethod, 3)
459 assert.Len(t, router.Handlers, 2)
460 assert.Len(t, router.noMethod, 1)
461
462 compareFunc(t, router.Handlers[0], middleware2)
463 compareFunc(t, router.Handlers[1], middleware1)
464 compareFunc(t, router.noMethod[0], middleware0)
465 compareFunc(t, router.allNoMethod[0], middleware2)
466 compareFunc(t, router.allNoMethod[1], middleware1)
467 compareFunc(t, router.allNoMethod[2], middleware0)
468 }
469
// compareFunc reports a test error ("different functions") unless a and b
// have the same reflect pointer value, i.e. refer to the same function.
//
// Both arguments must be of a kind that reflect.Value.Pointer accepts (such
// as funcs); other kinds make reflect panic.
func compareFunc(t *testing.T, a, b any) {
	// Mark as a helper so a failure is attributed to the calling test line.
	t.Helper()

	sf1 := reflect.ValueOf(a)
	sf2 := reflect.ValueOf(b)
	if sf1.Pointer() != sf2.Pointer() {
		t.Error("different functions")
	}
}
477
478 func TestListOfRoutes(t *testing.T) {
479 router := New()
480 router.GET("/favicon.ico", handlerTest1)
481 router.GET("/", handlerTest1)
482 group := router.Group("/users")
483 {
484 group.GET("/", handlerTest2)
485 group.GET("/:id", handlerTest1)
486 group.POST("/:id", handlerTest2)
487 }
488 router.Static("/static", ".")
489
490 list := router.Routes()
491
492 assert.Len(t, list, 7)
493 assertRoutePresent(t, list, RouteInfo{
494 Method: "GET",
495 Path: "/favicon.ico",
496 Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
497 })
498 assertRoutePresent(t, list, RouteInfo{
499 Method: "GET",
500 Path: "/",
501 Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
502 })
503 assertRoutePresent(t, list, RouteInfo{
504 Method: "GET",
505 Path: "/users/",
506 Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$",
507 })
508 assertRoutePresent(t, list, RouteInfo{
509 Method: "GET",
510 Path: "/users/:id",
511 Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
512 })
513 assertRoutePresent(t, list, RouteInfo{
514 Method: "POST",
515 Path: "/users/:id",
516 Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$",
517 })
518 }
519
520 func TestEngineHandleContext(t *testing.T) {
521 r := New()
522 r.GET("/", func(c *Context) {
523 c.Request.URL.Path = "/v2"
524 r.HandleContext(c)
525 })
526 v2 := r.Group("/v2")
527 {
528 v2.GET("/", func(c *Context) {})
529 }
530
531 assert.NotPanics(t, func() {
532 w := PerformRequest(r, "GET", "/")
533 assert.Equal(t, 301, w.Code)
534 })
535 }
536
537 func TestEngineHandleContextManyReEntries(t *testing.T) {
538 expectValue := 10000
539
540 var handlerCounter, middlewareCounter int64
541
542 r := New()
543 r.Use(func(c *Context) {
544 atomic.AddInt64(&middlewareCounter, 1)
545 })
546 r.GET("/:count", func(c *Context) {
547 countStr := c.Param("count")
548 count, err := strconv.Atoi(countStr)
549 assert.NoError(t, err)
550
551 n, err := c.Writer.Write([]byte("."))
552 assert.NoError(t, err)
553 assert.Equal(t, 1, n)
554
555 switch {
556 case count > 0:
557 c.Request.URL.Path = "/" + strconv.Itoa(count-1)
558 r.HandleContext(c)
559 }
560 }, func(c *Context) {
561 atomic.AddInt64(&handlerCounter, 1)
562 })
563
564 assert.NotPanics(t, func() {
565 w := PerformRequest(r, "GET", "/"+strconv.Itoa(expectValue-1))
566 assert.Equal(t, 200, w.Code)
567 assert.Equal(t, expectValue, w.Body.Len())
568 })
569
570 assert.Equal(t, int64(expectValue), handlerCounter)
571 assert.Equal(t, int64(expectValue), middlewareCounter)
572 }
573
574 func TestPrepareTrustedCIRDsWith(t *testing.T) {
575 r := New()
576
577
578 {
579 expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")}
580 err := r.SetTrustedProxies([]string{"0.0.0.0/0"})
581
582 assert.NoError(t, err)
583 assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
584 }
585
586
587 {
588 err := r.SetTrustedProxies([]string{"192.168.1.33/33"})
589
590 assert.Error(t, err)
591 }
592
593
594 {
595 expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")}
596
597 err := r.SetTrustedProxies([]string{"192.168.1.33"})
598
599 assert.NoError(t, err)
600 assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
601 }
602
603
604 {
605 err := r.SetTrustedProxies([]string{"192.168.1.256"})
606
607 assert.Error(t, err)
608 }
609
610
611 {
612 expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")}
613 err := r.SetTrustedProxies([]string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"})
614
615 assert.NoError(t, err)
616 assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
617 }
618
619
620 {
621 err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"})
622
623 assert.Error(t, err)
624 }
625
626
627 {
628 expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")}
629 err := r.SetTrustedProxies([]string{"::/0"})
630
631 assert.NoError(t, err)
632 assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
633 }
634
635
636 {
637 err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"})
638
639 assert.Error(t, err)
640 }
641
642
643 {
644 expectedTrustedCIDRs := []*net.IPNet{
645 parseCIDR("::/0"),
646 parseCIDR("192.168.0.0/16"),
647 parseCIDR("172.16.0.1/32"),
648 }
649 err := r.SetTrustedProxies([]string{
650 "::/0",
651 "192.168.0.0/16",
652 "172.16.0.1",
653 })
654
655 assert.NoError(t, err)
656 assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
657 }
658
659
660 {
661 err := r.SetTrustedProxies([]string{
662 "::/0",
663 "192.168.0.0/16",
664 "172.16.0.256",
665 })
666
667 assert.Error(t, err)
668 }
669
670
671 {
672 err := r.SetTrustedProxies(nil)
673
674 assert.Nil(t, r.trustedCIDRs)
675 assert.Nil(t, err)
676 }
677 }
678
// parseCIDR parses cidr with net.ParseCIDR and returns the resulting network.
// If parsing fails, the error is printed to standard output and nil is
// returned.
func parseCIDR(cidr string) *net.IPNet {
	_, network, err := net.ParseCIDR(cidr)
	if err == nil {
		return network
	}
	fmt.Println(err)
	return nil
}
686
687 func assertRoutePresent(t *testing.T, gotRoutes RoutesInfo, wantRoute RouteInfo) {
688 for _, gotRoute := range gotRoutes {
689 if gotRoute.Path == wantRoute.Path && gotRoute.Method == wantRoute.Method {
690 assert.Regexp(t, wantRoute.Handler, gotRoute.Handler)
691 return
692 }
693 }
694 t.Errorf("route not found: %v", wantRoute)
695 }
696
// handlerTest1 and handlerTest2 are no-op handlers. TestListOfRoutes matches
// them by function name, so they are named functions rather than literals.
func handlerTest1(c *Context) {}
func handlerTest2(c *Context) {}
699
View as plain text