refactor: replace custom CSRF and rate-limiting with off-the-shelf libraries
All checks were successful
check / check (push) Successful in 4s

Replace custom CSRF middleware with gorilla/csrf and custom rate-limiting
middleware with go-chi/httprate, as requested in code review.

CSRF changes:
- Replace session-based CSRF tokens with gorilla/csrf cookie-based
  double-submit pattern (HMAC-authenticated cookies)
- Keep same form field name (csrf_token) for template compatibility
- Keep same route exclusions (webhook/API routes)
- In dev mode, mark requests as plaintext HTTP to skip Referer check

Rate limiting changes:
- Replace custom token-bucket rate limiter with httprate sliding-window
  counter (per-IP, 5 POST requests/min on login endpoint)
- Remove custom IP extraction (httprate.KeyByRealIP handles
  X-Forwarded-For, X-Real-IP, True-Client-IP)
- Remove custom cleanup goroutine (httprate manages its own state)

Kept as-is:
- SSRF prevention code (internal/delivery/ssrf.go) — application-specific
- CSRFToken() wrapper function — handlers unchanged

Updated README security section and architecture overview to reflect
library choices.
This commit is contained in:
clawbot
2026-03-10 10:05:38 -07:00
parent 5c69efb5bc
commit 0829f9a75d
11 changed files with 126 additions and 317 deletions

View File

@@ -157,6 +157,10 @@ It uses:
logging with TTY detection (text for dev, JSON for prod) logging with TTY detection (text for dev, JSON for prod)
- **[gorilla/sessions](https://github.com/gorilla/sessions)** for - **[gorilla/sessions](https://github.com/gorilla/sessions)** for
encrypted cookie-based session management encrypted cookie-based session management
- **[gorilla/csrf](https://github.com/gorilla/csrf)** for CSRF
protection (cookie-based double-submit tokens)
- **[go-chi/httprate](https://github.com/go-chi/httprate)** for
per-IP login rate limiting (sliding window counter)
- **[Prometheus](https://prometheus.io)** for metrics, served at - **[Prometheus](https://prometheus.io)** for metrics, served at
`/metrics` behind basic auth `/metrics` behind basic auth
- **[Sentry](https://sentry.io)** for optional error reporting - **[Sentry](https://sentry.io)** for optional error reporting
@@ -726,8 +730,8 @@ webhooker/
│ │ └── logger.go # slog setup with TTY detection │ │ └── logger.go # slog setup with TTY detection
│ ├── middleware/ │ ├── middleware/
│ │ ├── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize │ │ ├── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize
│ │ ├── csrf.go # CSRF protection middleware (session-based tokens) │ │ ├── csrf.go # CSRF protection middleware (gorilla/csrf)
│ │ └── ratelimit.go # Per-IP rate limiting middleware (login endpoint) │ │ └── ratelimit.go # Per-IP rate limiting middleware (go-chi/httprate)
│ ├── server/ │ ├── server/
│ │ ├── server.go # Server struct, fx lifecycle, signal handling │ │ ├── server.go # Server struct, fx lifecycle, signal handling
│ │ ├── http.go # HTTP server setup with timeouts │ │ ├── http.go # HTTP server setup with timeouts
@@ -814,17 +818,19 @@ Additionally, form endpoints (`/pages`, `/sources`, `/source/*`) apply a
(`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy, (`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy,
and Permissions-Policy and Permissions-Policy
- Request body size limits (1 MB) on all form POST endpoints - Request body size limits (1 MB) on all form POST endpoints
- **CSRF protection** on all state-changing forms (session-based tokens - **CSRF protection** via [gorilla/csrf](https://github.com/gorilla/csrf)
with constant-time comparison). Applied to `/pages`, `/sources`, on all state-changing forms (cookie-based double-submit tokens with
`/source`, and `/user` routes. Excluded from `/webhook` (inbound HMAC authentication). Applied to `/pages`, `/sources`, `/source`, and
webhook POSTs) and `/api` (stateless API) `/user` routes. Excluded from `/webhook` (inbound webhook POSTs) and
`/api` (stateless API)
- **SSRF prevention** for HTTP delivery targets: private/reserved IP - **SSRF prevention** for HTTP delivery targets: private/reserved IP
ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked
both at target creation time (URL validation) and at delivery time both at target creation time (URL validation) and at delivery time
(custom HTTP transport with SSRF-safe dialer that validates resolved (custom HTTP transport with SSRF-safe dialer that validates resolved
IPs before connecting, preventing DNS rebinding attacks) IPs before connecting, preventing DNS rebinding attacks)
- **Login rate limiting**: per-IP rate limiter on the login endpoint - **Login rate limiting** via [go-chi/httprate](https://github.com/go-chi/httprate):
(5 attempts per minute per IP) to prevent brute-force attacks per-IP sliding-window rate limiter on the login endpoint (5 POST
attempts per minute per IP) to prevent brute-force attacks
- Prometheus metrics behind basic auth - Prometheus metrics behind basic auth
- Static assets embedded in binary (no filesystem access needed at - Static assets embedded in binary (no filesystem access needed at
runtime) runtime)

5
go.mod
View File

@@ -9,7 +9,9 @@ require (
github.com/getsentry/sentry-go v0.25.0 github.com/getsentry/sentry-go v0.25.0
github.com/go-chi/chi v1.5.5 github.com/go-chi/chi v1.5.5
github.com/go-chi/cors v1.2.1 github.com/go-chi/cors v1.2.1
github.com/go-chi/httprate v0.15.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/gorilla/csrf v1.7.3
github.com/gorilla/sessions v1.4.0 github.com/gorilla/sessions v1.4.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_golang v1.18.0
@@ -17,7 +19,6 @@ require (
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
go.uber.org/fx v1.20.1 go.uber.org/fx v1.20.1
golang.org/x/crypto v0.38.0 golang.org/x/crypto v0.38.0
golang.org/x/time v0.14.0
gorm.io/driver/sqlite v1.5.4 gorm.io/driver/sqlite v1.5.4
gorm.io/gorm v1.25.5 gorm.io/gorm v1.25.5
modernc.org/sqlite v1.28.0 modernc.org/sqlite v1.28.0
@@ -32,6 +33,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect
@@ -41,6 +43,7 @@ require (
github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/common v0.45.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect
go.uber.org/dig v1.17.0 // indirect go.uber.org/dig v1.17.0 // indirect
go.uber.org/multierr v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect

12
go.sum
View File

@@ -19,6 +19,8 @@ github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE=
github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw=
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g=
github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
@@ -31,6 +33,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
@@ -43,6 +47,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -80,6 +86,10 @@ github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdr
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI=
@@ -103,8 +113,6 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -1,102 +1,56 @@
package middleware package middleware
import ( import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"net/http" "net/http"
"github.com/gorilla/csrf"
) )
const (
// csrfTokenLength is the byte length of generated CSRF tokens.
// 32 bytes = 64 hex characters, providing 256 bits of entropy.
csrfTokenLength = 32
// csrfSessionKey is the session key where the CSRF token is stored.
csrfSessionKey = "csrf_token"
// csrfFormField is the HTML form field name for the CSRF token.
csrfFormField = "csrf_token"
)
// csrfContextKey is the context key type for CSRF tokens.
type csrfContextKey struct{}
// CSRFToken retrieves the CSRF token from the request context. // CSRFToken retrieves the CSRF token from the request context.
// Returns an empty string if no token is present. // Returns an empty string if the gorilla/csrf middleware has not run.
func CSRFToken(r *http.Request) string { func CSRFToken(r *http.Request) string {
if token, ok := r.Context().Value(csrfContextKey{}).(string); ok { return csrf.Token(r)
return token
}
return ""
} }
// CSRF returns middleware that provides CSRF protection for state-changing // CSRF returns middleware that provides CSRF protection using the
// requests. For every request, it ensures a CSRF token exists in the // gorilla/csrf library. The middleware uses the session authentication
// session and makes it available via the request context. For POST, PUT, // key to sign a CSRF cookie and validates a masked token submitted via
// PATCH, and DELETE requests, it validates the submitted csrf_token form // the "csrf_token" form field (or the "X-CSRF-Token" header) on
// field against the session token. Requests with an invalid or missing // POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing
// token receive a 403 Forbidden response. // token receive a 403 Forbidden response.
//
// In development mode, requests are marked as plaintext HTTP so that
// gorilla/csrf skips the strict Referer-origin check (which is only
// meaningful over TLS).
func (m *Middleware) CSRF() func(http.Handler) http.Handler { func (m *Middleware) CSRF() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { protect := csrf.Protect(
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.session.GetKey(),
sess, err := m.session.Get(r) csrf.FieldName("csrf_token"),
if err != nil { csrf.Secure(!m.params.Config.IsDev()),
m.log.Error("csrf: failed to get session", "error", err) csrf.SameSite(csrf.SameSiteLaxMode),
http.Error(w, "Forbidden", http.StatusForbidden) csrf.Path("/"),
return csrf.ErrorHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
} m.log.Warn("csrf: token validation failed",
// Ensure a CSRF token exists in the session
token, ok := sess.Values[csrfSessionKey].(string)
if !ok {
token = ""
}
if token == "" {
token, err = generateCSRFToken()
if err != nil {
m.log.Error("csrf: failed to generate token", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
sess.Values[csrfSessionKey] = token
if saveErr := m.session.Save(r, w, sess); saveErr != nil {
m.log.Error("csrf: failed to save session", "error", saveErr)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}
// Store token in context for templates
ctx := context.WithValue(r.Context(), csrfContextKey{}, token)
r = r.WithContext(ctx)
// Validate token on state-changing methods
switch r.Method {
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
submitted := r.FormValue(csrfFormField)
if subtle.ConstantTimeCompare([]byte(submitted), []byte(token)) != 1 {
m.log.Warn("csrf: token mismatch",
"method", r.Method, "method", r.Method,
"path", r.URL.Path, "path", r.URL.Path,
"remote_addr", r.RemoteAddr, "remote_addr", r.RemoteAddr,
"reason", csrf.FailureReason(r),
) )
http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden)
return })),
} )
}
next.ServeHTTP(w, r) // In development (plaintext HTTP), signal gorilla/csrf to skip
// the strict TLS Referer check by injecting the PlaintextHTTP
// context key before the CSRF handler sees the request.
if m.params.Config.IsDev() {
return func(next http.Handler) http.Handler {
csrfHandler := protect(next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
csrfHandler.ServeHTTP(w, csrf.PlaintextHTTPRequest(r))
}) })
} }
}
// generateCSRFToken creates a cryptographically random hex-encoded token.
func generateCSRFToken() (string, error) {
b := make([]byte, csrfTokenLength)
if _, err := rand.Read(b); err != nil {
return "", err
} }
return hex.EncodeToString(b), nil
return protect
} }

View File

@@ -27,20 +27,19 @@ func TestCSRF_GETSetsToken(t *testing.T) {
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET") assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
assert.Len(t, gotToken, csrfTokenLength*2, "CSRF token should be hex-encoded 32 bytes")
} }
func TestCSRF_POSTWithValidToken(t *testing.T) { func TestCSRF_POSTWithValidToken(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
// Use a separate handler for the GET to capture the token // Capture the token from a GET request
var token string var token string
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { csrfMiddleware := m.CSRF()
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
token = CSRFToken(r) token = CSRFToken(r)
})) }))
// GET to establish the session and capture token
getReq := httptest.NewRequest(http.MethodGet, "/form", nil) getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
getW := httptest.NewRecorder() getW := httptest.NewRecorder()
getHandler.ServeHTTP(getW, getReq) getHandler.ServeHTTP(getW, getReq)
@@ -49,14 +48,13 @@ func TestCSRF_POSTWithValidToken(t *testing.T) {
require.NotEmpty(t, cookies) require.NotEmpty(t, cookies)
require.NotEmpty(t, token) require.NotEmpty(t, token)
// POST handler that tracks whether it was called // POST with valid token and cookies from the GET response
var called bool var called bool
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
called = true called = true
})) }))
// POST with valid token form := url.Values{"csrf_token": {token}}
form := url.Values{csrfFormField: {token}}
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range cookies { for _, c := range cookies {
@@ -73,23 +71,21 @@ func TestCSRF_POSTWithoutToken(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
// GET handler to establish session csrfMiddleware := m.CSRF()
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
// no-op — just establishes session
}))
// GET to establish the CSRF cookie
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
getReq := httptest.NewRequest(http.MethodGet, "/form", nil) getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
getW := httptest.NewRecorder() getW := httptest.NewRecorder()
getHandler.ServeHTTP(getW, getReq) getHandler.ServeHTTP(getW, getReq)
cookies := getW.Result().Cookies() cookies := getW.Result().Cookies()
// POST handler that tracks whether it was called // POST without CSRF token
var called bool var called bool
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
called = true called = true
})) }))
// POST without CSRF token
postReq := httptest.NewRequest(http.MethodPost, "/form", nil) postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range cookies { for _, c := range cookies {
@@ -107,24 +103,22 @@ func TestCSRF_POSTWithInvalidToken(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
// GET handler to establish session csrfMiddleware := m.CSRF()
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
// no-op — just establishes session
}))
// GET to establish the CSRF cookie
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
getReq := httptest.NewRequest(http.MethodGet, "/form", nil) getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
getW := httptest.NewRecorder() getW := httptest.NewRecorder()
getHandler.ServeHTTP(getW, getReq) getHandler.ServeHTTP(getW, getReq)
cookies := getW.Result().Cookies() cookies := getW.Result().Cookies()
// POST handler that tracks whether it was called // POST with wrong CSRF token
var called bool var called bool
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
called = true called = true
})) }))
// POST with wrong CSRF token form := url.Values{"csrf_token": {"invalid-token-value"}}
form := url.Values{csrfFormField: {"invalid-token-value"}}
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range cookies { for _, c := range cookies {
@@ -156,20 +150,8 @@ func TestCSRF_GETDoesNotValidate(t *testing.T) {
assert.True(t, called, "GET requests should pass through CSRF middleware") assert.True(t, called, "GET requests should pass through CSRF middleware")
} }
func TestCSRFToken_NoContext(t *testing.T) { func TestCSRFToken_NoMiddleware(t *testing.T) {
t.Parallel() t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when no token in context") assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
}
func TestGenerateCSRFToken(t *testing.T) {
t.Parallel()
token, err := generateCSRFToken()
require.NoError(t, err)
assert.Len(t, token, csrfTokenLength*2, "token should be hex-encoded")
// Verify uniqueness
token2, err := generateCSRFToken()
require.NoError(t, err)
assert.NotEqual(t, token, token2, "each generated token should be unique")
} }

View File

@@ -40,7 +40,7 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
} }
sessManager := newTestSession(t, store, cfg, log) sessManager := newTestSession(t, store, cfg, log, key)
m := &Middleware{ m := &Middleware{
log: log, log: log,
@@ -55,9 +55,9 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
// newTestSession creates a session.Session with a pre-configured cookie store // newTestSession creates a session.Session with a pre-configured cookie store
// for testing. This avoids needing the fx lifecycle and database. // for testing. This avoids needing the fx lifecycle and database.
func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *session.Session { func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *session.Session {
t.Helper() t.Helper()
return session.NewForTest(store, cfg, log) return session.NewForTest(store, cfg, log, key)
} }
// --- Logging Middleware Tests --- // --- Logging Middleware Tests ---
@@ -326,7 +326,7 @@ func TestMetricsAuth_ValidCredentials(t *testing.T) {
store := sessions.NewCookieStore(key) store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{Path: "/", MaxAge: 86400} store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
sessManager := session.NewForTest(store, cfg, log) sessManager := session.NewForTest(store, cfg, log, key)
m := &Middleware{ m := &Middleware{
log: log, log: log,
@@ -366,7 +366,7 @@ func TestMetricsAuth_InvalidCredentials(t *testing.T) {
store := sessions.NewCookieStore(key) store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{Path: "/", MaxAge: 86400} store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
sessManager := session.NewForTest(store, cfg, log) sessManager := session.NewForTest(store, cfg, log, key)
m := &Middleware{ m := &Middleware{
log: log, log: log,
@@ -406,7 +406,7 @@ func TestMetricsAuth_NoCredentials(t *testing.T) {
store := sessions.NewCookieStore(key) store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{Path: "/", MaxAge: 86400} store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
sessManager := session.NewForTest(store, cfg, log) sessManager := session.NewForTest(store, cfg, log, key)
m := &Middleware{ m := &Middleware{
log: log, log: log,

View File

@@ -1,12 +1,10 @@
package middleware package middleware
import ( import (
"net"
"net/http" "net/http"
"sync"
"time" "time"
"golang.org/x/time/rate" "github.com/go-chi/httprate"
) )
const ( const (
@@ -15,158 +13,36 @@ const (
// loginRateInterval is the time window for the rate limit. // loginRateInterval is the time window for the rate limit.
loginRateInterval = 1 * time.Minute loginRateInterval = 1 * time.Minute
// limiterCleanupInterval is how often stale per-IP limiters are pruned.
limiterCleanupInterval = 5 * time.Minute
// limiterMaxAge is how long an unused limiter is kept before pruning.
limiterMaxAge = 10 * time.Minute
) )
// ipLimiter holds a rate limiter and the time it was last used.
type ipLimiter struct {
limiter *rate.Limiter
lastSeen time.Time
}
// rateLimiterMap manages per-IP rate limiters with periodic cleanup.
type rateLimiterMap struct {
mu sync.Mutex
limiters map[string]*ipLimiter
rate rate.Limit
burst int
}
// newRateLimiterMap creates a new per-IP rate limiter map.
func newRateLimiterMap(r rate.Limit, burst int) *rateLimiterMap {
rlm := &rateLimiterMap{
limiters: make(map[string]*ipLimiter),
rate: r,
burst: burst,
}
// Start background cleanup goroutine
go rlm.cleanup()
return rlm
}
// getLimiter returns the rate limiter for the given IP, creating one if
// it doesn't exist.
func (rlm *rateLimiterMap) getLimiter(ip string) *rate.Limiter {
rlm.mu.Lock()
defer rlm.mu.Unlock()
if entry, ok := rlm.limiters[ip]; ok {
entry.lastSeen = time.Now()
return entry.limiter
}
limiter := rate.NewLimiter(rlm.rate, rlm.burst)
rlm.limiters[ip] = &ipLimiter{
limiter: limiter,
lastSeen: time.Now(),
}
return limiter
}
// cleanup periodically removes stale rate limiters to prevent unbounded
// memory growth from unique IPs.
func (rlm *rateLimiterMap) cleanup() {
ticker := time.NewTicker(limiterCleanupInterval)
defer ticker.Stop()
for range ticker.C {
rlm.mu.Lock()
cutoff := time.Now().Add(-limiterMaxAge)
for ip, entry := range rlm.limiters {
if entry.lastSeen.Before(cutoff) {
delete(rlm.limiters, ip)
}
}
rlm.mu.Unlock()
}
}
// LoginRateLimit returns middleware that enforces per-IP rate limiting // LoginRateLimit returns middleware that enforces per-IP rate limiting
// on login attempts. Only POST requests are rate-limited; GET requests // on login attempts using go-chi/httprate. Only POST requests are
// (rendering the login form) pass through unaffected. When the rate // rate-limited; GET requests (rendering the login form) pass through
// limit is exceeded, a 429 Too Many Requests response is returned. // unaffected. When the rate limit is exceeded, a 429 Too Many Requests
// response is returned. IP extraction honours X-Forwarded-For,
// X-Real-IP, and True-Client-IP headers for reverse-proxy setups.
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
// Calculate rate: loginRateLimit events per loginRateInterval limiter := httprate.Limit(
r := rate.Limit(float64(loginRateLimit) / loginRateInterval.Seconds()) loginRateLimit,
rlm := newRateLimiterMap(r, loginRateLimit) loginRateInterval,
httprate.WithKeyFuncs(httprate.KeyByRealIP),
httprate.WithLimitHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.log.Warn("login rate limit exceeded",
"path", r.URL.Path,
)
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
})),
)
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
limited := limiter(next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only rate-limit POST requests (actual login attempts) // Only rate-limit POST requests (actual login attempts)
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
limited.ServeHTTP(w, r)
ip := extractIP(r)
limiter := rlm.getLimiter(ip)
if !limiter.Allow() {
m.log.Warn("login rate limit exceeded",
"ip", ip,
"path", r.URL.Path,
)
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
}) })
} }
} }
// extractIP extracts the client IP address from the request. It checks
// X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups),
// then falls back to RemoteAddr.
func extractIP(r *http.Request) string {
// Check X-Forwarded-For header (first IP in chain)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2
// The first one is the original client
for i := 0; i < len(xff); i++ {
if xff[i] == ',' {
ip := xff[:i]
// Trim whitespace
for len(ip) > 0 && ip[0] == ' ' {
ip = ip[1:]
}
for len(ip) > 0 && ip[len(ip)-1] == ' ' {
ip = ip[:len(ip)-1]
}
if ip != "" {
return ip
}
break
}
}
trimmed := xff
for len(trimmed) > 0 && trimmed[0] == ' ' {
trimmed = trimmed[1:]
}
for len(trimmed) > 0 && trimmed[len(trimmed)-1] == ' ' {
trimmed = trimmed[:len(trimmed)-1]
}
if trimmed != "" {
return trimmed
}
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// Fall back to RemoteAddr
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}

View File

@@ -88,34 +88,3 @@ func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected") assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected")
} }
func TestExtractIP_RemoteAddr(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "192.168.1.100:54321"
assert.Equal(t, "192.168.1.100", extractIP(req))
}
func TestExtractIP_XForwardedFor(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
req.Header.Set("X-Forwarded-For", "203.0.113.50, 70.41.3.18, 150.172.238.178")
assert.Equal(t, "203.0.113.50", extractIP(req))
}
func TestExtractIP_XRealIP(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
req.Header.Set("X-Real-IP", "203.0.113.50")
assert.Equal(t, "203.0.113.50", extractIP(req))
}
func TestExtractIP_XForwardedForSingle(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
req.Header.Set("X-Forwarded-For", "203.0.113.50")
assert.Equal(t, "203.0.113.50", extractIP(req))
}

View File

@@ -39,6 +39,7 @@ type SessionParams struct {
// Session manages encrypted session storage // Session manages encrypted session storage
type Session struct { type Session struct {
store *sessions.CookieStore store *sessions.CookieStore
key []byte // raw 32-byte auth key, also used for CSRF cookie signing
log *slog.Logger log *slog.Logger
config *config.Config config *config.Config
} }
@@ -79,6 +80,7 @@ func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
} }
s.key = keyBytes
s.store = store s.store = store
s.log.Info("session manager initialized") s.log.Info("session manager initialized")
return nil return nil
@@ -93,6 +95,12 @@ func (s *Session) Get(r *http.Request) (*sessions.Session, error) {
return s.store.Get(r, SessionName) return s.store.Get(r, SessionName)
} }
// GetKey returns the raw 32-byte authentication key used for session
// encryption. This key is also suitable for CSRF cookie signing.
func (s *Session) GetKey() []byte {
return s.key
}
// Save saves the session // Save saves the session
func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error { func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error {
return sess.Save(r, w) return sess.Save(r, w)

View File

@@ -34,7 +34,7 @@ func testSession(t *testing.T) *Session {
} }
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
return NewForTest(store, cfg, log) return NewForTest(store, cfg, log, key)
} }
// --- Get and Save Tests --- // --- Get and Save Tests ---

View File

@@ -9,10 +9,13 @@ import (
// NewForTest creates a Session with a pre-configured cookie store for use // NewForTest creates a Session with a pre-configured cookie store for use
// in tests. This bypasses the fx lifecycle and database dependency, allowing // in tests. This bypasses the fx lifecycle and database dependency, allowing
// middleware and handler tests to use real session functionality. // middleware and handler tests to use real session functionality. The key
func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *Session { // parameter is the raw 32-byte authentication key used for session encryption
// and CSRF cookie signing.
func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *Session {
return &Session{ return &Session{
store: store, store: store,
key: key,
config: cfg, config: cfg,
log: log, log: log,
} }