diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index 6d7d899..3deaff7 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -17,6 +17,7 @@ func apiRouter(tc *testContext) http.Handler { r := chi.NewRouter() r.Route("/api/v1", func(apiR chi.Router) { + apiR.Use(tc.middleware.APICSRFProtection()) apiR.Post("/login", tc.handlers.HandleAPILoginPOST()) apiR.Group(func(apiR chi.Router) { @@ -50,6 +51,7 @@ func setupAPITest(t *testing.T) (*testContext, []*http.Cookie) { loginBody := `{"username":"admin","password":"password123"}` req := httptest.NewRequest(http.MethodPost, "/api/v1/login", strings.NewReader(loginBody)) req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Requested-With", "XMLHttpRequest") rr := httptest.NewRecorder() r.ServeHTTP(rr, req) @@ -80,6 +82,8 @@ func apiRequest( req = httptest.NewRequest(method, path, nil) } + req.Header.Set("X-Requested-With", "XMLHttpRequest") + for _, c := range cookies { req.AddCookie(c) } @@ -105,6 +109,7 @@ func TestAPILoginSuccess(t *testing.T) { body := `{"username":"admin","password":"password123"}` req := httptest.NewRequest(http.MethodPost, "/api/v1/login", strings.NewReader(body)) req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Requested-With", "XMLHttpRequest") rr := httptest.NewRecorder() r.ServeHTTP(rr, req) @@ -132,6 +137,7 @@ func TestAPILoginInvalidCredentials(t *testing.T) { body := `{"username":"admin","password":"wrong"}` req := httptest.NewRequest(http.MethodPost, "/api/v1/login", strings.NewReader(body)) req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Requested-With", "XMLHttpRequest") rr := httptest.NewRecorder() r.ServeHTTP(rr, req) @@ -149,6 +155,7 @@ func TestAPILoginMissingFields(t *testing.T) { body := `{"username":"","password":""}` req := httptest.NewRequest(http.MethodPost, "/api/v1/login", strings.NewReader(body)) req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Requested-With", "XMLHttpRequest") rr := httptest.NewRecorder() r.ServeHTTP(rr, req) @@ -297,3 +304,46 @@ func TestAPIListDeployments(t *testing.T) { require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &deployments)) assert.Empty(t, deployments) } + +func TestAPICSRFRejectsPostWithoutHeader(t *testing.T) { + t.Parallel() + + tc, cookies := setupAPITest(t) + + r := apiRouter(tc) + + // POST without X-Requested-With header should be rejected. + body := `{"name":"csrf-test","repoUrl":"https://example.com/repo"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + for _, c := range cookies { + req.AddCookie(c) + } + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), "X-Requested-With") +} + +func TestAPICSRFAllowsGetWithoutHeader(t *testing.T) { + t.Parallel() + + tc, cookies := setupAPITest(t) + + r := apiRouter(tc) + + // GET without X-Requested-With should still work. + req := httptest.NewRequest(http.MethodGet, "/api/v1/apps", nil) + + for _, c := range cookies { + req.AddCookie(c) + } + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) +} diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go new file mode 100644 index 0000000..4f533d3 --- /dev/null +++ b/internal/middleware/csrf_test.go @@ -0,0 +1,87 @@ +package middleware //nolint:testpackage // tests internal CSRF behavior + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAPICSRFProtection_BlocksPostWithoutHeader(t *testing.T) { + t.Parallel() + + mw := &Middleware{} + handler := mw.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), "X-Requested-With") +} + +func TestAPICSRFProtection_AllowsPostWithHeader(t *testing.T) { + t.Parallel() + + mw := &Middleware{} + handler := mw.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil) + req.Header.Set("X-Requested-With", "XMLHttpRequest") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestAPICSRFProtection_AllowsGetWithoutHeader(t *testing.T) { + t.Parallel() + + mw := &Middleware{} + handler := mw.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/apps", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestAPICSRFProtection_BlocksDeleteWithoutHeader(t *testing.T) { + t.Parallel() + + mw := &Middleware{} + handler := mw.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodDelete, "/api/v1/apps/123", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code) +} + +func TestAPICSRFProtection_AllowsHeadWithoutHeader(t *testing.T) { + t.Parallel() + + mw := &Middleware{} + handler := mw.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodHead, "/api/v1/apps", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 4a57c31..7bae051 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -193,7 +193,7 @@ func (m *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowedOrigins: origins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "X-Requested-With"}, ExposedHeaders: []string{"Link"}, AllowCredentials: true, MaxAge: corsMaxAge, @@ -370,6 +370,43 @@ func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { } } +// apiCSRFSafeMethods are HTTP methods that do not require CSRF protection. +// +//nolint:gochecknoglobals // package-level constant set +var apiCSRFSafeMethods = map[string]bool{ + http.MethodGet: true, + http.MethodHead: true, + http.MethodOptions: true, +} + +// APICSRFProtection returns middleware that protects API routes against CSRF attacks +// by requiring a custom header (X-Requested-With) on all state-changing requests. +// Browsers will not send custom headers in cross-origin simple requests (form posts, +// navigations), so this blocks CSRF attacks without requiring tokens. +func (m *Middleware) APICSRFProtection() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func( + writer http.ResponseWriter, + request *http.Request, + ) { + if !apiCSRFSafeMethods[request.Method] { + if request.Header.Get("X-Requested-With") == "" { + writer.Header().Set("Content-Type", "application/json") + http.Error( + writer, + `{"error":"missing X-Requested-With header"}`, + http.StatusForbidden, + ) + + return + } + } + + next.ServeHTTP(writer, request) + }) + } +} + // APISessionAuth returns middleware that requires session authentication for API routes. // Unlike SessionAuth, it returns JSON 401 responses instead of redirecting to /login. func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler { diff --git a/internal/server/routes.go b/internal/server/routes.go index 3339bb8..45348a3 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -102,8 +102,11 @@ func (s *Server) SetupRoutes() { }) }) - // API v1 routes (cookie-based session auth, no CSRF) + // API v1 routes (cookie-based session auth with CSRF protection) s.router.Route("/api/v1", func(r chi.Router) { + // CSRF protection: require X-Requested-With header on state-changing requests + r.Use(s.mw.APICSRFProtection()) + // Login endpoint is public (returns session cookie) r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleAPILoginPOST())