diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 637b51d..daf15f1 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -2,9 +2,13 @@ package middleware import ( + "fmt" "log/slog" + "math" "net" "net/http" + "strings" + "sync" "time" "github.com/99designs/basicauth-go" @@ -12,6 +16,7 @@ import ( "github.com/go-chi/cors" "github.com/gorilla/csrf" "go.uber.org/fx" + "golang.org/x/time/rate" "git.eeqj.de/sneak/upaas/internal/config" "git.eeqj.de/sneak/upaas/internal/globals" @@ -86,7 +91,7 @@ func (m *Middleware) Logging() func(http.Handler) http.Handler { "request_id", reqID, "referer", request.Referer(), "proto", request.Proto, - "remoteIP", ipFromHostPort(request.RemoteAddr), + "remoteIP", realIP(request), "status", lrw.statusCode, "latency_ms", latency.Milliseconds(), ) @@ -106,6 +111,28 @@ func ipFromHostPort(hostPort string) string { return host } +// realIP extracts the client's real IP address from the request, +// checking proxy headers first (trusted reverse proxy like Traefik), +// then falling back to RemoteAddr. +func realIP(r *http.Request) string { + // 1. X-Real-IP (set by Traefik/nginx) + if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" { + return ip + } + + // 2. X-Forwarded-For: take the first (leftmost/client) IP + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + if parts := strings.SplitN(xff, ",", 2); len(parts) > 0 { //nolint:mnd + if ip := strings.TrimSpace(parts[0]); ip != "" { + return ip + } + } + } + + // 3. Fall back to RemoteAddr + return ipFromHostPort(r.RemoteAddr) +} + // CORS returns CORS middleware. func (m *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ @@ -162,6 +189,113 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler { ) } +// loginRateLimit configures the login rate limiter. +const ( + loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds + loginBurst = 5 // allow burst of 5 + limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes + limiterCleanupEvery = 1 * time.Minute // sweep interval +) + +// ipLimiterEntry stores a rate limiter with its last-seen timestamp. +type ipLimiterEntry struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// ipLimiter tracks per-IP rate limiters for login attempts with automatic +// eviction of stale entries to prevent unbounded memory growth. +type ipLimiter struct { + mu sync.Mutex + limiters map[string]*ipLimiterEntry + lastSweep time.Time +} + +func newIPLimiter() *ipLimiter { + return &ipLimiter{ + limiters: make(map[string]*ipLimiterEntry), + lastSweep: time.Now(), + } +} + +// sweep removes entries not seen within limiterExpiry. Must be called with mu held. +func (i *ipLimiter) sweep(now time.Time) { + for ip, entry := range i.limiters { + if now.Sub(entry.lastSeen) > limiterExpiry { + delete(i.limiters, ip) + } + } + i.lastSweep = now +} + +func (i *ipLimiter) getLimiter(ip string) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + now := time.Now() + + // Lazy sweep: clean up stale entries periodically. + if now.Sub(i.lastSweep) >= limiterCleanupEvery { + i.sweep(now) + } + + entry, exists := i.limiters[ip] + if !exists { + entry = &ipLimiterEntry{ + limiter: rate.NewLimiter(loginRateLimit, loginBurst), + } + i.limiters[ip] = entry + } + entry.lastSeen = now + + return entry.limiter +} + +// loginLimiter is the singleton IP rate limiter for login attempts. +// +//nolint:gochecknoglobals // intentional singleton for rate limiting state +var loginLimiter = newIPLimiter() + +// LoginRateLimit returns middleware that rate-limits login attempts per IP. +// It allows 5 attempts per minute and returns 429 Too Many Requests when exceeded. +func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func( + writer http.ResponseWriter, + request *http.Request, + ) { + ip := realIP(request) + limiter := loginLimiter.getLimiter(ip) + + if !limiter.Allow() { + m.log.WarnContext(request.Context(), "login rate limit exceeded", + "remoteIP", ip, + ) + + // Compute seconds until the next token is available. + reservation := limiter.Reserve() + delay := reservation.Delay() + reservation.Cancel() + retryAfter := int(math.Ceil(delay.Seconds())) + if retryAfter < 1 { + retryAfter = 1 + } + writer.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter)) + + http.Error( + writer, + "Too Many Requests", + http.StatusTooManyRequests, + ) + + return + } + + next.ServeHTTP(writer, request) + }) + } +} + // SetupRequired returns middleware that redirects to setup if no user exists. func (m *Middleware) SetupRequired() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { diff --git a/internal/middleware/ratelimit_test.go b/internal/middleware/ratelimit_test.go new file mode 100644 index 0000000..a932ccf --- /dev/null +++ b/internal/middleware/ratelimit_test.go @@ -0,0 +1,135 @@ +package middleware + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "git.eeqj.de/sneak/upaas/internal/config" +) + +func newTestMiddleware(t *testing.T) *Middleware { + t.Helper() + + return &Middleware{ + log: slog.Default(), + params: &Params{ + Config: &config.Config{}, + }, + } +} + +func TestLoginRateLimitAllowsUpToBurst(t *testing.T) { + // Reset the global limiter to get clean state + loginLimiter = newIPLimiter() + + mw := newTestMiddleware(t) + + handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First 5 requests should succeed (burst) + for i := range 5 { + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1) + } + + // 6th request should be rate limited + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited") +} + +func TestLoginRateLimitIsolatesIPs(t *testing.T) { + loginLimiter = newIPLimiter() + + mw := newTestMiddleware(t) + + handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust IP1's budget + for range 5 { + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "10.0.0.1:1234" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + } + + // IP1 should be blocked + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "10.0.0.1:1234" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code) + + // IP2 should still work + req2 := httptest.NewRequest(http.MethodPost, "/login", nil) + req2.RemoteAddr = "10.0.0.2:1234" + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + assert.Equal(t, http.StatusOK, rec2.Code, "different IP should not be rate limited") +} + +func TestLoginRateLimitReturns429Body(t *testing.T) { + loginLimiter = newIPLimiter() + + mw := newTestMiddleware(t) + + handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust burst + for range 5 { + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "172.16.0.1:5555" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + } + + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "172.16.0.1:5555" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code) + assert.Contains(t, rec.Body.String(), "Too Many Requests") + assert.NotEmpty(t, rec.Header().Get("Retry-After"), "should include Retry-After header") +} + +func TestIPLimiterEvictsStaleEntries(t *testing.T) { + il := newIPLimiter() + + // Add an entry and backdate its lastSeen + il.mu.Lock() + il.limiters["1.2.3.4"] = &ipLimiterEntry{ + limiter: nil, + lastSeen: time.Now().Add(-15 * time.Minute), + } + il.limiters["5.6.7.8"] = &ipLimiterEntry{ + limiter: nil, + lastSeen: time.Now(), + } + il.mu.Unlock() + + // Trigger sweep + il.mu.Lock() + il.sweep(time.Now()) + il.mu.Unlock() + + il.mu.Lock() + defer il.mu.Unlock() + assert.NotContains(t, il.limiters, "1.2.3.4", "stale entry should be evicted") + assert.Contains(t, il.limiters, "5.6.7.8", "fresh entry should remain") +} diff --git a/internal/middleware/realip_test.go b/internal/middleware/realip_test.go new file mode 100644 index 0000000..3271a85 --- /dev/null +++ b/internal/middleware/realip_test.go @@ -0,0 +1,83 @@ +package middleware + +import ( + "net/http" + "testing" +) + +func TestRealIP(t *testing.T) { + tests := []struct { + name string + remoteAddr string + xRealIP string + xff string + want string + }{ + { + name: "X-Real-IP takes priority", + remoteAddr: "10.0.0.1:1234", + xRealIP: "203.0.113.5", + xff: "198.51.100.1, 10.0.0.1", + want: "203.0.113.5", + }, + { + name: "X-Forwarded-For used when no X-Real-IP", + remoteAddr: "10.0.0.1:1234", + xff: "198.51.100.1, 10.0.0.1", + want: "198.51.100.1", + }, + { + name: "X-Forwarded-For single IP", + remoteAddr: "10.0.0.1:1234", + xff: "203.0.113.10", + want: "203.0.113.10", + }, + { + name: "falls back to RemoteAddr", + remoteAddr: "192.168.1.1:5678", + want: "192.168.1.1", + }, + { + name: "RemoteAddr without port", + remoteAddr: "192.168.1.1", + want: "192.168.1.1", + }, + { + name: "X-Real-IP with whitespace", + remoteAddr: "10.0.0.1:1234", + xRealIP: " 203.0.113.5 ", + want: "203.0.113.5", + }, + { + name: "X-Forwarded-For with whitespace", + remoteAddr: "10.0.0.1:1234", + xff: " 198.51.100.1 , 10.0.0.1", + want: "198.51.100.1", + }, + { + name: "empty X-Real-IP falls through to XFF", + remoteAddr: "10.0.0.1:1234", + xRealIP: " ", + xff: "198.51.100.1", + want: "198.51.100.1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = tt.remoteAddr + if tt.xRealIP != "" { + req.Header.Set("X-Real-IP", tt.xRealIP) + } + if tt.xff != "" { + req.Header.Set("X-Forwarded-For", tt.xff) + } + + got := realIP(req) + if got != tt.want { + t.Errorf("realIP() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 860d0b9..2fb632f 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -46,7 +46,7 @@ func (s *Server) SetupRoutes() { // Public routes r.Get("/login", s.handlers.HandleLoginGET()) - r.Post("/login", s.handlers.HandleLoginPOST()) + r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleLoginPOST()) r.Get("/setup", s.handlers.HandleSetupGET()) r.Post("/setup", s.handlers.HandleSetupPOST())