package middleware //nolint:testpackage // tests internal middleware behavior

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// newAPICSRFTestMiddleware builds the Middleware under test via
// newCORSTestMiddleware with an empty argument.
func newAPICSRFTestMiddleware() *Middleware {
	return newCORSTestMiddleware("")
}

// newAPICSRFTestHandler wraps a handler that always responds 200 OK with the
// APICSRFProtection middleware. The returned *bool is set to true once the
// wrapped handler runs.
//
// Checking the flag matters for blocked requests: httptest.ResponseRecorder
// keeps only the first status code written. A middleware that writes 403 and
// then still calls the next handler would therefore pass a status-only
// assertion, even though the protected handler executed.
func newAPICSRFTestHandler() (http.Handler, *bool) {
	m := newAPICSRFTestMiddleware()
	called := false
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	})

	return m.APICSRFProtection()(next), &called
}

// TestAPICSRFProtection_SafeMethods_Allowed checks that GET, HEAD and OPTIONS
// pass through without an X-Requested-With header.
func TestAPICSRFProtection_SafeMethods_Allowed(t *testing.T) {
	t.Parallel()

	for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
		// Fresh handler per method so the "called" flag is not shared.
		handler, called := newAPICSRFTestHandler()
		req := httptest.NewRequest(method, "/api/v1/apps", nil)
		rec := httptest.NewRecorder()

		handler.ServeHTTP(rec, req)

		assert.Equal(t, http.StatusOK, rec.Code, "expected %s to be allowed without X-Requested-With", method)
		assert.True(t, *called, "expected %s to reach the wrapped handler", method)
	}
}

// TestAPICSRFProtection_POST_WithoutHeader_Blocked checks that a POST without
// X-Requested-With is rejected with 403, names the missing header in the body,
// and never reaches the wrapped handler.
func TestAPICSRFProtection_POST_WithoutHeader_Blocked(t *testing.T) {
	t.Parallel()

	handler, called := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusForbidden, rec.Code)
	assert.Contains(t, rec.Body.String(), "X-Requested-With")
	assert.False(t, *called, "blocked POST must not reach the wrapped handler")
}

// TestAPICSRFProtection_POST_WithHeader_Allowed checks that a POST carrying
// X-Requested-With: XMLHttpRequest is passed through to the wrapped handler.
func TestAPICSRFProtection_POST_WithHeader_Allowed(t *testing.T) {
	t.Parallel()

	handler, called := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil)
	req.Header.Set("X-Requested-With", "XMLHttpRequest")
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.True(t, *called, "POST with X-Requested-With must reach the wrapped handler")
}

// TestAPICSRFProtection_DELETE_WithoutHeader_Blocked checks that a DELETE
// without X-Requested-With is rejected the same way as an unprotected POST.
func TestAPICSRFProtection_DELETE_WithoutHeader_Blocked(t *testing.T) {
	t.Parallel()

	handler, called := newAPICSRFTestHandler()
	req := httptest.NewRequest(http.MethodDelete, "/api/v1/apps/1", nil)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusForbidden, rec.Code)
	assert.Contains(t, rec.Body.String(), "X-Requested-With")
	assert.False(t, *called, "blocked DELETE must not reach the wrapped handler")
}