package middleware //nolint:testpackage // tests internal CSRF behavior

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// newAPICSRFTestHandler wraps a stub downstream handler with
// Middleware.APICSRFProtection and returns the wrapped handler together with
// a flag that the stub sets when it is invoked. The stub always responds
// 200 OK.
//
// The flag exists because httptest.ResponseRecorder ignores every WriteHeader
// call after the first one: a middleware that wrote 403 but still forwarded
// the request to the next handler would otherwise pass a status-only check.
func newAPICSRFTestHandler() (http.Handler, *bool) {
	nextCalled := new(bool)
	mw := &Middleware{}

	handler := mw.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		*nextCalled = true
		w.WriteHeader(http.StatusOK)
	}))

	return handler, nextCalled
}

func TestAPICSRFProtection_BlocksPostWithoutHeader(t *testing.T) {
	t.Parallel()

	handler, nextCalled := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil)
	rr := httptest.NewRecorder()

	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), "X-Requested-With")
	assert.False(t, *nextCalled, "blocked request must not reach the next handler")
}

func TestAPICSRFProtection_AllowsPostWithHeader(t *testing.T) {
	t.Parallel()

	handler, nextCalled := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil)
	req.Header.Set("X-Requested-With", "XMLHttpRequest")
	rr := httptest.NewRecorder()

	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)
	assert.True(t, *nextCalled, "allowed request must reach the next handler")
}

func TestAPICSRFProtection_AllowsGetWithoutHeader(t *testing.T) {
	t.Parallel()

	handler, nextCalled := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodGet, "/api/v1/apps", nil)
	rr := httptest.NewRecorder()

	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)
	assert.True(t, *nextCalled, "allowed request must reach the next handler")
}

func TestAPICSRFProtection_BlocksDeleteWithoutHeader(t *testing.T) {
	t.Parallel()

	handler, nextCalled := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodDelete, "/api/v1/apps/123", nil)
	rr := httptest.NewRecorder()

	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.False(t, *nextCalled, "blocked request must not reach the next handler")
}

func TestAPICSRFProtection_AllowsHeadWithoutHeader(t *testing.T) {
	t.Parallel()

	handler, nextCalled := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodHead, "/api/v1/apps", nil)
	rr := httptest.NewRecorder()

	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)
	assert.True(t, *nextCalled, "allowed request must reach the next handler")
}