1
2
3
4
5 package gin
6
7 import (
8 "encoding/base64"
9 "net/http"
10 "net/http/httptest"
11 "testing"
12
13 "github.com/stretchr/testify/assert"
14 )
15
16 func TestBasicAuth(t *testing.T) {
17 pairs := processAccounts(Accounts{
18 "admin": "password",
19 "foo": "bar",
20 "bar": "foo",
21 })
22
23 assert.Len(t, pairs, 3)
24 assert.Contains(t, pairs, authPair{
25 user: "bar",
26 value: "Basic YmFyOmZvbw==",
27 })
28 assert.Contains(t, pairs, authPair{
29 user: "foo",
30 value: "Basic Zm9vOmJhcg==",
31 })
32 assert.Contains(t, pairs, authPair{
33 user: "admin",
34 value: "Basic YWRtaW46cGFzc3dvcmQ=",
35 })
36 }
37
38 func TestBasicAuthFails(t *testing.T) {
39 assert.Panics(t, func() { processAccounts(nil) })
40 assert.Panics(t, func() {
41 processAccounts(Accounts{
42 "": "password",
43 "foo": "bar",
44 })
45 })
46 }
47
48 func TestBasicAuthSearchCredential(t *testing.T) {
49 pairs := processAccounts(Accounts{
50 "admin": "password",
51 "foo": "bar",
52 "bar": "foo",
53 })
54
55 user, found := pairs.searchCredential(authorizationHeader("admin", "password"))
56 assert.Equal(t, "admin", user)
57 assert.True(t, found)
58
59 user, found = pairs.searchCredential(authorizationHeader("foo", "bar"))
60 assert.Equal(t, "foo", user)
61 assert.True(t, found)
62
63 user, found = pairs.searchCredential(authorizationHeader("bar", "foo"))
64 assert.Equal(t, "bar", user)
65 assert.True(t, found)
66
67 user, found = pairs.searchCredential(authorizationHeader("admins", "password"))
68 assert.Empty(t, user)
69 assert.False(t, found)
70
71 user, found = pairs.searchCredential(authorizationHeader("foo", "bar "))
72 assert.Empty(t, user)
73 assert.False(t, found)
74
75 user, found = pairs.searchCredential("")
76 assert.Empty(t, user)
77 assert.False(t, found)
78 }
79
80 func TestBasicAuthAuthorizationHeader(t *testing.T) {
81 assert.Equal(t, "Basic YWRtaW46cGFzc3dvcmQ=", authorizationHeader("admin", "password"))
82 }
83
84 func TestBasicAuthSucceed(t *testing.T) {
85 accounts := Accounts{"admin": "password"}
86 router := New()
87 router.Use(BasicAuth(accounts))
88 router.GET("/login", func(c *Context) {
89 c.String(http.StatusOK, c.MustGet(AuthUserKey).(string))
90 })
91
92 w := httptest.NewRecorder()
93 req, _ := http.NewRequest("GET", "/login", nil)
94 req.Header.Set("Authorization", authorizationHeader("admin", "password"))
95 router.ServeHTTP(w, req)
96
97 assert.Equal(t, http.StatusOK, w.Code)
98 assert.Equal(t, "admin", w.Body.String())
99 }
100
101 func TestBasicAuth401(t *testing.T) {
102 called := false
103 accounts := Accounts{"foo": "bar"}
104 router := New()
105 router.Use(BasicAuth(accounts))
106 router.GET("/login", func(c *Context) {
107 called = true
108 c.String(http.StatusOK, c.MustGet(AuthUserKey).(string))
109 })
110
111 w := httptest.NewRecorder()
112 req, _ := http.NewRequest("GET", "/login", nil)
113 req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
114 router.ServeHTTP(w, req)
115
116 assert.False(t, called)
117 assert.Equal(t, http.StatusUnauthorized, w.Code)
118 assert.Equal(t, "Basic realm=\"Authorization Required\"", w.Header().Get("WWW-Authenticate"))
119 }
120
121 func TestBasicAuth401WithCustomRealm(t *testing.T) {
122 called := false
123 accounts := Accounts{"foo": "bar"}
124 router := New()
125 router.Use(BasicAuthForRealm(accounts, "My Custom \"Realm\""))
126 router.GET("/login", func(c *Context) {
127 called = true
128 c.String(http.StatusOK, c.MustGet(AuthUserKey).(string))
129 })
130
131 w := httptest.NewRecorder()
132 req, _ := http.NewRequest("GET", "/login", nil)
133 req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
134 router.ServeHTTP(w, req)
135
136 assert.False(t, called)
137 assert.Equal(t, http.StatusUnauthorized, w.Code)
138 assert.Equal(t, "Basic realm=\"My Custom \\\"Realm\\\"\"", w.Header().Get("WWW-Authenticate"))
139 }
140
View as plain text