refactor: use pinned golangci-lint Docker image for linting
All checks were successful
check / check (push) Successful in 1m37s
All checks were successful
check / check (push) Successful in 1m37s
Refactor Dockerfile to use a separate lint stage with a pinned golangci-lint v2.11.3 Docker image instead of installing golangci-lint via curl in the builder stage. This follows the pattern used by sneak/pixa. Changes: - Dockerfile: separate lint stage using golangci/golangci-lint:v2.11.3 (Debian-based, pinned by sha256) with COPY --from=lint dependency - Bump Go from 1.24 to 1.26.1 (golang:1.26.1-bookworm, pinned) - Bump golangci-lint from v1.64.8 to v2.11.3 - Migrate .golangci.yml from v1 to v2 format (same linters, format only) - All Docker images pinned by sha256 digest - Fix all lint issues from the v2 linter upgrade: - Add package comments to all packages - Add doc comments to all exported types, functions, and methods - Fix unchecked errors (errcheck) - Fix unused parameters (revive) - Fix gosec warnings (MaxBytesReader for form parsing) - Fix staticcheck suggestions (fmt.Fprintf instead of WriteString) - Rename DeliveryTask to Task to avoid stutter (delivery.Task) - Rename shadowed builtin 'max' parameter - Update README.md version requirements
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -11,362 +12,483 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/middleware"
|
||||
)
|
||||
|
||||
// csrfCookieName is the gorilla/csrf cookie name.
|
||||
const csrfCookieName = "_gorilla_csrf"
|
||||
|
||||
// csrfGetToken performs a GET request through the CSRF middleware
|
||||
// and returns the token and cookies.
|
||||
func csrfGetToken(
|
||||
t *testing.T,
|
||||
csrfMW func(http.Handler) http.Handler,
|
||||
getReq *http.Request,
|
||||
) (string, []*http.Cookie) {
|
||||
t.Helper()
|
||||
|
||||
var token string
|
||||
|
||||
getHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = middleware.CSRFToken(r)
|
||||
},
|
||||
))
|
||||
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies, "CSRF cookie should be set")
|
||||
require.NotEmpty(t, token, "CSRF token should be set")
|
||||
|
||||
return token, cookies
|
||||
}
|
||||
|
||||
// csrfPostWithToken performs a POST request with the given CSRF
|
||||
// token and cookies through the middleware. Returns whether the
|
||||
// handler was called and the response code.
|
||||
func csrfPostWithToken(
|
||||
t *testing.T,
|
||||
csrfMW func(http.Handler) http.Handler,
|
||||
postReq *http.Request,
|
||||
token string,
|
||||
cookies []*http.Cookie,
|
||||
) (bool, int) {
|
||||
t.Helper()
|
||||
|
||||
var called bool
|
||||
|
||||
postHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
postReq.Body = http.NoBody
|
||||
postReq.Body = nil
|
||||
|
||||
// Rebuild the request with the form body
|
||||
rebuilt := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
postReq.Method, postReq.URL.String(),
|
||||
strings.NewReader(form.Encode()),
|
||||
)
|
||||
rebuilt.Header = postReq.Header.Clone()
|
||||
rebuilt.TLS = postReq.TLS
|
||||
rebuilt.Header.Set(
|
||||
"Content-Type", "application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
for _, c := range cookies {
|
||||
rebuilt.AddCookie(c)
|
||||
}
|
||||
|
||||
postW := httptest.NewRecorder()
|
||||
postHandler.ServeHTTP(postW, rebuilt)
|
||||
|
||||
return called, postW.Code
|
||||
}
|
||||
|
||||
func TestCSRF_GETSetsToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var gotToken string
|
||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
gotToken = CSRFToken(r)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
handler := m.CSRF()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, r *http.Request) {
|
||||
gotToken = middleware.CSRFToken(r)
|
||||
},
|
||||
))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
|
||||
assert.NotEmpty(
|
||||
t, gotToken,
|
||||
"CSRF token should be set in context on GET",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
// Capture the token from a GET request
|
||||
var token string
|
||||
csrfMiddleware := m.CSRF()
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = CSRFToken(r)
|
||||
}))
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/form", nil,
|
||||
)
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form", nil,
|
||||
)
|
||||
called, _ := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
// POST with valid token and cookies from the GET response
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.True(t, called, "handler should be called with valid CSRF token")
|
||||
assert.True(
|
||||
t, called,
|
||||
"handler should be called with valid CSRF token",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithoutToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
// csrfPOSTWithoutTokenTest is a shared helper for testing POST
|
||||
// requests without a CSRF token in both dev and prod modes.
|
||||
func csrfPOSTWithoutTokenTest(
|
||||
t *testing.T,
|
||||
env string,
|
||||
msg string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
csrfMiddleware := m.CSRF()
|
||||
m, _ := testMiddleware(t, env)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
// GET to establish the CSRF cookie
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {},
|
||||
))
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
|
||||
// POST without CSRF token
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
postHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form", nil,
|
||||
)
|
||||
postReq.Header.Set(
|
||||
"Content-Type", "application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(t, called, "handler should NOT be called without CSRF token")
|
||||
assert.False(t, called, msg)
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithoutToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
csrfPOSTWithoutTokenTest(
|
||||
t,
|
||||
config.EnvironmentDev,
|
||||
"handler should NOT be called without CSRF token",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
csrfMiddleware := m.CSRF()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
// GET to establish the CSRF cookie
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {},
|
||||
))
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
|
||||
// POST with wrong CSRF token
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
postHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
|
||||
form := url.Values{"csrf_token": {"invalid-token-value"}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form",
|
||||
strings.NewReader(form.Encode()),
|
||||
)
|
||||
postReq.Header.Set(
|
||||
"Content-Type", "application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(t, called, "handler should NOT be called with invalid CSRF token")
|
||||
assert.False(
|
||||
t, called,
|
||||
"handler should NOT be called with invalid CSRF token",
|
||||
)
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
}
|
||||
|
||||
func TestCSRF_GETDoesNotValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var called bool
|
||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
// GET requests should pass through without CSRF validation
|
||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
handler := m.CSRF()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.True(t, called, "GET requests should pass through CSRF middleware")
|
||||
assert.True(
|
||||
t, called,
|
||||
"GET requests should pass through CSRF middleware",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCSRFToken_NoMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
assert.Empty(
|
||||
t, middleware.CSRFToken(req),
|
||||
"CSRFToken should return empty string when "+
|
||||
"middleware has not run",
|
||||
)
|
||||
}
|
||||
|
||||
// --- TLS Detection Tests ---
|
||||
|
||||
func TestIsClientTLS_DirectTLS(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.TLS = &tls.ConnectionState{} // simulate direct TLS
|
||||
assert.True(t, isClientTLS(r), "should detect direct TLS connection")
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
r.TLS = &tls.ConnectionState{}
|
||||
|
||||
assert.True(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect direct TLS connection",
|
||||
)
|
||||
}
|
||||
|
||||
func TestIsClientTLS_XForwardedProto(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
r.Header.Set("X-Forwarded-Proto", "https")
|
||||
assert.True(t, isClientTLS(r), "should detect TLS via X-Forwarded-Proto")
|
||||
|
||||
assert.True(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect TLS via X-Forwarded-Proto",
|
||||
)
|
||||
}
|
||||
|
||||
func TestIsClientTLS_PlaintextHTTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
assert.False(t, isClientTLS(r), "should detect plaintext HTTP")
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
assert.False(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect plaintext HTTP",
|
||||
)
|
||||
}
|
||||
|
||||
func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
r.Header.Set("X-Forwarded-Proto", "http")
|
||||
assert.False(t, isClientTLS(r), "should detect plaintext when X-Forwarded-Proto is http")
|
||||
|
||||
assert.False(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect plaintext when X-Forwarded-Proto is http",
|
||||
)
|
||||
}
|
||||
|
||||
// --- Production Mode: POST over plaintext HTTP ---
|
||||
|
||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(t *testing.T) {
|
||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
// This tests the critical fix: prod mode over plaintext HTTP should
|
||||
// work because the middleware detects the transport per-request.
|
||||
var token string
|
||||
csrfMiddleware := m.CSRF()
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = CSRFToken(r)
|
||||
}))
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/form", nil,
|
||||
)
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
||||
require.NotEmpty(t, token, "CSRF token should be set in context on GET")
|
||||
|
||||
// Verify the cookie is NOT Secure (plaintext HTTP in prod mode)
|
||||
// Verify cookie is NOT Secure (plaintext HTTP in prod)
|
||||
for _, c := range cookies {
|
||||
if c.Name == "_gorilla_csrf" {
|
||||
assert.False(t, c.Secure, "CSRF cookie should not be Secure over plaintext HTTP")
|
||||
if c.Name == csrfCookieName {
|
||||
assert.False(t, c.Secure,
|
||||
"CSRF cookie should not be Secure "+
|
||||
"over plaintext HTTP")
|
||||
}
|
||||
}
|
||||
|
||||
// POST with valid token — should succeed
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form", nil,
|
||||
)
|
||||
called, code := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.True(t, called, "handler should be called — prod mode over plaintext HTTP must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
||||
assert.True(t, called,
|
||||
"handler should be called -- prod mode over "+
|
||||
"plaintext HTTP must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, code,
|
||||
"should not return 403")
|
||||
}
|
||||
|
||||
// --- Production Mode: POST with X-Forwarded-Proto (reverse proxy) ---
|
||||
// --- Production Mode: POST with X-Forwarded-Proto ---
|
||||
|
||||
func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(t *testing.T) {
|
||||
func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
// Simulates a deployment behind a TLS-terminating reverse proxy.
|
||||
// The Go server sees HTTP but X-Forwarded-Proto is "https".
|
||||
var token string
|
||||
csrfMiddleware := m.CSRF()
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = CSRFToken(r)
|
||||
}))
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "http://example.com/form", nil)
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "http://example.com/form", nil,
|
||||
)
|
||||
getReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
||||
require.NotEmpty(t, token, "CSRF token should be set in context")
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
|
||||
// Verify the cookie IS Secure (X-Forwarded-Proto: https)
|
||||
// Verify cookie IS Secure (X-Forwarded-Proto: https)
|
||||
for _, c := range cookies {
|
||||
if c.Name == "_gorilla_csrf" {
|
||||
assert.True(t, c.Secure, "CSRF cookie should be Secure behind TLS proxy")
|
||||
if c.Name == csrfCookieName {
|
||||
assert.True(t, c.Secure,
|
||||
"CSRF cookie should be Secure behind "+
|
||||
"TLS proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// POST with valid token, HTTPS Origin (as a browser behind proxy would send)
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "http://example.com/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "http://example.com/form", nil,
|
||||
)
|
||||
postReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
postReq.Header.Set("Origin", "https://example.com")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
called, code := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
assert.True(t, called, "handler should be called — prod mode behind TLS proxy must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
||||
assert.True(t, called,
|
||||
"handler should be called -- prod mode behind "+
|
||||
"TLS proxy must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, code,
|
||||
"should not return 403")
|
||||
}
|
||||
|
||||
// --- Production Mode: direct TLS ---
|
||||
|
||||
func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(t *testing.T) {
|
||||
func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
var token string
|
||||
csrfMiddleware := m.CSRF()
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = CSRFToken(r)
|
||||
}))
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "https://example.com/form", nil)
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "https://example.com/form", nil,
|
||||
)
|
||||
getReq.TLS = &tls.ConnectionState{}
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
||||
require.NotEmpty(t, token, "CSRF token should be set in context")
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
|
||||
// Verify the cookie IS Secure (direct TLS)
|
||||
// Verify cookie IS Secure (direct TLS)
|
||||
for _, c := range cookies {
|
||||
if c.Name == "_gorilla_csrf" {
|
||||
assert.True(t, c.Secure, "CSRF cookie should be Secure over direct TLS")
|
||||
if c.Name == csrfCookieName {
|
||||
assert.True(t, c.Secure,
|
||||
"CSRF cookie should be Secure over "+
|
||||
"direct TLS")
|
||||
}
|
||||
}
|
||||
|
||||
// POST with valid token over direct TLS
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "https://example.com/form", strings.NewReader(form.Encode()))
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "https://example.com/form", nil,
|
||||
)
|
||||
postReq.TLS = &tls.ConnectionState{}
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
postReq.Header.Set("Origin", "https://example.com")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
called, code := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
assert.True(t, called, "handler should be called — direct TLS must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
||||
assert.True(t, called,
|
||||
"handler should be called -- direct TLS must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, code,
|
||||
"should not return 403")
|
||||
}
|
||||
|
||||
// --- Production Mode: POST without token still rejects ---
|
||||
|
||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(t *testing.T) {
|
||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
|
||||
csrfMiddleware := m.CSRF()
|
||||
|
||||
// GET to establish the CSRF cookie
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
cookies := getW.Result().Cookies()
|
||||
|
||||
// POST without CSRF token — should be rejected
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(t, called, "handler should NOT be called without CSRF token even in prod+plaintext")
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
csrfPOSTWithoutTokenTest(
|
||||
t,
|
||||
config.EnvironmentProd,
|
||||
"handler should NOT be called without CSRF token "+
|
||||
"even in prod+plaintext",
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user