package middleware_test

import (
	"context"
	"crypto/tls"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"sneak.berlin/go/webhooker/internal/config"
	"sneak.berlin/go/webhooker/internal/middleware"
)

// csrfCookieName is the gorilla/csrf cookie name. NOTE(review): this
// assumes the middleware keeps gorilla/csrf's default cookie name.
const csrfCookieName = "_gorilla_csrf"

// csrfGetToken sends getReq through csrfMW to a handler that captures
// the CSRF token placed in the request context. It fails the test
// unless the response set at least one cookie and a non-empty token
// was captured. It returns the token and the response cookies so a
// follow-up request can replay them.
func csrfGetToken(
	t *testing.T,
	csrfMW func(http.Handler) http.Handler,
	getReq *http.Request,
) (string, []*http.Cookie) {
	t.Helper()

	var captured string

	wrapped := csrfMW(http.HandlerFunc(
		func(_ http.ResponseWriter, r *http.Request) {
			captured = middleware.CSRFToken(r)
		},
	))

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, getReq)

	jar := rec.Result().Cookies()
	require.NotEmpty(t, jar, "CSRF cookie should be set")
	require.NotEmpty(t, captured, "CSRF token should be set")

	return captured, jar
}

// csrfPostWithToken sends a form-encoded request carrying the given
// CSRF token (as the "csrf_token" field) and cookies through the
// middleware. It uses postReq's method, URL, headers and TLS state. It
// reports whether the wrapped handler ran and the status code written.
func csrfPostWithToken( t *testing.T, csrfMW func(http.Handler) http.Handler, postReq *http.Request, token string, cookies []*http.Cookie, ) (bool, int) { t.Helper() var called bool postHandler := csrfMW(http.HandlerFunc( func(_ http.ResponseWriter, _ *http.Request) { called = true }, )) form := url.Values{"csrf_token": {token}} postReq.Body = http.NoBody postReq.Body = nil // Rebuild the request with the form body rebuilt := httptest.NewRequestWithContext( context.Background(), postReq.Method, postReq.URL.String(), strings.NewReader(form.Encode()), ) rebuilt.Header = postReq.Header.Clone() rebuilt.TLS = postReq.TLS rebuilt.Header.Set( "Content-Type", "application/x-www-form-urlencoded", ) for _, c := range cookies { rebuilt.AddCookie(c) } postW := httptest.NewRecorder() postHandler.ServeHTTP(postW, rebuilt) return called, postW.Code } func TestCSRF_GETSetsToken(t *testing.T) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentDev) var gotToken string handler := m.CSRF()(http.HandlerFunc( func(_ http.ResponseWriter, r *http.Request) { gotToken = middleware.CSRFToken(r) }, )) req := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/form", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) assert.NotEmpty( t, gotToken, "CSRF token should be set in context on GET", ) } func TestCSRF_POSTWithValidToken(t *testing.T) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentDev) csrfMW := m.CSRF() getReq := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/form", nil, ) token, cookies := csrfGetToken(t, csrfMW, getReq) postReq := httptest.NewRequestWithContext( context.Background(), http.MethodPost, "/form", nil, ) called, _ := csrfPostWithToken( t, csrfMW, postReq, token, cookies, ) assert.True( t, called, "handler should be called with valid CSRF token", ) } // csrfPOSTWithoutTokenTest is a shared helper for testing POST // requests without a CSRF token in both dev and prod modes. 
func csrfPOSTWithoutTokenTest( t *testing.T, env string, msg string, ) { t.Helper() m, _ := testMiddleware(t, env) csrfMW := m.CSRF() // GET to establish the CSRF cookie getHandler := csrfMW(http.HandlerFunc( func(_ http.ResponseWriter, _ *http.Request) {}, )) getReq := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/form", nil) getW := httptest.NewRecorder() getHandler.ServeHTTP(getW, getReq) cookies := getW.Result().Cookies() // POST without CSRF token var called bool postHandler := csrfMW(http.HandlerFunc( func(_ http.ResponseWriter, _ *http.Request) { called = true }, )) postReq := httptest.NewRequestWithContext( context.Background(), http.MethodPost, "/form", nil, ) postReq.Header.Set( "Content-Type", "application/x-www-form-urlencoded", ) for _, c := range cookies { postReq.AddCookie(c) } postW := httptest.NewRecorder() postHandler.ServeHTTP(postW, postReq) assert.False(t, called, msg) assert.Equal(t, http.StatusForbidden, postW.Code) } func TestCSRF_POSTWithoutToken(t *testing.T) { t.Parallel() csrfPOSTWithoutTokenTest( t, config.EnvironmentDev, "handler should NOT be called without CSRF token", ) } func TestCSRF_POSTWithInvalidToken(t *testing.T) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentDev) csrfMW := m.CSRF() // GET to establish the CSRF cookie getHandler := csrfMW(http.HandlerFunc( func(_ http.ResponseWriter, _ *http.Request) {}, )) getReq := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/form", nil) getW := httptest.NewRecorder() getHandler.ServeHTTP(getW, getReq) cookies := getW.Result().Cookies() // POST with wrong CSRF token var called bool postHandler := csrfMW(http.HandlerFunc( func(_ http.ResponseWriter, _ *http.Request) { called = true }, )) form := url.Values{"csrf_token": {"invalid-token-value"}} postReq := httptest.NewRequestWithContext( context.Background(), http.MethodPost, "/form", strings.NewReader(form.Encode()), ) postReq.Header.Set( "Content-Type", 
"application/x-www-form-urlencoded", ) for _, c := range cookies { postReq.AddCookie(c) } postW := httptest.NewRecorder() postHandler.ServeHTTP(postW, postReq) assert.False( t, called, "handler should NOT be called with invalid CSRF token", ) assert.Equal(t, http.StatusForbidden, postW.Code) } func TestCSRF_GETDoesNotValidate(t *testing.T) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentDev) var called bool handler := m.CSRF()(http.HandlerFunc( func(_ http.ResponseWriter, _ *http.Request) { called = true }, )) req := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/form", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) assert.True( t, called, "GET requests should pass through CSRF middleware", ) } func TestCSRFToken_NoMiddleware(t *testing.T) { t.Parallel() req := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/", nil) assert.Empty( t, middleware.CSRFToken(req), "CSRFToken should return empty string when "+ "middleware has not run", ) } // --- TLS Detection Tests --- func TestIsClientTLS_DirectTLS(t *testing.T) { t.Parallel() r := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/", nil) r.TLS = &tls.ConnectionState{} assert.True( t, middleware.IsClientTLS(r), "should detect direct TLS connection", ) } func TestIsClientTLS_XForwardedProto(t *testing.T) { t.Parallel() r := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/", nil) r.Header.Set("X-Forwarded-Proto", "https") assert.True( t, middleware.IsClientTLS(r), "should detect TLS via X-Forwarded-Proto", ) } func TestIsClientTLS_PlaintextHTTP(t *testing.T) { t.Parallel() r := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/", nil) assert.False( t, middleware.IsClientTLS(r), "should detect plaintext HTTP", ) } func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) { t.Parallel() r := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/", nil) 
r.Header.Set("X-Forwarded-Proto", "http") assert.False( t, middleware.IsClientTLS(r), "should detect plaintext when X-Forwarded-Proto is http", ) } // --- Production Mode: POST over plaintext HTTP --- func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken( t *testing.T, ) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentProd) csrfMW := m.CSRF() getReq := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "/form", nil, ) token, cookies := csrfGetToken(t, csrfMW, getReq) // Verify cookie is NOT Secure (plaintext HTTP in prod) for _, c := range cookies { if c.Name == csrfCookieName { assert.False(t, c.Secure, "CSRF cookie should not be Secure "+ "over plaintext HTTP") } } postReq := httptest.NewRequestWithContext( context.Background(), http.MethodPost, "/form", nil, ) called, code := csrfPostWithToken( t, csrfMW, postReq, token, cookies, ) assert.True(t, called, "handler should be called -- prod mode over "+ "plaintext HTTP must work") assert.NotEqual(t, http.StatusForbidden, code, "should not return 403") } // --- Production Mode: POST with X-Forwarded-Proto --- func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken( t *testing.T, ) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentProd) csrfMW := m.CSRF() getReq := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "http://example.com/form", nil, ) getReq.Header.Set("X-Forwarded-Proto", "https") token, cookies := csrfGetToken(t, csrfMW, getReq) // Verify cookie IS Secure (X-Forwarded-Proto: https) for _, c := range cookies { if c.Name == csrfCookieName { assert.True(t, c.Secure, "CSRF cookie should be Secure behind "+ "TLS proxy") } } postReq := httptest.NewRequestWithContext( context.Background(), http.MethodPost, "http://example.com/form", nil, ) postReq.Header.Set("X-Forwarded-Proto", "https") postReq.Header.Set("Origin", "https://example.com") called, code := csrfPostWithToken( t, csrfMW, postReq, token, cookies, ) assert.True(t, called, "handler should be 
called -- prod mode behind "+ "TLS proxy must work") assert.NotEqual(t, http.StatusForbidden, code, "should not return 403") } // --- Production Mode: direct TLS --- func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken( t *testing.T, ) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentProd) csrfMW := m.CSRF() getReq := httptest.NewRequestWithContext( context.Background(), http.MethodGet, "https://example.com/form", nil, ) getReq.TLS = &tls.ConnectionState{} token, cookies := csrfGetToken(t, csrfMW, getReq) // Verify cookie IS Secure (direct TLS) for _, c := range cookies { if c.Name == csrfCookieName { assert.True(t, c.Secure, "CSRF cookie should be Secure over "+ "direct TLS") } } postReq := httptest.NewRequestWithContext( context.Background(), http.MethodPost, "https://example.com/form", nil, ) postReq.TLS = &tls.ConnectionState{} postReq.Header.Set("Origin", "https://example.com") called, code := csrfPostWithToken( t, csrfMW, postReq, token, cookies, ) assert.True(t, called, "handler should be called -- direct TLS must work") assert.NotEqual(t, http.StatusForbidden, code, "should not return 403") } // --- Production Mode: POST without token still rejects --- func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken( t *testing.T, ) { t.Parallel() csrfPOSTWithoutTokenTest( t, config.EnvironmentProd, "handler should NOT be called without CSRF token "+ "even in prod+plaintext", ) }