Files
upaas/internal/middleware/middleware.go
clawbot 96eea71c54 fix: set authenticated user on request context in bearer token auth
tryBearerAuth validated the bearer token but never looked up the
associated user or set it on the request context. This meant
downstream handlers calling GetCurrentUser would get nil even
with a valid token.

Changes:
- Add ContextWithUser/UserFromContext helpers in auth package
- tryBearerAuth now looks up the user by token's UserID and
  sets it on the request context via auth.ContextWithUser
- GetCurrentUser checks context first before falling back to
  session cookie
- Add integration tests for bearer auth user context
2026-02-19 23:43:22 -08:00

504 lines
12 KiB
Go

// Package middleware provides HTTP middleware.
package middleware
import (
"log/slog"
"math"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/99designs/basicauth-go"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/gorilla/csrf"
"go.uber.org/fx"
"golang.org/x/time/rate"
"git.eeqj.de/sneak/upaas/internal/config"
"git.eeqj.de/sneak/upaas/internal/database"
"git.eeqj.de/sneak/upaas/internal/globals"
"git.eeqj.de/sneak/upaas/internal/logger"
"git.eeqj.de/sneak/upaas/internal/models"
"git.eeqj.de/sneak/upaas/internal/service/auth"
)
// corsMaxAge is the maximum age for CORS preflight responses in seconds
// (300 s = 5 minutes). It is the value given to cors.Options.MaxAge by the
// CORS middleware.
const corsMaxAge = 300
// Params contains dependencies for Middleware.
// It embeds fx.In and is received by value in New; presumably the fields are
// populated by the fx container (see go.uber.org/fx documentation).
type Params struct {
fx.In
Logger *logger.Logger // New calls Logger.Get() to obtain the *slog.Logger
Globals *globals.Globals
Config *config.Config // read for CORSOrigins, MetricsUsername/Password, SessionSecret
Auth *auth.Service // used for GetCurrentUser and IsSetupRequired
Database *database.Database // passed to models.FindAPITokenByHash and models.FindUser
}
// Middleware provides HTTP middleware.
// Each exported method returns a func(http.Handler) http.Handler that can be
// installed on a router.
type Middleware struct {
log *slog.Logger // logger obtained from Params.Logger in New
params *Params // injected dependencies (config, auth service, database)
}
// New builds a Middleware from the injected params. The fx.Lifecycle
// argument is accepted but unused, and the returned error is always nil.
func New(_ fx.Lifecycle, params Params) (*Middleware, error) {
	mw := new(Middleware)
	mw.log = params.Logger.Get()
	mw.params = &params

	return mw, nil
}
// loggingResponseWriter wraps http.ResponseWriter to capture the status code
// passed to WriteHeader so the Logging middleware can report it.
//
// Wrapping a ResponseWriter hides the optional interfaces implemented by the
// underlying writer. To keep streaming handlers working behind the Logging
// middleware, the wrapper also provides Flush (for handlers that type-assert
// http.Flusher) and Unwrap (so http.ResponseController can reach the
// original writer).
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode int
}

// newLoggingResponseWriter wraps writer. The recorded status starts at
// http.StatusOK, so a handler that never calls WriteHeader is logged as 200.
func newLoggingResponseWriter(
	writer http.ResponseWriter,
) *loggingResponseWriter {
	return &loggingResponseWriter{
		ResponseWriter: writer,
		statusCode:     http.StatusOK,
	}
}

// WriteHeader records code and forwards it to the wrapped writer.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	lrw.statusCode = code
	lrw.ResponseWriter.WriteHeader(code)
}

// Flush forwards to the wrapped writer when it implements http.Flusher and
// is a no-op otherwise.
func (lrw *loggingResponseWriter) Flush() {
	if flusher, ok := lrw.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer; http.ResponseController uses it to find
// the original ResponseWriter.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// Logging returns middleware that emits one structured "request" log entry
// per request, including method, URL, client IP, response status, and
// latency in milliseconds.
func (m *Middleware) Logging() func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(writer http.ResponseWriter, request *http.Request) {
			began := time.Now()
			recorder := newLoggingResponseWriter(writer)
			ctx := request.Context()

			// Deferred so the entry is written after the downstream handler
			// finishes, whatever way it exits.
			defer func() {
				elapsed := time.Since(began)

				m.log.InfoContext(ctx, "request",
					"request_start", began,
					"method", request.Method,
					"url", request.URL.String(),
					"useragent", request.UserAgent(),
					"request_id", middleware.GetReqID(ctx),
					"referer", request.Referer(),
					"proto", request.Proto,
					"remoteIP", realIP(request),
					"status", recorder.statusCode,
					"latency_ms", elapsed.Milliseconds(),
				)
			}()

			next.ServeHTTP(recorder, request)
		}

		return http.HandlerFunc(serve)
	}

	return wrap
}
// ipFromHostPort returns the host part of a "host:port" string. When the
// input cannot be split (for example, it has no port), it is returned
// unchanged.
func ipFromHostPort(hostPort string) string {
	if host, _, err := net.SplitHostPort(hostPort); err == nil {
		return host
	}

	return hostPort
}
// trustedProxyNets holds the parsed private (RFC 1918), loopback, and IPv6
// unique-local networks whose proxy headers we trust.
//
//nolint:gochecknoglobals // package-level constant nets parsed once
var trustedProxyNets = func() []*net.IPNet {
	ranges := [...]string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"127.0.0.0/8",
		"::1/128",
		"fc00::/7",
	}

	parsed := make([]*net.IPNet, len(ranges))
	for idx := range ranges {
		// The parse error is discarded, as the entries are fixed literals.
		_, parsed[idx], _ = net.ParseCIDR(ranges[idx])
	}

	return parsed
}()
// isTrustedProxy reports whether ip falls inside any network listed in
// trustedProxyNets (RFC 1918, loopback, or ULA ranges).
func isTrustedProxy(ip net.IP) bool {
	for idx := 0; idx < len(trustedProxyNets); idx++ {
		if trustedProxyNets[idx].Contains(ip) {
			return true
		}
	}

	return false
}
// realIP extracts the client's real IP address from the request.
//
// Proxy headers (X-Real-IP, X-Forwarded-For) are only trusted when the
// direct connection originates from an RFC1918/loopback/ULA address.
// Otherwise, headers are ignored and RemoteAddr is used (fail closed).
//
// A header value is only accepted when it parses as an IP address. The
// result is logged and used as the login rate-limiter key, so arbitrary
// header text must not be returned; an unparsable value falls through to the
// next source.
//
// NOTE(review): the leftmost X-Forwarded-For entry is whatever the client
// sent when the proxy appends to an existing header. Confirm the fronting
// proxy overwrites the header (or always sets X-Real-IP) before relying on
// this value for security decisions.
func realIP(r *http.Request) string {
	addr := ipFromHostPort(r.RemoteAddr)

	// Only trust proxy headers from private/loopback sources.
	remoteIP := net.ParseIP(addr)
	if remoteIP == nil || !isTrustedProxy(remoteIP) {
		return addr
	}

	// 1. X-Real-IP (set by Traefik/nginx). net.ParseIP("") is nil, so an
	// absent header is skipped as well.
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); net.ParseIP(ip) != nil {
		return ip
	}

	// 2. X-Forwarded-For: take the first (leftmost/client) entry.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := strings.TrimSpace(first); net.ParseIP(ip) != nil {
		return ip
	}

	// 3. Fall back to RemoteAddr.
	return addr
}
// CORS returns CORS middleware.
// With no origins configured (the default, empty UPAAS_CORS_ORIGINS) the
// returned middleware is a pass-through that adds no CORS headers, leaving
// the browser's same-origin policy in effect. With origins configured, only
// those origins are allowed and credentials (cookies) are permitted.
func (m *Middleware) CORS() func(http.Handler) http.Handler {
	allowed := parseCORSOrigins(m.params.Config.CORSOrigins)

	if len(allowed) > 0 {
		opts := cors.Options{
			AllowedOrigins:   allowed,
			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
			AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
			ExposedHeaders:   []string{"Link"},
			AllowCredentials: true,
			MaxAge:           corsMaxAge,
		}

		return cors.Handler(opts)
	}

	// No origins configured: hand back the next handler untouched.
	return func(next http.Handler) http.Handler { return next }
}
// parseCORSOrigins turns a comma-separated list of origins into a slice.
// Each entry is whitespace-trimmed and empty entries are dropped. An empty
// input yields nil; a non-empty input with no usable entries yields an
// empty, non-nil slice.
func parseCORSOrigins(raw string) []string {
	if len(raw) == 0 {
		return nil
	}

	result := make([]string, 0, strings.Count(raw, ",")+1)

	// Peel off one comma-delimited field per iteration until no separator
	// remains; the final field is processed before the loop ends.
	for rest, more := raw, true; more; {
		var field string

		field, rest, more = strings.Cut(rest, ",")
		if trimmed := strings.TrimSpace(field); trimmed != "" {
			result = append(result, trimmed)
		}
	}

	return result
}
// MetricsAuth returns basic-auth middleware for the metrics endpoint, using
// the configured metrics username and password under the realm "metrics".
// When no metrics username is configured the returned middleware is a
// pass-through, i.e. the endpoint is not protected.
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	cfg := m.params.Config

	if cfg.MetricsUsername != "" {
		credentials := map[string][]string{
			cfg.MetricsUsername: {cfg.MetricsPassword},
		}

		return basicauth.New("metrics", credentials)
	}

	return func(next http.Handler) http.Handler { return next }
}
// SessionAuth returns middleware that requires an authenticated user.
// Requests for which the auth service returns an error or no user are
// redirected to /login with 303 See Other.
func (m *Middleware) SessionAuth() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		guard := func(writer http.ResponseWriter, request *http.Request) {
			user, err := m.params.Auth.GetCurrentUser(request.Context(), request)
			if err == nil && user != nil {
				next.ServeHTTP(writer, request)

				return
			}

			http.Redirect(writer, request, "/login", http.StatusSeeOther)
		}

		return http.HandlerFunc(guard)
	}
}
// CSRF returns CSRF protection middleware using gorilla/csrf, keyed with the
// configured session secret and scoped to the path "/".
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
	authKey := []byte(m.params.Config.SessionSecret)

	opts := []csrf.Option{
		// Allow HTTP for development; reverse proxy handles TLS.
		csrf.Secure(false),
		csrf.Path("/"),
	}

	return csrf.Protect(authKey, opts...)
}
// Login rate-limiter tuning.
// loginRateLimit and loginBurst are the arguments given to rate.NewLimiter
// for each client IP; limiterExpiry and limiterCleanupEvery control how idle
// per-IP entries are evicted from the limiter map.
const (
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
loginBurst = 5 // allow burst of 5
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
limiterCleanupEvery = 1 * time.Minute // sweep interval
)
// ipLimiterEntry stores a rate limiter with its last-seen timestamp.
type ipLimiterEntry struct {
limiter *rate.Limiter // per-IP limiter created with loginRateLimit/loginBurst
lastSeen time.Time // refreshed on every getLimiter call; read by sweep for eviction
}
// ipLimiter tracks per-IP rate limiters for login attempts with automatic
// eviction of stale entries to prevent unbounded memory growth.
// Eviction is lazy: it happens inside getLimiter, not on a background timer.
type ipLimiter struct {
mu sync.Mutex // held by getLimiter while limiters and lastSweep are accessed
limiters map[string]*ipLimiterEntry // keyed by the client IP string
lastSweep time.Time // time of the most recent sweep
}
// newIPLimiter returns an empty ipLimiter whose sweep clock starts now.
func newIPLimiter() *ipLimiter {
	tracker := new(ipLimiter)
	tracker.limiters = map[string]*ipLimiterEntry{}
	tracker.lastSweep = time.Now()

	return tracker
}
// sweep drops every entry that has been idle for longer than limiterExpiry
// and records now as the time of the last sweep. Must be called with mu held.
func (i *ipLimiter) sweep(now time.Time) {
	for key := range i.limiters {
		idle := now.Sub(i.limiters[key].lastSeen)
		if idle <= limiterExpiry {
			continue
		}

		delete(i.limiters, key)
	}

	i.lastSweep = now
}
// getLimiter returns the rate limiter for ip, creating it on first use, and
// refreshes the entry's last-seen time. Stale entries are swept first when at
// least limiterCleanupEvery has passed since the previous sweep.
func (i *ipLimiter) getLimiter(ip string) *rate.Limiter {
	i.mu.Lock()
	defer i.mu.Unlock()

	now := time.Now()

	// Lazy sweep: clean up stale entries periodically.
	if sinceSweep := now.Sub(i.lastSweep); sinceSweep >= limiterCleanupEvery {
		i.sweep(now)
	}

	if known, ok := i.limiters[ip]; ok {
		known.lastSeen = now

		return known.limiter
	}

	fresh := &ipLimiterEntry{
		limiter:  rate.NewLimiter(loginRateLimit, loginBurst),
		lastSeen: now,
	}
	i.limiters[ip] = fresh

	return fresh.limiter
}
// loginLimiter is the singleton IP rate limiter for login attempts.
// It is package-level state, so every handler produced by LoginRateLimit
// consults the same per-IP limiter map.
//
//nolint:gochecknoglobals // intentional singleton for rate limiting state
var loginLimiter = newIPLimiter()
// LoginRateLimit returns middleware that rate-limits login attempts per
// client IP (as determined by realIP). Each IP gets a burst of 5 attempts
// refilled at 5 per minute; once exhausted the request is answered with
// 429 Too Many Requests and a Retry-After header of at least one second.
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		limited := func(writer http.ResponseWriter, request *http.Request) {
			clientIP := realIP(request)
			bucket := loginLimiter.getLimiter(clientIP)

			if bucket.Allow() {
				next.ServeHTTP(writer, request)

				return
			}

			m.log.WarnContext(request.Context(), "login rate limit exceeded",
				"remoteIP", clientIP,
			)

			// Compute seconds until the next token is available: take a
			// reservation only to read its delay, then cancel it.
			probe := bucket.Reserve()
			wait := probe.Delay()
			probe.Cancel()

			seconds := int(math.Ceil(wait.Seconds()))
			if seconds < 1 {
				seconds = 1
			}

			writer.Header().Set("Retry-After", strconv.Itoa(seconds))
			http.Error(
				writer,
				"Too Many Requests",
				http.StatusTooManyRequests,
			)
		}

		return http.HandlerFunc(limited)
	}
}
// bearerPrefix is the expected prefix for Authorization headers carrying an
// API token. tryBearerAuth matches it exactly (case-sensitive, single space)
// and treats everything after it as the raw token.
const bearerPrefix = "Bearer "
// APISessionAuth returns middleware that requires authentication
// for API routes. It checks Bearer token first, then falls back
// to session cookie. On failure it responds with 401 and the JSON body
// {"error":"unauthorized"}.
func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(
			writer http.ResponseWriter,
			request *http.Request,
		) {
			// Try Bearer token first. On success the returned request
			// carries the authenticated user on its context.
			if authedReq, ok := m.tryBearerAuth(request); ok {
				next.ServeHTTP(writer, authedReq)

				return
			}

			// Fall back to session cookie.
			user, err := m.params.Auth.GetCurrentUser(
				request.Context(), request,
			)
			if err != nil || user == nil {
				// http.Error is deliberately not used here: it overwrites
				// Content-Type with "text/plain; charset=utf-8", so the
				// JSON body would be mislabelled. Write the response
				// directly instead.
				header := writer.Header()
				header.Set("Content-Type", "application/json")
				header.Set("X-Content-Type-Options", "nosniff")
				writer.WriteHeader(http.StatusUnauthorized)

				// Best effort: a write failure means the client has gone
				// away and there is nothing further to do.
				_, _ = writer.Write([]byte(`{"error":"unauthorized"}` + "\n"))

				return
			}

			next.ServeHTTP(writer, request)
		})
	}
}
// SetupRequired returns middleware that gates requests on whether initial
// setup has been completed:
//
//   - setup still required: only /setup is served; everything else is
//     redirected to /setup.
//   - setup already done: /setup is redirected to /; everything else is
//     served.
//
// If the setup status cannot be determined the request fails with 500.
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(
			writer http.ResponseWriter,
			request *http.Request,
		) {
			ctx := request.Context()

			setupRequired, err := m.params.Auth.IsSetupRequired(ctx)
			if err != nil {
				// Use the context-aware variant like the other log calls in
				// this file so the request context reaches the log handler.
				m.log.ErrorContext(ctx, "failed to check setup status", "error", err)
				http.Error(
					writer,
					"Internal Server Error",
					http.StatusInternalServerError,
				)

				return
			}

			onSetupPage := request.URL.Path == "/setup"

			switch {
			case setupRequired && !onSetupPage:
				// Nothing but the setup page is reachable before setup.
				http.Redirect(writer, request, "/setup", http.StatusSeeOther)
			case !setupRequired && onSetupPage:
				// Block setup page if already set up.
				http.Redirect(writer, request, "/", http.StatusSeeOther)
			default:
				next.ServeHTTP(writer, request)
			}
		})
	}
}
// tryBearerAuth authenticates a request from its "Authorization: Bearer ..."
// header. On success it returns a copy of the request whose context carries
// the token's user, together with true. In every other case (no bearer
// header, empty token, unknown or expired token, lookup error, missing user)
// it returns the original request and false.
func (m *Middleware) tryBearerAuth(
	request *http.Request,
) (*http.Request, bool) {
	ctx := request.Context()

	rawToken, hasScheme := strings.CutPrefix(
		request.Header.Get("Authorization"), bearerPrefix,
	)
	if !hasScheme || rawToken == "" {
		return request, false
	}

	// Tokens are looked up by hash; the raw token is not passed to the
	// lookup.
	apiToken, err := models.FindAPITokenByHash(
		ctx, m.params.Database, database.HashAPIToken(rawToken),
	)
	if err != nil || apiToken == nil || apiToken.IsExpired() {
		return request, false
	}

	// Look up the user associated with the token.
	user, err := models.FindUser(ctx, m.params.Database, apiToken.UserID)
	if err != nil || user == nil {
		return request, false
	}

	// Update last_used_at (best effort): a failure here must not reject an
	// otherwise valid token.
	_ = apiToken.TouchLastUsed(ctx)

	// Set the authenticated user on the request context.
	return request.WithContext(auth.ContextWithUser(ctx, user)), true
}