- Store a lastSeen timestamp per IP limiter entry
- Lazy sweep removes entries older than 10 minutes on each request
- Add a Retry-After header to 429 responses
- Add a test for stale-entry eviction

Fixes a memory leak under sustained attack from many IPs.
136 lines
3.5 KiB
Go
136 lines
3.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"git.eeqj.de/sneak/upaas/internal/config"
|
|
)
|
|
|
|
func newTestMiddleware(t *testing.T) *Middleware {
|
|
t.Helper()
|
|
|
|
return &Middleware{
|
|
log: slog.Default(),
|
|
params: &Params{
|
|
Config: &config.Config{},
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestLoginRateLimitAllowsUpToBurst(t *testing.T) {
|
|
// Reset the global limiter to get clean state
|
|
loginLimiter = newIPLimiter()
|
|
|
|
mw := newTestMiddleware(t)
|
|
|
|
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// First 5 requests should succeed (burst)
|
|
for i := range 5 {
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1)
|
|
}
|
|
|
|
// 6th request should be rate limited
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited")
|
|
}
|
|
|
|
func TestLoginRateLimitIsolatesIPs(t *testing.T) {
|
|
loginLimiter = newIPLimiter()
|
|
|
|
mw := newTestMiddleware(t)
|
|
|
|
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Exhaust IP1's budget
|
|
for range 5 {
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
}
|
|
|
|
// IP1 should be blocked
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
|
|
|
// IP2 should still work
|
|
req2 := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req2.RemoteAddr = "10.0.0.2:1234"
|
|
rec2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec2, req2)
|
|
assert.Equal(t, http.StatusOK, rec2.Code, "different IP should not be rate limited")
|
|
}
|
|
|
|
func TestLoginRateLimitReturns429Body(t *testing.T) {
|
|
loginLimiter = newIPLimiter()
|
|
|
|
mw := newTestMiddleware(t)
|
|
|
|
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Exhaust burst
|
|
for range 5 {
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "172.16.0.1:5555"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "172.16.0.1:5555"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
|
assert.Contains(t, rec.Body.String(), "Too Many Requests")
|
|
assert.NotEmpty(t, rec.Header().Get("Retry-After"), "should include Retry-After header")
|
|
}
|
|
|
|
func TestIPLimiterEvictsStaleEntries(t *testing.T) {
|
|
il := newIPLimiter()
|
|
|
|
// Add an entry and backdate its lastSeen
|
|
il.mu.Lock()
|
|
il.limiters["1.2.3.4"] = &ipLimiterEntry{
|
|
limiter: nil,
|
|
lastSeen: time.Now().Add(-15 * time.Minute),
|
|
}
|
|
il.limiters["5.6.7.8"] = &ipLimiterEntry{
|
|
limiter: nil,
|
|
lastSeen: time.Now(),
|
|
}
|
|
il.mu.Unlock()
|
|
|
|
// Trigger sweep
|
|
il.mu.Lock()
|
|
il.sweep(time.Now())
|
|
il.mu.Unlock()
|
|
|
|
il.mu.Lock()
|
|
defer il.mu.Unlock()
|
|
assert.NotContains(t, il.limiters, "1.2.3.4", "stale entry should be evicted")
|
|
assert.Contains(t, il.limiters, "5.6.7.8", "fresh entry should remain")
|
|
}
|