Add rate limiting to login endpoint to prevent brute force (closes #12) #14

Merged
sneak merged 3 commits from :fix/issue-12 into main 2026-02-16 06:15:49 +01:00
2 changed files with 82 additions and 10 deletions
Showing only changes of commit a1b06219e7 - Show all commits

View File

@ -2,7 +2,9 @@
package middleware package middleware
import ( import (
"fmt"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"sync" "sync"
@ -166,33 +168,64 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
// loginRateLimit configures the login rate limiter. // loginRateLimit configures the login rate limiter.
const ( const (
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
loginBurst = 5 // allow burst of 5 loginBurst = 5 // allow burst of 5
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
limiterCleanupEvery = 1 * time.Minute // sweep interval
) )
// ipLimiter tracks per-IP rate limiters for login attempts. // ipLimiterEntry stores a rate limiter with its last-seen timestamp.
type ipLimiterEntry struct {
limiter *rate.Limiter
lastSeen time.Time
}
// ipLimiter tracks per-IP rate limiters for login attempts with automatic
// eviction of stale entries to prevent unbounded memory growth.
type ipLimiter struct { type ipLimiter struct {
mu sync.Mutex mu sync.Mutex
limiters map[string]*rate.Limiter limiters map[string]*ipLimiterEntry
lastSweep time.Time
} }
func newIPLimiter() *ipLimiter { func newIPLimiter() *ipLimiter {
return &ipLimiter{ return &ipLimiter{
limiters: make(map[string]*rate.Limiter), limiters: make(map[string]*ipLimiterEntry),
lastSweep: time.Now(),
} }
} }
// sweep removes entries not seen within limiterExpiry. Must be called with mu held.
func (i *ipLimiter) sweep(now time.Time) {
for ip, entry := range i.limiters {
if now.Sub(entry.lastSeen) > limiterExpiry {
delete(i.limiters, ip)
}
}
i.lastSweep = now
}
func (i *ipLimiter) getLimiter(ip string) *rate.Limiter { func (i *ipLimiter) getLimiter(ip string) *rate.Limiter {
i.mu.Lock() i.mu.Lock()
defer i.mu.Unlock() defer i.mu.Unlock()
limiter, exists := i.limiters[ip] now := time.Now()
if !exists {
limiter = rate.NewLimiter(loginRateLimit, loginBurst) // Lazy sweep: clean up stale entries periodically.
i.limiters[ip] = limiter if now.Sub(i.lastSweep) >= limiterCleanupEvery {
i.sweep(now)
} }
return limiter entry, exists := i.limiters[ip]
if !exists {
entry = &ipLimiterEntry{
limiter: rate.NewLimiter(loginRateLimit, loginBurst),
}
i.limiters[ip] = entry
}
entry.lastSeen = now
return entry.limiter
} }
// loginLimiter is the singleton IP rate limiter for login attempts. // loginLimiter is the singleton IP rate limiter for login attempts.
@ -215,6 +248,17 @@ func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
m.log.WarnContext(request.Context(), "login rate limit exceeded", m.log.WarnContext(request.Context(), "login rate limit exceeded",
"remoteIP", ip, "remoteIP", ip,
) )
// Compute seconds until the next token is available.
reservation := limiter.Reserve()
delay := reservation.Delay()
reservation.Cancel()
retryAfter := int(math.Ceil(delay.Seconds()))
if retryAfter < 1 {
retryAfter = 1
}
writer.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
http.Error( http.Error(
writer, writer,
"Too Many Requests", "Too Many Requests",

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -104,4 +105,31 @@ func TestLoginRateLimitReturns429Body(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code) assert.Equal(t, http.StatusTooManyRequests, rec.Code)
assert.Contains(t, rec.Body.String(), "Too Many Requests") assert.Contains(t, rec.Body.String(), "Too Many Requests")
assert.NotEmpty(t, rec.Header().Get("Retry-After"), "should include Retry-After header")
}
func TestIPLimiterEvictsStaleEntries(t *testing.T) {
il := newIPLimiter()
// Add an entry and backdate its lastSeen
il.mu.Lock()
il.limiters["1.2.3.4"] = &ipLimiterEntry{
limiter: nil,
lastSeen: time.Now().Add(-15 * time.Minute),
}
il.limiters["5.6.7.8"] = &ipLimiterEntry{
limiter: nil,
lastSeen: time.Now(),
}
il.mu.Unlock()
// Trigger sweep
il.mu.Lock()
il.sweep(time.Now())
il.mu.Unlock()
il.mu.Lock()
defer il.mu.Unlock()
assert.NotContains(t, il.limiters, "1.2.3.4", "stale entry should be evicted")
assert.Contains(t, il.limiters, "5.6.7.8", "fresh entry should remain")
} }