package middleware

import (
	"log/slog"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"

	"git.eeqj.de/sneak/upaas/internal/config"
)

// testLoginBurst is the number of back-to-back login requests from one client
// address that these tests expect to succeed before LoginRateLimit starts
// answering with 429.
//
// NOTE(review): this must agree with the burst the login limiter is actually
// configured with; the limiter definition is not visible from this file.
const testLoginBurst = 5

// newTestMiddleware returns a Middleware wired with the default slog logger
// and an empty configuration, which is all the rate-limit tests need.
func newTestMiddleware(t *testing.T) *Middleware {
	t.Helper()

	return &Middleware{
		log: slog.Default(),
		params: &Params{
			Config: &config.Config{},
		},
	}
}

// newLoginRateLimitHandler installs a fresh package-level loginLimiter for the
// duration of the calling test and returns the LoginRateLimit middleware
// wrapped around a handler that always replies 200 OK.
//
// The previous limiter is put back when the test finishes so that the
// exhausted state produced here does not leak into other tests in the package.
func newLoginRateLimitHandler(t *testing.T) http.Handler {
	t.Helper()

	previous := loginLimiter
	loginLimiter = newIPLimiter()

	t.Cleanup(func() { loginLimiter = previous })

	alwaysOK := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	return newTestMiddleware(t).LoginRateLimit()(alwaysOK)
}

// postLogin sends a single POST /login through h with the request's RemoteAddr
// set to remoteAddr and returns the recorder holding the response.
func postLogin(h http.Handler, remoteAddr string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodPost, "/login", nil)
	req.RemoteAddr = remoteAddr

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	return rec
}

// TestLoginRateLimitAllowsUpToBurst checks that one address gets
// testLoginBurst successful requests and that the next one is rejected.
func TestLoginRateLimitAllowsUpToBurst(t *testing.T) {
	const addr = "192.168.1.1:12345"

	handler := newLoginRateLimitHandler(t)

	// Every request inside the burst must pass through to the wrapped handler.
	for i := range testLoginBurst {
		rec := postLogin(handler, addr)
		assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1)
	}

	// The first request past the burst must be rate limited.
	rec := postLogin(handler, addr)
	assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited")
}

// TestLoginRateLimitIsolatesIPs checks that exhausting one address's budget
// does not affect requests arriving from a different address.
func TestLoginRateLimitIsolatesIPs(t *testing.T) {
	const (
		firstAddr  = "10.0.0.1:1234"
		secondAddr = "10.0.0.2:1234"
	)

	handler := newLoginRateLimitHandler(t)

	// Use up the first address's budget; the responses themselves are not
	// under test here.
	for range testLoginBurst {
		postLogin(handler, firstAddr)
	}

	// The first address is now blocked.
	blocked := postLogin(handler, firstAddr)
	assert.Equal(t, http.StatusTooManyRequests, blocked.Code)

	// The second address has its own, untouched budget.
	allowed := postLogin(handler, secondAddr)
	assert.Equal(t, http.StatusOK, allowed.Code, "different IP should not be rate limited")
}

// TestLoginRateLimitReturns429Body checks that a rejected request carries
// both the 429 status and a body mentioning "Too Many Requests".
func TestLoginRateLimitReturns429Body(t *testing.T) {
	const addr = "172.16.0.1:5555"

	handler := newLoginRateLimitHandler(t)

	// Use up the burst so the next request is rejected.
	for range testLoginBurst {
		postLogin(handler, addr)
	}

	rec := postLogin(handler, addr)
	assert.Equal(t, http.StatusTooManyRequests, rec.Code)
	assert.Contains(t, rec.Body.String(), "Too Many Requests")
}