All checks were successful
check / check (push) Successful in 2m0s
The CSRF middleware previously tied its PlaintextHTTPRequest wrapping and cookie Secure flag to the IsDev() environment check. This meant production mode always assumed HTTPS, which broke login in two common deployment scenarios:

1. Production behind a TLS-terminating reverse proxy: gorilla/csrf assumed HTTPS but r.TLS was nil, causing Origin/Referer scheme mismatches and 'referer not supplied' errors.
2. Production over direct HTTP (testing/development with prod config): the Secure cookie flag prevented the browser from sending the CSRF cookie back on POST, causing 'CSRF token invalid' errors.

The fix detects the actual transport protocol per-request using r.TLS (direct TLS) and the X-Forwarded-Proto header (reverse proxy). Two gorilla/csrf instances are maintained — one with Secure cookies for TLS and one without for plaintext — since the csrf.Secure option is set at creation time. Both instances share the same signing key, so cookies are interchangeable between them.

Behavior after fix:

- Direct TLS: Secure cookies, strict Origin/Referer checks
- Behind TLS proxy (X-Forwarded-Proto: https): same as direct TLS
- Plaintext HTTP: non-Secure cookies, relaxed Origin/Referer checks (csrf.PlaintextHTTPRequest), token validation still enforced

Closes #53
373 lines
12 KiB
Go
373 lines
12 KiB
Go
package middleware
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"sneak.berlin/go/webhooker/internal/config"
|
|
)
|
|
|
|
func TestCSRF_GETSetsToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
|
|
|
var gotToken string
|
|
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
gotToken = CSRFToken(r)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
|
|
}
|
|
|
|
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
|
|
|
// Capture the token from a GET request
|
|
var token string
|
|
csrfMiddleware := m.CSRF()
|
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
token = CSRFToken(r)
|
|
}))
|
|
|
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
|
getW := httptest.NewRecorder()
|
|
getHandler.ServeHTTP(getW, getReq)
|
|
|
|
cookies := getW.Result().Cookies()
|
|
require.NotEmpty(t, cookies)
|
|
require.NotEmpty(t, token)
|
|
|
|
// POST with valid token and cookies from the GET response
|
|
var called bool
|
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
form := url.Values{"csrf_token": {token}}
|
|
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
postReq.AddCookie(c)
|
|
}
|
|
postW := httptest.NewRecorder()
|
|
|
|
postHandler.ServeHTTP(postW, postReq)
|
|
|
|
assert.True(t, called, "handler should be called with valid CSRF token")
|
|
}
|
|
|
|
func TestCSRF_POSTWithoutToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
|
|
|
csrfMiddleware := m.CSRF()
|
|
|
|
// GET to establish the CSRF cookie
|
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
|
getW := httptest.NewRecorder()
|
|
getHandler.ServeHTTP(getW, getReq)
|
|
cookies := getW.Result().Cookies()
|
|
|
|
// POST without CSRF token
|
|
var called bool
|
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
postReq.AddCookie(c)
|
|
}
|
|
postW := httptest.NewRecorder()
|
|
|
|
postHandler.ServeHTTP(postW, postReq)
|
|
|
|
assert.False(t, called, "handler should NOT be called without CSRF token")
|
|
assert.Equal(t, http.StatusForbidden, postW.Code)
|
|
}
|
|
|
|
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
|
|
|
csrfMiddleware := m.CSRF()
|
|
|
|
// GET to establish the CSRF cookie
|
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
|
getW := httptest.NewRecorder()
|
|
getHandler.ServeHTTP(getW, getReq)
|
|
cookies := getW.Result().Cookies()
|
|
|
|
// POST with wrong CSRF token
|
|
var called bool
|
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
form := url.Values{"csrf_token": {"invalid-token-value"}}
|
|
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
postReq.AddCookie(c)
|
|
}
|
|
postW := httptest.NewRecorder()
|
|
|
|
postHandler.ServeHTTP(postW, postReq)
|
|
|
|
assert.False(t, called, "handler should NOT be called with invalid CSRF token")
|
|
assert.Equal(t, http.StatusForbidden, postW.Code)
|
|
}
|
|
|
|
func TestCSRF_GETDoesNotValidate(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
|
|
|
var called bool
|
|
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
// GET requests should pass through without CSRF validation
|
|
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.True(t, called, "GET requests should pass through CSRF middleware")
|
|
}
|
|
|
|
func TestCSRFToken_NoMiddleware(t *testing.T) {
|
|
t.Parallel()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
|
|
}
|
|
|
|
// --- TLS Detection Tests ---
|
|
|
|
func TestIsClientTLS_DirectTLS(t *testing.T) {
|
|
t.Parallel()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r.TLS = &tls.ConnectionState{} // simulate direct TLS
|
|
assert.True(t, isClientTLS(r), "should detect direct TLS connection")
|
|
}
|
|
|
|
func TestIsClientTLS_XForwardedProto(t *testing.T) {
|
|
t.Parallel()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r.Header.Set("X-Forwarded-Proto", "https")
|
|
assert.True(t, isClientTLS(r), "should detect TLS via X-Forwarded-Proto")
|
|
}
|
|
|
|
func TestIsClientTLS_PlaintextHTTP(t *testing.T) {
|
|
t.Parallel()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
assert.False(t, isClientTLS(r), "should detect plaintext HTTP")
|
|
}
|
|
|
|
func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) {
|
|
t.Parallel()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r.Header.Set("X-Forwarded-Proto", "http")
|
|
assert.False(t, isClientTLS(r), "should detect plaintext when X-Forwarded-Proto is http")
|
|
}
|
|
|
|
// --- Production Mode: POST over plaintext HTTP ---
|
|
|
|
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
|
|
|
// This tests the critical fix: prod mode over plaintext HTTP should
|
|
// work because the middleware detects the transport per-request.
|
|
var token string
|
|
csrfMiddleware := m.CSRF()
|
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
token = CSRFToken(r)
|
|
}))
|
|
|
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
|
getW := httptest.NewRecorder()
|
|
getHandler.ServeHTTP(getW, getReq)
|
|
|
|
cookies := getW.Result().Cookies()
|
|
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
|
require.NotEmpty(t, token, "CSRF token should be set in context on GET")
|
|
|
|
// Verify the cookie is NOT Secure (plaintext HTTP in prod mode)
|
|
for _, c := range cookies {
|
|
if c.Name == "_gorilla_csrf" {
|
|
assert.False(t, c.Secure, "CSRF cookie should not be Secure over plaintext HTTP")
|
|
}
|
|
}
|
|
|
|
// POST with valid token — should succeed
|
|
var called bool
|
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
form := url.Values{"csrf_token": {token}}
|
|
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
postReq.AddCookie(c)
|
|
}
|
|
postW := httptest.NewRecorder()
|
|
|
|
postHandler.ServeHTTP(postW, postReq)
|
|
|
|
assert.True(t, called, "handler should be called — prod mode over plaintext HTTP must work")
|
|
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
|
}
|
|
|
|
// --- Production Mode: POST with X-Forwarded-Proto (reverse proxy) ---
|
|
|
|
func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
|
|
|
// Simulates a deployment behind a TLS-terminating reverse proxy.
|
|
// The Go server sees HTTP but X-Forwarded-Proto is "https".
|
|
var token string
|
|
csrfMiddleware := m.CSRF()
|
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
token = CSRFToken(r)
|
|
}))
|
|
|
|
getReq := httptest.NewRequest(http.MethodGet, "http://example.com/form", nil)
|
|
getReq.Header.Set("X-Forwarded-Proto", "https")
|
|
getW := httptest.NewRecorder()
|
|
getHandler.ServeHTTP(getW, getReq)
|
|
|
|
cookies := getW.Result().Cookies()
|
|
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
|
require.NotEmpty(t, token, "CSRF token should be set in context")
|
|
|
|
// Verify the cookie IS Secure (X-Forwarded-Proto: https)
|
|
for _, c := range cookies {
|
|
if c.Name == "_gorilla_csrf" {
|
|
assert.True(t, c.Secure, "CSRF cookie should be Secure behind TLS proxy")
|
|
}
|
|
}
|
|
|
|
// POST with valid token, HTTPS Origin (as a browser behind proxy would send)
|
|
var called bool
|
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
form := url.Values{"csrf_token": {token}}
|
|
postReq := httptest.NewRequest(http.MethodPost, "http://example.com/form", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
postReq.Header.Set("X-Forwarded-Proto", "https")
|
|
postReq.Header.Set("Origin", "https://example.com")
|
|
for _, c := range cookies {
|
|
postReq.AddCookie(c)
|
|
}
|
|
postW := httptest.NewRecorder()
|
|
|
|
postHandler.ServeHTTP(postW, postReq)
|
|
|
|
assert.True(t, called, "handler should be called — prod mode behind TLS proxy must work")
|
|
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
|
}
|
|
|
|
// --- Production Mode: direct TLS ---
|
|
|
|
func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
|
|
|
var token string
|
|
csrfMiddleware := m.CSRF()
|
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
token = CSRFToken(r)
|
|
}))
|
|
|
|
getReq := httptest.NewRequest(http.MethodGet, "https://example.com/form", nil)
|
|
getReq.TLS = &tls.ConnectionState{}
|
|
getW := httptest.NewRecorder()
|
|
getHandler.ServeHTTP(getW, getReq)
|
|
|
|
cookies := getW.Result().Cookies()
|
|
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
|
require.NotEmpty(t, token, "CSRF token should be set in context")
|
|
|
|
// Verify the cookie IS Secure (direct TLS)
|
|
for _, c := range cookies {
|
|
if c.Name == "_gorilla_csrf" {
|
|
assert.True(t, c.Secure, "CSRF cookie should be Secure over direct TLS")
|
|
}
|
|
}
|
|
|
|
// POST with valid token over direct TLS
|
|
var called bool
|
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
form := url.Values{"csrf_token": {token}}
|
|
postReq := httptest.NewRequest(http.MethodPost, "https://example.com/form", strings.NewReader(form.Encode()))
|
|
postReq.TLS = &tls.ConnectionState{}
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
postReq.Header.Set("Origin", "https://example.com")
|
|
for _, c := range cookies {
|
|
postReq.AddCookie(c)
|
|
}
|
|
postW := httptest.NewRecorder()
|
|
|
|
postHandler.ServeHTTP(postW, postReq)
|
|
|
|
assert.True(t, called, "handler should be called — direct TLS must work")
|
|
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
|
}
|
|
|
|
// --- Production Mode: POST without token still rejects ---
|
|
|
|
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(t *testing.T) {
|
|
t.Parallel()
|
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
|
|
|
csrfMiddleware := m.CSRF()
|
|
|
|
// GET to establish the CSRF cookie
|
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
|
getW := httptest.NewRecorder()
|
|
getHandler.ServeHTTP(getW, getReq)
|
|
cookies := getW.Result().Cookies()
|
|
|
|
// POST without CSRF token — should be rejected
|
|
var called bool
|
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
}))
|
|
|
|
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
postReq.AddCookie(c)
|
|
}
|
|
postW := httptest.NewRecorder()
|
|
|
|
postHandler.ServeHTTP(postW, postReq)
|
|
|
|
assert.False(t, called, "handler should NOT be called without CSRF token even in prod+plaintext")
|
|
assert.Equal(t, http.StatusForbidden, postW.Code)
|
|
}
|