diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 637b51d..6e465e4 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -5,6 +5,7 @@ import ( "log/slog" "net" "net/http" + "sync" "time" "github.com/99designs/basicauth-go" @@ -12,6 +13,7 @@ import ( "github.com/go-chi/cors" "github.com/gorilla/csrf" "go.uber.org/fx" + "golang.org/x/time/rate" "git.eeqj.de/sneak/upaas/internal/config" "git.eeqj.de/sneak/upaas/internal/globals" @@ -162,6 +164,71 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler { ) } +// loginRateLimit configures the login rate limiter. +const ( + loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds + loginBurst = 5 // allow burst of 5 +) + +// ipLimiter tracks per-IP rate limiters for login attempts. +type ipLimiter struct { + mu sync.Mutex + limiters map[string]*rate.Limiter +} + +func newIPLimiter() *ipLimiter { + return &ipLimiter{ + limiters: make(map[string]*rate.Limiter), + } +} + +func (i *ipLimiter) getLimiter(ip string) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + limiter, exists := i.limiters[ip] + if !exists { + limiter = rate.NewLimiter(loginRateLimit, loginBurst) + i.limiters[ip] = limiter + } + + return limiter +} + +// loginLimiter is the singleton IP rate limiter for login attempts. +// +//nolint:gochecknoglobals // intentional singleton for rate limiting state +var loginLimiter = newIPLimiter() + +// LoginRateLimit returns middleware that rate-limits login attempts per IP. +// It allows 5 attempts per minute and returns 429 Too Many Requests when exceeded. +func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func( + writer http.ResponseWriter, + request *http.Request, + ) { + ip := ipFromHostPort(request.RemoteAddr) + limiter := loginLimiter.getLimiter(ip) + + if !limiter.Allow() { + m.log.WarnContext(request.Context(), "login rate limit exceeded", + "remoteIP", ip, + ) + http.Error( + writer, + "Too Many Requests", + http.StatusTooManyRequests, + ) + + return + } + + next.ServeHTTP(writer, request) + }) + } +} + // SetupRequired returns middleware that redirects to setup if no user exists. func (m *Middleware) SetupRequired() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { diff --git a/internal/middleware/ratelimit_test.go b/internal/middleware/ratelimit_test.go new file mode 100644 index 0000000..ea3a448 --- /dev/null +++ b/internal/middleware/ratelimit_test.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "git.eeqj.de/sneak/upaas/internal/config" +) + +func newTestMiddleware(t *testing.T) *Middleware { + t.Helper() + + return &Middleware{ + log: slog.Default(), + params: &Params{ + Config: &config.Config{}, + }, + } +} + +func TestLoginRateLimitAllowsUpToBurst(t *testing.T) { + // Reset the global limiter to get clean state + loginLimiter = newIPLimiter() + + mw := newTestMiddleware(t) + + handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First 5 requests should succeed (burst) + for i := range 5 { + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1) + } + + // 6th request should be rate limited + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited") +} + +func TestLoginRateLimitIsolatesIPs(t *testing.T) { + loginLimiter = newIPLimiter() + + mw := newTestMiddleware(t) + + handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust IP1's budget + for range 5 { + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "10.0.0.1:1234" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + } + + // IP1 should be blocked + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "10.0.0.1:1234" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code) + + // IP2 should still work + req2 := httptest.NewRequest(http.MethodPost, "/login", nil) + req2.RemoteAddr = "10.0.0.2:1234" + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + assert.Equal(t, http.StatusOK, rec2.Code, "different IP should not be rate limited") +} + +func TestLoginRateLimitReturns429Body(t *testing.T) { + loginLimiter = newIPLimiter() + + mw := newTestMiddleware(t) + + handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust burst + for range 5 { + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "172.16.0.1:5555" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + } + + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "172.16.0.1:5555" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code) + assert.Contains(t, rec.Body.String(), "Too Many Requests") +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 860d0b9..2fb632f 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -46,7 +46,7 @@ func (s *Server) SetupRoutes() { // Public routes r.Get("/login", s.handlers.HandleLoginGET()) - r.Post("/login", s.handlers.HandleLoginPOST()) + r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleLoginPOST()) r.Get("/setup", s.handlers.HandleSetupGET()) r.Post("/setup", s.handlers.HandleSetupPOST())