Source file src/net/http/csrf_test.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http_test
     6  
     7  import (
     8  	"io"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"strings"
    12  	"testing"
    13  )
    14  
    15  // httptestNewRequest works around https://go.dev/issue/73151.
    16  func httptestNewRequest(method, target string) *http.Request {
    17  	req := httptest.NewRequest(method, target, nil)
    18  	req.URL.Scheme = ""
    19  	req.URL.Host = ""
    20  	return req
    21  }
    22  
    23  var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    24  	w.WriteHeader(http.StatusOK)
    25  })
    26  
    27  func TestCrossOriginProtectionSecFetchSite(t *testing.T) {
    28  	protection := http.NewCrossOriginProtection()
    29  	handler := protection.Handler(okHandler)
    30  
    31  	tests := []struct {
    32  		name           string
    33  		method         string
    34  		secFetchSite   string
    35  		origin         string
    36  		expectedStatus int
    37  	}{
    38  		{"same-origin allowed", "POST", "same-origin", "", http.StatusOK},
    39  		{"none allowed", "POST", "none", "", http.StatusOK},
    40  		{"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden},
    41  		{"same-site blocked", "POST", "same-site", "", http.StatusForbidden},
    42  
    43  		{"no header with no origin", "POST", "", "", http.StatusOK},
    44  		{"no header with matching origin", "POST", "", "https://example.com", http.StatusOK},
    45  		{"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden},
    46  		{"no header with null origin", "POST", "", "null", http.StatusForbidden},
    47  
    48  		{"GET allowed", "GET", "cross-site", "", http.StatusOK},
    49  		{"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK},
    50  		{"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK},
    51  		{"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden},
    52  	}
    53  
    54  	for _, tc := range tests {
    55  		t.Run(tc.name, func(t *testing.T) {
    56  			req := httptestNewRequest(tc.method, "https://example.com/")
    57  			if tc.secFetchSite != "" {
    58  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
    59  			}
    60  			if tc.origin != "" {
    61  				req.Header.Set("Origin", tc.origin)
    62  			}
    63  
    64  			w := httptest.NewRecorder()
    65  			handler.ServeHTTP(w, req)
    66  
    67  			if w.Code != tc.expectedStatus {
    68  				t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
    69  			}
    70  		})
    71  	}
    72  }
    73  
    74  func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) {
    75  	protection := http.NewCrossOriginProtection()
    76  	err := protection.AddTrustedOrigin("https://trusted.example")
    77  	if err != nil {
    78  		t.Fatalf("AddTrustedOrigin: %v", err)
    79  	}
    80  	handler := protection.Handler(okHandler)
    81  
    82  	tests := []struct {
    83  		name           string
    84  		origin         string
    85  		secFetchSite   string
    86  		expectedStatus int
    87  	}{
    88  		{"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK},
    89  		{"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK},
    90  		{"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden},
    91  		{"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden},
    92  	}
    93  
    94  	for _, tc := range tests {
    95  		t.Run(tc.name, func(t *testing.T) {
    96  			req := httptestNewRequest("POST", "https://example.com/")
    97  			req.Header.Set("Origin", tc.origin)
    98  			if tc.secFetchSite != "" {
    99  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
   100  			}
   101  
   102  			w := httptest.NewRecorder()
   103  			handler.ServeHTTP(w, req)
   104  
   105  			if w.Code != tc.expectedStatus {
   106  				t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
   107  			}
   108  		})
   109  	}
   110  }
   111  
   112  func TestCrossOriginProtectionPatternBypass(t *testing.T) {
   113  	protection := http.NewCrossOriginProtection()
   114  	protection.AddInsecureBypassPattern("/bypass/")
   115  	protection.AddInsecureBypassPattern("/only/{foo}")
   116  	handler := protection.Handler(okHandler)
   117  
   118  	tests := []struct {
   119  		name           string
   120  		path           string
   121  		secFetchSite   string
   122  		expectedStatus int
   123  	}{
   124  		{"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK},
   125  		{"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK},
   126  		{"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden},
   127  		{"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden},
   128  
   129  		{"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusOK},
   130  		{"redirect to bypass path with trailing slash", "/bypass", "", http.StatusOK},
   131  		{"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden},
   132  		{"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden},
   133  
   134  		{"wildcard bypass", "/only/123", "", http.StatusOK},
   135  		{"non-wildcard", "/only/123/foo", "", http.StatusForbidden},
   136  	}
   137  
   138  	for _, tc := range tests {
   139  		t.Run(tc.name, func(t *testing.T) {
   140  			req := httptestNewRequest("POST", "https://example.com"+tc.path)
   141  			req.Header.Set("Origin", "https://attacker.example")
   142  			if tc.secFetchSite != "" {
   143  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
   144  			}
   145  
   146  			w := httptest.NewRecorder()
   147  			handler.ServeHTTP(w, req)
   148  
   149  			if w.Code != tc.expectedStatus {
   150  				t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
   151  			}
   152  		})
   153  	}
   154  }
   155  
   156  func TestCrossOriginProtectionSetDenyHandler(t *testing.T) {
   157  	protection := http.NewCrossOriginProtection()
   158  
   159  	handler := protection.Handler(okHandler)
   160  
   161  	req := httptestNewRequest("POST", "https://example.com/")
   162  	req.Header.Set("Sec-Fetch-Site", "cross-site")
   163  
   164  	w := httptest.NewRecorder()
   165  	handler.ServeHTTP(w, req)
   166  
   167  	if w.Code != http.StatusForbidden {
   168  		t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
   169  	}
   170  
   171  	customErrHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   172  		w.WriteHeader(http.StatusTeapot)
   173  		io.WriteString(w, "custom error")
   174  	})
   175  	protection.SetDenyHandler(customErrHandler)
   176  
   177  	w = httptest.NewRecorder()
   178  	handler.ServeHTTP(w, req)
   179  
   180  	if w.Code != http.StatusTeapot {
   181  		t.Errorf("got status %d, want %d", w.Code, http.StatusTeapot)
   182  	}
   183  
   184  	if !strings.Contains(w.Body.String(), "custom error") {
   185  		t.Errorf("expected custom error message, got: %q", w.Body.String())
   186  	}
   187  
   188  	req = httptestNewRequest("GET", "https://example.com/")
   189  
   190  	w = httptest.NewRecorder()
   191  	handler.ServeHTTP(w, req)
   192  
   193  	if w.Code != http.StatusOK {
   194  		t.Errorf("got status %d, want %d", w.Code, http.StatusOK)
   195  	}
   196  
   197  	protection.SetDenyHandler(nil)
   198  
   199  	req = httptestNewRequest("POST", "https://example.com/")
   200  	req.Header.Set("Sec-Fetch-Site", "cross-site")
   201  
   202  	w = httptest.NewRecorder()
   203  	handler.ServeHTTP(w, req)
   204  
   205  	if w.Code != http.StatusForbidden {
   206  		t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
   207  	}
   208  }
   209  
   210  func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) {
   211  	protection := http.NewCrossOriginProtection()
   212  
   213  	tests := []struct {
   214  		name    string
   215  		origin  string
   216  		wantErr bool
   217  	}{
   218  		{"valid origin", "https://example.com", false},
   219  		{"valid origin with port", "https://example.com:8080", false},
   220  		{"http origin", "http://example.com", false},
   221  		{"missing scheme", "example.com", true},
   222  		{"missing host", "https://", true},
   223  		{"trailing slash", "https://example.com/", true},
   224  		{"with path", "https://example.com/path", true},
   225  		{"with query", "https://example.com?query=value", true},
   226  		{"with fragment", "https://example.com#fragment", true},
   227  		{"invalid url", "https://ex ample.com", true},
   228  		{"empty string", "", true},
   229  		{"null", "null", true},
   230  	}
   231  
   232  	for _, tc := range tests {
   233  		t.Run(tc.name, func(t *testing.T) {
   234  			err := protection.AddTrustedOrigin(tc.origin)
   235  			if (err != nil) != tc.wantErr {
   236  				t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", tc.origin, err, tc.wantErr)
   237  			}
   238  		})
   239  	}
   240  }
   241  
   242  func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) {
   243  	protection := http.NewCrossOriginProtection()
   244  	handler := protection.Handler(okHandler)
   245  
   246  	req := httptestNewRequest("POST", "https://example.com/")
   247  	req.Header.Set("Origin", "https://concurrent.example")
   248  	req.Header.Set("Sec-Fetch-Site", "cross-site")
   249  
   250  	w := httptest.NewRecorder()
   251  	handler.ServeHTTP(w, req)
   252  
   253  	if w.Code != http.StatusForbidden {
   254  		t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
   255  	}
   256  
   257  	start := make(chan struct{})
   258  	done := make(chan struct{})
   259  	go func() {
   260  		close(start)
   261  		defer close(done)
   262  		for range 10 {
   263  			w := httptest.NewRecorder()
   264  			handler.ServeHTTP(w, req)
   265  		}
   266  	}()
   267  
   268  	// Add bypasses while the requests are in flight.
   269  	<-start
   270  	protection.AddTrustedOrigin("https://concurrent.example")
   271  	protection.AddInsecureBypassPattern("/foo/")
   272  	<-done
   273  
   274  	w = httptest.NewRecorder()
   275  	handler.ServeHTTP(w, req)
   276  
   277  	if w.Code != http.StatusOK {
   278  		t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK)
   279  	}
   280  }
   281  
   282  func TestCrossOriginProtectionServer(t *testing.T) {
   283  	protection := http.NewCrossOriginProtection()
   284  	protection.AddTrustedOrigin("https://trusted.example")
   285  	protection.AddInsecureBypassPattern("/bypass/")
   286  	handler := protection.Handler(okHandler)
   287  
   288  	ts := httptest.NewServer(handler)
   289  	defer ts.Close()
   290  
   291  	tests := []struct {
   292  		name           string
   293  		method         string
   294  		url            string
   295  		origin         string
   296  		secFetchSite   string
   297  		expectedStatus int
   298  	}{
   299  		{"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden},
   300  		{"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK},
   301  		{"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK},
   302  		{"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK},
   303  		{"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden},
   304  		{"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK},
   305  	}
   306  
   307  	for _, tc := range tests {
   308  		t.Run(tc.name, func(t *testing.T) {
   309  			req, err := http.NewRequest(tc.method, tc.url, nil)
   310  			if err != nil {
   311  				t.Fatalf("NewRequest: %v", err)
   312  			}
   313  			if tc.origin != "" {
   314  				req.Header.Set("Origin", tc.origin)
   315  			}
   316  			if tc.secFetchSite != "" {
   317  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
   318  			}
   319  			client := &http.Client{}
   320  			resp, err := client.Do(req)
   321  			if err != nil {
   322  				t.Fatalf("Do: %v", err)
   323  			}
   324  			defer resp.Body.Close()
   325  			if resp.StatusCode != tc.expectedStatus {
   326  				t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus)
   327  			}
   328  		})
   329  	}
   330  }
   331  

View as plain text