refactor: replace custom CSRF and rate-limiting with off-the-shelf libraries
All checks were successful
check / check (push) Successful in 4s
All checks were successful
check / check (push) Successful in 4s
Replace custom CSRF middleware with gorilla/csrf and custom rate-limiting middleware with go-chi/httprate, as requested in code review. CSRF changes: - Replace session-based CSRF tokens with gorilla/csrf cookie-based double-submit pattern (HMAC-authenticated cookies) - Keep same form field name (csrf_token) for template compatibility - Keep same route exclusions (webhook/API routes) - In dev mode, mark requests as plaintext HTTP to skip Referer check Rate limiting changes: - Replace custom token-bucket rate limiter with httprate sliding-window counter (per-IP, 5 POST requests/min on login endpoint) - Remove custom IP extraction (httprate.KeyByRealIP handles X-Forwarded-For, X-Real-IP, True-Client-IP) - Remove custom cleanup goroutine (httprate manages its own state) Kept as-is: - SSRF prevention code (internal/delivery/ssrf.go) — application-specific - CSRFToken() wrapper function — handlers unchanged Updated README security section and architecture overview to reflect library choices.
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
"github.com/go-chi/httprate"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -15,158 +13,36 @@ const (
|
||||
|
||||
// loginRateInterval is the time window for the rate limit.
|
||||
loginRateInterval = 1 * time.Minute
|
||||
|
||||
// limiterCleanupInterval is how often stale per-IP limiters are pruned.
|
||||
limiterCleanupInterval = 5 * time.Minute
|
||||
|
||||
// limiterMaxAge is how long an unused limiter is kept before pruning.
|
||||
limiterMaxAge = 10 * time.Minute
|
||||
)
|
||||
|
||||
// ipLimiter holds a rate limiter and the time it was last used.
|
||||
type ipLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// rateLimiterMap manages per-IP rate limiters with periodic cleanup.
|
||||
type rateLimiterMap struct {
|
||||
mu sync.Mutex
|
||||
limiters map[string]*ipLimiter
|
||||
rate rate.Limit
|
||||
burst int
|
||||
}
|
||||
|
||||
// newRateLimiterMap creates a new per-IP rate limiter map.
|
||||
func newRateLimiterMap(r rate.Limit, burst int) *rateLimiterMap {
|
||||
rlm := &rateLimiterMap{
|
||||
limiters: make(map[string]*ipLimiter),
|
||||
rate: r,
|
||||
burst: burst,
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go rlm.cleanup()
|
||||
|
||||
return rlm
|
||||
}
|
||||
|
||||
// getLimiter returns the rate limiter for the given IP, creating one if
|
||||
// it doesn't exist.
|
||||
func (rlm *rateLimiterMap) getLimiter(ip string) *rate.Limiter {
|
||||
rlm.mu.Lock()
|
||||
defer rlm.mu.Unlock()
|
||||
|
||||
if entry, ok := rlm.limiters[ip]; ok {
|
||||
entry.lastSeen = time.Now()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
limiter := rate.NewLimiter(rlm.rate, rlm.burst)
|
||||
rlm.limiters[ip] = &ipLimiter{
|
||||
limiter: limiter,
|
||||
lastSeen: time.Now(),
|
||||
}
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanup periodically removes stale rate limiters to prevent unbounded
|
||||
// memory growth from unique IPs.
|
||||
func (rlm *rateLimiterMap) cleanup() {
|
||||
ticker := time.NewTicker(limiterCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rlm.mu.Lock()
|
||||
cutoff := time.Now().Add(-limiterMaxAge)
|
||||
for ip, entry := range rlm.limiters {
|
||||
if entry.lastSeen.Before(cutoff) {
|
||||
delete(rlm.limiters, ip)
|
||||
}
|
||||
}
|
||||
rlm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRateLimit returns middleware that enforces per-IP rate limiting
|
||||
// on login attempts. Only POST requests are rate-limited; GET requests
|
||||
// (rendering the login form) pass through unaffected. When the rate
|
||||
// limit is exceeded, a 429 Too Many Requests response is returned.
|
||||
// on login attempts using go-chi/httprate. Only POST requests are
|
||||
// rate-limited; GET requests (rendering the login form) pass through
|
||||
// unaffected. When the rate limit is exceeded, a 429 Too Many Requests
|
||||
// response is returned. IP extraction honours X-Forwarded-For,
|
||||
// X-Real-IP, and True-Client-IP headers for reverse-proxy setups.
|
||||
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
||||
// Calculate rate: loginRateLimit events per loginRateInterval
|
||||
r := rate.Limit(float64(loginRateLimit) / loginRateInterval.Seconds())
|
||||
rlm := newRateLimiterMap(r, loginRateLimit)
|
||||
limiter := httprate.Limit(
|
||||
loginRateLimit,
|
||||
loginRateInterval,
|
||||
httprate.WithKeyFuncs(httprate.KeyByRealIP),
|
||||
httprate.WithLimitHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
m.log.Warn("login rate limit exceeded",
|
||||
"path", r.URL.Path,
|
||||
)
|
||||
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
||||
})),
|
||||
)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
limited := limiter(next)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only rate-limit POST requests (actual login attempts)
|
||||
if r.Method != http.MethodPost {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ip := extractIP(r)
|
||||
limiter := rlm.getLimiter(ip)
|
||||
|
||||
if !limiter.Allow() {
|
||||
m.log.Warn("login rate limit exceeded",
|
||||
"ip", ip,
|
||||
"path", r.URL.Path,
|
||||
)
|
||||
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
limited.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractIP extracts the client IP address from the request. It checks
|
||||
// X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups),
|
||||
// then falls back to RemoteAddr.
|
||||
func extractIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (first IP in chain)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2
|
||||
// The first one is the original client
|
||||
for i := 0; i < len(xff); i++ {
|
||||
if xff[i] == ',' {
|
||||
ip := xff[:i]
|
||||
// Trim whitespace
|
||||
for len(ip) > 0 && ip[0] == ' ' {
|
||||
ip = ip[1:]
|
||||
}
|
||||
for len(ip) > 0 && ip[len(ip)-1] == ' ' {
|
||||
ip = ip[:len(ip)-1]
|
||||
}
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
trimmed := xff
|
||||
for len(trimmed) > 0 && trimmed[0] == ' ' {
|
||||
trimmed = trimmed[1:]
|
||||
}
|
||||
for len(trimmed) > 0 && trimmed[len(trimmed)-1] == ' ' {
|
||||
trimmed = trimmed[:len(trimmed)-1]
|
||||
}
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user