package middleware

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"sneak.berlin/go/webhooker/internal/config"
)

// csrfGETCookies sends GET /form through h (a handler already wrapped by the
// CSRF middleware) and returns the cookies set on the response. It fails the
// test if no cookie was issued, so that POST tests exercise token validation
// against an established session rather than passing merely because no
// session exists.
func csrfGETCookies(t *testing.T, h http.Handler) []*http.Cookie {
	t.Helper()

	req := httptest.NewRequest(http.MethodGet, "/form", nil)
	w := httptest.NewRecorder()
	h.ServeHTTP(w, req)

	res := w.Result()
	cookies := res.Cookies()
	require.NoError(t, res.Body.Close())
	require.NotEmpty(t, cookies, "GET through CSRF middleware should establish a session cookie")

	return cookies
}

// newCSRFFormPOST builds a form-encoded POST /form request that carries the
// given cookies. When token is non-empty it is sent in the csrfFormField form
// field; when token is empty the field is omitted entirely.
func newCSRFFormPOST(token string, cookies []*http.Cookie) *http.Request {
	form := url.Values{}
	if token != "" {
		form.Set(csrfFormField, token)
	}

	req := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range cookies {
		req.AddCookie(c)
	}

	return req
}

// TestCSRF_GETSetsToken verifies that a GET request passing through the CSRF
// middleware has a hex-encoded token available via CSRFToken.
func TestCSRF_GETSetsToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var gotToken string
	handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		gotToken = CSRFToken(r)
	}))

	req := httptest.NewRequest(http.MethodGet, "/form", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
	assert.Len(t, gotToken, csrfTokenLength*2, "CSRF token should be hex-encoded csrfTokenLength bytes")
}

// TestCSRF_POSTWithValidToken verifies that a POST carrying the token issued
// on a prior GET (plus that GET's cookies) reaches the wrapped handler.
func TestCSRF_POSTWithValidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	// GET through the middleware to establish the session and capture the
	// token it exposes to the handler.
	var token string
	getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		token = CSRFToken(r)
	}))
	cookies := csrfGETCookies(t, getHandler)
	require.NotEmpty(t, token)

	var called bool
	postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	postW := httptest.NewRecorder()
	postHandler.ServeHTTP(postW, newCSRFFormPOST(token, cookies))

	assert.True(t, called, "handler should be called with valid CSRF token")
	assert.NotEqual(t, http.StatusForbidden, postW.Code, "valid CSRF token should not be rejected")
}

// TestCSRF_POSTWithoutToken verifies that a POST with a valid session but no
// CSRF form field is rejected with 403 and never reaches the handler.
func TestCSRF_POSTWithoutToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		// no-op — just establishes session
	}))
	cookies := csrfGETCookies(t, getHandler)

	var called bool
	postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	postW := httptest.NewRecorder()
	postHandler.ServeHTTP(postW, newCSRFFormPOST("", cookies))

	assert.False(t, called, "handler should NOT be called without CSRF token")
	assert.Equal(t, http.StatusForbidden, postW.Code)
}

// TestCSRF_POSTWithInvalidToken verifies that a POST with a valid session but
// a token that does not match is rejected with 403 and never reaches the
// handler.
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		// no-op — just establishes session
	}))
	cookies := csrfGETCookies(t, getHandler)

	var called bool
	postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	postW := httptest.NewRecorder()
	postHandler.ServeHTTP(postW, newCSRFFormPOST("invalid-token-value", cookies))

	assert.False(t, called, "handler should NOT be called with invalid CSRF token")
	assert.Equal(t, http.StatusForbidden, postW.Code)
}

// TestCSRF_GETDoesNotValidate verifies that GET requests pass through the
// middleware without any token or cookie.
func TestCSRF_GETDoesNotValidate(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var called bool
	handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))

	// GET requests should pass through without CSRF validation.
	req := httptest.NewRequest(http.MethodGet, "/form", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	assert.True(t, called, "GET requests should pass through CSRF middleware")
}

// TestCSRFToken_NoContext verifies that CSRFToken returns "" for a request
// that never went through the middleware.
func TestCSRFToken_NoContext(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when no token in context")
}

// TestGenerateCSRFToken verifies token length (hex of csrfTokenLength bytes)
// and that two consecutive tokens differ.
func TestGenerateCSRFToken(t *testing.T) {
	t.Parallel()

	token, err := generateCSRFToken()
	require.NoError(t, err)
	assert.Len(t, token, csrfTokenLength*2, "token should be hex-encoded")

	// Verify uniqueness.
	token2, err := generateCSRFToken()
	require.NoError(t, err)
	assert.NotEqual(t, token, token2, "each generated token should be unique")
}

// TestSecureCompare checks equality, mismatch, differing lengths and the
// empty-string cases of secureCompare.
func TestSecureCompare(t *testing.T) {
	t.Parallel()

	assert.True(t, secureCompare("abc", "abc"))
	assert.False(t, secureCompare("abc", "abd"))
	assert.False(t, secureCompare("abc", "ab"))
	assert.False(t, secureCompare("", "a"))
	assert.True(t, secureCompare("", ""))
}