From 52ae9a1f1c1a6aee88e915c99fb0f2281a37de20 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 17 Mar 2026 05:28:54 -0700 Subject: [PATCH] fix: detect TLS per-request in CSRF middleware to fix login MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CSRF middleware previously tied its PlaintextHTTPRequest wrapping and cookie Secure flag to the IsDev() environment check. This meant production mode always assumed HTTPS, which broke login in two common deployment scenarios: 1. Production behind a TLS-terminating reverse proxy: gorilla/csrf assumed HTTPS but r.TLS was nil, causing Origin/Referer scheme mismatches and 'referer not supplied' errors. 2. Production over direct HTTP (testing/development with prod config): the Secure cookie flag prevented the browser from sending the CSRF cookie back on POST, causing 'CSRF token invalid' errors. The fix detects the actual transport protocol per-request using r.TLS (direct TLS) and the X-Forwarded-Proto header (reverse proxy). Two gorilla/csrf instances are maintained — one with Secure cookies for TLS and one without for plaintext — since the csrf.Secure option is set at creation time. Both instances share the same signing key, so cookies are interchangeable between them. Behavior after fix: - Direct TLS: Secure cookies, strict Origin/Referer checks - Behind TLS proxy (X-Forwarded-Proto: https): same as direct TLS - Plaintext HTTP: non-Secure cookies, relaxed Origin/Referer checks (csrf.PlaintextHTTPRequest), token validation still enforced Closes #53 --- README.md | 19 ++- internal/middleware/csrf.go | 84 ++++++++---- internal/middleware/csrf_test.go | 215 +++++++++++++++++++++++++++++++ 3 files changed, 289 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 50faf6c..3713b1c 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,21 @@ or `prod` (default: `dev`). The setting controls several behaviors: | CORS | Allows any origin (`*`) | Disabled (no-op) | | Session cookie Secure | `false` (works over plain HTTP) | `true` (requires HTTPS) | +The CSRF cookie's `Secure` flag and Origin/Referer validation mode are +determined per-request based on the actual transport protocol, not the +environment setting. The middleware checks `r.TLS` (direct TLS) and the +`X-Forwarded-Proto` header (TLS-terminating reverse proxy) to decide: + +- **Direct TLS or `X-Forwarded-Proto: https`**: Secure cookies, strict + Origin/Referer validation. +- **Plaintext HTTP**: Non-Secure cookies, relaxed Origin/Referer + checks (token validation still enforced). + +This means CSRF protection works correctly in all deployment scenarios: +behind a TLS-terminating reverse proxy, with direct TLS, or over plain +HTTP during development. When running behind a reverse proxy, ensure it +sets the `X-Forwarded-Proto: https` header. + All other differences (log format, security headers, etc.) are independent of the environment setting — log format is determined by TTY detection, and security headers are always applied. @@ -841,7 +856,9 @@ Additionally, form endpoints (`/pages`, `/sources`, `/source/*`) apply a on all state-changing forms (cookie-based double-submit tokens with HMAC authentication). Applied to `/pages`, `/sources`, `/source`, and `/user` routes. Excluded from `/webhook` (inbound webhook POSTs) and - `/api` (stateless API) + `/api` (stateless API). The middleware auto-detects TLS status + per-request (via `r.TLS` and `X-Forwarded-Proto`) to set appropriate + cookie security flags and Origin/Referer validation mode - **SSRF prevention** for HTTP delivery targets: private/reserved IP ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked both at target creation time (URL validation) and at delivery time diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index df2ffb2..b9e62f6 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -12,6 +12,13 @@ func CSRFToken(r *http.Request) string { return csrf.Token(r) } +// isClientTLS reports whether the client-facing connection uses TLS. +// It checks for a direct TLS connection (r.TLS) or a TLS-terminating +// reverse proxy that sets the standard X-Forwarded-Proto header. +func isClientTLS(r *http.Request) bool { + return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" +} + // CSRF returns middleware that provides CSRF protection using the // gorilla/csrf library. The middleware uses the session authentication // key to sign a CSRF cookie and validates a masked token submitted via @@ -19,38 +26,59 @@ func CSRFToken(r *http.Request) string { // POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing // token receive a 403 Forbidden response. // -// In development mode, requests are marked as plaintext HTTP so that -// gorilla/csrf skips the strict Referer-origin check (which is only -// meaningful over TLS). +// The middleware detects the client-facing transport protocol per-request +// using r.TLS and the X-Forwarded-Proto header. This allows correct +// behavior in all deployment scenarios: +// +// - Direct HTTPS: strict Referer/Origin checks, Secure cookies. +// - Behind a TLS-terminating reverse proxy: strict checks (the +// browser is on HTTPS, so Origin/Referer headers use https://), +// Secure cookies (the browser sees HTTPS from the proxy). +// - Direct HTTP: relaxed Referer/Origin checks via PlaintextHTTPRequest, +// non-Secure cookies so the browser sends them over HTTP. +// +// Two gorilla/csrf instances are maintained — one with Secure cookies +// (for TLS) and one without (for plaintext HTTP) — because the +// csrf.Secure option is set at creation time, not per-request. func (m *Middleware) CSRF() func(http.Handler) http.Handler { - protect := csrf.Protect( - m.session.GetKey(), + csrfErrorHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.log.Warn("csrf: token validation failed", + "method", r.Method, + "path", r.URL.Path, + "remote_addr", r.RemoteAddr, + "reason", csrf.FailureReason(r), + ) + http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) + }) + + key := m.session.GetKey() + baseOpts := []csrf.Option{ csrf.FieldName("csrf_token"), - csrf.Secure(!m.params.Config.IsDev()), csrf.SameSite(csrf.SameSiteLaxMode), csrf.Path("/"), - csrf.ErrorHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - m.log.Warn("csrf: token validation failed", - "method", r.Method, - "path", r.URL.Path, - "remote_addr", r.RemoteAddr, - "reason", csrf.FailureReason(r), - ) - http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) - })), - ) - - // In development (plaintext HTTP), signal gorilla/csrf to skip - // the strict TLS Referer check by injecting the PlaintextHTTP - // context key before the CSRF handler sees the request. - if m.params.Config.IsDev() { - return func(next http.Handler) http.Handler { - csrfHandler := protect(next) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - csrfHandler.ServeHTTP(w, csrf.PlaintextHTTPRequest(r)) - }) - } + csrf.ErrorHandler(csrfErrorHandler), } - return protect + // Two middleware instances with different Secure flags but the + // same signing key, so cookies are interchangeable between them. + tlsProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(true))...) + httpProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(false))...) + + return func(next http.Handler) http.Handler { + tlsCSRF := tlsProtect(next) + httpCSRF := httpProtect(next) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isClientTLS(r) { + // Client is on TLS (directly or via reverse proxy). + // Use Secure cookies and strict Origin/Referer checks. + tlsCSRF.ServeHTTP(w, r) + } else { + // Plaintext HTTP: use non-Secure cookies and tell + // gorilla/csrf to use "http" for scheme comparisons, + // skipping the strict Referer check that assumes TLS. + httpCSRF.ServeHTTP(w, csrf.PlaintextHTTPRequest(r)) + } + }) + } } diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go index ae1791d..ca9d0d8 100644 --- a/internal/middleware/csrf_test.go +++ b/internal/middleware/csrf_test.go @@ -1,6 +1,7 @@ package middleware import ( + "crypto/tls" "net/http" "net/http/httptest" "net/url" @@ -155,3 +156,217 @@ func TestCSRFToken_NoMiddleware(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run") } + +// --- TLS Detection Tests --- + +func TestIsClientTLS_DirectTLS(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.TLS = &tls.ConnectionState{} // simulate direct TLS + assert.True(t, isClientTLS(r), "should detect direct TLS connection") +} + +func TestIsClientTLS_XForwardedProto(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("X-Forwarded-Proto", "https") + assert.True(t, isClientTLS(r), "should detect TLS via X-Forwarded-Proto") +} + +func TestIsClientTLS_PlaintextHTTP(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) + assert.False(t, isClientTLS(r), "should detect plaintext HTTP") +} + +func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("X-Forwarded-Proto", "http") + assert.False(t, isClientTLS(r), "should detect plaintext when X-Forwarded-Proto is http") +} + +// --- Production Mode: POST over plaintext HTTP --- + +func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentProd) + + // This tests the critical fix: prod mode over plaintext HTTP should + // work because the middleware detects the transport per-request. + var token string + csrfMiddleware := m.CSRF() + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + token = CSRFToken(r) + })) + + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + + cookies := getW.Result().Cookies() + require.NotEmpty(t, cookies, "CSRF cookie should be set on GET") + require.NotEmpty(t, token, "CSRF token should be set in context on GET") + + // Verify the cookie is NOT Secure (plaintext HTTP in prod mode) + for _, c := range cookies { + if c.Name == "_gorilla_csrf" { + assert.False(t, c.Secure, "CSRF cookie should not be Secure over plaintext HTTP") + } + } + + // POST with valid token — should succeed + var called bool + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + form := url.Values{"csrf_token": {token}} + postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.True(t, called, "handler should be called — prod mode over plaintext HTTP must work") + assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403") +} + +// --- Production Mode: POST with X-Forwarded-Proto (reverse proxy) --- + +func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentProd) + + // Simulates a deployment behind a TLS-terminating reverse proxy. + // The Go server sees HTTP but X-Forwarded-Proto is "https". + var token string + csrfMiddleware := m.CSRF() + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + token = CSRFToken(r) + })) + + getReq := httptest.NewRequest(http.MethodGet, "http://example.com/form", nil) + getReq.Header.Set("X-Forwarded-Proto", "https") + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + + cookies := getW.Result().Cookies() + require.NotEmpty(t, cookies, "CSRF cookie should be set on GET") + require.NotEmpty(t, token, "CSRF token should be set in context") + + // Verify the cookie IS Secure (X-Forwarded-Proto: https) + for _, c := range cookies { + if c.Name == "_gorilla_csrf" { + assert.True(t, c.Secure, "CSRF cookie should be Secure behind TLS proxy") + } + } + + // POST with valid token, HTTPS Origin (as a browser behind proxy would send) + var called bool + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + form := url.Values{"csrf_token": {token}} + postReq := httptest.NewRequest(http.MethodPost, "http://example.com/form", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + postReq.Header.Set("X-Forwarded-Proto", "https") + postReq.Header.Set("Origin", "https://example.com") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.True(t, called, "handler should be called — prod mode behind TLS proxy must work") + assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403") +} + +// --- Production Mode: direct TLS --- + +func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentProd) + + var token string + csrfMiddleware := m.CSRF() + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + token = CSRFToken(r) + })) + + getReq := httptest.NewRequest(http.MethodGet, "https://example.com/form", nil) + getReq.TLS = &tls.ConnectionState{} + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + + cookies := getW.Result().Cookies() + require.NotEmpty(t, cookies, "CSRF cookie should be set on GET") + require.NotEmpty(t, token, "CSRF token should be set in context") + + // Verify the cookie IS Secure (direct TLS) + for _, c := range cookies { + if c.Name == "_gorilla_csrf" { + assert.True(t, c.Secure, "CSRF cookie should be Secure over direct TLS") + } + } + + // POST with valid token over direct TLS + var called bool + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + form := url.Values{"csrf_token": {token}} + postReq := httptest.NewRequest(http.MethodPost, "https://example.com/form", strings.NewReader(form.Encode())) + postReq.TLS = &tls.ConnectionState{} + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + postReq.Header.Set("Origin", "https://example.com") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.True(t, called, "handler should be called — direct TLS must work") + assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403") +} + +// --- Production Mode: POST without token still rejects --- + +func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentProd) + + csrfMiddleware := m.CSRF() + + // GET to establish the CSRF cookie + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + cookies := getW.Result().Cookies() + + // POST without CSRF token — should be rejected + var called bool + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + postReq := httptest.NewRequest(http.MethodPost, "/form", nil) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.False(t, called, "handler should NOT be called without CSRF token even in prod+plaintext") + assert.Equal(t, http.StatusForbidden, postW.Code) +}