Add rate limiting to login endpoint to prevent brute force (closes #12) #14

Merged
sneak merged 3 commits from :fix/issue-12 into main 2026-02-16 06:15:49 +01:00
3 changed files with 175 additions and 1 deletions
Showing only changes of commit 66661d1b1d - Show all commits

View File

@ -5,6 +5,7 @@ import (
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/99designs/basicauth-go" "github.com/99designs/basicauth-go"
@ -12,6 +13,7 @@ import (
"github.com/go-chi/cors" "github.com/go-chi/cors"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"go.uber.org/fx" "go.uber.org/fx"
"golang.org/x/time/rate"
"git.eeqj.de/sneak/upaas/internal/config" "git.eeqj.de/sneak/upaas/internal/config"
"git.eeqj.de/sneak/upaas/internal/globals" "git.eeqj.de/sneak/upaas/internal/globals"
@ -162,6 +164,71 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
) )
} }
// loginRateLimit configures the login rate limiter.
const (
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
loginBurst = 5 // allow burst of 5
)
// ipLimiter tracks per-IP rate limiters for login attempts.
type ipLimiter struct {
mu sync.Mutex
limiters map[string]*rate.Limiter
}
func newIPLimiter() *ipLimiter {
return &ipLimiter{
limiters: make(map[string]*rate.Limiter),
}
}
func (i *ipLimiter) getLimiter(ip string) *rate.Limiter {
i.mu.Lock()
defer i.mu.Unlock()
limiter, exists := i.limiters[ip]
if !exists {
limiter = rate.NewLimiter(loginRateLimit, loginBurst)
i.limiters[ip] = limiter
}
return limiter
}
// loginLimiter is the singleton IP rate limiter for login attempts.
//
//nolint:gochecknoglobals // intentional singleton for rate limiting state
var loginLimiter = newIPLimiter()
// LoginRateLimit returns middleware that rate-limits login attempts per IP.
// It allows 5 attempts per minute and returns 429 Too Many Requests when exceeded.
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(
writer http.ResponseWriter,
request *http.Request,
) {
ip := ipFromHostPort(request.RemoteAddr)
limiter := loginLimiter.getLimiter(ip)
if !limiter.Allow() {
m.log.WarnContext(request.Context(), "login rate limit exceeded",
"remoteIP", ip,
)
http.Error(
writer,
"Too Many Requests",
http.StatusTooManyRequests,
)
return
}
next.ServeHTTP(writer, request)
})
}
}
// SetupRequired returns middleware that redirects to setup if no user exists. // SetupRequired returns middleware that redirects to setup if no user exists.
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler { func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {

View File

@ -0,0 +1,107 @@
package middleware
import (
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"git.eeqj.de/sneak/upaas/internal/config"
)
func newTestMiddleware(t *testing.T) *Middleware {
t.Helper()
return &Middleware{
log: slog.Default(),
params: &Params{
Config: &config.Config{},
},
}
}
func TestLoginRateLimitAllowsUpToBurst(t *testing.T) {
// Reset the global limiter to get clean state
loginLimiter = newIPLimiter()
mw := newTestMiddleware(t)
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First 5 requests should succeed (burst)
for i := range 5 {
req := httptest.NewRequest(http.MethodPost, "/login", nil)
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "request %d should succeed", i+1)
}
// 6th request should be rate limited
req := httptest.NewRequest(http.MethodPost, "/login", nil)
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "6th request should be rate limited")
}
func TestLoginRateLimitIsolatesIPs(t *testing.T) {
loginLimiter = newIPLimiter()
mw := newTestMiddleware(t)
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Exhaust IP1's budget
for range 5 {
req := httptest.NewRequest(http.MethodPost, "/login", nil)
req.RemoteAddr = "10.0.0.1:1234"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
}
// IP1 should be blocked
req := httptest.NewRequest(http.MethodPost, "/login", nil)
req.RemoteAddr = "10.0.0.1:1234"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
// IP2 should still work
req2 := httptest.NewRequest(http.MethodPost, "/login", nil)
req2.RemoteAddr = "10.0.0.2:1234"
rec2 := httptest.NewRecorder()
handler.ServeHTTP(rec2, req2)
assert.Equal(t, http.StatusOK, rec2.Code, "different IP should not be rate limited")
}
func TestLoginRateLimitReturns429Body(t *testing.T) {
loginLimiter = newIPLimiter()
mw := newTestMiddleware(t)
handler := mw.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Exhaust burst
for range 5 {
req := httptest.NewRequest(http.MethodPost, "/login", nil)
req.RemoteAddr = "172.16.0.1:5555"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
}
req := httptest.NewRequest(http.MethodPost, "/login", nil)
req.RemoteAddr = "172.16.0.1:5555"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
assert.Contains(t, rec.Body.String(), "Too Many Requests")
}

View File

@ -46,7 +46,7 @@ func (s *Server) SetupRoutes() {
// Public routes // Public routes
r.Get("/login", s.handlers.HandleLoginGET()) r.Get("/login", s.handlers.HandleLoginGET())
r.Post("/login", s.handlers.HandleLoginPOST()) r.With(s.mw.LoginRateLimit()).Post("/login", s.handlers.HandleLoginPOST())
r.Get("/setup", s.handlers.HandleSetupGET()) r.Get("/setup", s.handlers.HandleSetupGET())
r.Post("/setup", s.handlers.HandleSetupPOST()) r.Post("/setup", s.handlers.HandleSetupPOST())