...

Source file src/github.com/gin-gonic/gin/gin_test.go

Documentation: github.com/gin-gonic/gin

     1  // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
     2  // Use of this source code is governed by a MIT style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gin
     6  
     7  import (
     8  	"crypto/tls"
     9  	"fmt"
    10  	"html/template"
    11  	"io"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"reflect"
    16  	"strconv"
    17  	"sync/atomic"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/stretchr/testify/assert"
    22  	"golang.org/x/net/http2"
    23  )
    24  
    25  func formatAsDate(t time.Time) string {
    26  	year, month, day := t.Date()
    27  	return fmt.Sprintf("%d/%02d/%02d", year, month, day)
    28  }
    29  
    30  func setupHTMLFiles(t *testing.T, mode string, tls bool, loadMethod func(*Engine)) *httptest.Server {
    31  	SetMode(mode)
    32  	defer SetMode(TestMode)
    33  
    34  	var router *Engine
    35  	captureOutput(t, func() {
    36  		router = New()
    37  		router.Delims("{[{", "}]}")
    38  		router.SetFuncMap(template.FuncMap{
    39  			"formatAsDate": formatAsDate,
    40  		})
    41  		loadMethod(router)
    42  		router.GET("/test", func(c *Context) {
    43  			c.HTML(http.StatusOK, "hello.tmpl", map[string]string{"name": "world"})
    44  		})
    45  		router.GET("/raw", func(c *Context) {
    46  			c.HTML(http.StatusOK, "raw.tmpl", map[string]any{
    47  				"now": time.Date(2017, 07, 01, 0, 0, 0, 0, time.UTC),
    48  			})
    49  		})
    50  	})
    51  
    52  	var ts *httptest.Server
    53  
    54  	if tls {
    55  		ts = httptest.NewTLSServer(router)
    56  	} else {
    57  		ts = httptest.NewServer(router)
    58  	}
    59  
    60  	return ts
    61  }
    62  
    63  func TestLoadHTMLGlobDebugMode(t *testing.T) {
    64  	ts := setupHTMLFiles(
    65  		t,
    66  		DebugMode,
    67  		false,
    68  		func(router *Engine) {
    69  			router.LoadHTMLGlob("./testdata/template/*")
    70  		},
    71  	)
    72  	defer ts.Close()
    73  
    74  	res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
    75  	if err != nil {
    76  		t.Error(err)
    77  	}
    78  
    79  	resp, _ := io.ReadAll(res.Body)
    80  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
    81  }
    82  
    83  func TestH2c(t *testing.T) {
    84  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    85  	if err != nil {
    86  		t.Error(err)
    87  	}
    88  	r := Default()
    89  	r.UseH2C = true
    90  	r.GET("/", func(c *Context) {
    91  		c.String(200, "<h1>Hello world</h1>")
    92  	})
    93  	go func() {
    94  		err := http.Serve(ln, r.Handler())
    95  		if err != nil {
    96  			t.Log(err)
    97  		}
    98  	}()
    99  	defer ln.Close()
   100  
   101  	url := "http://" + ln.Addr().String() + "/"
   102  
   103  	httpClient := http.Client{
   104  		Transport: &http2.Transport{
   105  			AllowHTTP: true,
   106  			DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
   107  				return net.Dial(netw, addr)
   108  			},
   109  		},
   110  	}
   111  
   112  	res, err := httpClient.Get(url)
   113  	if err != nil {
   114  		t.Error(err)
   115  	}
   116  
   117  	resp, _ := io.ReadAll(res.Body)
   118  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   119  }
   120  
   121  func TestLoadHTMLGlobTestMode(t *testing.T) {
   122  	ts := setupHTMLFiles(
   123  		t,
   124  		TestMode,
   125  		false,
   126  		func(router *Engine) {
   127  			router.LoadHTMLGlob("./testdata/template/*")
   128  		},
   129  	)
   130  	defer ts.Close()
   131  
   132  	res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
   133  	if err != nil {
   134  		t.Error(err)
   135  	}
   136  
   137  	resp, _ := io.ReadAll(res.Body)
   138  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   139  }
   140  
   141  func TestLoadHTMLGlobReleaseMode(t *testing.T) {
   142  	ts := setupHTMLFiles(
   143  		t,
   144  		ReleaseMode,
   145  		false,
   146  		func(router *Engine) {
   147  			router.LoadHTMLGlob("./testdata/template/*")
   148  		},
   149  	)
   150  	defer ts.Close()
   151  
   152  	res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
   153  	if err != nil {
   154  		t.Error(err)
   155  	}
   156  
   157  	resp, _ := io.ReadAll(res.Body)
   158  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   159  }
   160  
   161  func TestLoadHTMLGlobUsingTLS(t *testing.T) {
   162  	ts := setupHTMLFiles(
   163  		t,
   164  		DebugMode,
   165  		true,
   166  		func(router *Engine) {
   167  			router.LoadHTMLGlob("./testdata/template/*")
   168  		},
   169  	)
   170  	defer ts.Close()
   171  
   172  	// Use InsecureSkipVerify for avoiding `x509: certificate signed by unknown authority` error
   173  	tr := &http.Transport{
   174  		TLSClientConfig: &tls.Config{
   175  			InsecureSkipVerify: true,
   176  		},
   177  	}
   178  	client := &http.Client{Transport: tr}
   179  	res, err := client.Get(fmt.Sprintf("%s/test", ts.URL))
   180  	if err != nil {
   181  		t.Error(err)
   182  	}
   183  
   184  	resp, _ := io.ReadAll(res.Body)
   185  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   186  }
   187  
   188  func TestLoadHTMLGlobFromFuncMap(t *testing.T) {
   189  	ts := setupHTMLFiles(
   190  		t,
   191  		DebugMode,
   192  		false,
   193  		func(router *Engine) {
   194  			router.LoadHTMLGlob("./testdata/template/*")
   195  		},
   196  	)
   197  	defer ts.Close()
   198  
   199  	res, err := http.Get(fmt.Sprintf("%s/raw", ts.URL))
   200  	if err != nil {
   201  		t.Error(err)
   202  	}
   203  
   204  	resp, _ := io.ReadAll(res.Body)
   205  	assert.Equal(t, "Date: 2017/07/01", string(resp))
   206  }
   207  
   208  func init() {
   209  	SetMode(TestMode)
   210  }
   211  
   212  func TestCreateEngine(t *testing.T) {
   213  	router := New()
   214  	assert.Equal(t, "/", router.basePath)
   215  	assert.Equal(t, router.engine, router)
   216  	assert.Empty(t, router.Handlers)
   217  }
   218  
   219  func TestLoadHTMLFilesTestMode(t *testing.T) {
   220  	ts := setupHTMLFiles(
   221  		t,
   222  		TestMode,
   223  		false,
   224  		func(router *Engine) {
   225  			router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
   226  		},
   227  	)
   228  	defer ts.Close()
   229  
   230  	res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
   231  	if err != nil {
   232  		t.Error(err)
   233  	}
   234  
   235  	resp, _ := io.ReadAll(res.Body)
   236  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   237  }
   238  
   239  func TestLoadHTMLFilesDebugMode(t *testing.T) {
   240  	ts := setupHTMLFiles(
   241  		t,
   242  		DebugMode,
   243  		false,
   244  		func(router *Engine) {
   245  			router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
   246  		},
   247  	)
   248  	defer ts.Close()
   249  
   250  	res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
   251  	if err != nil {
   252  		t.Error(err)
   253  	}
   254  
   255  	resp, _ := io.ReadAll(res.Body)
   256  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   257  }
   258  
   259  func TestLoadHTMLFilesReleaseMode(t *testing.T) {
   260  	ts := setupHTMLFiles(
   261  		t,
   262  		ReleaseMode,
   263  		false,
   264  		func(router *Engine) {
   265  			router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
   266  		},
   267  	)
   268  	defer ts.Close()
   269  
   270  	res, err := http.Get(fmt.Sprintf("%s/test", ts.URL))
   271  	if err != nil {
   272  		t.Error(err)
   273  	}
   274  
   275  	resp, _ := io.ReadAll(res.Body)
   276  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   277  }
   278  
   279  func TestLoadHTMLFilesUsingTLS(t *testing.T) {
   280  	ts := setupHTMLFiles(
   281  		t,
   282  		TestMode,
   283  		true,
   284  		func(router *Engine) {
   285  			router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
   286  		},
   287  	)
   288  	defer ts.Close()
   289  
   290  	// Use InsecureSkipVerify for avoiding `x509: certificate signed by unknown authority` error
   291  	tr := &http.Transport{
   292  		TLSClientConfig: &tls.Config{
   293  			InsecureSkipVerify: true,
   294  		},
   295  	}
   296  	client := &http.Client{Transport: tr}
   297  	res, err := client.Get(fmt.Sprintf("%s/test", ts.URL))
   298  	if err != nil {
   299  		t.Error(err)
   300  	}
   301  
   302  	resp, _ := io.ReadAll(res.Body)
   303  	assert.Equal(t, "<h1>Hello world</h1>", string(resp))
   304  }
   305  
   306  func TestLoadHTMLFilesFuncMap(t *testing.T) {
   307  	ts := setupHTMLFiles(
   308  		t,
   309  		TestMode,
   310  		false,
   311  		func(router *Engine) {
   312  			router.LoadHTMLFiles("./testdata/template/hello.tmpl", "./testdata/template/raw.tmpl")
   313  		},
   314  	)
   315  	defer ts.Close()
   316  
   317  	res, err := http.Get(fmt.Sprintf("%s/raw", ts.URL))
   318  	if err != nil {
   319  		t.Error(err)
   320  	}
   321  
   322  	resp, _ := io.ReadAll(res.Body)
   323  	assert.Equal(t, "Date: 2017/07/01", string(resp))
   324  }
   325  
   326  func TestAddRoute(t *testing.T) {
   327  	router := New()
   328  	router.addRoute("GET", "/", HandlersChain{func(_ *Context) {}})
   329  
   330  	assert.Len(t, router.trees, 1)
   331  	assert.NotNil(t, router.trees.get("GET"))
   332  	assert.Nil(t, router.trees.get("POST"))
   333  
   334  	router.addRoute("POST", "/", HandlersChain{func(_ *Context) {}})
   335  
   336  	assert.Len(t, router.trees, 2)
   337  	assert.NotNil(t, router.trees.get("GET"))
   338  	assert.NotNil(t, router.trees.get("POST"))
   339  
   340  	router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}})
   341  	assert.Len(t, router.trees, 2)
   342  }
   343  
   344  func TestAddRouteFails(t *testing.T) {
   345  	router := New()
   346  	assert.Panics(t, func() { router.addRoute("", "/", HandlersChain{func(_ *Context) {}}) })
   347  	assert.Panics(t, func() { router.addRoute("GET", "a", HandlersChain{func(_ *Context) {}}) })
   348  	assert.Panics(t, func() { router.addRoute("GET", "/", HandlersChain{}) })
   349  
   350  	router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}})
   351  	assert.Panics(t, func() {
   352  		router.addRoute("POST", "/post", HandlersChain{func(_ *Context) {}})
   353  	})
   354  }
   355  
   356  func TestCreateDefaultRouter(t *testing.T) {
   357  	router := Default()
   358  	assert.Len(t, router.Handlers, 2)
   359  }
   360  
   361  func TestNoRouteWithoutGlobalHandlers(t *testing.T) {
   362  	var middleware0 HandlerFunc = func(c *Context) {}
   363  	var middleware1 HandlerFunc = func(c *Context) {}
   364  
   365  	router := New()
   366  
   367  	router.NoRoute(middleware0)
   368  	assert.Nil(t, router.Handlers)
   369  	assert.Len(t, router.noRoute, 1)
   370  	assert.Len(t, router.allNoRoute, 1)
   371  	compareFunc(t, router.noRoute[0], middleware0)
   372  	compareFunc(t, router.allNoRoute[0], middleware0)
   373  
   374  	router.NoRoute(middleware1, middleware0)
   375  	assert.Len(t, router.noRoute, 2)
   376  	assert.Len(t, router.allNoRoute, 2)
   377  	compareFunc(t, router.noRoute[0], middleware1)
   378  	compareFunc(t, router.allNoRoute[0], middleware1)
   379  	compareFunc(t, router.noRoute[1], middleware0)
   380  	compareFunc(t, router.allNoRoute[1], middleware0)
   381  }
   382  
   383  func TestNoRouteWithGlobalHandlers(t *testing.T) {
   384  	var middleware0 HandlerFunc = func(c *Context) {}
   385  	var middleware1 HandlerFunc = func(c *Context) {}
   386  	var middleware2 HandlerFunc = func(c *Context) {}
   387  
   388  	router := New()
   389  	router.Use(middleware2)
   390  
   391  	router.NoRoute(middleware0)
   392  	assert.Len(t, router.allNoRoute, 2)
   393  	assert.Len(t, router.Handlers, 1)
   394  	assert.Len(t, router.noRoute, 1)
   395  
   396  	compareFunc(t, router.Handlers[0], middleware2)
   397  	compareFunc(t, router.noRoute[0], middleware0)
   398  	compareFunc(t, router.allNoRoute[0], middleware2)
   399  	compareFunc(t, router.allNoRoute[1], middleware0)
   400  
   401  	router.Use(middleware1)
   402  	assert.Len(t, router.allNoRoute, 3)
   403  	assert.Len(t, router.Handlers, 2)
   404  	assert.Len(t, router.noRoute, 1)
   405  
   406  	compareFunc(t, router.Handlers[0], middleware2)
   407  	compareFunc(t, router.Handlers[1], middleware1)
   408  	compareFunc(t, router.noRoute[0], middleware0)
   409  	compareFunc(t, router.allNoRoute[0], middleware2)
   410  	compareFunc(t, router.allNoRoute[1], middleware1)
   411  	compareFunc(t, router.allNoRoute[2], middleware0)
   412  }
   413  
   414  func TestNoMethodWithoutGlobalHandlers(t *testing.T) {
   415  	var middleware0 HandlerFunc = func(c *Context) {}
   416  	var middleware1 HandlerFunc = func(c *Context) {}
   417  
   418  	router := New()
   419  
   420  	router.NoMethod(middleware0)
   421  	assert.Empty(t, router.Handlers)
   422  	assert.Len(t, router.noMethod, 1)
   423  	assert.Len(t, router.allNoMethod, 1)
   424  	compareFunc(t, router.noMethod[0], middleware0)
   425  	compareFunc(t, router.allNoMethod[0], middleware0)
   426  
   427  	router.NoMethod(middleware1, middleware0)
   428  	assert.Len(t, router.noMethod, 2)
   429  	assert.Len(t, router.allNoMethod, 2)
   430  	compareFunc(t, router.noMethod[0], middleware1)
   431  	compareFunc(t, router.allNoMethod[0], middleware1)
   432  	compareFunc(t, router.noMethod[1], middleware0)
   433  	compareFunc(t, router.allNoMethod[1], middleware0)
   434  }
   435  
   436  func TestRebuild404Handlers(t *testing.T) {
   437  }
   438  
   439  func TestNoMethodWithGlobalHandlers(t *testing.T) {
   440  	var middleware0 HandlerFunc = func(c *Context) {}
   441  	var middleware1 HandlerFunc = func(c *Context) {}
   442  	var middleware2 HandlerFunc = func(c *Context) {}
   443  
   444  	router := New()
   445  	router.Use(middleware2)
   446  
   447  	router.NoMethod(middleware0)
   448  	assert.Len(t, router.allNoMethod, 2)
   449  	assert.Len(t, router.Handlers, 1)
   450  	assert.Len(t, router.noMethod, 1)
   451  
   452  	compareFunc(t, router.Handlers[0], middleware2)
   453  	compareFunc(t, router.noMethod[0], middleware0)
   454  	compareFunc(t, router.allNoMethod[0], middleware2)
   455  	compareFunc(t, router.allNoMethod[1], middleware0)
   456  
   457  	router.Use(middleware1)
   458  	assert.Len(t, router.allNoMethod, 3)
   459  	assert.Len(t, router.Handlers, 2)
   460  	assert.Len(t, router.noMethod, 1)
   461  
   462  	compareFunc(t, router.Handlers[0], middleware2)
   463  	compareFunc(t, router.Handlers[1], middleware1)
   464  	compareFunc(t, router.noMethod[0], middleware0)
   465  	compareFunc(t, router.allNoMethod[0], middleware2)
   466  	compareFunc(t, router.allNoMethod[1], middleware1)
   467  	compareFunc(t, router.allNoMethod[2], middleware0)
   468  }
   469  
   470  func compareFunc(t *testing.T, a, b any) {
   471  	sf1 := reflect.ValueOf(a)
   472  	sf2 := reflect.ValueOf(b)
   473  	if sf1.Pointer() != sf2.Pointer() {
   474  		t.Error("different functions")
   475  	}
   476  }
   477  
   478  func TestListOfRoutes(t *testing.T) {
   479  	router := New()
   480  	router.GET("/favicon.ico", handlerTest1)
   481  	router.GET("/", handlerTest1)
   482  	group := router.Group("/users")
   483  	{
   484  		group.GET("/", handlerTest2)
   485  		group.GET("/:id", handlerTest1)
   486  		group.POST("/:id", handlerTest2)
   487  	}
   488  	router.Static("/static", ".")
   489  
   490  	list := router.Routes()
   491  
   492  	assert.Len(t, list, 7)
   493  	assertRoutePresent(t, list, RouteInfo{
   494  		Method:  "GET",
   495  		Path:    "/favicon.ico",
   496  		Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
   497  	})
   498  	assertRoutePresent(t, list, RouteInfo{
   499  		Method:  "GET",
   500  		Path:    "/",
   501  		Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
   502  	})
   503  	assertRoutePresent(t, list, RouteInfo{
   504  		Method:  "GET",
   505  		Path:    "/users/",
   506  		Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$",
   507  	})
   508  	assertRoutePresent(t, list, RouteInfo{
   509  		Method:  "GET",
   510  		Path:    "/users/:id",
   511  		Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest1$",
   512  	})
   513  	assertRoutePresent(t, list, RouteInfo{
   514  		Method:  "POST",
   515  		Path:    "/users/:id",
   516  		Handler: "^(.*/vendor/)?github.com/gin-gonic/gin.handlerTest2$",
   517  	})
   518  }
   519  
   520  func TestEngineHandleContext(t *testing.T) {
   521  	r := New()
   522  	r.GET("/", func(c *Context) {
   523  		c.Request.URL.Path = "/v2"
   524  		r.HandleContext(c)
   525  	})
   526  	v2 := r.Group("/v2")
   527  	{
   528  		v2.GET("/", func(c *Context) {})
   529  	}
   530  
   531  	assert.NotPanics(t, func() {
   532  		w := PerformRequest(r, "GET", "/")
   533  		assert.Equal(t, 301, w.Code)
   534  	})
   535  }
   536  
   537  func TestEngineHandleContextManyReEntries(t *testing.T) {
   538  	expectValue := 10000
   539  
   540  	var handlerCounter, middlewareCounter int64
   541  
   542  	r := New()
   543  	r.Use(func(c *Context) {
   544  		atomic.AddInt64(&middlewareCounter, 1)
   545  	})
   546  	r.GET("/:count", func(c *Context) {
   547  		countStr := c.Param("count")
   548  		count, err := strconv.Atoi(countStr)
   549  		assert.NoError(t, err)
   550  
   551  		n, err := c.Writer.Write([]byte("."))
   552  		assert.NoError(t, err)
   553  		assert.Equal(t, 1, n)
   554  
   555  		switch {
   556  		case count > 0:
   557  			c.Request.URL.Path = "/" + strconv.Itoa(count-1)
   558  			r.HandleContext(c)
   559  		}
   560  	}, func(c *Context) {
   561  		atomic.AddInt64(&handlerCounter, 1)
   562  	})
   563  
   564  	assert.NotPanics(t, func() {
   565  		w := PerformRequest(r, "GET", "/"+strconv.Itoa(expectValue-1)) // include 0 value
   566  		assert.Equal(t, 200, w.Code)
   567  		assert.Equal(t, expectValue, w.Body.Len())
   568  	})
   569  
   570  	assert.Equal(t, int64(expectValue), handlerCounter)
   571  	assert.Equal(t, int64(expectValue), middlewareCounter)
   572  }
   573  
   574  func TestPrepareTrustedCIRDsWith(t *testing.T) {
   575  	r := New()
   576  
   577  	// valid ipv4 cidr
   578  	{
   579  		expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")}
   580  		err := r.SetTrustedProxies([]string{"0.0.0.0/0"})
   581  
   582  		assert.NoError(t, err)
   583  		assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
   584  	}
   585  
   586  	// invalid ipv4 cidr
   587  	{
   588  		err := r.SetTrustedProxies([]string{"192.168.1.33/33"})
   589  
   590  		assert.Error(t, err)
   591  	}
   592  
   593  	// valid ipv4 address
   594  	{
   595  		expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")}
   596  
   597  		err := r.SetTrustedProxies([]string{"192.168.1.33"})
   598  
   599  		assert.NoError(t, err)
   600  		assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
   601  	}
   602  
   603  	// invalid ipv4 address
   604  	{
   605  		err := r.SetTrustedProxies([]string{"192.168.1.256"})
   606  
   607  		assert.Error(t, err)
   608  	}
   609  
   610  	// valid ipv6 address
   611  	{
   612  		expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")}
   613  		err := r.SetTrustedProxies([]string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"})
   614  
   615  		assert.NoError(t, err)
   616  		assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
   617  	}
   618  
   619  	// invalid ipv6 address
   620  	{
   621  		err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"})
   622  
   623  		assert.Error(t, err)
   624  	}
   625  
   626  	// valid ipv6 cidr
   627  	{
   628  		expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")}
   629  		err := r.SetTrustedProxies([]string{"::/0"})
   630  
   631  		assert.NoError(t, err)
   632  		assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
   633  	}
   634  
   635  	// invalid ipv6 cidr
   636  	{
   637  		err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"})
   638  
   639  		assert.Error(t, err)
   640  	}
   641  
   642  	// valid combination
   643  	{
   644  		expectedTrustedCIDRs := []*net.IPNet{
   645  			parseCIDR("::/0"),
   646  			parseCIDR("192.168.0.0/16"),
   647  			parseCIDR("172.16.0.1/32"),
   648  		}
   649  		err := r.SetTrustedProxies([]string{
   650  			"::/0",
   651  			"192.168.0.0/16",
   652  			"172.16.0.1",
   653  		})
   654  
   655  		assert.NoError(t, err)
   656  		assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
   657  	}
   658  
   659  	// invalid combination
   660  	{
   661  		err := r.SetTrustedProxies([]string{
   662  			"::/0",
   663  			"192.168.0.0/16",
   664  			"172.16.0.256",
   665  		})
   666  
   667  		assert.Error(t, err)
   668  	}
   669  
   670  	// nil value
   671  	{
   672  		err := r.SetTrustedProxies(nil)
   673  
   674  		assert.Nil(t, r.trustedCIDRs)
   675  		assert.Nil(t, err)
   676  	}
   677  }
   678  
   679  func parseCIDR(cidr string) *net.IPNet {
   680  	_, parsedCIDR, err := net.ParseCIDR(cidr)
   681  	if err != nil {
   682  		fmt.Println(err)
   683  	}
   684  	return parsedCIDR
   685  }
   686  
   687  func assertRoutePresent(t *testing.T, gotRoutes RoutesInfo, wantRoute RouteInfo) {
   688  	for _, gotRoute := range gotRoutes {
   689  		if gotRoute.Path == wantRoute.Path && gotRoute.Method == wantRoute.Method {
   690  			assert.Regexp(t, wantRoute.Handler, gotRoute.Handler)
   691  			return
   692  		}
   693  	}
   694  	t.Errorf("route not found: %v", wantRoute)
   695  }
   696  
   697  func handlerTest1(c *Context) {}
   698  func handlerTest2(c *Context) {}
   699  

View as plain text