Add rate limiting to login endpoint to prevent brute force
Apply per-IP rate limiting (5 attempts/minute) to POST /login using golang.org/x/time/rate. Returns 429 Too Many Requests when exceeded. Closes #12
This commit is contained in:
parent
3a2bd0e51d
commit
66661d1b1d
@ -5,6 +5,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/99designs/basicauth-go"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/gorilla/csrf"
|
||||
"go.uber.org/fx"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"git.eeqj.de/sneak/upaas/internal/config"
|
||||
"git.eeqj.de/sneak/upaas/internal/globals"
|
||||
@ -162,6 +164,71 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
||||
)
|
||||
}
|
||||
|
||||
// loginRateLimit configures the login rate limiter.
|
||||
const (
|
||||
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
|
||||
loginBurst = 5 // allow burst of 5
|
||||
)
|
||||
|
||||
// ipLimiter tracks per-IP rate limiters for login attempts.
|
||||
type ipLimiter struct {
|
||||
mu sync.Mutex
|
||||
limiters map[string]*rate.Limiter
|
||||
}
|
||||
|
||||
func newIPLimiter() *ipLimiter {
|
||||
return &ipLimiter{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
func (i *ipLimiter) getLimiter(ip string) *rate.Limiter {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
limiter, exists := i.limiters[ip]
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(loginRateLimit, loginBurst)
|
||||
i.limiters[ip] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// loginLimiter is the singleton IP rate limiter for login attempts.
|
||||
//
|
||||
//nolint:gochecknoglobals // intentional singleton for rate limiting state
|
||||
var loginLimiter = newIPLimiter()
|
||||
|
||||
// LoginRateLimit returns middleware that rate-limits login attempts per IP.
|
||||
// It allows 5 attempts per minute and returns 429 Too Many Requests when exceeded.
|
||||
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
ip := ipFromHostPort(request.RemoteAddr)
|
||||
limiter := loginLimiter.getLimiter(ip)
|
||||
|
||||
if !limiter.Allow() {
|
||||
m.log.WarnContext(request.Context(), "login rate limit exceeded",
|
||||
"remoteIP", ip,
|
||||
)
|
||||
http.Error(
|
||||
writer,
|
||||
"Too Many Requests",
|
||||
http.StatusTooManyRequests,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(writer, request)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SetupRequired returns middleware that redirects to setup if no user exists.
|
||||
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
107
internal/middleware/ratelimit_test.go
Normal file
107
internal/middleware/ratelimit_test.go
Normal file
@ -0,0 +1,107 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"git.eeqj.de/sneak/upaas/internal/config"
|
||||
)
|
||||
|
||||
func newTestMiddleware(t *testing.T) *Middleware {
|
||||
t.Helper()
|
||||
|
||||
return &Middleware{
|
||||
log: slog.Default(),
|
||||
params: &Params{
|
||||
Config: &config.Config{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginRateLimitAllowsUpToBurst(t *testing.T) {
|
||||
// Reset the global limiter to get clean state
|
||||
loginLimiter = newIPLimiter()
|
||||
|
||||
mw := newTestMiddleware(t)
|
||||
|
||||
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First 5 requests should succeed (burst)
|
||||
for i := range 5 {
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1)
|
||||
}
|
||||
|
||||
// 6th request should be rate limited
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited")
|
||||
}
|
||||
|
||||
func TestLoginRateLimitIsolatesIPs(t *testing.T) {
|
||||
loginLimiter = newIPLimiter()
|
||||
|
||||
mw := newTestMiddleware(t)
|
||||
|
||||
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Exhaust IP1's budget
|
||||
for range 5 {
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
// IP1 should be blocked
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
|
||||
// IP2 should still work
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||
req2.RemoteAddr = "10.0.0.2:1234"
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec2, req2)
|
||||
assert.Equal(t, http.StatusOK, rec2.Code, "different IP should not be rate limited")
|
||||
}
|
||||
|
||||
func TestLoginRateLimitReturns429Body(t *testing.T) {
|
||||
loginLimiter = newIPLimiter()
|
||||
|
||||
mw := newTestMiddleware(t)
|
||||
|
||||
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Exhaust burst
|
||||
for range 5 {
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||
req.RemoteAddr = "172.16.0.1:5555"
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
||||
req.RemoteAddr = "172.16.0.1:5555"
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Too Many Requests")
|
||||
}
|
||||
@ -46,7 +46,7 @@ func (s *Server) SetupRoutes() {
|
||||
|
||||
// Public routes
|
||||
r.Get("/login", s.handlers.HandleLoginGET())
|
||||
r.Post("/login", s.handlers.HandleLoginPOST())
|
||||
r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleLoginPOST())
|
||||
r.Get("/setup", s.handlers.HandleSetupGET())
|
||||
r.Post("/setup", s.handlers.HandleSetupPOST())
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user