package middleware

import (
	"crypto/tls"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"sneak.berlin/go/webhooker/internal/config"
)

// gorillaCSRFCookieName is the cookie these tests expect the CSRF middleware
// to set on GET. NOTE(review): this is gorilla/csrf's default cookie name;
// update it if the middleware configures a custom cookie name.
const gorillaCSRFCookieName = "_gorilla_csrf"

// runCSRFGet sends req through csrfMiddleware and returns the token exposed
// to the wrapped handler via CSRFToken, plus the cookies set on the response.
// It fails the test immediately if either is missing. That way a later POST
// that is expected to be rejected cannot be rejected for the wrong reason
// (e.g. a missing cookie rather than a missing or invalid token).
func runCSRFGet(t *testing.T, csrfMiddleware func(http.Handler) http.Handler, req *http.Request) (string, []*http.Cookie) {
	t.Helper()

	var token string
	handler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		token = CSRFToken(r)
	}))

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	res := w.Result()
	cookies := res.Cookies()
	require.NoError(t, res.Body.Close())

	require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
	require.NotEmpty(t, token, "CSRF token should be set in context on GET")

	return token, cookies
}

// newCSRFFormPost builds a form-encoded POST to target that replays cookies
// from an earlier GET. The csrf_token form field is included only when token
// is non-empty, so passing "" models a form submitted without a token.
// Callers may set extra headers (Origin, X-Forwarded-Proto) or TLS state on
// the returned request before serving it.
func newCSRFFormPost(target, token string, cookies []*http.Cookie) *http.Request {
	form := url.Values{}
	if token != "" {
		form.Set("csrf_token", token)
	}

	req := httptest.NewRequest(http.MethodPost, target, strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range cookies {
		req.AddCookie(c)
	}

	return req
}

// runCSRFPost sends req through csrfMiddleware and reports whether the wrapped
// handler was reached, together with the recorded response status code.
func runCSRFPost(csrfMiddleware func(http.Handler) http.Handler, req *http.Request) (called bool, status int) {
	handler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	return called, w.Code
}

// lookupCookie returns the first cookie with the given name, or nil if none
// is present.
func lookupCookie(cookies []*http.Cookie, name string) *http.Cookie {
	for _, c := range cookies {
		if c.Name == name {
			return c
		}
	}

	return nil
}

// TestCSRF_GETSetsToken checks that a GET through the middleware exposes a
// CSRF token to the wrapped handler.
func TestCSRF_GETSetsToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var gotToken string
	handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		gotToken = CSRFToken(r)
	}))

	req := httptest.NewRequest(http.MethodGet, "/form", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
}

// TestCSRF_POSTWithValidToken checks that a POST carrying the token and
// cookies issued by a prior GET reaches the wrapped handler.
func TestCSRF_POSTWithValidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)
	csrfMiddleware := m.CSRF()

	// Capture the token and cookies from a GET request.
	token, cookies := runCSRFGet(t, csrfMiddleware, httptest.NewRequest(http.MethodGet, "/form", nil))

	// POST with the valid token and the cookies from the GET response.
	called, status := runCSRFPost(csrfMiddleware, newCSRFFormPost("/form", token, cookies))

	assert.True(t, called, "handler should be called with valid CSRF token")
	assert.NotEqual(t, http.StatusForbidden, status, "should not return 403")
}

// TestCSRF_POSTWithoutToken checks that a POST with a valid CSRF cookie but
// no token is rejected with 403.
func TestCSRF_POSTWithoutToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)
	csrfMiddleware := m.CSRF()

	// GET to establish the CSRF cookie.
	_, cookies := runCSRFGet(t, csrfMiddleware, httptest.NewRequest(http.MethodGet, "/form", nil))

	// POST without a CSRF token.
	called, status := runCSRFPost(csrfMiddleware, newCSRFFormPost("/form", "", cookies))

	assert.False(t, called, "handler should NOT be called without CSRF token")
	assert.Equal(t, http.StatusForbidden, status)
}

// TestCSRF_POSTWithInvalidToken checks that a POST with a valid CSRF cookie
// but a wrong token is rejected with 403.
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)
	csrfMiddleware := m.CSRF()

	// GET to establish the CSRF cookie.
	_, cookies := runCSRFGet(t, csrfMiddleware, httptest.NewRequest(http.MethodGet, "/form", nil))

	// POST with a wrong CSRF token.
	called, status := runCSRFPost(csrfMiddleware, newCSRFFormPost("/form", "invalid-token-value", cookies))

	assert.False(t, called, "handler should NOT be called with invalid CSRF token")
	assert.Equal(t, http.StatusForbidden, status)
}

// TestCSRF_GETDoesNotValidate checks that GET requests reach the wrapped
// handler without presenting any token.
func TestCSRF_GETDoesNotValidate(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var called bool
	handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	// GET requests should pass through without CSRF validation.
	req := httptest.NewRequest(http.MethodGet, "/form", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	assert.True(t, called, "GET requests should pass through CSRF middleware")
}

// TestCSRFToken_NoMiddleware checks that CSRFToken returns "" for a request
// the middleware has not processed.
func TestCSRFToken_NoMiddleware(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
}

// --- TLS Detection Tests ---

func TestIsClientTLS_DirectTLS(t *testing.T) {
	t.Parallel()

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.TLS = &tls.ConnectionState{} // simulate direct TLS

	assert.True(t, isClientTLS(r), "should detect direct TLS connection")
}

func TestIsClientTLS_XForwardedProto(t *testing.T) {
	t.Parallel()

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("X-Forwarded-Proto", "https")

	assert.True(t, isClientTLS(r), "should detect TLS via X-Forwarded-Proto")
}

func TestIsClientTLS_PlaintextHTTP(t *testing.T) {
	t.Parallel()

	r := httptest.NewRequest(http.MethodGet, "/", nil)

	assert.False(t, isClientTLS(r), "should detect plaintext HTTP")
}

func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) {
	t.Parallel()

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("X-Forwarded-Proto", "http")

	assert.False(t, isClientTLS(r), "should detect plaintext when X-Forwarded-Proto is http")
}

// --- Production Mode: POST over plaintext HTTP ---

func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentProd)
	csrfMiddleware := m.CSRF()

	// This tests the critical fix: prod mode over plaintext HTTP should
	// work because the middleware detects the transport per-request.
	token, cookies := runCSRFGet(t, csrfMiddleware, httptest.NewRequest(http.MethodGet, "/form", nil))

	// Verify the cookie is NOT Secure (plaintext HTTP in prod mode). Require
	// the cookie to exist so the Secure check can never pass vacuously.
	cookie := lookupCookie(cookies, gorillaCSRFCookieName)
	require.NotNil(t, cookie, "CSRF cookie %q should be set on GET", gorillaCSRFCookieName)
	assert.False(t, cookie.Secure, "CSRF cookie should not be Secure over plaintext HTTP")

	// POST with valid token — should succeed.
	called, status := runCSRFPost(csrfMiddleware, newCSRFFormPost("/form", token, cookies))

	assert.True(t, called, "handler should be called — prod mode over plaintext HTTP must work")
	assert.NotEqual(t, http.StatusForbidden, status, "should not return 403")
}

// --- Production Mode: POST with X-Forwarded-Proto (reverse proxy) ---

func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentProd)
	csrfMiddleware := m.CSRF()

	// Simulates a deployment behind a TLS-terminating reverse proxy.
	// The Go server sees HTTP but X-Forwarded-Proto is "https".
	getReq := httptest.NewRequest(http.MethodGet, "http://example.com/form", nil)
	getReq.Header.Set("X-Forwarded-Proto", "https")
	token, cookies := runCSRFGet(t, csrfMiddleware, getReq)

	// Verify the cookie IS Secure (X-Forwarded-Proto: https).
	cookie := lookupCookie(cookies, gorillaCSRFCookieName)
	require.NotNil(t, cookie, "CSRF cookie %q should be set on GET", gorillaCSRFCookieName)
	assert.True(t, cookie.Secure, "CSRF cookie should be Secure behind TLS proxy")

	// POST with valid token, HTTPS Origin (as a browser behind proxy would send).
	postReq := newCSRFFormPost("http://example.com/form", token, cookies)
	postReq.Header.Set("X-Forwarded-Proto", "https")
	postReq.Header.Set("Origin", "https://example.com")
	called, status := runCSRFPost(csrfMiddleware, postReq)

	assert.True(t, called, "handler should be called — prod mode behind TLS proxy must work")
	assert.NotEqual(t, http.StatusForbidden, status, "should not return 403")
}

// --- Production Mode: direct TLS ---

func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentProd)
	csrfMiddleware := m.CSRF()

	getReq := httptest.NewRequest(http.MethodGet, "https://example.com/form", nil)
	getReq.TLS = &tls.ConnectionState{}
	token, cookies := runCSRFGet(t, csrfMiddleware, getReq)

	// Verify the cookie IS Secure (direct TLS).
	cookie := lookupCookie(cookies, gorillaCSRFCookieName)
	require.NotNil(t, cookie, "CSRF cookie %q should be set on GET", gorillaCSRFCookieName)
	assert.True(t, cookie.Secure, "CSRF cookie should be Secure over direct TLS")

	// POST with valid token over direct TLS.
	postReq := newCSRFFormPost("https://example.com/form", token, cookies)
	postReq.TLS = &tls.ConnectionState{}
	postReq.Header.Set("Origin", "https://example.com")
	called, status := runCSRFPost(csrfMiddleware, postReq)

	assert.True(t, called, "handler should be called — direct TLS must work")
	assert.NotEqual(t, http.StatusForbidden, status, "should not return 403")
}

// --- Production Mode: POST without token still rejects ---

func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentProd)
	csrfMiddleware := m.CSRF()

	// GET to establish the CSRF cookie.
	_, cookies := runCSRFGet(t, csrfMiddleware, httptest.NewRequest(http.MethodGet, "/form", nil))

	// POST without CSRF token — should be rejected.
	called, status := runCSRFPost(csrfMiddleware, newCSRFFormPost("/form", "", cookies))

	assert.False(t, called, "handler should NOT be called without CSRF token even in prod+plaintext")
	assert.Equal(t, http.StatusForbidden, status)
}