Add rate limiting to login endpoint to prevent brute force (closes #12) #14
@ -2,9 +2,13 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/99designs/basicauth-go"
|
"github.com/99designs/basicauth-go"
|
||||||
@ -12,6 +16,7 @@ import (
|
|||||||
"github.com/go-chi/cors"
|
"github.com/go-chi/cors"
|
||||||
"github.com/gorilla/csrf"
|
"github.com/gorilla/csrf"
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"git.eeqj.de/sneak/upaas/internal/config"
|
"git.eeqj.de/sneak/upaas/internal/config"
|
||||||
"git.eeqj.de/sneak/upaas/internal/globals"
|
"git.eeqj.de/sneak/upaas/internal/globals"
|
||||||
@ -86,7 +91,7 @@ func (m *Middleware) Logging() func(http.Handler) http.Handler {
|
|||||||
"request_id", reqID,
|
"request_id", reqID,
|
||||||
"referer", request.Referer(),
|
"referer", request.Referer(),
|
||||||
"proto", request.Proto,
|
"proto", request.Proto,
|
||||||
"remoteIP", ipFromHostPort(request.RemoteAddr),
|
"remoteIP", realIP(request),
|
||||||
"status", lrw.statusCode,
|
"status", lrw.statusCode,
|
||||||
"latency_ms", latency.Milliseconds(),
|
"latency_ms", latency.Milliseconds(),
|
||||||
)
|
)
|
||||||
@ -106,6 +111,28 @@ func ipFromHostPort(hostPort string) string {
|
|||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// realIP extracts the client's real IP address from the request,
|
||||||
|
// checking proxy headers first (trusted reverse proxy like Traefik),
|
||||||
|
// then falling back to RemoteAddr.
|
||||||
|
func realIP(r *http.Request) string {
|
||||||
|
// 1. X-Real-IP (set by Traefik/nginx)
|
||||||
|
if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. X-Forwarded-For: take the first (leftmost/client) IP
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
if parts := strings.SplitN(xff, ",", 2); len(parts) > 0 { //nolint:mnd
|
||||||
|
if ip := strings.TrimSpace(parts[0]); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Fall back to RemoteAddr
|
||||||
|
return ipFromHostPort(r.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// CORS returns CORS middleware.
|
// CORS returns CORS middleware.
|
||||||
func (m *Middleware) CORS() func(http.Handler) http.Handler {
|
func (m *Middleware) CORS() func(http.Handler) http.Handler {
|
||||||
return cors.Handler(cors.Options{
|
return cors.Handler(cors.Options{
|
||||||
@ -162,6 +189,113 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loginRateLimit configures the login rate limiter.
|
||||||
|
const (
|
||||||
|
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
|
||||||
|
loginBurst = 5 // allow burst of 5
|
||||||
|
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
|
||||||
|
limiterCleanupEvery = 1 * time.Minute // sweep interval
|
||||||
|
)
|
||||||
|
|
||||||
|
// ipLimiterEntry stores a rate limiter with its last-seen timestamp.
|
||||||
|
type ipLimiterEntry struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
lastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipLimiter tracks per-IP rate limiters for login attempts with automatic
|
||||||
|
// eviction of stale entries to prevent unbounded memory growth.
|
||||||
|
type ipLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
limiters map[string]*ipLimiterEntry
|
||||||
|
lastSweep time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIPLimiter() *ipLimiter {
|
||||||
|
return &ipLimiter{
|
||||||
|
limiters: make(map[string]*ipLimiterEntry),
|
||||||
|
lastSweep: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sweep removes entries not seen within limiterExpiry. Must be called with mu held.
|
||||||
|
func (i *ipLimiter) sweep(now time.Time) {
|
||||||
|
for ip, entry := range i.limiters {
|
||||||
|
if now.Sub(entry.lastSeen) > limiterExpiry {
|
||||||
|
delete(i.limiters, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i.lastSweep = now
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipLimiter) getLimiter(ip string) *rate.Limiter {
|
||||||
|
i.mu.Lock()
|
||||||
|
defer i.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Lazy sweep: clean up stale entries periodically.
|
||||||
|
if now.Sub(i.lastSweep) >= limiterCleanupEvery {
|
||||||
|
i.sweep(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, exists := i.limiters[ip]
|
||||||
|
if !exists {
|
||||||
|
entry = &ipLimiterEntry{
|
||||||
|
limiter: rate.NewLimiter(loginRateLimit, loginBurst),
|
||||||
|
}
|
||||||
|
i.limiters[ip] = entry
|
||||||
|
}
|
||||||
|
entry.lastSeen = now
|
||||||
|
|
||||||
|
return entry.limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// loginLimiter is the singleton IP rate limiter for login attempts.
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // intentional singleton for rate limiting state
|
||||||
|
var loginLimiter = newIPLimiter()
|
||||||
|
|
||||||
|
// LoginRateLimit returns middleware that rate-limits login attempts per IP.
|
||||||
|
// It allows 5 attempts per minute and returns 429 Too Many Requests when exceeded.
|
||||||
|
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
request *http.Request,
|
||||||
|
) {
|
||||||
|
ip := realIP(request)
|
||||||
|
limiter := loginLimiter.getLimiter(ip)
|
||||||
|
|
||||||
|
if !limiter.Allow() {
|
||||||
|
m.log.WarnContext(request.Context(), "login rate limit exceeded",
|
||||||
|
"remoteIP", ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Compute seconds until the next token is available.
|
||||||
|
reservation := limiter.Reserve()
|
||||||
|
delay := reservation.Delay()
|
||||||
|
reservation.Cancel()
|
||||||
|
retryAfter := int(math.Ceil(delay.Seconds()))
|
||||||
|
if retryAfter < 1 {
|
||||||
|
retryAfter = 1
|
||||||
|
}
|
||||||
|
writer.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
|
||||||
|
|
||||||
|
http.Error(
|
||||||
|
writer,
|
||||||
|
"Too Many Requests",
|
||||||
|
http.StatusTooManyRequests,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(writer, request)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupRequired returns middleware that redirects to setup if no user exists.
|
// SetupRequired returns middleware that redirects to setup if no user exists.
|
||||||
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
|
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
|||||||
135
internal/middleware/ratelimit_test.go
Normal file
135
internal/middleware/ratelimit_test.go
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestMiddleware(t *testing.T) *Middleware {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
return &Middleware{
|
||||||
|
log: slog.Default(),
|
||||||
|
params: &Params{
|
||||||
|
Config: &config.Config{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginRateLimitAllowsUpToBurst(t *testing.T) {
|
||||||
|
// Reset the global limiter to get clean state
|
||||||
|
loginLimiter = newIPLimiter()
|
||||||
|
|
||||||
|
mw := newTestMiddleware(t)
|
||||||
|
|
||||||
|
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// First 5 requests should succeed (burst)
|
||||||
|
for i := range 5 {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6th request should be rate limited
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginRateLimitIsolatesIPs(t *testing.T) {
|
||||||
|
loginLimiter = newIPLimiter()
|
||||||
|
|
||||||
|
mw := newTestMiddleware(t)
|
||||||
|
|
||||||
|
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Exhaust IP1's budget
|
||||||
|
for range 5 {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP1 should be blocked
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||||
|
|
||||||
|
// IP2 should still work
|
||||||
|
req2 := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||||
|
req2.RemoteAddr = "10.0.0.2:1234"
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec2, req2)
|
||||||
|
assert.Equal(t, http.StatusOK, rec2.Code, "different IP should not be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginRateLimitReturns429Body(t *testing.T) {
|
||||||
|
loginLimiter = newIPLimiter()
|
||||||
|
|
||||||
|
mw := newTestMiddleware(t)
|
||||||
|
|
||||||
|
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Exhaust burst
|
||||||
|
for range 5 {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||||
|
req.RemoteAddr = "172.16.0.1:5555"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||||
|
req.RemoteAddr = "172.16.0.1:5555"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||||
|
assert.Contains(t, rec.Body.String(), "Too Many Requests")
|
||||||
|
assert.NotEmpty(t, rec.Header().Get("Retry-After"), "should include Retry-After header")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPLimiterEvictsStaleEntries(t *testing.T) {
|
||||||
|
il := newIPLimiter()
|
||||||
|
|
||||||
|
// Add an entry and backdate its lastSeen
|
||||||
|
il.mu.Lock()
|
||||||
|
il.limiters["1.2.3.4"] = &ipLimiterEntry{
|
||||||
|
limiter: nil,
|
||||||
|
lastSeen: time.Now().Add(-15 * time.Minute),
|
||||||
|
}
|
||||||
|
il.limiters["5.6.7.8"] = &ipLimiterEntry{
|
||||||
|
limiter: nil,
|
||||||
|
lastSeen: time.Now(),
|
||||||
|
}
|
||||||
|
il.mu.Unlock()
|
||||||
|
|
||||||
|
// Trigger sweep
|
||||||
|
il.mu.Lock()
|
||||||
|
il.sweep(time.Now())
|
||||||
|
il.mu.Unlock()
|
||||||
|
|
||||||
|
il.mu.Lock()
|
||||||
|
defer il.mu.Unlock()
|
||||||
|
assert.NotContains(t, il.limiters, "1.2.3.4", "stale entry should be evicted")
|
||||||
|
assert.Contains(t, il.limiters, "5.6.7.8", "fresh entry should remain")
|
||||||
|
}
|
||||||
83
internal/middleware/realip_test.go
Normal file
83
internal/middleware/realip_test.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRealIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
xRealIP string
|
||||||
|
xff string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "X-Real-IP takes priority",
|
||||||
|
remoteAddr: "10.0.0.1:1234",
|
||||||
|
xRealIP: "203.0.113.5",
|
||||||
|
xff: "198.51.100.1, 10.0.0.1",
|
||||||
|
want: "203.0.113.5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For used when no X-Real-IP",
|
||||||
|
remoteAddr: "10.0.0.1:1234",
|
||||||
|
xff: "198.51.100.1, 10.0.0.1",
|
||||||
|
want: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For single IP",
|
||||||
|
remoteAddr: "10.0.0.1:1234",
|
||||||
|
xff: "203.0.113.10",
|
||||||
|
want: "203.0.113.10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "falls back to RemoteAddr",
|
||||||
|
remoteAddr: "192.168.1.1:5678",
|
||||||
|
want: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr without port",
|
||||||
|
remoteAddr: "192.168.1.1",
|
||||||
|
want: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Real-IP with whitespace",
|
||||||
|
remoteAddr: "10.0.0.1:1234",
|
||||||
|
xRealIP: " 203.0.113.5 ",
|
||||||
|
want: "203.0.113.5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For with whitespace",
|
||||||
|
remoteAddr: "10.0.0.1:1234",
|
||||||
|
xff: " 198.51.100.1 , 10.0.0.1",
|
||||||
|
want: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty X-Real-IP falls through to XFF",
|
||||||
|
remoteAddr: "10.0.0.1:1234",
|
||||||
|
xRealIP: " ",
|
||||||
|
xff: "198.51.100.1",
|
||||||
|
want: "198.51.100.1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
if tt.xRealIP != "" {
|
||||||
|
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||||
|
}
|
||||||
|
if tt.xff != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tt.xff)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := realIP(req)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("realIP() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -46,7 +46,7 @@ func (s *Server) SetupRoutes() {
|
|||||||
|
|
||||||
// Public routes
|
// Public routes
|
||||||
r.Get("/login", s.handlers.HandleLoginGET())
|
r.Get("/login", s.handlers.HandleLoginGET())
|
||||||
r.Post("/login", s.handlers.HandleLoginPOST())
|
r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleLoginPOST())
|
||||||
r.Get("/setup", s.handlers.HandleSetupGET())
|
r.Get("/setup", s.handlers.HandleSetupGET())
|
||||||
r.Post("/setup", s.handlers.HandleSetupPOST())
|
r.Post("/setup", s.handlers.HandleSetupPOST())
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user