Add API CSRF protection via X-Requested-With header (closes #112) #116
79
internal/middleware/apicsrf_test.go
Normal file
79
internal/middleware/apicsrf_test.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package middleware //nolint:testpackage // tests internal middleware behavior
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAPICSRFTestMiddleware() *Middleware {
|
||||||
|
return newCORSTestMiddleware("")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICSRFProtection_SafeMethods_Allowed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
m := newAPICSRFTestMiddleware()
|
||||||
|
handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
|
||||||
|
req := httptest.NewRequest(method, "/api/v1/apps", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code,
|
||||||
|
"expected %s to be allowed without X-Requested-With", method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICSRFProtection_POST_WithoutHeader_Blocked(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
m := newAPICSRFTestMiddleware()
|
||||||
|
handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
assert.Contains(t, rec.Body.String(), "X-Requested-With")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICSRFProtection_POST_WithHeader_Allowed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
m := newAPICSRFTestMiddleware()
|
||||||
|
handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/apps", nil)
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICSRFProtection_DELETE_WithoutHeader_Blocked(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
m := newAPICSRFTestMiddleware()
|
||||||
|
handler := m.APICSRFProtection()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/api/v1/apps/1", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
}
|
||||||
@ -193,7 +193,7 @@ func (m *Middleware) CORS() func(http.Handler) http.Handler {
|
|||||||
return cors.Handler(cors.Options{
|
return cors.Handler(cors.Options{
|
||||||
AllowedOrigins: origins,
|
AllowedOrigins: origins,
|
||||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "X-Requested-With"},
|
||||||
ExposedHeaders: []string{"Link"},
|
ExposedHeaders: []string{"Link"},
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
MaxAge: corsMaxAge,
|
MaxAge: corsMaxAge,
|
||||||
@ -370,6 +370,38 @@ func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APICSRFProtection returns middleware that requires a custom header on
|
||||||
|
// state-changing API requests (POST, PUT, DELETE, PATCH) to prevent CSRF.
|
||||||
|
// Browsers will not send custom headers on cross-origin requests without a
|
||||||
|
// CORS preflight, which effectively blocks CSRF attacks via cookie-based auth.
|
||||||
|
// Safe methods (GET, HEAD, OPTIONS) are allowed without the header.
|
||||||
|
func (m *Middleware) APICSRFProtection() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
request *http.Request,
|
||||||
|
) {
|
||||||
|
switch request.Method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||||
|
// Safe methods — no CSRF risk.
|
||||||
|
default:
|
||||||
|
if request.Header.Get("X-Requested-With") == "" {
|
||||||
|
writer.Header().Set("Content-Type", "application/json")
|
||||||
|
http.Error(
|
||||||
|
writer,
|
||||||
|
`{"error":"missing X-Requested-With header"}`,
|
||||||
|
http.StatusForbidden,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(writer, request)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// APISessionAuth returns middleware that requires session authentication for API routes.
|
// APISessionAuth returns middleware that requires session authentication for API routes.
|
||||||
// Unlike SessionAuth, it returns JSON 401 responses instead of redirecting to /login.
|
// Unlike SessionAuth, it returns JSON 401 responses instead of redirecting to /login.
|
||||||
func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler {
|
func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler {
|
||||||
|
|||||||
@ -102,8 +102,10 @@ func (s *Server) SetupRoutes() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// API v1 routes (cookie-based session auth, no CSRF)
|
// API v1 routes (cookie-based session auth, CSRF protected via custom header)
|
||||||
s.router.Route("/api/v1", func(r chi.Router) {
|
s.router.Route("/api/v1", func(r chi.Router) {
|
||||||
|
r.Use(s.mw.APICSRFProtection())
|
||||||
|
|
||||||
// Login endpoint is public (returns session cookie)
|
// Login endpoint is public (returns session cookie)
|
||||||
r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleAPILoginPOST())
|
r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleAPILoginPOST())
|
||||||
|
|
||||||
|
|||||||
@ -73,7 +73,7 @@ func New(_ fx.Lifecycle, params ServiceParams) (*Service, error) {
|
|||||||
MaxAge: sessionMaxAgeSeconds,
|
MaxAge: sessionMaxAgeSeconds,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: !params.Config.Debug,
|
Secure: !params.Config.Debug,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteStrictMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Service{
|
return &Service{
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user