refactor: replace custom CSRF and rate-limiting with off-the-shelf libraries
All checks were successful
check / check (push) Successful in 4s
All checks were successful
check / check (push) Successful in 4s
Replace custom CSRF middleware with gorilla/csrf and custom rate-limiting middleware with go-chi/httprate, as requested in code review. CSRF changes: - Replace session-based CSRF tokens with gorilla/csrf cookie-based double-submit pattern (HMAC-authenticated cookies) - Keep same form field name (csrf_token) for template compatibility - Keep same route exclusions (webhook/API routes) - In dev mode, mark requests as plaintext HTTP to skip Referer check Rate limiting changes: - Replace custom token-bucket rate limiter with httprate sliding-window counter (per-IP, 5 POST requests/min on login endpoint) - Remove custom IP extraction (httprate.KeyByRealIP handles X-Forwarded-For, X-Real-IP, True-Client-IP) - Remove custom cleanup goroutine (httprate manages its own state) Kept as-is: - SSRF prevention code (internal/delivery/ssrf.go) — application-specific - CSRFToken() wrapper function — handlers unchanged Updated README security section and architecture overview to reflect library choices.
This commit is contained in:
22
README.md
22
README.md
@@ -157,6 +157,10 @@ It uses:
|
|||||||
logging with TTY detection (text for dev, JSON for prod)
|
logging with TTY detection (text for dev, JSON for prod)
|
||||||
- **[gorilla/sessions](https://github.com/gorilla/sessions)** for
|
- **[gorilla/sessions](https://github.com/gorilla/sessions)** for
|
||||||
encrypted cookie-based session management
|
encrypted cookie-based session management
|
||||||
|
- **[gorilla/csrf](https://github.com/gorilla/csrf)** for CSRF
|
||||||
|
protection (cookie-based double-submit tokens)
|
||||||
|
- **[go-chi/httprate](https://github.com/go-chi/httprate)** for
|
||||||
|
per-IP login rate limiting (sliding window counter)
|
||||||
- **[Prometheus](https://prometheus.io)** for metrics, served at
|
- **[Prometheus](https://prometheus.io)** for metrics, served at
|
||||||
`/metrics` behind basic auth
|
`/metrics` behind basic auth
|
||||||
- **[Sentry](https://sentry.io)** for optional error reporting
|
- **[Sentry](https://sentry.io)** for optional error reporting
|
||||||
@@ -726,8 +730,8 @@ webhooker/
|
|||||||
│ │ └── logger.go # slog setup with TTY detection
|
│ │ └── logger.go # slog setup with TTY detection
|
||||||
│ ├── middleware/
|
│ ├── middleware/
|
||||||
│ │ ├── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize
|
│ │ ├── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize
|
||||||
│ │ ├── csrf.go # CSRF protection middleware (session-based tokens)
|
│ │ ├── csrf.go # CSRF protection middleware (gorilla/csrf)
|
||||||
│ │ └── ratelimit.go # Per-IP rate limiting middleware (login endpoint)
|
│ │ └── ratelimit.go # Per-IP rate limiting middleware (go-chi/httprate)
|
||||||
│ ├── server/
|
│ ├── server/
|
||||||
│ │ ├── server.go # Server struct, fx lifecycle, signal handling
|
│ │ ├── server.go # Server struct, fx lifecycle, signal handling
|
||||||
│ │ ├── http.go # HTTP server setup with timeouts
|
│ │ ├── http.go # HTTP server setup with timeouts
|
||||||
@@ -814,17 +818,19 @@ Additionally, form endpoints (`/pages`, `/sources`, `/source/*`) apply a
|
|||||||
(`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy,
|
(`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy,
|
||||||
and Permissions-Policy
|
and Permissions-Policy
|
||||||
- Request body size limits (1 MB) on all form POST endpoints
|
- Request body size limits (1 MB) on all form POST endpoints
|
||||||
- **CSRF protection** on all state-changing forms (session-based tokens
|
- **CSRF protection** via [gorilla/csrf](https://github.com/gorilla/csrf)
|
||||||
with constant-time comparison). Applied to `/pages`, `/sources`,
|
on all state-changing forms (cookie-based double-submit tokens with
|
||||||
`/source`, and `/user` routes. Excluded from `/webhook` (inbound
|
HMAC authentication). Applied to `/pages`, `/sources`, `/source`, and
|
||||||
webhook POSTs) and `/api` (stateless API)
|
`/user` routes. Excluded from `/webhook` (inbound webhook POSTs) and
|
||||||
|
`/api` (stateless API)
|
||||||
- **SSRF prevention** for HTTP delivery targets: private/reserved IP
|
- **SSRF prevention** for HTTP delivery targets: private/reserved IP
|
||||||
ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked
|
ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked
|
||||||
both at target creation time (URL validation) and at delivery time
|
both at target creation time (URL validation) and at delivery time
|
||||||
(custom HTTP transport with SSRF-safe dialer that validates resolved
|
(custom HTTP transport with SSRF-safe dialer that validates resolved
|
||||||
IPs before connecting, preventing DNS rebinding attacks)
|
IPs before connecting, preventing DNS rebinding attacks)
|
||||||
- **Login rate limiting**: per-IP rate limiter on the login endpoint
|
- **Login rate limiting** via [go-chi/httprate](https://github.com/go-chi/httprate):
|
||||||
(5 attempts per minute per IP) to prevent brute-force attacks
|
per-IP sliding-window rate limiter on the login endpoint (5 POST
|
||||||
|
attempts per minute per IP) to prevent brute-force attacks
|
||||||
- Prometheus metrics behind basic auth
|
- Prometheus metrics behind basic auth
|
||||||
- Static assets embedded in binary (no filesystem access needed at
|
- Static assets embedded in binary (no filesystem access needed at
|
||||||
runtime)
|
runtime)
|
||||||
|
|||||||
5
go.mod
5
go.mod
@@ -9,7 +9,9 @@ require (
|
|||||||
github.com/getsentry/sentry-go v0.25.0
|
github.com/getsentry/sentry-go v0.25.0
|
||||||
github.com/go-chi/chi v1.5.5
|
github.com/go-chi/chi v1.5.5
|
||||||
github.com/go-chi/cors v1.2.1
|
github.com/go-chi/cors v1.2.1
|
||||||
|
github.com/go-chi/httprate v0.15.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/gorilla/csrf v1.7.3
|
||||||
github.com/gorilla/sessions v1.4.0
|
github.com/gorilla/sessions v1.4.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/prometheus/client_golang v1.18.0
|
github.com/prometheus/client_golang v1.18.0
|
||||||
@@ -17,7 +19,6 @@ require (
|
|||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
go.uber.org/fx v1.20.1
|
go.uber.org/fx v1.20.1
|
||||||
golang.org/x/crypto v0.38.0
|
golang.org/x/crypto v0.38.0
|
||||||
golang.org/x/time v0.14.0
|
|
||||||
gorm.io/driver/sqlite v1.5.4
|
gorm.io/driver/sqlite v1.5.4
|
||||||
gorm.io/gorm v1.25.5
|
gorm.io/gorm v1.25.5
|
||||||
modernc.org/sqlite v1.28.0
|
modernc.org/sqlite v1.28.0
|
||||||
@@ -32,6 +33,7 @@ require (
|
|||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
|
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||||
@@ -41,6 +43,7 @@ require (
|
|||||||
github.com/prometheus/common v0.45.0 // indirect
|
github.com/prometheus/common v0.45.0 // indirect
|
||||||
github.com/prometheus/procfs v0.12.0 // indirect
|
github.com/prometheus/procfs v0.12.0 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/dig v1.17.0 // indirect
|
go.uber.org/dig v1.17.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
|
|||||||
12
go.sum
12
go.sum
@@ -19,6 +19,8 @@ github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE=
|
|||||||
github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw=
|
github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw=
|
||||||
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
|
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
|
||||||
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
||||||
|
github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g=
|
||||||
|
github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4=
|
||||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||||
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
@@ -31,6 +33,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu
|
|||||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
|
||||||
|
github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
|
||||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||||
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
||||||
@@ -43,6 +47,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
|||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
|
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
|
||||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
|
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
@@ -80,6 +86,10 @@ github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdr
|
|||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
|
||||||
|
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
|
||||||
|
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||||
|
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||||
go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI=
|
go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI=
|
||||||
@@ -103,8 +113,6 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
|||||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
|
||||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
||||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|||||||
@@ -1,102 +1,56 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/hex"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/csrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// csrfTokenLength is the byte length of generated CSRF tokens.
|
|
||||||
// 32 bytes = 64 hex characters, providing 256 bits of entropy.
|
|
||||||
csrfTokenLength = 32
|
|
||||||
|
|
||||||
// csrfSessionKey is the session key where the CSRF token is stored.
|
|
||||||
csrfSessionKey = "csrf_token"
|
|
||||||
|
|
||||||
// csrfFormField is the HTML form field name for the CSRF token.
|
|
||||||
csrfFormField = "csrf_token"
|
|
||||||
)
|
|
||||||
|
|
||||||
// csrfContextKey is the context key type for CSRF tokens.
|
|
||||||
type csrfContextKey struct{}
|
|
||||||
|
|
||||||
// CSRFToken retrieves the CSRF token from the request context.
|
// CSRFToken retrieves the CSRF token from the request context.
|
||||||
// Returns an empty string if no token is present.
|
// Returns an empty string if the gorilla/csrf middleware has not run.
|
||||||
func CSRFToken(r *http.Request) string {
|
func CSRFToken(r *http.Request) string {
|
||||||
if token, ok := r.Context().Value(csrfContextKey{}).(string); ok {
|
return csrf.Token(r)
|
||||||
return token
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRF returns middleware that provides CSRF protection for state-changing
|
// CSRF returns middleware that provides CSRF protection using the
|
||||||
// requests. For every request, it ensures a CSRF token exists in the
|
// gorilla/csrf library. The middleware uses the session authentication
|
||||||
// session and makes it available via the request context. For POST, PUT,
|
// key to sign a CSRF cookie and validates a masked token submitted via
|
||||||
// PATCH, and DELETE requests, it validates the submitted csrf_token form
|
// the "csrf_token" form field (or the "X-CSRF-Token" header) on
|
||||||
// field against the session token. Requests with an invalid or missing
|
// POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing
|
||||||
// token receive a 403 Forbidden response.
|
// token receive a 403 Forbidden response.
|
||||||
|
//
|
||||||
|
// In development mode, requests are marked as plaintext HTTP so that
|
||||||
|
// gorilla/csrf skips the strict Referer-origin check (which is only
|
||||||
|
// meaningful over TLS).
|
||||||
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
protect := csrf.Protect(
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
m.session.GetKey(),
|
||||||
sess, err := m.session.Get(r)
|
csrf.FieldName("csrf_token"),
|
||||||
if err != nil {
|
csrf.Secure(!m.params.Config.IsDev()),
|
||||||
m.log.Error("csrf: failed to get session", "error", err)
|
csrf.SameSite(csrf.SameSiteLaxMode),
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
csrf.Path("/"),
|
||||||
return
|
csrf.ErrorHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
m.log.Warn("csrf: token validation failed",
|
||||||
|
|
||||||
// Ensure a CSRF token exists in the session
|
|
||||||
token, ok := sess.Values[csrfSessionKey].(string)
|
|
||||||
if !ok {
|
|
||||||
token = ""
|
|
||||||
}
|
|
||||||
if token == "" {
|
|
||||||
token, err = generateCSRFToken()
|
|
||||||
if err != nil {
|
|
||||||
m.log.Error("csrf: failed to generate token", "error", err)
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sess.Values[csrfSessionKey] = token
|
|
||||||
if saveErr := m.session.Save(r, w, sess); saveErr != nil {
|
|
||||||
m.log.Error("csrf: failed to save session", "error", saveErr)
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store token in context for templates
|
|
||||||
ctx := context.WithValue(r.Context(), csrfContextKey{}, token)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
|
|
||||||
// Validate token on state-changing methods
|
|
||||||
switch r.Method {
|
|
||||||
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
|
|
||||||
submitted := r.FormValue(csrfFormField)
|
|
||||||
if subtle.ConstantTimeCompare([]byte(submitted), []byte(token)) != 1 {
|
|
||||||
m.log.Warn("csrf: token mismatch",
|
|
||||||
"method", r.Method,
|
"method", r.Method,
|
||||||
"path", r.URL.Path,
|
"path", r.URL.Path,
|
||||||
"remote_addr", r.RemoteAddr,
|
"remote_addr", r.RemoteAddr,
|
||||||
|
"reason", csrf.FailureReason(r),
|
||||||
)
|
)
|
||||||
http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden)
|
http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden)
|
||||||
return
|
})),
|
||||||
}
|
)
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
// In development (plaintext HTTP), signal gorilla/csrf to skip
|
||||||
|
// the strict TLS Referer check by injecting the PlaintextHTTP
|
||||||
|
// context key before the CSRF handler sees the request.
|
||||||
|
if m.params.Config.IsDev() {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
csrfHandler := protect(next)
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
csrfHandler.ServeHTTP(w, csrf.PlaintextHTTPRequest(r))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateCSRFToken creates a cryptographically random hex-encoded token.
|
return protect
|
||||||
func generateCSRFToken() (string, error) {
|
|
||||||
b := make([]byte, csrfTokenLength)
|
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(b), nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,20 +27,19 @@ func TestCSRF_GETSetsToken(t *testing.T) {
|
|||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
|
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
|
||||||
assert.Len(t, gotToken, csrfTokenLength*2, "CSRF token should be hex-encoded 32 bytes")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
// Use a separate handler for the GET to capture the token
|
// Capture the token from a GET request
|
||||||
var token string
|
var token string
|
||||||
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
csrfMiddleware := m.CSRF()
|
||||||
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||||
token = CSRFToken(r)
|
token = CSRFToken(r)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// GET to establish the session and capture token
|
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||||
getW := httptest.NewRecorder()
|
getW := httptest.NewRecorder()
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
getHandler.ServeHTTP(getW, getReq)
|
||||||
@@ -49,14 +48,13 @@ func TestCSRF_POSTWithValidToken(t *testing.T) {
|
|||||||
require.NotEmpty(t, cookies)
|
require.NotEmpty(t, cookies)
|
||||||
require.NotEmpty(t, token)
|
require.NotEmpty(t, token)
|
||||||
|
|
||||||
// POST handler that tracks whether it was called
|
// POST with valid token and cookies from the GET response
|
||||||
var called bool
|
var called bool
|
||||||
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
called = true
|
called = true
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// POST with valid token
|
form := url.Values{"csrf_token": {token}}
|
||||||
form := url.Values{csrfFormField: {token}}
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
@@ -73,23 +71,21 @@ func TestCSRF_POSTWithoutToken(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
// GET handler to establish session
|
csrfMiddleware := m.CSRF()
|
||||||
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
// no-op — just establishes session
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
// GET to establish the CSRF cookie
|
||||||
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||||
getW := httptest.NewRecorder()
|
getW := httptest.NewRecorder()
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
getHandler.ServeHTTP(getW, getReq)
|
||||||
cookies := getW.Result().Cookies()
|
cookies := getW.Result().Cookies()
|
||||||
|
|
||||||
// POST handler that tracks whether it was called
|
// POST without CSRF token
|
||||||
var called bool
|
var called bool
|
||||||
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
called = true
|
called = true
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// POST without CSRF token
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
@@ -107,24 +103,22 @@ func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
// GET handler to establish session
|
csrfMiddleware := m.CSRF()
|
||||||
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
// no-op — just establishes session
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
// GET to establish the CSRF cookie
|
||||||
|
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||||
getW := httptest.NewRecorder()
|
getW := httptest.NewRecorder()
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
getHandler.ServeHTTP(getW, getReq)
|
||||||
cookies := getW.Result().Cookies()
|
cookies := getW.Result().Cookies()
|
||||||
|
|
||||||
// POST handler that tracks whether it was called
|
// POST with wrong CSRF token
|
||||||
var called bool
|
var called bool
|
||||||
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
called = true
|
called = true
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// POST with wrong CSRF token
|
form := url.Values{"csrf_token": {"invalid-token-value"}}
|
||||||
form := url.Values{csrfFormField: {"invalid-token-value"}}
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
@@ -156,20 +150,8 @@ func TestCSRF_GETDoesNotValidate(t *testing.T) {
|
|||||||
assert.True(t, called, "GET requests should pass through CSRF middleware")
|
assert.True(t, called, "GET requests should pass through CSRF middleware")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFToken_NoContext(t *testing.T) {
|
func TestCSRFToken_NoMiddleware(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when no token in context")
|
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateCSRFToken(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
token, err := generateCSRFToken()
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, token, csrfTokenLength*2, "token should be hex-encoded")
|
|
||||||
|
|
||||||
// Verify uniqueness
|
|
||||||
token2, err := generateCSRFToken()
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotEqual(t, token, token2, "each generated token should be unique")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
|
|||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
sessManager := newTestSession(t, store, cfg, log)
|
sessManager := newTestSession(t, store, cfg, log, key)
|
||||||
|
|
||||||
m := &Middleware{
|
m := &Middleware{
|
||||||
log: log,
|
log: log,
|
||||||
@@ -55,9 +55,9 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
|
|||||||
|
|
||||||
// newTestSession creates a session.Session with a pre-configured cookie store
|
// newTestSession creates a session.Session with a pre-configured cookie store
|
||||||
// for testing. This avoids needing the fx lifecycle and database.
|
// for testing. This avoids needing the fx lifecycle and database.
|
||||||
func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *session.Session {
|
func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *session.Session {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
return session.NewForTest(store, cfg, log)
|
return session.NewForTest(store, cfg, log, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Logging Middleware Tests ---
|
// --- Logging Middleware Tests ---
|
||||||
@@ -326,7 +326,7 @@ func TestMetricsAuth_ValidCredentials(t *testing.T) {
|
|||||||
store := sessions.NewCookieStore(key)
|
store := sessions.NewCookieStore(key)
|
||||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
||||||
|
|
||||||
sessManager := session.NewForTest(store, cfg, log)
|
sessManager := session.NewForTest(store, cfg, log, key)
|
||||||
|
|
||||||
m := &Middleware{
|
m := &Middleware{
|
||||||
log: log,
|
log: log,
|
||||||
@@ -366,7 +366,7 @@ func TestMetricsAuth_InvalidCredentials(t *testing.T) {
|
|||||||
store := sessions.NewCookieStore(key)
|
store := sessions.NewCookieStore(key)
|
||||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
||||||
|
|
||||||
sessManager := session.NewForTest(store, cfg, log)
|
sessManager := session.NewForTest(store, cfg, log, key)
|
||||||
|
|
||||||
m := &Middleware{
|
m := &Middleware{
|
||||||
log: log,
|
log: log,
|
||||||
@@ -406,7 +406,7 @@ func TestMetricsAuth_NoCredentials(t *testing.T) {
|
|||||||
store := sessions.NewCookieStore(key)
|
store := sessions.NewCookieStore(key)
|
||||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
||||||
|
|
||||||
sessManager := session.NewForTest(store, cfg, log)
|
sessManager := session.NewForTest(store, cfg, log, key)
|
||||||
|
|
||||||
m := &Middleware{
|
m := &Middleware{
|
||||||
log: log,
|
log: log,
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/time/rate"
|
"github.com/go-chi/httprate"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -15,158 +13,36 @@ const (
|
|||||||
|
|
||||||
// loginRateInterval is the time window for the rate limit.
|
// loginRateInterval is the time window for the rate limit.
|
||||||
loginRateInterval = 1 * time.Minute
|
loginRateInterval = 1 * time.Minute
|
||||||
|
|
||||||
// limiterCleanupInterval is how often stale per-IP limiters are pruned.
|
|
||||||
limiterCleanupInterval = 5 * time.Minute
|
|
||||||
|
|
||||||
// limiterMaxAge is how long an unused limiter is kept before pruning.
|
|
||||||
limiterMaxAge = 10 * time.Minute
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ipLimiter holds a rate limiter and the time it was last used.
|
|
||||||
type ipLimiter struct {
|
|
||||||
limiter *rate.Limiter
|
|
||||||
lastSeen time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// rateLimiterMap manages per-IP rate limiters with periodic cleanup.
|
|
||||||
type rateLimiterMap struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
limiters map[string]*ipLimiter
|
|
||||||
rate rate.Limit
|
|
||||||
burst int
|
|
||||||
}
|
|
||||||
|
|
||||||
// newRateLimiterMap creates a new per-IP rate limiter map.
|
|
||||||
func newRateLimiterMap(r rate.Limit, burst int) *rateLimiterMap {
|
|
||||||
rlm := &rateLimiterMap{
|
|
||||||
limiters: make(map[string]*ipLimiter),
|
|
||||||
rate: r,
|
|
||||||
burst: burst,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start background cleanup goroutine
|
|
||||||
go rlm.cleanup()
|
|
||||||
|
|
||||||
return rlm
|
|
||||||
}
|
|
||||||
|
|
||||||
// getLimiter returns the rate limiter for the given IP, creating one if
|
|
||||||
// it doesn't exist.
|
|
||||||
func (rlm *rateLimiterMap) getLimiter(ip string) *rate.Limiter {
|
|
||||||
rlm.mu.Lock()
|
|
||||||
defer rlm.mu.Unlock()
|
|
||||||
|
|
||||||
if entry, ok := rlm.limiters[ip]; ok {
|
|
||||||
entry.lastSeen = time.Now()
|
|
||||||
return entry.limiter
|
|
||||||
}
|
|
||||||
|
|
||||||
limiter := rate.NewLimiter(rlm.rate, rlm.burst)
|
|
||||||
rlm.limiters[ip] = &ipLimiter{
|
|
||||||
limiter: limiter,
|
|
||||||
lastSeen: time.Now(),
|
|
||||||
}
|
|
||||||
return limiter
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup periodically removes stale rate limiters to prevent unbounded
|
|
||||||
// memory growth from unique IPs.
|
|
||||||
func (rlm *rateLimiterMap) cleanup() {
|
|
||||||
ticker := time.NewTicker(limiterCleanupInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
rlm.mu.Lock()
|
|
||||||
cutoff := time.Now().Add(-limiterMaxAge)
|
|
||||||
for ip, entry := range rlm.limiters {
|
|
||||||
if entry.lastSeen.Before(cutoff) {
|
|
||||||
delete(rlm.limiters, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rlm.mu.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoginRateLimit returns middleware that enforces per-IP rate limiting
|
// LoginRateLimit returns middleware that enforces per-IP rate limiting
|
||||||
// on login attempts. Only POST requests are rate-limited; GET requests
|
// on login attempts using go-chi/httprate. Only POST requests are
|
||||||
// (rendering the login form) pass through unaffected. When the rate
|
// rate-limited; GET requests (rendering the login form) pass through
|
||||||
// limit is exceeded, a 429 Too Many Requests response is returned.
|
// unaffected. When the rate limit is exceeded, a 429 Too Many Requests
|
||||||
|
// response is returned. IP extraction honours X-Forwarded-For,
|
||||||
|
// X-Real-IP, and True-Client-IP headers for reverse-proxy setups.
|
||||||
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
||||||
// Calculate rate: loginRateLimit events per loginRateInterval
|
limiter := httprate.Limit(
|
||||||
r := rate.Limit(float64(loginRateLimit) / loginRateInterval.Seconds())
|
loginRateLimit,
|
||||||
rlm := newRateLimiterMap(r, loginRateLimit)
|
loginRateInterval,
|
||||||
|
httprate.WithKeyFuncs(httprate.KeyByRealIP),
|
||||||
|
httprate.WithLimitHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
m.log.Warn("login rate limit exceeded",
|
||||||
|
"path", r.URL.Path,
|
||||||
|
)
|
||||||
|
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
limited := limiter(next)
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Only rate-limit POST requests (actual login attempts)
|
// Only rate-limit POST requests (actual login attempts)
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
limited.ServeHTTP(w, r)
|
||||||
ip := extractIP(r)
|
|
||||||
limiter := rlm.getLimiter(ip)
|
|
||||||
|
|
||||||
if !limiter.Allow() {
|
|
||||||
m.log.Warn("login rate limit exceeded",
|
|
||||||
"ip", ip,
|
|
||||||
"path", r.URL.Path,
|
|
||||||
)
|
|
||||||
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractIP extracts the client IP address from the request. It checks
|
|
||||||
// X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups),
|
|
||||||
// then falls back to RemoteAddr.
|
|
||||||
func extractIP(r *http.Request) string {
|
|
||||||
// Check X-Forwarded-For header (first IP in chain)
|
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
||||||
// X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2
|
|
||||||
// The first one is the original client
|
|
||||||
for i := 0; i < len(xff); i++ {
|
|
||||||
if xff[i] == ',' {
|
|
||||||
ip := xff[:i]
|
|
||||||
// Trim whitespace
|
|
||||||
for len(ip) > 0 && ip[0] == ' ' {
|
|
||||||
ip = ip[1:]
|
|
||||||
}
|
|
||||||
for len(ip) > 0 && ip[len(ip)-1] == ' ' {
|
|
||||||
ip = ip[:len(ip)-1]
|
|
||||||
}
|
|
||||||
if ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
trimmed := xff
|
|
||||||
for len(trimmed) > 0 && trimmed[0] == ' ' {
|
|
||||||
trimmed = trimmed[1:]
|
|
||||||
}
|
|
||||||
for len(trimmed) > 0 && trimmed[len(trimmed)-1] == ' ' {
|
|
||||||
trimmed = trimmed[:len(trimmed)-1]
|
|
||||||
}
|
|
||||||
if trimmed != "" {
|
|
||||||
return trimmed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check X-Real-IP header
|
|
||||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
|
||||||
return xri
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to RemoteAddr
|
|
||||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -88,34 +88,3 @@ func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
|
|||||||
handler.ServeHTTP(w2, req2)
|
handler.ServeHTTP(w2, req2)
|
||||||
assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected")
|
assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractIP_RemoteAddr(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "192.168.1.100:54321"
|
|
||||||
assert.Equal(t, "192.168.1.100", extractIP(req))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractIP_XForwardedFor(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "10.0.0.1:1234"
|
|
||||||
req.Header.Set("X-Forwarded-For", "203.0.113.50, 70.41.3.18, 150.172.238.178")
|
|
||||||
assert.Equal(t, "203.0.113.50", extractIP(req))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractIP_XRealIP(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "10.0.0.1:1234"
|
|
||||||
req.Header.Set("X-Real-IP", "203.0.113.50")
|
|
||||||
assert.Equal(t, "203.0.113.50", extractIP(req))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractIP_XForwardedForSingle(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "10.0.0.1:1234"
|
|
||||||
req.Header.Set("X-Forwarded-For", "203.0.113.50")
|
|
||||||
assert.Equal(t, "203.0.113.50", extractIP(req))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ type SessionParams struct {
|
|||||||
// Session manages encrypted session storage
|
// Session manages encrypted session storage
|
||||||
type Session struct {
|
type Session struct {
|
||||||
store *sessions.CookieStore
|
store *sessions.CookieStore
|
||||||
|
key []byte // raw 32-byte auth key, also used for CSRF cookie signing
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
config *config.Config
|
config *config.Config
|
||||||
}
|
}
|
||||||
@@ -79,6 +80,7 @@ func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
|
|||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.key = keyBytes
|
||||||
s.store = store
|
s.store = store
|
||||||
s.log.Info("session manager initialized")
|
s.log.Info("session manager initialized")
|
||||||
return nil
|
return nil
|
||||||
@@ -93,6 +95,12 @@ func (s *Session) Get(r *http.Request) (*sessions.Session, error) {
|
|||||||
return s.store.Get(r, SessionName)
|
return s.store.Get(r, SessionName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetKey returns the raw 32-byte authentication key used for session
|
||||||
|
// encryption. This key is also suitable for CSRF cookie signing.
|
||||||
|
func (s *Session) GetKey() []byte {
|
||||||
|
return s.key
|
||||||
|
}
|
||||||
|
|
||||||
// Save saves the session
|
// Save saves the session
|
||||||
func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error {
|
func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error {
|
||||||
return sess.Save(r, w)
|
return sess.Save(r, w)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func testSession(t *testing.T) *Session {
|
|||||||
}
|
}
|
||||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||||
|
|
||||||
return NewForTest(store, cfg, log)
|
return NewForTest(store, cfg, log, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Get and Save Tests ---
|
// --- Get and Save Tests ---
|
||||||
|
|||||||
@@ -9,10 +9,13 @@ import (
|
|||||||
|
|
||||||
// NewForTest creates a Session with a pre-configured cookie store for use
|
// NewForTest creates a Session with a pre-configured cookie store for use
|
||||||
// in tests. This bypasses the fx lifecycle and database dependency, allowing
|
// in tests. This bypasses the fx lifecycle and database dependency, allowing
|
||||||
// middleware and handler tests to use real session functionality.
|
// middleware and handler tests to use real session functionality. The key
|
||||||
func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *Session {
|
// parameter is the raw 32-byte authentication key used for session encryption
|
||||||
|
// and CSRF cookie signing.
|
||||||
|
func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *Session {
|
||||||
return &Session{
|
return &Session{
|
||||||
store: store,
|
store: store,
|
||||||
|
key: key,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user