Source file
src/net/http/csrf_test.go
1
2
3
4
5 package http_test
6
7 import (
8 "io"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 )
14
15
// httptestNewRequest builds a body-less request with httptest.NewRequest and
// then clears the Scheme and Host of its URL, leaving only the path and query.
// That is the shape of the URL on a request received by a server, whereas
// httptest.NewRequest fills in Scheme and Host when target is an absolute URL.
// The request's Host field is left exactly as httptest.NewRequest set it.
func httptestNewRequest(method, target string) *http.Request {
	r := httptest.NewRequest(method, target, nil)
	u := r.URL
	u.Scheme, u.Host = "", ""
	return r
}
22
// okHandler is the handler wrapped by CrossOriginProtection in these tests.
// It answers every request that reaches it with 200 OK, so any other status a
// test observes was produced by the protection layer.
var okHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
})
26
27 func TestCrossOriginProtectionSecFetchSite(t *testing.T) {
28 protection := http.NewCrossOriginProtection()
29 handler := protection.Handler(okHandler)
30
31 tests := []struct {
32 name string
33 method string
34 secFetchSite string
35 origin string
36 expectedStatus int
37 }{
38 {"same-origin allowed", "POST", "same-origin", "", http.StatusOK},
39 {"none allowed", "POST", "none", "", http.StatusOK},
40 {"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden},
41 {"same-site blocked", "POST", "same-site", "", http.StatusForbidden},
42
43 {"no header with no origin", "POST", "", "", http.StatusOK},
44 {"no header with matching origin", "POST", "", "https://example.com", http.StatusOK},
45 {"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden},
46 {"no header with null origin", "POST", "", "null", http.StatusForbidden},
47
48 {"GET allowed", "GET", "cross-site", "", http.StatusOK},
49 {"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK},
50 {"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK},
51 {"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden},
52 }
53
54 for _, tc := range tests {
55 t.Run(tc.name, func(t *testing.T) {
56 req := httptestNewRequest(tc.method, "https://example.com/")
57 if tc.secFetchSite != "" {
58 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
59 }
60 if tc.origin != "" {
61 req.Header.Set("Origin", tc.origin)
62 }
63
64 w := httptest.NewRecorder()
65 handler.ServeHTTP(w, req)
66
67 if w.Code != tc.expectedStatus {
68 t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
69 }
70 })
71 }
72 }
73
74 func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) {
75 protection := http.NewCrossOriginProtection()
76 err := protection.AddTrustedOrigin("https://trusted.example")
77 if err != nil {
78 t.Fatalf("AddTrustedOrigin: %v", err)
79 }
80 handler := protection.Handler(okHandler)
81
82 tests := []struct {
83 name string
84 origin string
85 secFetchSite string
86 expectedStatus int
87 }{
88 {"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK},
89 {"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK},
90 {"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden},
91 {"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden},
92 }
93
94 for _, tc := range tests {
95 t.Run(tc.name, func(t *testing.T) {
96 req := httptestNewRequest("POST", "https://example.com/")
97 req.Header.Set("Origin", tc.origin)
98 if tc.secFetchSite != "" {
99 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
100 }
101
102 w := httptest.NewRecorder()
103 handler.ServeHTTP(w, req)
104
105 if w.Code != tc.expectedStatus {
106 t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
107 }
108 })
109 }
110 }
111
112 func TestCrossOriginProtectionPatternBypass(t *testing.T) {
113 protection := http.NewCrossOriginProtection()
114 protection.AddInsecureBypassPattern("/bypass/")
115 protection.AddInsecureBypassPattern("/only/{foo}")
116 handler := protection.Handler(okHandler)
117
118 tests := []struct {
119 name string
120 path string
121 secFetchSite string
122 expectedStatus int
123 }{
124 {"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK},
125 {"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK},
126 {"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden},
127 {"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden},
128
129 {"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusOK},
130 {"redirect to bypass path with trailing slash", "/bypass", "", http.StatusOK},
131 {"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden},
132 {"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden},
133
134 {"wildcard bypass", "/only/123", "", http.StatusOK},
135 {"non-wildcard", "/only/123/foo", "", http.StatusForbidden},
136 }
137
138 for _, tc := range tests {
139 t.Run(tc.name, func(t *testing.T) {
140 req := httptestNewRequest("POST", "https://example.com"+tc.path)
141 req.Header.Set("Origin", "https://attacker.example")
142 if tc.secFetchSite != "" {
143 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
144 }
145
146 w := httptest.NewRecorder()
147 handler.ServeHTTP(w, req)
148
149 if w.Code != tc.expectedStatus {
150 t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
151 }
152 })
153 }
154 }
155
156 func TestCrossOriginProtectionSetDenyHandler(t *testing.T) {
157 protection := http.NewCrossOriginProtection()
158
159 handler := protection.Handler(okHandler)
160
161 req := httptestNewRequest("POST", "https://example.com/")
162 req.Header.Set("Sec-Fetch-Site", "cross-site")
163
164 w := httptest.NewRecorder()
165 handler.ServeHTTP(w, req)
166
167 if w.Code != http.StatusForbidden {
168 t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
169 }
170
171 customErrHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
172 w.WriteHeader(http.StatusTeapot)
173 io.WriteString(w, "custom error")
174 })
175 protection.SetDenyHandler(customErrHandler)
176
177 w = httptest.NewRecorder()
178 handler.ServeHTTP(w, req)
179
180 if w.Code != http.StatusTeapot {
181 t.Errorf("got status %d, want %d", w.Code, http.StatusTeapot)
182 }
183
184 if !strings.Contains(w.Body.String(), "custom error") {
185 t.Errorf("expected custom error message, got: %q", w.Body.String())
186 }
187
188 req = httptestNewRequest("GET", "https://example.com/")
189
190 w = httptest.NewRecorder()
191 handler.ServeHTTP(w, req)
192
193 if w.Code != http.StatusOK {
194 t.Errorf("got status %d, want %d", w.Code, http.StatusOK)
195 }
196
197 protection.SetDenyHandler(nil)
198
199 req = httptestNewRequest("POST", "https://example.com/")
200 req.Header.Set("Sec-Fetch-Site", "cross-site")
201
202 w = httptest.NewRecorder()
203 handler.ServeHTTP(w, req)
204
205 if w.Code != http.StatusForbidden {
206 t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
207 }
208 }
209
210 func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) {
211 protection := http.NewCrossOriginProtection()
212
213 tests := []struct {
214 name string
215 origin string
216 wantErr bool
217 }{
218 {"valid origin", "https://example.com", false},
219 {"valid origin with port", "https://example.com:8080", false},
220 {"http origin", "http://example.com", false},
221 {"missing scheme", "example.com", true},
222 {"missing host", "https://", true},
223 {"trailing slash", "https://example.com/", true},
224 {"with path", "https://example.com/path", true},
225 {"with query", "https://example.com?query=value", true},
226 {"with fragment", "https://example.com#fragment", true},
227 {"invalid url", "https://ex ample.com", true},
228 {"empty string", "", true},
229 {"null", "null", true},
230 }
231
232 for _, tc := range tests {
233 t.Run(tc.name, func(t *testing.T) {
234 err := protection.AddTrustedOrigin(tc.origin)
235 if (err != nil) != tc.wantErr {
236 t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", tc.origin, err, tc.wantErr)
237 }
238 })
239 }
240 }
241
242 func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) {
243 protection := http.NewCrossOriginProtection()
244 handler := protection.Handler(okHandler)
245
246 req := httptestNewRequest("POST", "https://example.com/")
247 req.Header.Set("Origin", "https://concurrent.example")
248 req.Header.Set("Sec-Fetch-Site", "cross-site")
249
250 w := httptest.NewRecorder()
251 handler.ServeHTTP(w, req)
252
253 if w.Code != http.StatusForbidden {
254 t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
255 }
256
257 start := make(chan struct{})
258 done := make(chan struct{})
259 go func() {
260 close(start)
261 defer close(done)
262 for range 10 {
263 w := httptest.NewRecorder()
264 handler.ServeHTTP(w, req)
265 }
266 }()
267
268
269 <-start
270 protection.AddTrustedOrigin("https://concurrent.example")
271 protection.AddInsecureBypassPattern("/foo/")
272 <-done
273
274 w = httptest.NewRecorder()
275 handler.ServeHTTP(w, req)
276
277 if w.Code != http.StatusOK {
278 t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK)
279 }
280 }
281
282 func TestCrossOriginProtectionServer(t *testing.T) {
283 protection := http.NewCrossOriginProtection()
284 protection.AddTrustedOrigin("https://trusted.example")
285 protection.AddInsecureBypassPattern("/bypass/")
286 handler := protection.Handler(okHandler)
287
288 ts := httptest.NewServer(handler)
289 defer ts.Close()
290
291 tests := []struct {
292 name string
293 method string
294 url string
295 origin string
296 secFetchSite string
297 expectedStatus int
298 }{
299 {"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden},
300 {"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK},
301 {"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK},
302 {"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK},
303 {"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden},
304 {"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK},
305 }
306
307 for _, tc := range tests {
308 t.Run(tc.name, func(t *testing.T) {
309 req, err := http.NewRequest(tc.method, tc.url, nil)
310 if err != nil {
311 t.Fatalf("NewRequest: %v", err)
312 }
313 if tc.origin != "" {
314 req.Header.Set("Origin", tc.origin)
315 }
316 if tc.secFetchSite != "" {
317 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
318 }
319 client := &http.Client{}
320 resp, err := client.Do(req)
321 if err != nil {
322 t.Fatalf("Do: %v", err)
323 }
324 defer resp.Body.Close()
325 if resp.StatusCode != tc.expectedStatus {
326 t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus)
327 }
328 })
329 }
330 }
331
View as plain text