1
2
3
4
5 package gin
6
7 import (
8 "errors"
9 "net/http"
10 "strings"
11 "testing"
12
13 "github.com/gin-contrib/sse"
14 "github.com/stretchr/testify/assert"
15 )
16
17 func TestMiddlewareGeneralCase(t *testing.T) {
18 signature := ""
19 router := New()
20 router.Use(func(c *Context) {
21 signature += "A"
22 c.Next()
23 signature += "B"
24 })
25 router.Use(func(c *Context) {
26 signature += "C"
27 })
28 router.GET("/", func(c *Context) {
29 signature += "D"
30 })
31 router.NoRoute(func(c *Context) {
32 signature += " X "
33 })
34 router.NoMethod(func(c *Context) {
35 signature += " XX "
36 })
37
38 w := PerformRequest(router, "GET", "/")
39
40
41 assert.Equal(t, http.StatusOK, w.Code)
42 assert.Equal(t, "ACDB", signature)
43 }
44
45 func TestMiddlewareNoRoute(t *testing.T) {
46 signature := ""
47 router := New()
48 router.Use(func(c *Context) {
49 signature += "A"
50 c.Next()
51 signature += "B"
52 })
53 router.Use(func(c *Context) {
54 signature += "C"
55 c.Next()
56 c.Next()
57 c.Next()
58 c.Next()
59 signature += "D"
60 })
61 router.NoRoute(func(c *Context) {
62 signature += "E"
63 c.Next()
64 signature += "F"
65 }, func(c *Context) {
66 signature += "G"
67 c.Next()
68 signature += "H"
69 })
70 router.NoMethod(func(c *Context) {
71 signature += " X "
72 })
73
74 w := PerformRequest(router, "GET", "/")
75
76
77 assert.Equal(t, http.StatusNotFound, w.Code)
78 assert.Equal(t, "ACEGHFDB", signature)
79 }
80
81 func TestMiddlewareNoMethodEnabled(t *testing.T) {
82 signature := ""
83 router := New()
84 router.HandleMethodNotAllowed = true
85 router.Use(func(c *Context) {
86 signature += "A"
87 c.Next()
88 signature += "B"
89 })
90 router.Use(func(c *Context) {
91 signature += "C"
92 c.Next()
93 signature += "D"
94 })
95 router.NoMethod(func(c *Context) {
96 signature += "E"
97 c.Next()
98 signature += "F"
99 }, func(c *Context) {
100 signature += "G"
101 c.Next()
102 signature += "H"
103 })
104 router.NoRoute(func(c *Context) {
105 signature += " X "
106 })
107 router.POST("/", func(c *Context) {
108 signature += " XX "
109 })
110
111 w := PerformRequest(router, "GET", "/")
112
113
114 assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
115 assert.Equal(t, "ACEGHFDB", signature)
116 }
117
118 func TestMiddlewareNoMethodDisabled(t *testing.T) {
119 signature := ""
120 router := New()
121
122
123 router.HandleMethodNotAllowed = false
124
125 router.Use(func(c *Context) {
126 signature += "A"
127 c.Next()
128 signature += "B"
129 })
130 router.Use(func(c *Context) {
131 signature += "C"
132 c.Next()
133 signature += "D"
134 })
135 router.NoMethod(func(c *Context) {
136 signature += "E"
137 c.Next()
138 signature += "F"
139 }, func(c *Context) {
140 signature += "G"
141 c.Next()
142 signature += "H"
143 })
144 router.NoRoute(func(c *Context) {
145 signature += " X "
146 })
147 router.POST("/", func(c *Context) {
148 signature += " XX "
149 })
150
151
152 w := PerformRequest(router, "GET", "/")
153
154
155 assert.Equal(t, http.StatusNotFound, w.Code)
156 assert.Equal(t, "AC X DB", signature)
157 }
158
159 func TestMiddlewareAbort(t *testing.T) {
160 signature := ""
161 router := New()
162 router.Use(func(c *Context) {
163 signature += "A"
164 })
165 router.Use(func(c *Context) {
166 signature += "C"
167 c.AbortWithStatus(http.StatusUnauthorized)
168 c.Next()
169 signature += "D"
170 })
171 router.GET("/", func(c *Context) {
172 signature += " X "
173 c.Next()
174 signature += " XX "
175 })
176
177
178 w := PerformRequest(router, "GET", "/")
179
180
181 assert.Equal(t, http.StatusUnauthorized, w.Code)
182 assert.Equal(t, "ACD", signature)
183 }
184
185 func TestMiddlewareAbortHandlersChainAndNext(t *testing.T) {
186 signature := ""
187 router := New()
188 router.Use(func(c *Context) {
189 signature += "A"
190 c.Next()
191 c.AbortWithStatus(http.StatusGone)
192 signature += "B"
193 })
194 router.GET("/", func(c *Context) {
195 signature += "C"
196 c.Next()
197 })
198
199 w := PerformRequest(router, "GET", "/")
200
201
202 assert.Equal(t, http.StatusGone, w.Code)
203 assert.Equal(t, "ACB", signature)
204 }
205
206
207
208 func TestMiddlewareFailHandlersChain(t *testing.T) {
209
210 signature := ""
211 router := New()
212 router.Use(func(context *Context) {
213 signature += "A"
214 context.AbortWithError(http.StatusInternalServerError, errors.New("foo"))
215 })
216 router.Use(func(context *Context) {
217 signature += "B"
218 context.Next()
219 signature += "C"
220 })
221
222 w := PerformRequest(router, "GET", "/")
223
224
225 assert.Equal(t, http.StatusInternalServerError, w.Code)
226 assert.Equal(t, "A", signature)
227 }
228
// TestMiddlewareWrite checks that every handler in the chain can write to the
// response. The String, XML and JSON middleware outputs are followed by the
// route's JSON and SSE outputs in handler order, and the status is 400.
// Spaces are stripped from both sides before comparing, presumably so
// incidental whitespace in the rendered formats does not matter.
func TestMiddlewareWrite(t *testing.T) {
	router := New()

	router.Use(func(c *Context) {
		c.String(http.StatusBadRequest, "hola\n")
	})
	router.Use(func(c *Context) {
		c.XML(http.StatusBadRequest, H{"foo": "bar"})
	})
	router.Use(func(c *Context) {
		c.JSON(http.StatusBadRequest, H{"foo": "bar"})
	})

	routeJSON := func(c *Context) {
		c.JSON(http.StatusBadRequest, H{"foo": "bar"})
	}
	routeSSE := func(c *Context) {
		c.Render(http.StatusBadRequest, sse.Event{
			Event: "test",
			Data:  "message",
		})
	}
	router.GET("/", routeJSON, routeSSE)

	resp := PerformRequest(router, http.MethodGet, "/")

	// One piece per handler, in execution order.
	want := "hola\n" +
		"<map><foo>bar</foo></map>" +
		`{"foo":"bar"}` +
		`{"foo":"bar"}` +
		"event:test\ndata:message\n\n"
	stripSpaces := func(s string) string {
		return strings.Replace(s, " ", "", -1)
	}

	assert.Equal(t, http.StatusBadRequest, resp.Code)
	assert.Equal(t, stripSpaces(want), stripSpaces(resp.Body.String()))
}
254
View as plain text