package middleware //nolint:testpackage // tests unexported types and globals import ( "log/slog" "net/http" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" "git.eeqj.de/sneak/upaas/internal/config" ) func newTestMiddleware(t *testing.T) *Middleware { t.Helper() return &Middleware{ log: slog.Default(), params: &Params{ Config: &config.Config{}, }, } } //nolint:paralleltest // mutates global loginLimiter func TestLoginRateLimitAllowsUpToBurst(t *testing.T) { // Reset the global limiter to get clean state loginLimiter = newIPLimiter() mw := newTestMiddleware(t) handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) // First 5 requests should succeed (burst) for i := range 5 { req := httptest.NewRequest(http.MethodPost, "/login", nil) req.RemoteAddr = "192.168.1.1:12345" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1) } // 6th request should be rate limited req := httptest.NewRequest(http.MethodPost, "/login", nil) req.RemoteAddr = "192.168.1.1:12345" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited") } //nolint:paralleltest // mutates global loginLimiter func TestLoginRateLimitIsolatesIPs(t *testing.T) { loginLimiter = newIPLimiter() mw := newTestMiddleware(t) handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) // Exhaust IP1's budget for range 5 { req := httptest.NewRequest(http.MethodPost, "/login", nil) req.RemoteAddr = "10.0.0.1:1234" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) } // IP1 should be blocked req := httptest.NewRequest(http.MethodPost, "/login", nil) req.RemoteAddr = "10.0.0.1:1234" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusTooManyRequests, rec.Code) // IP2 should still work req2 := httptest.NewRequest(http.MethodPost, "/login", nil) req2.RemoteAddr = "10.0.0.2:1234" rec2 := httptest.NewRecorder() handler.ServeHTTP(rec2, req2) assert.Equal(t, http.StatusOK, rec2.Code, "different IP should not be rate limited") } //nolint:paralleltest // mutates global loginLimiter func TestLoginRateLimitReturns429Body(t *testing.T) { loginLimiter = newIPLimiter() mw := newTestMiddleware(t) handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) // Exhaust burst for range 5 { req := httptest.NewRequest(http.MethodPost, "/login", nil) req.RemoteAddr = "172.16.0.1:5555" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) } req := httptest.NewRequest(http.MethodPost, "/login", nil) req.RemoteAddr = "172.16.0.1:5555" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusTooManyRequests, rec.Code) assert.Contains(t, rec.Body.String(), "Too Many Requests") assert.NotEmpty(t, rec.Header().Get("Retry-After"), "should include Retry-After header") } func TestIPLimiterEvictsStaleEntries(t *testing.T) { t.Parallel() il := newIPLimiter() // Add an entry and backdate its lastSeen il.mu.Lock() il.limiters["1.2.3.4"] = &ipLimiterEntry{ limiter: nil, lastSeen: time.Now().Add(-15 * time.Minute), } il.limiters["5.6.7.8"] = &ipLimiterEntry{ limiter: nil, lastSeen: time.Now(), } il.mu.Unlock() // Trigger sweep il.mu.Lock() il.sweep(time.Now()) il.mu.Unlock() il.mu.Lock() defer il.mu.Unlock() assert.NotContains(t, il.limiters, "1.2.3.4", "stale entry should be evicted") assert.Contains(t, il.limiters, "5.6.7.8", "fresh entry should remain") }