package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"sneak.berlin/go/webhooker/internal/config"
)

// TestLoginRateLimit_AllowsGET verifies that GET requests to the login page
// pass through the LoginRateLimit middleware even when a single client sends
// more requests than the POST limit allows.
func TestLoginRateLimit_AllowsGET(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var callCount int
	handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		callCount++
		w.WriteHeader(http.StatusOK)
	}))

	// GET requests should never be rate-limited. Send comfortably more than
	// the POST limit so a pass proves GETs are exempt, not merely under the
	// threshold.
	requests := 2*loginRateLimit + 1
	for i := 0; i < requests; i++ {
		req := httptest.NewRequest(http.MethodGet, "/pages/login", nil)
		req.RemoteAddr = "192.168.1.1:12345"
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "GET request %d should pass", i)
	}

	assert.Equal(t, requests, callCount)
}

// TestLoginRateLimit_LimitsPOST verifies that a single client may make
// exactly loginRateLimit POST requests, after which the middleware answers
// 429 Too Many Requests without invoking the wrapped handler.
func TestLoginRateLimit_LimitsPOST(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	var callCount int
	handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		callCount++
		w.WriteHeader(http.StatusOK)
	}))

	// First loginRateLimit POST requests should succeed.
	for i := 0; i < loginRateLimit; i++ {
		req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
		req.RemoteAddr = "10.0.0.1:12345"
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "POST request %d should pass", i)
	}

	// Next POST should be rate-limited.
	req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	assert.Equal(t, http.StatusTooManyRequests, w.Code, "POST after limit should be 429")

	// The rejected request must not have reached the wrapped handler.
	assert.Equal(t, loginRateLimit, callCount)
}

// TestLoginRateLimit_IndependentPerIP verifies that exhausting the POST limit
// for one client IP does not affect requests from a different IP.
func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
	t.Parallel()

	m, _ := testMiddleware(t, config.EnvironmentDev)

	handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Exhaust the limit for IP1. Each warm-up request must itself succeed;
	// otherwise a limiter that trips too early would go unnoticed.
	for i := 0; i < loginRateLimit; i++ {
		req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
		req.RemoteAddr = "1.2.3.4:12345"
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "warm-up POST request %d should pass", i)
	}

	// IP1 should now be rate-limited.
	req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
	req.RemoteAddr = "1.2.3.4:12345"
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	assert.Equal(t, http.StatusTooManyRequests, w.Code)

	// IP2 should still be allowed.
	req2 := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
	req2.RemoteAddr = "5.6.7.8:12345"
	w2 := httptest.NewRecorder()
	handler.ServeHTTP(w2, req2)
	assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected")
}

// TestExtractIP_RemoteAddr verifies that, with no forwarding headers set,
// extractIP returns the host portion of RemoteAddr without the port.
func TestExtractIP_RemoteAddr(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.168.1.100:54321"

	assert.Equal(t, "192.168.1.100", extractIP(req))
}

// TestExtractIP_XForwardedFor verifies that extractIP returns the first
// (left-most) address of a multi-hop X-Forwarded-For header in preference
// to RemoteAddr.
func TestExtractIP_XForwardedFor(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:1234"
	req.Header.Set("X-Forwarded-For", "203.0.113.50, 70.41.3.18, 150.172.238.178")

	assert.Equal(t, "203.0.113.50", extractIP(req))
}

// TestExtractIP_XRealIP verifies that extractIP returns the X-Real-IP header
// value in preference to RemoteAddr.
func TestExtractIP_XRealIP(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:1234"
	req.Header.Set("X-Real-IP", "203.0.113.50")

	assert.Equal(t, "203.0.113.50", extractIP(req))
}

// TestExtractIP_XForwardedForSingle verifies that a single-address
// X-Forwarded-For header is returned as-is, in preference to RemoteAddr.
func TestExtractIP_XForwardedForSingle(t *testing.T) {
	t.Parallel()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:1234"
	req.Header.Set("X-Forwarded-For", "203.0.113.50")

	assert.Equal(t, "203.0.113.50", extractIP(req))
}