upaas/internal/middleware/middleware.go
clawbot 66661d1b1d Add rate limiting to login endpoint to prevent brute force
Apply per-IP rate limiting (5 attempts/minute) to POST /login using
golang.org/x/time/rate. Returns 429 Too Many Requests when exceeded.

Closes #12
2026-02-15 21:01:11 -08:00

275 lines
6.6 KiB
Go

// Package middleware provides HTTP middleware.
package middleware
import (
"log/slog"
"net"
"net/http"
"sync"
"time"
"github.com/99designs/basicauth-go"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/gorilla/csrf"
"go.uber.org/fx"
"golang.org/x/time/rate"
"git.eeqj.de/sneak/upaas/internal/config"
"git.eeqj.de/sneak/upaas/internal/globals"
"git.eeqj.de/sneak/upaas/internal/logger"
"git.eeqj.de/sneak/upaas/internal/service/auth"
)
// corsMaxAge is how long, in seconds, a CORS preflight response may be
// cached (five minutes).
const corsMaxAge = 5 * 60
// Params contains dependencies for Middleware.
//
// It embeds fx.In, so the fields are populated by the fx container when
// New is invoked as a constructor.
type Params struct {
fx.In
// Logger supplies the *slog.Logger used by the middleware (see New).
Logger *logger.Logger
// Globals is injected alongside the other dependencies; no middleware in
// this part of the file reads it.
Globals *globals.Globals
// Config provides MetricsUsername/MetricsPassword (MetricsAuth) and
// SessionSecret (CSRF).
Config *config.Config
// Auth answers "who is the current user" (SessionAuth) and "is first-run
// setup still required" (SetupRequired).
Auth *auth.Service
}
// Middleware provides HTTP middleware.
//
// Each exported method returns a func(http.Handler) http.Handler that can be
// installed on a router.
type Middleware struct {
// log is the structured logger obtained from Params.Logger in New.
log *slog.Logger
// params holds the injected dependencies (configuration, auth service, ...).
params *Params
}
// New builds a Middleware from its injected dependencies.
//
// The lifecycle argument is accepted for constructor-signature compatibility
// and is unused. The returned error is always nil.
func New(_ fx.Lifecycle, params Params) (*Middleware, error) {
	mw := new(Middleware)
	mw.log = params.Logger.Get()
	// params is this function's own copy, so keeping its address is safe.
	mw.params = &params

	return mw, nil
}
// loggingResponseWriter wraps http.ResponseWriter to capture the status code
// that is actually sent to the client.
//
// Only the first final (non-informational) status is recorded: net/http
// ignores any later WriteHeader call, and a Write without a prior
// WriteHeader implies 200 OK, so later calls must not change what is logged.
type loggingResponseWriter struct {
	http.ResponseWriter
	// statusCode is the status to report; it defaults to 200 OK.
	statusCode int
	// wroteHeader is true once the final status has been committed, either
	// explicitly via WriteHeader or implicitly by the first Write.
	wroteHeader bool
}

// newLoggingResponseWriter wraps writer with the recorded status preset to
// http.StatusOK, which is what net/http sends when a handler never calls
// WriteHeader.
func newLoggingResponseWriter(
	writer http.ResponseWriter,
) *loggingResponseWriter {
	return &loggingResponseWriter{
		ResponseWriter: writer,
		statusCode:     http.StatusOK,
	}
}

// WriteHeader records code if it is the first final status for this response
// and always forwards the call to the wrapped writer.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	// 1xx responses other than 101 are interim: a final status still follows,
	// so they neither get recorded nor lock the recorded status.
	informational := code >= 100 && code < 200 &&
		code != http.StatusSwitchingProtocols

	if !lrw.wroteHeader && !informational {
		lrw.statusCode = code
		lrw.wroteHeader = true
	}

	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards data to the wrapped writer. The first Write commits the
// response header, so the recorded status is frozen from that point on.
func (lrw *loggingResponseWriter) Write(data []byte) (int, error) {
	lrw.wroteHeader = true

	return lrw.ResponseWriter.Write(data)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional capabilities (flushing, hijacking, deadlines) of the original.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// Logging returns middleware that writes one structured "request" log entry
// for every request after the wrapped handler has finished.
//
// The entry carries the start time, method, URL, user agent, chi request ID,
// referer, protocol, client IP (RemoteAddr without the port), response status
// and latency in milliseconds.
func (m *Middleware) Logging() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			recorder := newLoggingResponseWriter(w)
			ctx := r.Context()

			// Deferred so the entry is still emitted if next panics.
			defer func() {
				elapsed := time.Since(began)

				m.log.InfoContext(ctx, "request",
					"request_start", began,
					"method", r.Method,
					"url", r.URL.String(),
					"useragent", r.UserAgent(),
					"request_id", middleware.GetReqID(ctx),
					"referer", r.Referer(),
					"proto", r.Proto,
					"remoteIP", ipFromHostPort(r.RemoteAddr),
					"status", recorder.statusCode,
					"latency_ms", elapsed.Milliseconds(),
				)
			}()

			next.ServeHTTP(recorder, r)
		}

		return http.HandlerFunc(handle)
	}
}
// ipFromHostPort returns the host part of a "host:port" address. If the
// input cannot be split (for example, it has no port) it is returned
// unchanged.
func ipFromHostPort(hostPort string) string {
	if ip, _, splitErr := net.SplitHostPort(hostPort); splitErr == nil {
		return ip
	}

	return hostPort
}
// CORS returns middleware that applies the service's cross-origin policy:
// any origin, the listed methods and request headers, no credentials, and
// preflight caching for corsMaxAge seconds.
func (m *Middleware) CORS() func(http.Handler) http.Handler {
	policy := cors.Options{
		AllowedOrigins: []string{"*"},
		AllowedMethods: []string{
			http.MethodGet,
			http.MethodPost,
			http.MethodPut,
			http.MethodDelete,
			http.MethodOptions,
		},
		AllowedHeaders: []string{
			"Accept",
			"Authorization",
			"Content-Type",
			"X-CSRF-Token",
		},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           corsMaxAge,
	}

	return cors.Handler(policy)
}
// MetricsAuth returns middleware guarding the metrics endpoint with HTTP
// basic auth (realm "metrics") using the configured username and password.
//
// When no metrics username is configured the returned middleware is a
// pass-through and the endpoint is left unauthenticated.
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	cfg := m.params.Config

	if cfg.MetricsUsername != "" {
		credentials := map[string][]string{
			cfg.MetricsUsername: {cfg.MetricsPassword},
		}

		return basicauth.New("metrics", credentials)
	}

	// Nothing configured: hand the next handler back untouched.
	return func(next http.Handler) http.Handler {
		return next
	}
}
// SessionAuth returns middleware that only lets authenticated requests
// through. If the current user cannot be resolved (lookup error or no user)
// the client is redirected to /login with 303 See Other.
func (m *Middleware) SessionAuth() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			user, err := m.params.Auth.GetCurrentUser(r.Context(), r)
			if err == nil && user != nil {
				next.ServeHTTP(w, r)

				return
			}

			http.Redirect(w, r, "/login", http.StatusSeeOther)
		}

		return http.HandlerFunc(handle)
	}
}
// CSRF returns gorilla/csrf protection middleware keyed with the configured
// session secret and scoped to the whole site ("/").
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
	authKey := []byte(m.params.Config.SessionSecret)

	options := []csrf.Option{
		// Allow HTTP for development; reverse proxy handles TLS.
		csrf.Secure(false),
		csrf.Path("/"),
	}

	return csrf.Protect(authKey, options...)
}
// Login rate limiter settings, applied per client IP.
const (
	// loginBurst is the number of attempts that may be made back to back
	// before the sustained rate applies.
	loginBurst = 5

	// loginRateLimit is the sustained refill rate in events per second:
	// 5 requests per 60 seconds.
	loginRateLimit rate.Limit = 5.0 / 60.0
)
// Eviction settings for ipLimiter.
const (
	// limiterIdleTTL is how long an IP may go without a login attempt before
	// its limiter is dropped. It must be at least the time a limiter needs to
	// refill a full burst (loginBurst / loginRateLimit, i.e. one minute with
	// the current settings) so that an evicted limiter is indistinguishable
	// from a freshly created one.
	limiterIdleTTL = 2 * time.Minute

	// limiterSweepInterval bounds how often getLimiter scans for idle entries.
	limiterSweepInterval = time.Minute
)

// ipLimiter tracks per-IP rate limiters for login attempts.
//
// Entries that have been idle for longer than limiterIdleTTL are removed
// during later lookups, so the maps cannot grow without bound as new client
// addresses are seen. All fields are guarded by mu.
type ipLimiter struct {
	mu sync.Mutex
	// limiters maps a client IP to its token-bucket limiter.
	limiters map[string]*rate.Limiter
	// lastSeen records when each IP last requested its limiter.
	lastSeen map[string]time.Time
	// lastSweep is when idle entries were last evicted.
	lastSweep time.Time
}

// newIPLimiter returns an empty, ready-to-use ipLimiter.
func newIPLimiter() *ipLimiter {
	return &ipLimiter{
		limiters: make(map[string]*rate.Limiter),
		lastSeen: make(map[string]time.Time),
	}
}

// getLimiter returns the limiter for ip, creating one with the login rate and
// burst settings on first use. It also marks ip as recently seen and, at most
// once per limiterSweepInterval, evicts limiters that have gone idle.
func (i *ipLimiter) getLimiter(ip string) *rate.Limiter {
	i.mu.Lock()
	defer i.mu.Unlock()

	// Tolerate an ipLimiter that was not built by newIPLimiter.
	if i.limiters == nil {
		i.limiters = make(map[string]*rate.Limiter)
	}

	if i.lastSeen == nil {
		i.lastSeen = make(map[string]time.Time)
	}

	now := time.Now()
	i.evictIdle(now)

	limiter, exists := i.limiters[ip]
	if !exists {
		limiter = rate.NewLimiter(loginRateLimit, loginBurst)
		i.limiters[ip] = limiter
	}

	i.lastSeen[ip] = now

	return limiter
}

// evictIdle removes limiters not requested within limiterIdleTTL. It is a
// no-op unless limiterSweepInterval has elapsed since the previous sweep.
// The caller must hold i.mu.
func (i *ipLimiter) evictIdle(now time.Time) {
	if now.Sub(i.lastSweep) < limiterSweepInterval {
		return
	}

	i.lastSweep = now

	for ip := range i.limiters {
		seen, tracked := i.lastSeen[ip]
		if !tracked {
			// No timestamp yet: start the idle clock now rather than
			// discarding state of unknown age.
			i.lastSeen[ip] = now

			continue
		}

		if now.Sub(seen) > limiterIdleTTL {
			delete(i.limiters, ip)
			delete(i.lastSeen, ip)
		}
	}
}
// loginLimiter is the singleton IP rate limiter for login attempts.
//
// It is package-level state, so every handler produced by LoginRateLimit
// shares the same per-IP limiters.
//
//nolint:gochecknoglobals // intentional singleton for rate limiting state
var loginLimiter = newIPLimiter()
// LoginRateLimit returns middleware that throttles login attempts per client
// IP using the shared loginLimiter (burst loginBurst, refill loginRateLimit).
// A request over the limit is logged and answered with 429 Too Many Requests
// without reaching the wrapped handler.
//
// The client IP is taken from request.RemoteAddr.
// NOTE(review): behind a reverse proxy RemoteAddr may be the proxy's address
// unless an earlier middleware rewrites it; confirm against the router setup.
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			clientIP := ipFromHostPort(r.RemoteAddr)

			if loginLimiter.getLimiter(clientIP).Allow() {
				next.ServeHTTP(w, r)

				return
			}

			m.log.WarnContext(r.Context(), "login rate limit exceeded",
				"remoteIP", clientIP,
			)

			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
		}

		return http.HandlerFunc(handle)
	}
}
// SetupRequired returns middleware that steers clients according to whether
// first-run setup has been completed:
//
//   - setup still required: only /setup is served; every other path is
//     redirected to /setup;
//   - setup already done: /setup is redirected to /; every other path is
//     served normally.
//
// If the setup status cannot be determined the error is logged with the
// request context and the client receives 500 Internal Server Error.
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(
			writer http.ResponseWriter,
			request *http.Request,
		) {
			ctx := request.Context()

			setupRequired, err := m.params.Auth.IsSetupRequired(ctx)
			if err != nil {
				// Use the context-aware variant, like the other middleware
				// in this file, so request-scoped attributes are kept.
				m.log.ErrorContext(ctx, "failed to check setup status", "error", err)
				http.Error(
					writer,
					"Internal Server Error",
					http.StatusInternalServerError,
				)

				return
			}

			onSetupPage := request.URL.Path == "/setup"

			switch {
			case setupRequired && !onSetupPage:
				// Nothing else is usable until a user exists.
				http.Redirect(writer, request, "/setup", http.StatusSeeOther)
			case !setupRequired && onSetupPage:
				// Block the setup page once setup has been completed.
				http.Redirect(writer, request, "/", http.StatusSeeOther)
			default:
				next.ServeHTTP(writer, request)
			}
		})
	}
}