upaas/internal/middleware/middleware.go
user 0536f57ec2 feat: add JSON API with token auth (closes #69)
- Add API token model with SHA-256 hashed tokens
- Add migration 006_add_api_tokens.sql
- Add Bearer token auth middleware
- Add API endpoints under /api/v1/:
  - GET /whoami
  - POST /tokens (create new API token)
  - GET /apps (list all apps)
  - POST /apps (create app)
  - GET /apps/{id} (get app)
  - DELETE /apps/{id} (delete app)
  - POST /apps/{id}/deploy (trigger deployment)
  - GET /apps/{id}/deployments (list deployments)
- Add comprehensive tests for all API endpoints
- All tests pass, zero lint issues
2026-02-16 00:24:45 -08:00

460 lines
12 KiB
Go

// Package middleware provides HTTP middleware.
package middleware
import (
"context"
"log/slog"
"math"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/99designs/basicauth-go"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/gorilla/csrf"
"go.uber.org/fx"
"golang.org/x/time/rate"
"git.eeqj.de/sneak/upaas/internal/config"
"git.eeqj.de/sneak/upaas/internal/database"
"git.eeqj.de/sneak/upaas/internal/globals"
"git.eeqj.de/sneak/upaas/internal/logger"
"git.eeqj.de/sneak/upaas/internal/models"
"git.eeqj.de/sneak/upaas/internal/service/auth"
)
// corsMaxAge is how long, in seconds, a CORS preflight response may be
// cached; it is passed as cors.Options.MaxAge (five minutes).
const corsMaxAge = 5 * 60
// apiUserContextKey is the context key for the authenticated API user.
// APITokenAuth stores the user under apiUserContextKey{} and
// APIUserFromContext reads it back. Being an unexported type, the key cannot
// collide with context keys defined by other packages.
type apiUserContextKey struct{}
// Params contains dependencies for Middleware.
// It embeds fx.In (go.uber.org/fx) and is the argument type of New.
type Params struct {
fx.In
// Logger supplies the *slog.Logger used for all middleware logging (via Get).
Logger *logger.Logger
// Globals is injected but not referenced by the code visible in this file.
Globals *globals.Globals
// Config is read for MetricsUsername, MetricsPassword and SessionSecret.
Config *config.Config
// Auth resolves the current session user and reports whether setup is required.
Auth *auth.Service
// Database is handed to the models lookups performed by APITokenAuth.
Database *database.Database
}
// Middleware provides HTTP middleware.
// Each exported method returns a func(http.Handler) http.Handler wrapper.
type Middleware struct {
// log is the logger obtained from Params.Logger in New.
log *slog.Logger
// params points at the dependency set captured by New.
params *Params
}
// New builds a Middleware from its injected dependencies. The lifecycle
// argument is accepted but unused, and the returned error is always nil.
func New(_ fx.Lifecycle, params Params) (*Middleware, error) {
	mw := &Middleware{
		params: &params,
		log:    params.Logger.Get(),
	}

	return mw, nil
}
// loggingResponseWriter wraps an http.ResponseWriter and remembers the status
// code passed to WriteHeader so that it can be logged once the handler returns.
type loggingResponseWriter struct {
	http.ResponseWriter

	// statusCode is the most recent code given to WriteHeader. It starts at
	// http.StatusOK, which is what net/http sends when a handler writes a
	// body without calling WriteHeader.
	statusCode int
}

// newLoggingResponseWriter wraps writer with the recorded status preset to
// http.StatusOK.
func newLoggingResponseWriter(
	writer http.ResponseWriter,
) *loggingResponseWriter {
	return &loggingResponseWriter{
		ResponseWriter: writer,
		statusCode:     http.StatusOK,
	}
}

// WriteHeader records code and forwards it to the wrapped writer.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	lrw.statusCode = code
	lrw.ResponseWriter.WriteHeader(code)
}

// Unwrap returns the wrapped writer. http.ResponseController uses it to reach
// optional capabilities of the underlying writer (Flush, Hijack, deadlines)
// that the embedded http.ResponseWriter interface would otherwise hide.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// Logging returns middleware that writes one structured "request" log entry
// per request after the wrapped handler has finished. The entry carries the
// start time, method, URL, user agent, request ID, referer, protocol, client
// IP (see realIP), response status and latency in milliseconds.
func (m *Middleware) Logging() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(writer http.ResponseWriter, request *http.Request) {
			startedAt := time.Now()
			recorder := newLoggingResponseWriter(writer)
			ctx := request.Context()

			// Deferred so the entry is still written if next panics.
			defer func() {
				elapsed := time.Since(startedAt)

				m.log.InfoContext(ctx, "request",
					"request_start", startedAt,
					"method", request.Method,
					"url", request.URL.String(),
					"useragent", request.UserAgent(),
					"request_id", middleware.GetReqID(ctx),
					"referer", request.Referer(),
					"proto", request.Proto,
					"remoteIP", realIP(request),
					"status", recorder.statusCode,
					"latency_ms", elapsed.Milliseconds(),
				)
			}()

			next.ServeHTTP(recorder, request)
		}

		return http.HandlerFunc(handle)
	}
}
// ipFromHostPort returns the host part of a "host:port" string. When the
// input cannot be split (for example, it has no port), it is returned as is.
func ipFromHostPort(hostPort string) string {
	if host, _, err := net.SplitHostPort(hostPort); err == nil {
		return host
	}

	return hostPort
}
// trustedProxyNets are the RFC1918, loopback and IPv6 unique-local CIDRs
// whose proxy headers we trust. The list is parsed once at package
// initialisation.
//
// A malformed entry is a programming error, so it panics at startup rather
// than leaving a nil network in the slice that would crash isTrustedProxy on
// the first request.
//
//nolint:gochecknoglobals // package-level constant nets parsed once
var trustedProxyNets = func() []*net.IPNet {
	cidrs := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"127.0.0.0/8",
		"::1/128",
		"fc00::/7",
	}

	nets := make([]*net.IPNet, 0, len(cidrs))

	for _, cidr := range cidrs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			panic("middleware: invalid trusted proxy CIDR " + cidr + ": " + err.Error())
		}

		nets = append(nets, network)
	}

	return nets
}()
// isTrustedProxy reports whether ip falls inside any network listed in
// trustedProxyNets (RFC1918, loopback, or ULA ranges).
func isTrustedProxy(ip net.IP) bool {
	trusted := false

	for idx := 0; idx < len(trustedProxyNets) && !trusted; idx++ {
		trusted = trustedProxyNets[idx].Contains(ip)
	}

	return trusted
}
// realIP extracts the client's real IP address from the request.
//
// Proxy headers (X-Real-IP, X-Forwarded-For) are only consulted when the
// direct connection originates from an address accepted by isTrustedProxy.
// Otherwise, or when RemoteAddr does not parse as an IP, the headers are
// ignored and the host part of RemoteAddr is returned (fail closed).
//
// Header values are used only if they parse as an IP address, optionally
// followed by a port that is stripped. Anything else is skipped, so arbitrary
// header text never becomes a rate-limiter key or a log field.
func realIP(r *http.Request) string {
	addr := ipFromHostPort(r.RemoteAddr)

	// Only trust proxy headers from private/loopback sources.
	if remoteIP := net.ParseIP(addr); remoteIP == nil || !isTrustedProxy(remoteIP) {
		return addr
	}

	// 1. X-Real-IP (set by Traefik/nginx)
	if ip := headerClientIP(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	// 2. X-Forwarded-For: take the first (leftmost/client) IP.
	// NOTE(review): if the trusted proxy appends to an incoming
	// X-Forwarded-For instead of replacing it, the leftmost entry is
	// client-supplied — confirm the proxy overwrites the header.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := headerClientIP(first); ip != "" {
		return ip
	}

	// 3. Fall back to RemoteAddr
	return addr
}

// headerClientIP returns the IP address carried in a single proxy-header
// value, without any port, or "" when the value is not a valid IP address.
func headerClientIP(value string) string {
	candidate := ipFromHostPort(strings.TrimSpace(value))
	if net.ParseIP(candidate) == nil {
		return ""
	}

	return candidate
}
// CORS returns the cross-origin middleware. It allows any origin without
// credentials, the common REST methods, and the headers used for API and
// CSRF tokens. Preflight responses may be cached for corsMaxAge seconds.
func (m *Middleware) CORS() func(http.Handler) http.Handler {
	options := cors.Options{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           corsMaxAge,
	}

	return cors.Handler(options)
}
// MetricsAuth returns basic-auth middleware for the metrics endpoint, using
// the configured metrics username and password. When no username is
// configured the returned middleware is a pass-through.
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	cfg := m.params.Config

	if cfg.MetricsUsername != "" {
		credentials := map[string][]string{
			cfg.MetricsUsername: {cfg.MetricsPassword},
		}

		return basicauth.New("metrics", credentials)
	}

	// No username configured: leave the handler unprotected.
	return func(next http.Handler) http.Handler {
		return next
	}
}
// SessionAuth returns middleware that requires a logged-in session. Requests
// for which the auth service returns an error or no user are redirected to
// /login with 303 See Other; all others reach the wrapped handler.
func (m *Middleware) SessionAuth() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		guard := func(writer http.ResponseWriter, request *http.Request) {
			user, err := m.params.Auth.GetCurrentUser(request.Context(), request)
			if err == nil && user != nil {
				next.ServeHTTP(writer, request)

				return
			}

			http.Redirect(writer, request, "/login", http.StatusSeeOther)
		}

		return http.HandlerFunc(guard)
	}
}
// CSRF returns CSRF protection middleware built on gorilla/csrf, keyed with
// the configured session secret and scoped to the whole site ("/").
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
	authKey := []byte(m.params.Config.SessionSecret)

	options := []csrf.Option{
		// Allow HTTP for development; reverse proxy handles TLS
		csrf.Secure(false),
		csrf.Path("/"),
	}

	return csrf.Protect(authKey, options...)
}
// Tuning constants for the per-IP login rate limiter (see ipLimiter and
// LoginRateLimit).
const (
// loginRateLimit is the sustained refill rate handed to rate.NewLimiter.
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
// loginBurst is the bucket size handed to rate.NewLimiter.
loginBurst = 5 // allow burst of 5
// limiterExpiry is how long an entry may go unseen before sweep deletes it.
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
// limiterCleanupEvery is the minimum time between two sweeps in getLimiter.
limiterCleanupEvery = 1 * time.Minute // sweep interval
)
// ipLimiterEntry stores a rate limiter with its last-seen timestamp, so that
// sweep can evict entries that have gone idle.
type ipLimiterEntry struct {
// limiter is the rate limiter handed out by getLimiter for one key.
limiter *rate.Limiter
// lastSeen is refreshed by getLimiter on every lookup and compared against
// limiterExpiry by sweep.
lastSeen time.Time
}
// ipLimiter tracks per-IP rate limiters for login attempts with automatic
// eviction of stale entries to prevent unbounded memory growth.
type ipLimiter struct {
// mu is held by getLimiter for the whole lookup, including any sweep.
mu sync.Mutex
// limiters maps the string key passed to getLimiter to its entry.
limiters map[string]*ipLimiterEntry
// lastSweep is when sweep last ran (or when the ipLimiter was created).
lastSweep time.Time
}
// newIPLimiter returns an empty ipLimiter whose sweep clock starts now.
func newIPLimiter() *ipLimiter {
	tracker := new(ipLimiter)
	tracker.limiters = map[string]*ipLimiterEntry{}
	tracker.lastSweep = time.Now()

	return tracker
}
// sweep deletes every entry that has been idle for longer than limiterExpiry
// as of now, then records now as the time of the last sweep.
// The caller must hold mu.
func (i *ipLimiter) sweep(now time.Time) {
	for addr, tracked := range i.limiters {
		if idle := now.Sub(tracked.lastSeen); idle <= limiterExpiry {
			continue
		}

		delete(i.limiters, addr)
	}

	i.lastSweep = now
}
// getLimiter returns the rate limiter for ip, creating one on first use and
// refreshing the entry's last-seen time. Stale entries are swept lazily: at
// most once per limiterCleanupEvery, as part of a lookup.
func (i *ipLimiter) getLimiter(ip string) *rate.Limiter {
	i.mu.Lock()
	defer i.mu.Unlock()

	now := time.Now()

	// Lazy sweep: clean up stale entries periodically.
	if sinceSweep := now.Sub(i.lastSweep); sinceSweep >= limiterCleanupEvery {
		i.sweep(now)
	}

	if known, ok := i.limiters[ip]; ok {
		known.lastSeen = now

		return known.limiter
	}

	fresh := &ipLimiterEntry{
		limiter:  rate.NewLimiter(loginRateLimit, loginBurst),
		lastSeen: now,
	}
	i.limiters[ip] = fresh

	return fresh.limiter
}
// loginLimiter is the singleton IP rate limiter for login attempts.
// It is package-level state, so every Middleware value in the process shares
// it; LoginRateLimit is its only user in this file.
//
//nolint:gochecknoglobals // intentional singleton for rate limiting state
var loginLimiter = newIPLimiter()
// LoginRateLimit returns middleware that rate-limits login attempts per
// client IP (as determined by realIP). Each IP gets a burst of 5 attempts
// refilled at 5 per minute. Once that is exhausted the request is rejected
// with 429 Too Many Requests and a Retry-After header in whole seconds.
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		throttle := func(writer http.ResponseWriter, request *http.Request) {
			clientIP := realIP(request)
			bucket := loginLimiter.getLimiter(clientIP)

			if bucket.Allow() {
				next.ServeHTTP(writer, request)

				return
			}

			m.log.WarnContext(request.Context(), "login rate limit exceeded",
				"remoteIP", clientIP,
			)

			// Take a reservation only to learn how long the next token is
			// away, then cancel it straight away.
			probe := bucket.Reserve()
			wait := probe.Delay()
			probe.Cancel()

			// Round up to whole seconds and never advertise less than one.
			seconds := int(math.Ceil(wait.Seconds()))
			if seconds < 1 {
				seconds = 1
			}

			writer.Header().Set("Retry-After", strconv.Itoa(seconds))
			http.Error(writer, "Too Many Requests", http.StatusTooManyRequests)
		}

		return http.HandlerFunc(throttle)
	}
}
// APITokenAuth returns middleware that authenticates requests via an
// "Authorization: Bearer <token>" header.
//
// The token is hashed with models.HashAPIToken and looked up in the database.
// On success the owning user is stored in the request context (retrieve it
// with APIUserFromContext) and the wrapped handler is called. Failures are
// answered with a JSON error body:
//
//   - 401 when the header is missing, is not a Bearer credential, carries an
//     empty or unknown token, or the token's user cannot be loaded;
//   - 500 when the token lookup itself fails.
//
// The auth scheme is matched case-insensitively, as HTTP requires.
func (m *Middleware) APITokenAuth() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(
			writer http.ResponseWriter,
			request *http.Request,
		) {
			ctx := request.Context()

			authHeader := request.Header.Get("Authorization")
			if authHeader == "" {
				writeAPIAuthError(writer, http.StatusUnauthorized, `{"error":"missing Authorization header"}`)

				return
			}

			scheme, rawToken, hasCredentials := strings.Cut(authHeader, " ")
			if !hasCredentials || !strings.EqualFold(scheme, "Bearer") {
				writeAPIAuthError(writer, http.StatusUnauthorized, `{"error":"invalid Authorization header"}`)

				return
			}

			// Tolerate extra padding around the token itself.
			rawToken = strings.TrimSpace(rawToken)
			if rawToken == "" {
				writeAPIAuthError(writer, http.StatusUnauthorized, `{"error":"empty token"}`)

				return
			}

			hash := models.HashAPIToken(rawToken)

			apiToken, err := models.FindAPITokenByHash(ctx, m.params.Database, hash)
			if err != nil {
				m.log.ErrorContext(ctx, "api token lookup error", "error", err)
				writeAPIAuthError(writer, http.StatusInternalServerError, `{"error":"internal server error"}`)

				return
			}

			if apiToken == nil {
				writeAPIAuthError(writer, http.StatusUnauthorized, `{"error":"invalid token"}`)

				return
			}

			// Touch last used: best-effort, a failure must not block the
			// request, but it is logged so that it does not go unnoticed.
			if touchErr := apiToken.TouchLastUsed(ctx); touchErr != nil {
				m.log.WarnContext(ctx, "api token last-used update failed", "error", touchErr)
			}

			user, userErr := models.FindUser(ctx, m.params.Database, apiToken.UserID)
			if userErr != nil {
				// The client still gets a 401, as before; the log tells a
				// lookup failure apart from a user that no longer exists.
				m.log.ErrorContext(ctx, "api token user lookup error", "error", userErr)
			}

			if userErr != nil || user == nil {
				writeAPIAuthError(writer, http.StatusUnauthorized, `{"error":"token user not found"}`)

				return
			}

			authedCtx := context.WithValue(ctx, apiUserContextKey{}, user)
			next.ServeHTTP(writer, request.WithContext(authedCtx))
		})
	}
}

// writeAPIAuthError sends body, a pre-encoded JSON document, with the given
// status and a JSON content type. For 401 responses it also sets the
// WWW-Authenticate challenge that HTTP requires alongside that status.
func writeAPIAuthError(writer http.ResponseWriter, status int, body string) {
	header := writer.Header()
	header.Set("Content-Type", "application/json; charset=utf-8")
	header.Set("X-Content-Type-Options", "nosniff")

	if status == http.StatusUnauthorized {
		header.Set("WWW-Authenticate", "Bearer")
	}

	writer.WriteHeader(status)

	// A write error means the client has gone away; nothing useful can be
	// done with it at this point.
	_, _ = writer.Write([]byte(body + "\n"))
}
// APIUserFromContext returns the user that APITokenAuth stored in ctx, or nil
// when ctx carries no authenticated API user.
func APIUserFromContext(ctx context.Context) *models.User {
	if user, ok := ctx.Value(apiUserContextKey{}).(*models.User); ok {
		return user
	}

	return nil
}
// SetupRequired returns middleware that steers traffic around the setup page.
// While setup is still required, every path except /setup is redirected to
// /setup. Once setup is done, /setup is redirected to /. Everything else
// reaches the wrapped handler. A failure to determine the setup status is
// logged and answered with 500.
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
	const setupPath = "/setup"

	return func(next http.Handler) http.Handler {
		gate := func(writer http.ResponseWriter, request *http.Request) {
			needsSetup, err := m.params.Auth.IsSetupRequired(request.Context())
			if err != nil {
				m.log.Error("failed to check setup status", "error", err)
				http.Error(writer, "Internal Server Error", http.StatusInternalServerError)

				return
			}

			onSetupPage := request.URL.Path == setupPath

			switch {
			case needsSetup && !onSetupPage:
				// Not set up yet: everything funnels to the setup page.
				http.Redirect(writer, request, setupPath, http.StatusSeeOther)
			case !needsSetup && onSetupPage:
				// Block setup page if already set up
				http.Redirect(writer, request, "/", http.StatusSeeOther)
			default:
				next.ServeHTTP(writer, request)
			}
		}

		return http.HandlerFunc(gate)
	}
}