From 0829f9a75d1d99d8a78fea6b6166426d56e90536 Mon Sep 17 00:00:00 2001 From: clawbot Date: Tue, 10 Mar 2026 10:05:38 -0700 Subject: [PATCH] refactor: replace custom CSRF and rate-limiting with off-the-shelf libraries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace custom CSRF middleware with gorilla/csrf and custom rate-limiting middleware with go-chi/httprate, as requested in code review. CSRF changes: - Replace session-based CSRF tokens with gorilla/csrf cookie-based double-submit pattern (HMAC-authenticated cookies) - Keep same form field name (csrf_token) for template compatibility - Keep same route exclusions (webhook/API routes) - In dev mode, mark requests as plaintext HTTP to skip Referer check Rate limiting changes: - Replace custom token-bucket rate limiter with httprate sliding-window counter (per-IP, 5 POST requests/min on login endpoint) - Remove custom IP extraction (httprate.KeyByRealIP handles X-Forwarded-For, X-Real-IP, True-Client-IP) - Remove custom cleanup goroutine (httprate manages its own state) Kept as-is: - SSRF prevention code (internal/delivery/ssrf.go) — application-specific - CSRFToken() wrapper function — handlers unchanged Updated README security section and architecture overview to reflect library choices. --- README.md | 22 ++-- go.mod | 5 +- go.sum | 12 +- internal/middleware/csrf.go | 126 ++++++------------- internal/middleware/csrf_test.go | 56 +++------ internal/middleware/middleware_test.go | 12 +- internal/middleware/ratelimit.go | 162 +++---------------------- internal/middleware/ratelimit_test.go | 31 ----- internal/session/session.go | 8 ++ internal/session/session_test.go | 2 +- internal/session/testing.go | 7 +- 11 files changed, 126 insertions(+), 317 deletions(-) diff --git a/README.md b/README.md index e5ea0b4..68144b3 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,10 @@ It uses: logging with TTY detection (text for dev, JSON for prod) - **[gorilla/sessions](https://github.com/gorilla/sessions)** for encrypted cookie-based session management +- **[gorilla/csrf](https://github.com/gorilla/csrf)** for CSRF + protection (cookie-based double-submit tokens) +- **[go-chi/httprate](https://github.com/go-chi/httprate)** for + per-IP login rate limiting (sliding window counter) - **[Prometheus](https://prometheus.io)** for metrics, served at `/metrics` behind basic auth - **[Sentry](https://sentry.io)** for optional error reporting @@ -726,8 +730,8 @@ webhooker/ │ │ └── logger.go # slog setup with TTY detection │ ├── middleware/ │ │ ├── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize -│ │ ├── csrf.go # CSRF protection middleware (session-based tokens) -│ │ └── ratelimit.go # Per-IP rate limiting middleware (login endpoint) +│ │ ├── csrf.go # CSRF protection middleware (gorilla/csrf) +│ │ └── ratelimit.go # Per-IP rate limiting middleware (go-chi/httprate) │ ├── server/ │ │ ├── server.go # Server struct, fx lifecycle, signal handling │ │ ├── http.go # HTTP server setup with timeouts @@ -814,17 +818,19 @@ Additionally, form endpoints (`/pages`, `/sources`, `/source/*`) apply a (`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy, and Permissions-Policy - Request body size limits (1 MB) on all form POST endpoints -- **CSRF protection** on all state-changing forms (session-based tokens - with constant-time comparison). Applied to `/pages`, `/sources`, - `/source`, and `/user` routes. Excluded from `/webhook` (inbound - webhook POSTs) and `/api` (stateless API) +- **CSRF protection** via [gorilla/csrf](https://github.com/gorilla/csrf) + on all state-changing forms (cookie-based double-submit tokens with + HMAC authentication). Applied to `/pages`, `/sources`, `/source`, and + `/user` routes. Excluded from `/webhook` (inbound webhook POSTs) and + `/api` (stateless API) - **SSRF prevention** for HTTP delivery targets: private/reserved IP ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked both at target creation time (URL validation) and at delivery time (custom HTTP transport with SSRF-safe dialer that validates resolved IPs before connecting, preventing DNS rebinding attacks) -- **Login rate limiting**: per-IP rate limiter on the login endpoint - (5 attempts per minute per IP) to prevent brute-force attacks +- **Login rate limiting** via [go-chi/httprate](https://github.com/go-chi/httprate): + per-IP sliding-window rate limiter on the login endpoint (5 POST + attempts per minute per IP) to prevent brute-force attacks - Prometheus metrics behind basic auth - Static assets embedded in binary (no filesystem access needed at runtime) diff --git a/go.mod b/go.mod index 16f4ce1..6f4ff2a 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,9 @@ require ( github.com/getsentry/sentry-go v0.25.0 github.com/go-chi/chi v1.5.5 github.com/go-chi/cors v1.2.1 + github.com/go-chi/httprate v0.15.0 github.com/google/uuid v1.6.0 + github.com/gorilla/csrf v1.7.3 github.com/gorilla/sessions v1.4.0 github.com/joho/godotenv v1.5.1 github.com/prometheus/client_golang v1.18.0 @@ -17,7 +19,6 @@ require ( github.com/stretchr/testify v1.8.4 go.uber.org/fx v1.20.1 golang.org/x/crypto v0.38.0 - golang.org/x/time v0.14.0 gorm.io/driver/sqlite v1.5.4 gorm.io/gorm v1.25.5 modernc.org/sqlite v1.28.0 @@ -32,6 +33,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect @@ -41,6 +43,7 @@ require ( github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/zeebo/xxh3 v1.0.2 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/multierr v1.9.0 // indirect diff --git a/go.sum b/go.sum index 5b4395e..f7e0ba6 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE= github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= +github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g= +github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -31,6 +33,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= @@ -43,6 +47,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -80,6 +86,10 @@ github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdr github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= @@ -103,8 +113,6 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= -golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= -golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index 7e84f3a..df2ffb2 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -1,102 +1,56 @@ package middleware import ( - "context" - "crypto/rand" - "crypto/subtle" - "encoding/hex" "net/http" + + "github.com/gorilla/csrf" ) -const ( - // csrfTokenLength is the byte length of generated CSRF tokens. - // 32 bytes = 64 hex characters, providing 256 bits of entropy. - csrfTokenLength = 32 - - // csrfSessionKey is the session key where the CSRF token is stored. - csrfSessionKey = "csrf_token" - - // csrfFormField is the HTML form field name for the CSRF token. - csrfFormField = "csrf_token" -) - -// csrfContextKey is the context key type for CSRF tokens. -type csrfContextKey struct{} - // CSRFToken retrieves the CSRF token from the request context. -// Returns an empty string if no token is present. +// Returns an empty string if the gorilla/csrf middleware has not run. func CSRFToken(r *http.Request) string { - if token, ok := r.Context().Value(csrfContextKey{}).(string); ok { - return token - } - return "" + return csrf.Token(r) } -// CSRF returns middleware that provides CSRF protection for state-changing -// requests. For every request, it ensures a CSRF token exists in the -// session and makes it available via the request context. For POST, PUT, -// PATCH, and DELETE requests, it validates the submitted csrf_token form -// field against the session token. Requests with an invalid or missing +// CSRF returns middleware that provides CSRF protection using the +// gorilla/csrf library. The middleware uses the session authentication +// key to sign a CSRF cookie and validates a masked token submitted via +// the "csrf_token" form field (or the "X-CSRF-Token" header) on +// POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing // token receive a 403 Forbidden response. +// +// In development mode, requests are marked as plaintext HTTP so that +// gorilla/csrf skips the strict Referer-origin check (which is only +// meaningful over TLS). func (m *Middleware) CSRF() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sess, err := m.session.Get(r) - if err != nil { - m.log.Error("csrf: failed to get session", "error", err) - http.Error(w, "Forbidden", http.StatusForbidden) - return - } + protect := csrf.Protect( + m.session.GetKey(), + csrf.FieldName("csrf_token"), + csrf.Secure(!m.params.Config.IsDev()), + csrf.SameSite(csrf.SameSiteLaxMode), + csrf.Path("/"), + csrf.ErrorHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.log.Warn("csrf: token validation failed", + "method", r.Method, + "path", r.URL.Path, + "remote_addr", r.RemoteAddr, + "reason", csrf.FailureReason(r), + ) + http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) + })), + ) - // Ensure a CSRF token exists in the session - token, ok := sess.Values[csrfSessionKey].(string) - if !ok { - token = "" - } - if token == "" { - token, err = generateCSRFToken() - if err != nil { - m.log.Error("csrf: failed to generate token", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - sess.Values[csrfSessionKey] = token - if saveErr := m.session.Save(r, w, sess); saveErr != nil { - m.log.Error("csrf: failed to save session", "error", saveErr) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - } - - // Store token in context for templates - ctx := context.WithValue(r.Context(), csrfContextKey{}, token) - r = r.WithContext(ctx) - - // Validate token on state-changing methods - switch r.Method { - case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: - submitted := r.FormValue(csrfFormField) - if subtle.ConstantTimeCompare([]byte(submitted), []byte(token)) != 1 { - m.log.Warn("csrf: token mismatch", - "method", r.Method, - "path", r.URL.Path, - "remote_addr", r.RemoteAddr, - ) - http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) - return - } - } - - next.ServeHTTP(w, r) - }) + // In development (plaintext HTTP), signal gorilla/csrf to skip + // the strict TLS Referer check by injecting the PlaintextHTTP + // context key before the CSRF handler sees the request. + if m.params.Config.IsDev() { + return func(next http.Handler) http.Handler { + csrfHandler := protect(next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + csrfHandler.ServeHTTP(w, csrf.PlaintextHTTPRequest(r)) + }) + } } -} -// generateCSRFToken creates a cryptographically random hex-encoded token. -func generateCSRFToken() (string, error) { - b := make([]byte, csrfTokenLength) - if _, err := rand.Read(b); err != nil { - return "", err - } - return hex.EncodeToString(b), nil + return protect } diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go index 3940e61..ae1791d 100644 --- a/internal/middleware/csrf_test.go +++ b/internal/middleware/csrf_test.go @@ -27,20 +27,19 @@ func TestCSRF_GETSetsToken(t *testing.T) { handler.ServeHTTP(w, req) assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET") - assert.Len(t, gotToken, csrfTokenLength*2, "CSRF token should be hex-encoded 32 bytes") } func TestCSRF_POSTWithValidToken(t *testing.T) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentDev) - // Use a separate handler for the GET to capture the token + // Capture the token from a GET request var token string - getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + csrfMiddleware := m.CSRF() + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { token = CSRFToken(r) })) - // GET to establish the session and capture token getReq := httptest.NewRequest(http.MethodGet, "/form", nil) getW := httptest.NewRecorder() getHandler.ServeHTTP(getW, getReq) @@ -49,14 +48,13 @@ func TestCSRF_POSTWithValidToken(t *testing.T) { require.NotEmpty(t, cookies) require.NotEmpty(t, token) - // POST handler that tracks whether it was called + // POST with valid token and cookies from the GET response var called bool - postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { called = true })) - // POST with valid token - form := url.Values{csrfFormField: {token}} + form := url.Values{"csrf_token": {token}} postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies { @@ -73,23 +71,21 @@ func TestCSRF_POSTWithoutToken(t *testing.T) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentDev) - // GET handler to establish session - getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - // no-op — just establishes session - })) + csrfMiddleware := m.CSRF() + // GET to establish the CSRF cookie + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) getReq := httptest.NewRequest(http.MethodGet, "/form", nil) getW := httptest.NewRecorder() getHandler.ServeHTTP(getW, getReq) cookies := getW.Result().Cookies() - // POST handler that tracks whether it was called + // POST without CSRF token var called bool - postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { called = true })) - // POST without CSRF token postReq := httptest.NewRequest(http.MethodPost, "/form", nil) postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies { @@ -107,24 +103,22 @@ func TestCSRF_POSTWithInvalidToken(t *testing.T) { t.Parallel() m, _ := testMiddleware(t, config.EnvironmentDev) - // GET handler to establish session - getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - // no-op — just establishes session - })) + csrfMiddleware := m.CSRF() + // GET to establish the CSRF cookie + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) getReq := httptest.NewRequest(http.MethodGet, "/form", nil) getW := httptest.NewRecorder() getHandler.ServeHTTP(getW, getReq) cookies := getW.Result().Cookies() - // POST handler that tracks whether it was called + // POST with wrong CSRF token var called bool - postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { called = true })) - // POST with wrong CSRF token - form := url.Values{csrfFormField: {"invalid-token-value"}} + form := url.Values{"csrf_token": {"invalid-token-value"}} postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies { @@ -156,20 +150,8 @@ func TestCSRF_GETDoesNotValidate(t *testing.T) { assert.True(t, called, "GET requests should pass through CSRF middleware") } -func TestCSRFToken_NoContext(t *testing.T) { +func TestCSRFToken_NoMiddleware(t *testing.T) { t.Parallel() req := httptest.NewRequest(http.MethodGet, "/", nil) - assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when no token in context") -} - -func TestGenerateCSRFToken(t *testing.T) { - t.Parallel() - token, err := generateCSRFToken() - require.NoError(t, err) - assert.Len(t, token, csrfTokenLength*2, "token should be hex-encoded") - - // Verify uniqueness - token2, err := generateCSRFToken() - require.NoError(t, err) - assert.NotEqual(t, token, token2, "each generated token should be unique") + assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run") } diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 9267393..6702ad2 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -40,7 +40,7 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) { SameSite: http.SameSiteLaxMode, } - sessManager := newTestSession(t, store, cfg, log) + sessManager := newTestSession(t, store, cfg, log, key) m := &Middleware{ log: log, @@ -55,9 +55,9 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) { // newTestSession creates a session.Session with a pre-configured cookie store // for testing. This avoids needing the fx lifecycle and database. -func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *session.Session { +func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *session.Session { t.Helper() - return session.NewForTest(store, cfg, log) + return session.NewForTest(store, cfg, log, key) } // --- Logging Middleware Tests --- @@ -326,7 +326,7 @@ func TestMetricsAuth_ValidCredentials(t *testing.T) { store := sessions.NewCookieStore(key) store.Options = &sessions.Options{Path: "/", MaxAge: 86400} - sessManager := session.NewForTest(store, cfg, log) + sessManager := session.NewForTest(store, cfg, log, key) m := &Middleware{ log: log, @@ -366,7 +366,7 @@ func TestMetricsAuth_InvalidCredentials(t *testing.T) { store := sessions.NewCookieStore(key) store.Options = &sessions.Options{Path: "/", MaxAge: 86400} - sessManager := session.NewForTest(store, cfg, log) + sessManager := session.NewForTest(store, cfg, log, key) m := &Middleware{ log: log, @@ -406,7 +406,7 @@ func TestMetricsAuth_NoCredentials(t *testing.T) { store := sessions.NewCookieStore(key) store.Options = &sessions.Options{Path: "/", MaxAge: 86400} - sessManager := session.NewForTest(store, cfg, log) + sessManager := session.NewForTest(store, cfg, log, key) m := &Middleware{ log: log, diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index db23efd..e82bec3 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -1,12 +1,10 @@ package middleware import ( - "net" "net/http" - "sync" "time" - "golang.org/x/time/rate" + "github.com/go-chi/httprate" ) const ( @@ -15,158 +13,36 @@ const ( // loginRateInterval is the time window for the rate limit. loginRateInterval = 1 * time.Minute - - // limiterCleanupInterval is how often stale per-IP limiters are pruned. - limiterCleanupInterval = 5 * time.Minute - - // limiterMaxAge is how long an unused limiter is kept before pruning. - limiterMaxAge = 10 * time.Minute ) -// ipLimiter holds a rate limiter and the time it was last used. -type ipLimiter struct { - limiter *rate.Limiter - lastSeen time.Time -} - -// rateLimiterMap manages per-IP rate limiters with periodic cleanup. -type rateLimiterMap struct { - mu sync.Mutex - limiters map[string]*ipLimiter - rate rate.Limit - burst int -} - -// newRateLimiterMap creates a new per-IP rate limiter map. -func newRateLimiterMap(r rate.Limit, burst int) *rateLimiterMap { - rlm := &rateLimiterMap{ - limiters: make(map[string]*ipLimiter), - rate: r, - burst: burst, - } - - // Start background cleanup goroutine - go rlm.cleanup() - - return rlm -} - -// getLimiter returns the rate limiter for the given IP, creating one if -// it doesn't exist. -func (rlm *rateLimiterMap) getLimiter(ip string) *rate.Limiter { - rlm.mu.Lock() - defer rlm.mu.Unlock() - - if entry, ok := rlm.limiters[ip]; ok { - entry.lastSeen = time.Now() - return entry.limiter - } - - limiter := rate.NewLimiter(rlm.rate, rlm.burst) - rlm.limiters[ip] = &ipLimiter{ - limiter: limiter, - lastSeen: time.Now(), - } - return limiter -} - -// cleanup periodically removes stale rate limiters to prevent unbounded -// memory growth from unique IPs. -func (rlm *rateLimiterMap) cleanup() { - ticker := time.NewTicker(limiterCleanupInterval) - defer ticker.Stop() - - for range ticker.C { - rlm.mu.Lock() - cutoff := time.Now().Add(-limiterMaxAge) - for ip, entry := range rlm.limiters { - if entry.lastSeen.Before(cutoff) { - delete(rlm.limiters, ip) - } - } - rlm.mu.Unlock() - } -} - // LoginRateLimit returns middleware that enforces per-IP rate limiting -// on login attempts. Only POST requests are rate-limited; GET requests -// (rendering the login form) pass through unaffected. When the rate -// limit is exceeded, a 429 Too Many Requests response is returned. +// on login attempts using go-chi/httprate. Only POST requests are +// rate-limited; GET requests (rendering the login form) pass through +// unaffected. When the rate limit is exceeded, a 429 Too Many Requests +// response is returned. IP extraction honours X-Forwarded-For, +// X-Real-IP, and True-Client-IP headers for reverse-proxy setups. func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { - // Calculate rate: loginRateLimit events per loginRateInterval - r := rate.Limit(float64(loginRateLimit) / loginRateInterval.Seconds()) - rlm := newRateLimiterMap(r, loginRateLimit) + limiter := httprate.Limit( + loginRateLimit, + loginRateInterval, + httprate.WithKeyFuncs(httprate.KeyByRealIP), + httprate.WithLimitHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.log.Warn("login rate limit exceeded", + "path", r.URL.Path, + ) + http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests) + })), + ) return func(next http.Handler) http.Handler { + limited := limiter(next) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Only rate-limit POST requests (actual login attempts) if r.Method != http.MethodPost { next.ServeHTTP(w, r) return } - - ip := extractIP(r) - limiter := rlm.getLimiter(ip) - - if !limiter.Allow() { - m.log.Warn("login rate limit exceeded", - "ip", ip, - "path", r.URL.Path, - ) - http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests) - return - } - - next.ServeHTTP(w, r) + limited.ServeHTTP(w, r) }) } } - -// extractIP extracts the client IP address from the request. It checks -// X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups), -// then falls back to RemoteAddr. -func extractIP(r *http.Request) string { - // Check X-Forwarded-For header (first IP in chain) - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - // X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2 - // The first one is the original client - for i := 0; i < len(xff); i++ { - if xff[i] == ',' { - ip := xff[:i] - // Trim whitespace - for len(ip) > 0 && ip[0] == ' ' { - ip = ip[1:] - } - for len(ip) > 0 && ip[len(ip)-1] == ' ' { - ip = ip[:len(ip)-1] - } - if ip != "" { - return ip - } - break - } - } - trimmed := xff - for len(trimmed) > 0 && trimmed[0] == ' ' { - trimmed = trimmed[1:] - } - for len(trimmed) > 0 && trimmed[len(trimmed)-1] == ' ' { - trimmed = trimmed[:len(trimmed)-1] - } - if trimmed != "" { - return trimmed - } - } - - // Check X-Real-IP header - if xri := r.Header.Get("X-Real-IP"); xri != "" { - return xri - } - - // Fall back to RemoteAddr - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return ip -} diff --git a/internal/middleware/ratelimit_test.go b/internal/middleware/ratelimit_test.go index bd4e85f..6cea882 100644 --- a/internal/middleware/ratelimit_test.go +++ b/internal/middleware/ratelimit_test.go @@ -88,34 +88,3 @@ func TestLoginRateLimit_IndependentPerIP(t *testing.T) { handler.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected") } - -func TestExtractIP_RemoteAddr(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "192.168.1.100:54321" - assert.Equal(t, "192.168.1.100", extractIP(req)) -} - -func TestExtractIP_XForwardedFor(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "10.0.0.1:1234" - req.Header.Set("X-Forwarded-For", "203.0.113.50, 70.41.3.18, 150.172.238.178") - assert.Equal(t, "203.0.113.50", extractIP(req)) -} - -func TestExtractIP_XRealIP(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "10.0.0.1:1234" - req.Header.Set("X-Real-IP", "203.0.113.50") - assert.Equal(t, "203.0.113.50", extractIP(req)) -} - -func TestExtractIP_XForwardedForSingle(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "10.0.0.1:1234" - req.Header.Set("X-Forwarded-For", "203.0.113.50") - assert.Equal(t, "203.0.113.50", extractIP(req)) -} diff --git a/internal/session/session.go b/internal/session/session.go index 73a89ba..225ce15 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -39,6 +39,7 @@ type SessionParams struct { // Session manages encrypted session storage type Session struct { store *sessions.CookieStore + key []byte // raw 32-byte auth key, also used for CSRF cookie signing log *slog.Logger config *config.Config } @@ -79,6 +80,7 @@ func New(lc fx.Lifecycle, params SessionParams) (*Session, error) { SameSite: http.SameSiteLaxMode, } + s.key = keyBytes s.store = store s.log.Info("session manager initialized") return nil @@ -93,6 +95,12 @@ func (s *Session) Get(r *http.Request) (*sessions.Session, error) { return s.store.Get(r, SessionName) } +// GetKey returns the raw 32-byte authentication key used for session +// encryption. This key is also suitable for CSRF cookie signing. +func (s *Session) GetKey() []byte { + return s.key +} + // Save saves the session func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error { return sess.Save(r, w) diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 7c694f8..0950be5 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -34,7 +34,7 @@ func testSession(t *testing.T) *Session { } log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) - return NewForTest(store, cfg, log) + return NewForTest(store, cfg, log, key) } // --- Get and Save Tests --- diff --git a/internal/session/testing.go b/internal/session/testing.go index 2d20420..d15c4af 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -9,10 +9,13 @@ import ( // NewForTest creates a Session with a pre-configured cookie store for use // in tests. This bypasses the fx lifecycle and database dependency, allowing -// middleware and handler tests to use real session functionality. -func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *Session { +// middleware and handler tests to use real session functionality. The key +// parameter is the raw 32-byte authentication key used for session encryption +// and CSRF cookie signing. +func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *Session { return &Session{ store: store, + key: key, config: cfg, log: log, }