diff --git a/internal/middleware/apicsrf_test.go b/internal/middleware/apicsrf_test.go new file mode 100644 index 0000000..e3efaf2 --- /dev/null +++ b/internal/middleware/apicsrf_test.go @@ -0,0 +1,79 @@ +package middleware //nolint:testpackage // tests internal middleware behavior + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func newAPICSRFTestMiddleware() *Middleware { + return newCORSTestMiddleware("") +} + +func TestAPICSRFProtection_SafeMethods_Allowed(t *testing.T) { + t.Parallel() + + m := newAPICSRFTestMiddleware() + handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} { + req := httptest.NewRequest(method, "/api/v1/apps", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, + "expected %s to be allowed without X-Requested-With", method) + } +} + +func TestAPICSRFProtection_POST_WithoutHeader_Blocked(t *testing.T) { + t.Parallel() + + m := newAPICSRFTestMiddleware() + handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Contains(t, rec.Body.String(), "X-Requested-With") +} + +func TestAPICSRFProtection_POST_WithHeader_Allowed(t *testing.T) { + t.Parallel() + + m := newAPICSRFTestMiddleware() + handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil) + req.Header.Set("X-Requested-With", "XMLHttpRequest") + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestAPICSRFProtection_DELETE_WithoutHeader_Blocked(t *testing.T) { + t.Parallel() + + m := newAPICSRFTestMiddleware() + handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodDelete, "/api/v1/apps/1", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusForbidden, rec.Code) +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 4a57c31..78704be 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -193,7 +193,7 @@ func (m *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowedOrigins: origins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "X-Requested-With"}, ExposedHeaders: []string{"Link"}, AllowCredentials: true, MaxAge: corsMaxAge, @@ -370,6 +370,38 @@ func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { } } +// APICSRFProtection returns middleware that requires a custom header on +// state-changing API requests (POST, PUT, DELETE, PATCH) to prevent CSRF. +// Browsers will not send custom headers on cross-origin requests without a +// CORS preflight, which effectively blocks CSRF attacks via cookie-based auth. +// Safe methods (GET, HEAD, OPTIONS) are allowed without the header. +func (m *Middleware) APICSRFProtection() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func( + writer http.ResponseWriter, + request *http.Request, + ) { + switch request.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + // Safe methods — no CSRF risk. + default: + if request.Header.Get("X-Requested-With") == "" { + writer.Header().Set("Content-Type", "application/json") + http.Error( + writer, + `{"error":"missing X-Requested-With header"}`, + http.StatusForbidden, + ) + + return + } + } + + next.ServeHTTP(writer, request) + }) + } +} + // APISessionAuth returns middleware that requires session authentication for API routes. // Unlike SessionAuth, it returns JSON 401 responses instead of redirecting to /login. func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler { diff --git a/internal/server/routes.go b/internal/server/routes.go index 3339bb8..95e89b2 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -102,8 +102,10 @@ func (s *Server) SetupRoutes() { }) }) - // API v1 routes (cookie-based session auth, no CSRF) + // API v1 routes (cookie-based session auth, CSRF protected via custom header) s.router.Route("/api/v1", func(r chi.Router) { + r.Use(s.mw.APICSRFProtection()) + // Login endpoint is public (returns session cookie) r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleAPILoginPOST()) diff --git a/internal/service/auth/auth.go b/internal/service/auth/auth.go index 726c2c0..3beaa31 100644 --- a/internal/service/auth/auth.go +++ b/internal/service/auth/auth.go @@ -73,7 +73,7 @@ func New(_ fx.Lifecycle, params ServiceParams) (*Service, error) { MaxAge: sessionMaxAgeSeconds, HttpOnly: true, Secure: !params.Config.Debug, - SameSite: http.SameSiteLaxMode, + SameSite: http.SameSiteStrictMode, } return &Service{