package middleware

import (
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"sneak.berlin/go/webhooker/internal/config"
)

// csrfTestGET sends a GET to /form through csrfMiddleware and returns the
// token the middleware exposed via CSRFToken together with the cookies set on
// the response. It fails the test immediately if either is missing. This
// keeps the POST tests built on it from passing vacuously, e.g. getting a 403
// because no cookie was ever issued rather than because the token was bad.
func csrfTestGET(t *testing.T, csrfMiddleware func(http.Handler) http.Handler) (string, []*http.Cookie) {
	t.Helper()

	var token string
	handler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		token = CSRFToken(r)
	}))

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/form", nil))

	res := rec.Result()
	cookies := res.Cookies()
	// Cookies are parsed from the headers. The body is unused, but closing it
	// keeps the response lifecycle explicit (and bodyclose-style linters quiet).
	require.NoError(t, res.Body.Close())

	require.NotEmpty(t, token, "CSRF token should be set in context on GET")
	require.NotEmpty(t, cookies, "CSRF middleware should set a cookie on GET")

	return token, cookies
}

// csrfTestPOST sends a form-encoded POST to /form through csrfMiddleware,
// replaying the given cookies. A nil form sends the request with no body.
// It reports whether the wrapped handler ran and the recorded status code.
func csrfTestPOST(
	t *testing.T,
	csrfMiddleware func(http.Handler) http.Handler,
	cookies []*http.Cookie,
	form url.Values,
) (bool, int) {
	t.Helper()

	var called bool
	handler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	// Leave body as a nil interface when there is no form, so the request
	// matches one built with httptest.NewRequest(..., nil).
	var body io.Reader
	if form != nil {
		body = strings.NewReader(form.Encode())
	}

	req := httptest.NewRequest(http.MethodPost, "/form", body)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range cookies {
		req.AddCookie(c)
	}

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	return called, rec.Code
}

func TestCSRF_GETSetsToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var gotToken string
	handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		gotToken = CSRFToken(r)
	}))

	req := httptest.NewRequest(http.MethodGet, "/form", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
}

func TestCSRF_POSTWithValidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)
	csrfMiddleware := m.CSRF()

	// Capture the token and cookies from a GET, then replay both on a POST.
	token, cookies := csrfTestGET(t, csrfMiddleware)

	called, code := csrfTestPOST(t, csrfMiddleware, cookies, url.Values{"csrf_token": {token}})

	assert.True(t, called, "handler should be called with valid CSRF token")
	assert.Equal(t, http.StatusOK, code)
}

func TestCSRF_POSTWithoutToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)
	csrfMiddleware := m.CSRF()

	// GET to establish the CSRF cookie, so the rejection below is caused by
	// the missing token and not by a missing cookie.
	_, cookies := csrfTestGET(t, csrfMiddleware)

	called, code := csrfTestPOST(t, csrfMiddleware, cookies, nil)

	assert.False(t, called, "handler should NOT be called without CSRF token")
	assert.Equal(t, http.StatusForbidden, code)
}

func TestCSRF_POSTWithInvalidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)
	csrfMiddleware := m.CSRF()

	// GET to establish the CSRF cookie, so the rejection below is caused by
	// the wrong token and not by a missing cookie.
	_, cookies := csrfTestGET(t, csrfMiddleware)

	called, code := csrfTestPOST(t, csrfMiddleware, cookies, url.Values{"csrf_token": {"invalid-token-value"}})

	assert.False(t, called, "handler should NOT be called with invalid CSRF token")
	assert.Equal(t, http.StatusForbidden, code)
}

func TestCSRF_GETDoesNotValidate(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var called bool
	handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	// GET requests should pass through without CSRF validation, even with no
	// cookie or token present.
	req := httptest.NewRequest(http.MethodGet, "/form", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	assert.True(t, called, "GET requests should pass through CSRF middleware")
}

func TestCSRFToken_NoMiddleware(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
}