Files
pixa/internal/middleware/middleware.go
sneak 9bfae69ccf Fix logging: add response_bytes to middleware, cache_key to handler
- Middleware now tracks and logs bytes written via response_bytes
- Handler logs cache_key for cache hit debugging
- Changed "served encrypted image" to "image served" (only URL is encrypted)
2026-01-08 13:05:10 -08:00

167 lines
4.3 KiB
Go

// Package middleware provides HTTP middleware functions.
package middleware
import (
"log/slog"
"net"
"net/http"
"time"
basicauth "github.com/99designs/basicauth-go"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
metrics "github.com/slok/go-http-metrics/metrics/prometheus"
ghmm "github.com/slok/go-http-metrics/middleware"
"github.com/slok/go-http-metrics/middleware/std"
"go.uber.org/fx"
"sneak.berlin/go/pixa/internal/config"
"sneak.berlin/go/pixa/internal/logger"
)
// CORSMaxAgeSeconds is how long, in seconds, a browser may cache the result
// of a CORS preflight request: one day (24h × 60m × 60s).
const CORSMaxAgeSeconds = 24 * 60 * 60
// Params defines dependencies for Middleware.
//
// The embedded fx.In marks this as an fx parameter struct; its exported
// fields are the values New receives from the dependency graph.
type Params struct {
	fx.In
	// Logger provides the *slog.Logger (via Get) used for request logging.
	Logger *logger.Logger
	// Config is retained by the Middleware; MetricsAuth reads the metrics
	// username and password from it.
	Config *config.Config
}
// Middleware provides HTTP middleware functions.
//
// Each constructor method returns a func(http.Handler) http.Handler, the
// conventional wrapping shape for net/http middleware.
type Middleware struct {
	// log is the structured logger used by the Logging middleware.
	log *slog.Logger
	// config holds application settings; MetricsAuth reads the metrics
	// credentials from it.
	config *config.Config
}
// New builds a Middleware from its injected dependencies.
//
// The fx.Lifecycle parameter exists only to satisfy the fx constructor
// signature; no start/stop hooks are registered. The error result is always
// nil.
func New(_ fx.Lifecycle, params Params) (*Middleware, error) {
	mw := new(Middleware)
	mw.log = params.Logger.Get()
	mw.config = params.Config

	return mw, nil
}
// ipFromHostPort returns the host part of a "host:port" address such as
// http.Request.RemoteAddr, for use in access logs.
//
// net.SplitHostPort already strips the square brackets from IPv6 literals
// ("[::1]:80" -> "::1") and rejects hosts that still contain '[', so no
// manual unbracketing is needed.
//
// If hp carries no port but is itself a valid IP address, it is returned
// unchanged. This covers RemoteAddr having been rewritten to a bare client
// IP upstream (e.g. by a real-IP middleware), which would otherwise log as "".
// Any other unparseable input yields "".
func ipFromHostPort(hp string) string {
	host, _, err := net.SplitHostPort(hp)
	if err == nil {
		return host
	}

	// No parseable port: accept the value only if it is a plain IP address.
	if net.ParseIP(hp) != nil {
		return hp
	}

	return ""
}
// loggingResponseWriter wraps an http.ResponseWriter to record the final
// status code and the number of body bytes written, for request logging.
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode   int
	bytesWritten int64
	// wroteHeader is set once the final (non-1xx) status is committed,
	// either by an explicit WriteHeader or implicitly by the first Write.
	wroteHeader bool
}

// newLoggingResponseWriter wraps w. statusCode starts at 200 because
// net/http sends 200 OK when a handler writes a body, or returns, without
// calling WriteHeader.
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
	return &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
}

// WriteHeader forwards code to the wrapped writer and records it.
//
// Only the first final status is recorded. net/http ignores later
// WriteHeader calls ("superfluous WriteHeader"), so recording them would make
// the log disagree with what was actually sent. Informational 1xx codes
// (other than 101 Switching Protocols) may precede the final status, so they
// are forwarded without being recorded.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !lrw.wroteHeader && !informational {
		lrw.statusCode = code
		lrw.wroteHeader = true
	}
	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards b and adds the number of bytes actually written to the
// running total. A Write before any final WriteHeader commits the implicit
// 200, so later WriteHeader calls no longer change the recorded status.
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
	lrw.wroteHeader = true
	n, err := lrw.ResponseWriter.Write(b)
	lrw.bytesWritten += int64(n)
	return n, err
}

// Unwrap returns the wrapped writer so that http.ResponseController can
// reach optional interfaces (e.g. http.Flusher) that this wrapper does not
// implement itself.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// Logging returns middleware that emits one structured "request" log entry
// per request after the downstream handler finishes.
//
// The entry records request metadata (method, URL, user agent, request ID
// from the request context, referer, protocol, remote IP) together with the
// response status, the number of body bytes written, and the latency in
// milliseconds.
func (s *Middleware) Logging() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			recorder := newLoggingResponseWriter(w)
			reqCtx := r.Context()

			// Deferred so the entry is still emitted if next panics.
			defer func() {
				elapsed := time.Since(began)
				requestID, _ := reqCtx.Value(middleware.RequestIDKey).(string)

				s.log.InfoContext(reqCtx, "request",
					"request_start", began,
					"method", r.Method,
					"url", r.URL.String(),
					"useragent", r.UserAgent(),
					"request_id", requestID,
					"referer", r.Referer(),
					"proto", r.Proto,
					"remoteIP", ipFromHostPort(r.RemoteAddr),
					"status", recorder.statusCode,
					"response_bytes", recorder.bytesWritten,
					"latency_ms", elapsed.Milliseconds(),
				)
			}()

			next.ServeHTTP(recorder, r)
		}

		return http.HandlerFunc(serve)
	}
}
// CORS returns go-chi/cors middleware configured for a public, read-only API.
//
// Any origin may use GET, HEAD and OPTIONS. The Accept, Authorization and
// Content-Type request headers are allowed, the Link response header is
// exposed, credentials are not allowed, and preflight results may be cached
// for CORSMaxAgeSeconds.
func (s *Middleware) CORS() func(http.Handler) http.Handler {
	opts := cors.Options{
		AllowCredentials: false,
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
		AllowedMethods:   []string{"GET", "HEAD", "OPTIONS"},
		AllowedOrigins:   []string{"*"},
		ExposedHeaders:   []string{"Link"},
		MaxAge:           CORSMaxAgeSeconds,
	}

	return cors.Handler(opts)
}
// Metrics returns middleware that measures each request with go-http-metrics
// and records the results through a Prometheus recorder.
//
// NOTE(review): a fresh recorder is built on every call. If
// metrics.NewRecorder registers collectors on the default Prometheus
// registry, calling Metrics more than once may fail on duplicate
// registration — confirm.
//
// NOTE(review): the handler ID passed to std.Handler is empty. That
// presumably makes the library label metrics by URL path, which can have
// unbounded cardinality if paths embed per-image tokens — verify.
func (s *Middleware) Metrics() func(http.Handler) http.Handler {
	recorder := metrics.NewRecorder(metrics.Config{})
	measurer := ghmm.New(ghmm.Config{Recorder: recorder})

	wrap := func(next http.Handler) http.Handler {
		return std.Handler("", measurer, next)
	}

	return wrap
}
// MetricsAuth returns HTTP Basic Auth middleware (realm "metrics") that
// checks requests against the single username/password pair taken from the
// configuration (MetricsUsername / MetricsPassword).
//
// NOTE(review): if both configured values are empty, the credential table is
// {"": {""}}. Confirm that the metrics route is not mounted in that case;
// otherwise a request with empty Basic credentials may be accepted.
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	username := s.config.MetricsUsername
	password := s.config.MetricsPassword

	allowed := map[string][]string{username: {password}}

	return basicauth.New("metrics", allowed)
}
// SecurityHeaders returns middleware that sets a fixed set of defensive
// response headers before calling the next handler:
//
//   - X-Content-Type-Options: nosniff — stop browsers MIME-sniffing bodies.
//   - X-Frame-Options: DENY — forbid framing, as a clickjacking defence.
//   - Referrer-Policy: strict-origin-when-cross-origin — limit how much
//     referrer information leaves the site.
//   - X-XSS-Protection: 0 — turn off the legacy XSS filter, which modern
//     browsers do not need and which can itself cause problems.
func (s *Middleware) SecurityHeaders() func(http.Handler) http.Handler {
	headers := [...][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"X-XSS-Protection", "0"},
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			respHeader := w.Header()
			for _, kv := range headers {
				respHeader.Set(kv[0], kv[1])
			}

			next.ServeHTTP(w, r)
		})
	}
}