// Package middleware provides HTTP middleware. package middleware import ( "context" "log/slog" "math" "net" "net/http" "strconv" "strings" "sync" "time" "github.com/99designs/basicauth-go" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/gorilla/csrf" "go.uber.org/fx" "golang.org/x/time/rate" "git.eeqj.de/sneak/upaas/internal/config" "git.eeqj.de/sneak/upaas/internal/database" "git.eeqj.de/sneak/upaas/internal/globals" "git.eeqj.de/sneak/upaas/internal/logger" "git.eeqj.de/sneak/upaas/internal/models" "git.eeqj.de/sneak/upaas/internal/service/auth" ) // corsMaxAge is the maximum age for CORS preflight responses in seconds. const corsMaxAge = 300 // apiUserContextKey is the context key for the authenticated API user. type apiUserContextKey struct{} // Params contains dependencies for Middleware. type Params struct { fx.In Logger *logger.Logger Globals *globals.Globals Config *config.Config Auth *auth.Service Database *database.Database } // Middleware provides HTTP middleware. type Middleware struct { log *slog.Logger params *Params } // New creates a new Middleware instance. func New(_ fx.Lifecycle, params Params) (*Middleware, error) { return &Middleware{ log: params.Logger.Get(), params: ¶ms, }, nil } // loggingResponseWriter wraps http.ResponseWriter to capture status code. type loggingResponseWriter struct { http.ResponseWriter statusCode int } func newLoggingResponseWriter( writer http.ResponseWriter, ) *loggingResponseWriter { return &loggingResponseWriter{writer, http.StatusOK} } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } // Logging returns a request logging middleware. func (m *Middleware) Logging() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { start := time.Now() lrw := newLoggingResponseWriter(writer) ctx := request.Context() defer func() { latency := time.Since(start) reqID := middleware.GetReqID(ctx) m.log.InfoContext(ctx, "request", "request_start", start, "method", request.Method, "url", request.URL.String(), "useragent", request.UserAgent(), "request_id", reqID, "referer", request.Referer(), "proto", request.Proto, "remoteIP", realIP(request), "status", lrw.statusCode, "latency_ms", latency.Milliseconds(), ) }() next.ServeHTTP(lrw, request) }) } } func ipFromHostPort(hostPort string) string { host, _, err := net.SplitHostPort(hostPort) if err != nil { return hostPort } return host } // trustedProxyNets are RFC1918 and loopback CIDRs whose proxy headers we trust. // //nolint:gochecknoglobals // package-level constant nets parsed once var trustedProxyNets = func() []*net.IPNet { cidrs := []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", } nets := make([]*net.IPNet, 0, len(cidrs)) for _, cidr := range cidrs { _, n, _ := net.ParseCIDR(cidr) nets = append(nets, n) } return nets }() // isTrustedProxy reports whether ip is in an RFC1918, loopback, or ULA range. func isTrustedProxy(ip net.IP) bool { for _, n := range trustedProxyNets { if n.Contains(ip) { return true } } return false } // realIP extracts the client's real IP address from the request. // Proxy headers (X-Real-IP, X-Forwarded-For) are only trusted when the // direct connection originates from an RFC1918/loopback address. // Otherwise, headers are ignored and RemoteAddr is used (fail closed). func realIP(r *http.Request) string { addr := ipFromHostPort(r.RemoteAddr) remoteIP := net.ParseIP(addr) // Only trust proxy headers from private/loopback sources. if remoteIP == nil || !isTrustedProxy(remoteIP) { return addr } // 1. X-Real-IP (set by Traefik/nginx) if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" { return ip } // 2. X-Forwarded-For: take the first (leftmost/client) IP if xff := r.Header.Get("X-Forwarded-For"); xff != "" { if parts := strings.SplitN(xff, ",", 2); len(parts) > 0 { //nolint:mnd if ip := strings.TrimSpace(parts[0]); ip != "" { return ip } } } // 3. Fall back to RemoteAddr return addr } // CORS returns CORS middleware. func (m *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: corsMaxAge, }) } // MetricsAuth returns basic auth middleware for metrics endpoint. func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler { if m.params.Config.MetricsUsername == "" { return func(next http.Handler) http.Handler { return next } } return basicauth.New( "metrics", map[string][]string{ m.params.Config.MetricsUsername: {m.params.Config.MetricsPassword}, }, ) } // SessionAuth returns middleware that requires authentication. func (m *Middleware) SessionAuth() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { user, err := m.params.Auth.GetCurrentUser(request.Context(), request) if err != nil || user == nil { http.Redirect(writer, request, "/login", http.StatusSeeOther) return } next.ServeHTTP(writer, request) }) } } // CSRF returns CSRF protection middleware using gorilla/csrf. func (m *Middleware) CSRF() func(http.Handler) http.Handler { return csrf.Protect( []byte(m.params.Config.SessionSecret), csrf.Secure(false), // Allow HTTP for development; reverse proxy handles TLS csrf.Path("/"), ) } // loginRateLimit configures the login rate limiter. const ( loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds loginBurst = 5 // allow burst of 5 limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes limiterCleanupEvery = 1 * time.Minute // sweep interval ) // ipLimiterEntry stores a rate limiter with its last-seen timestamp. type ipLimiterEntry struct { limiter *rate.Limiter lastSeen time.Time } // ipLimiter tracks per-IP rate limiters for login attempts with automatic // eviction of stale entries to prevent unbounded memory growth. type ipLimiter struct { mu sync.Mutex limiters map[string]*ipLimiterEntry lastSweep time.Time } func newIPLimiter() *ipLimiter { return &ipLimiter{ limiters: make(map[string]*ipLimiterEntry), lastSweep: time.Now(), } } // sweep removes entries not seen within limiterExpiry. Must be called with mu held. func (i *ipLimiter) sweep(now time.Time) { for ip, entry := range i.limiters { if now.Sub(entry.lastSeen) > limiterExpiry { delete(i.limiters, ip) } } i.lastSweep = now } func (i *ipLimiter) getLimiter(ip string) *rate.Limiter { i.mu.Lock() defer i.mu.Unlock() now := time.Now() // Lazy sweep: clean up stale entries periodically. if now.Sub(i.lastSweep) >= limiterCleanupEvery { i.sweep(now) } entry, exists := i.limiters[ip] if !exists { entry = &ipLimiterEntry{ limiter: rate.NewLimiter(loginRateLimit, loginBurst), } i.limiters[ip] = entry } entry.lastSeen = now return entry.limiter } // loginLimiter is the singleton IP rate limiter for login attempts. // //nolint:gochecknoglobals // intentional singleton for rate limiting state var loginLimiter = newIPLimiter() // LoginRateLimit returns middleware that rate-limits login attempts per IP. // It allows 5 attempts per minute and returns 429 Too Many Requests when exceeded. func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { ip := realIP(request) limiter := loginLimiter.getLimiter(ip) if !limiter.Allow() { m.log.WarnContext(request.Context(), "login rate limit exceeded", "remoteIP", ip, ) // Compute seconds until the next token is available. reservation := limiter.Reserve() delay := reservation.Delay() reservation.Cancel() retryAfter := max(int(math.Ceil(delay.Seconds())), 1) writer.Header().Set("Retry-After", strconv.Itoa(retryAfter)) http.Error( writer, "Too Many Requests", http.StatusTooManyRequests, ) return } next.ServeHTTP(writer, request) }) } } // APITokenAuth returns middleware that authenticates requests via Bearer token. // It looks up the token hash in the database and stores the user in context. func (m *Middleware) APITokenAuth() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { authHeader := request.Header.Get("Authorization") if authHeader == "" { http.Error(writer, `{"error":"missing Authorization header"}`, http.StatusUnauthorized) return } const bearerPrefix = "Bearer " if !strings.HasPrefix(authHeader, bearerPrefix) { http.Error(writer, `{"error":"invalid Authorization header"}`, http.StatusUnauthorized) return } rawToken := strings.TrimPrefix(authHeader, bearerPrefix) if rawToken == "" { http.Error(writer, `{"error":"empty token"}`, http.StatusUnauthorized) return } hash := models.HashAPIToken(rawToken) apiToken, err := models.FindAPITokenByHash(request.Context(), m.params.Database, hash) if err != nil { m.log.Error("api token lookup error", "error", err) http.Error(writer, `{"error":"internal server error"}`, http.StatusInternalServerError) return } if apiToken == nil { http.Error(writer, `{"error":"invalid token"}`, http.StatusUnauthorized) return } // Touch last used (best-effort, don't block on error) _ = apiToken.TouchLastUsed(request.Context()) user, userErr := models.FindUser(request.Context(), m.params.Database, apiToken.UserID) if userErr != nil || user == nil { http.Error(writer, `{"error":"token user not found"}`, http.StatusUnauthorized) return } ctx := context.WithValue(request.Context(), apiUserContextKey{}, user) next.ServeHTTP(writer, request.WithContext(ctx)) }) } } // APIUserFromContext extracts the authenticated API user from the context. func APIUserFromContext(ctx context.Context) *models.User { user, _ := ctx.Value(apiUserContextKey{}).(*models.User) return user } // SetupRequired returns middleware that redirects to setup if no user exists. func (m *Middleware) SetupRequired() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { setupRequired, err := m.params.Auth.IsSetupRequired(request.Context()) if err != nil { m.log.Error("failed to check setup status", "error", err) http.Error( writer, "Internal Server Error", http.StatusInternalServerError, ) return } if setupRequired { // Allow access to setup page if request.URL.Path == "/setup" { next.ServeHTTP(writer, request) return } http.Redirect(writer, request, "/setup", http.StatusSeeOther) return } // Block setup page if already set up if request.URL.Path == "/setup" { http.Redirect(writer, request, "/", http.StatusSeeOther) return } next.ServeHTTP(writer, request) }) } }