package middleware

import (
	"net"
	"net/http"
	"sync"
	"time"

	"golang.org/x/time/rate"
)

const (
	// loginRateLimit is the maximum number of login attempts per interval.
	loginRateLimit = 5
	// loginRateInterval is the time window for the rate limit.
	loginRateInterval = 1 * time.Minute
	// limiterCleanupInterval is the minimum time between sweeps of stale
	// per-IP limiters; a sweep runs on the first getLimiter call after this
	// much time has passed since the previous one.
	limiterCleanupInterval = 5 * time.Minute
	// limiterMaxAge is how long an unused limiter is kept before pruning.
	limiterMaxAge = 10 * time.Minute
)

// ipLimiter holds a rate limiter and the time it was last used.
type ipLimiter struct {
	limiter  *rate.Limiter
	lastSeen time.Time
}

// rateLimiterMap manages per-IP rate limiters. Entries idle for longer than
// limiterMaxAge are pruned lazily from getLimiter rather than by a background
// goroutine, so constructing a map starts no goroutine and nothing keeps it
// alive once callers drop it. All fields are guarded by mu.
type rateLimiterMap struct {
	mu        sync.Mutex
	limiters  map[string]*ipLimiter
	rate      rate.Limit
	burst     int
	lastPrune time.Time // when stale entries were last swept
}

// newRateLimiterMap creates an empty per-IP limiter map. Each limiter it
// hands out is built with rate.NewLimiter(r, burst).
func newRateLimiterMap(r rate.Limit, burst int) *rateLimiterMap {
	return &rateLimiterMap{
		limiters:  make(map[string]*ipLimiter),
		rate:      r,
		burst:     burst,
		lastPrune: time.Now(),
	}
}

// getLimiter returns the rate limiter for the given IP, creating one if it
// doesn't exist, and records the access time. At most once per
// limiterCleanupInterval it first deletes limiters unused for longer than
// limiterMaxAge. Stale entries therefore linger only until the next call
// after the interval, which bounds memory from unique IPs while requests
// keep arriving.
func (rlm *rateLimiterMap) getLimiter(ip string) *rate.Limiter {
	rlm.mu.Lock()
	defer rlm.mu.Unlock()

	now := time.Now()
	if now.Sub(rlm.lastPrune) >= limiterCleanupInterval {
		rlm.pruneLocked(now)
	}

	if entry, ok := rlm.limiters[ip]; ok {
		entry.lastSeen = now
		return entry.limiter
	}

	limiter := rate.NewLimiter(rlm.rate, rlm.burst)
	rlm.limiters[ip] = &ipLimiter{limiter: limiter, lastSeen: now}
	return limiter
}

// pruneLocked deletes limiters whose last use is older than limiterMaxAge
// relative to now, then records now as the sweep time. rlm.mu must be held.
func (rlm *rateLimiterMap) pruneLocked(now time.Time) {
	cutoff := now.Add(-limiterMaxAge)
	for ip, entry := range rlm.limiters {
		if entry.lastSeen.Before(cutoff) {
			delete(rlm.limiters, ip)
		}
	}
	rlm.lastPrune = now
}

// cleanup deletes limiters idle for longer than limiterMaxAge every
// limiterCleanupInterval; it loops forever and never returns.
// NOTE(review): newRateLimiterMap no longer starts this in a goroutine
// (getLimiter prunes lazily instead); remove cleanup if nothing else in the
// package calls it.
func (rlm *rateLimiterMap) cleanup() { ticker := time.NewTicker(limiterCleanupInterval) defer ticker.Stop() for range ticker.C { rlm.mu.Lock() cutoff := time.Now().Add(-limiterMaxAge) for ip, entry := range rlm.limiters { if entry.lastSeen.Before(cutoff) { delete(rlm.limiters, ip) } } rlm.mu.Unlock() } } // LoginRateLimit returns middleware that enforces per-IP rate limiting // on login attempts. Only POST requests are rate-limited; GET requests // (rendering the login form) pass through unaffected. When the rate // limit is exceeded, a 429 Too Many Requests response is returned. func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { // Calculate rate: loginRateLimit events per loginRateInterval r := rate.Limit(float64(loginRateLimit) / loginRateInterval.Seconds()) rlm := newRateLimiterMap(r, loginRateLimit) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Only rate-limit POST requests (actual login attempts) if r.Method != http.MethodPost { next.ServeHTTP(w, r) return } ip := extractIP(r) limiter := rlm.getLimiter(ip) if !limiter.Allow() { m.log.Warn("login rate limit exceeded", "ip", ip, "path", r.URL.Path, ) http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } } // extractIP extracts the client IP address from the request. It checks // X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups), // then falls back to RemoteAddr. 
func extractIP(r *http.Request) string {
	// NOTE(review): header values are taken as-is from the request; unless a
	// trusted reverse proxy overwrites them, a client can claim any address.
	// Consider honouring them only when RemoteAddr is a known proxy.

	// X-Forwarded-For is "client, proxy1, proxy2"; only the leftmost entry is
	// considered. If it is empty or not an IP, fall through to the next
	// source instead of returning the raw header text.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		for i, c := range xff {
			if c == ',' {
				first = xff[:i]
				break
			}
		}
		if ip := canonicalHeaderIP(first); ip != "" {
			return ip
		}
	}

	if ip := canonicalHeaderIP(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	// Fall back to the transport-level peer address.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}

// canonicalHeaderIP trims surrounding spaces and tabs from a header-supplied
// address and returns it in canonical net.IP.String form, or "" when it is
// not a valid IP. A trailing port ("203.0.113.5:8080", "[2001:db8::1]:443")
// is accepted and discarded. Canonicalizing means different spellings of one
// address map to the same rate-limiter key.
func canonicalHeaderIP(s string) string {
	for len(s) > 0 && (s[0] == ' ' || s[0] == '\t') {
		s = s[1:]
	}
	for len(s) > 0 && (s[len(s)-1] == ' ' || s[len(s)-1] == '\t') {
		s = s[:len(s)-1]
	}
	if s == "" {
		return ""
	}
	if ip := net.ParseIP(s); ip != nil {
		return ip.String()
	}
	if host, _, err := net.SplitHostPort(s); err == nil {
		if ip := net.ParseIP(host); ip != nil {
			return ip.String()
		}
	}
	return ""
}