// Package middleware provides HTTP middleware. package middleware import ( "log/slog" "net" "net/http" "sync" "time" "github.com/99designs/basicauth-go" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/gorilla/csrf" "go.uber.org/fx" "golang.org/x/time/rate" "git.eeqj.de/sneak/upaas/internal/config" "git.eeqj.de/sneak/upaas/internal/globals" "git.eeqj.de/sneak/upaas/internal/logger" "git.eeqj.de/sneak/upaas/internal/service/auth" ) // corsMaxAge is the maximum age for CORS preflight responses in seconds. const corsMaxAge = 300 // Params contains dependencies for Middleware. type Params struct { fx.In Logger *logger.Logger Globals *globals.Globals Config *config.Config Auth *auth.Service } // Middleware provides HTTP middleware. type Middleware struct { log *slog.Logger params *Params } // New creates a new Middleware instance. func New(_ fx.Lifecycle, params Params) (*Middleware, error) { return &Middleware{ log: params.Logger.Get(), params: ¶ms, }, nil } // loggingResponseWriter wraps http.ResponseWriter to capture status code. type loggingResponseWriter struct { http.ResponseWriter statusCode int } func newLoggingResponseWriter( writer http.ResponseWriter, ) *loggingResponseWriter { return &loggingResponseWriter{writer, http.StatusOK} } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } // Logging returns a request logging middleware. func (m *Middleware) Logging() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { start := time.Now() lrw := newLoggingResponseWriter(writer) ctx := request.Context() defer func() { latency := time.Since(start) reqID := middleware.GetReqID(ctx) m.log.InfoContext(ctx, "request", "request_start", start, "method", request.Method, "url", request.URL.String(), "useragent", request.UserAgent(), "request_id", reqID, "referer", request.Referer(), "proto", request.Proto, "remoteIP", ipFromHostPort(request.RemoteAddr), "status", lrw.statusCode, "latency_ms", latency.Milliseconds(), ) }() next.ServeHTTP(lrw, request) }) } } func ipFromHostPort(hostPort string) string { host, _, err := net.SplitHostPort(hostPort) if err != nil { return hostPort } return host } // CORS returns CORS middleware. func (m *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: corsMaxAge, }) } // MetricsAuth returns basic auth middleware for metrics endpoint. func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler { if m.params.Config.MetricsUsername == "" { return func(next http.Handler) http.Handler { return next } } return basicauth.New( "metrics", map[string][]string{ m.params.Config.MetricsUsername: {m.params.Config.MetricsPassword}, }, ) } // SessionAuth returns middleware that requires authentication. func (m *Middleware) SessionAuth() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { user, err := m.params.Auth.GetCurrentUser(request.Context(), request) if err != nil || user == nil { http.Redirect(writer, request, "/login", http.StatusSeeOther) return } next.ServeHTTP(writer, request) }) } } // CSRF returns CSRF protection middleware using gorilla/csrf. func (m *Middleware) CSRF() func(http.Handler) http.Handler { return csrf.Protect( []byte(m.params.Config.SessionSecret), csrf.Secure(false), // Allow HTTP for development; reverse proxy handles TLS csrf.Path("/"), ) } // loginRateLimit configures the login rate limiter. const ( loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds loginBurst = 5 // allow burst of 5 ) // ipLimiter tracks per-IP rate limiters for login attempts. type ipLimiter struct { mu sync.Mutex limiters map[string]*rate.Limiter } func newIPLimiter() *ipLimiter { return &ipLimiter{ limiters: make(map[string]*rate.Limiter), } } func (i *ipLimiter) getLimiter(ip string) *rate.Limiter { i.mu.Lock() defer i.mu.Unlock() limiter, exists := i.limiters[ip] if !exists { limiter = rate.NewLimiter(loginRateLimit, loginBurst) i.limiters[ip] = limiter } return limiter } // loginLimiter is the singleton IP rate limiter for login attempts. // //nolint:gochecknoglobals // intentional singleton for rate limiting state var loginLimiter = newIPLimiter() // LoginRateLimit returns middleware that rate-limits login attempts per IP. // It allows 5 attempts per minute and returns 429 Too Many Requests when exceeded. func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { ip := ipFromHostPort(request.RemoteAddr) limiter := loginLimiter.getLimiter(ip) if !limiter.Allow() { m.log.WarnContext(request.Context(), "login rate limit exceeded", "remoteIP", ip, ) http.Error( writer, "Too Many Requests", http.StatusTooManyRequests, ) return } next.ServeHTTP(writer, request) }) } } // SetupRequired returns middleware that redirects to setup if no user exists. func (m *Middleware) SetupRequired() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { setupRequired, err := m.params.Auth.IsSetupRequired(request.Context()) if err != nil { m.log.Error("failed to check setup status", "error", err) http.Error( writer, "Internal Server Error", http.StatusInternalServerError, ) return } if setupRequired { // Allow access to setup page if request.URL.Path == "/setup" { next.ServeHTTP(writer, request) return } http.Redirect(writer, request, "/setup", http.StatusSeeOther) return } // Block setup page if already set up if request.URL.Path == "/setup" { http.Redirect(writer, request, "/", http.StatusSeeOther) return } next.ServeHTTP(writer, request) }) } }