Files
pixa/internal/middleware/middleware.go
sneak 5de7a26735 Add failing tests for security headers middleware
Tests for X-Content-Type-Options, X-Frame-Options, Referrer-Policy,
X-XSS-Protection headers on responses.
2026-01-08 10:01:36 -08:00

146 lines
3.6 KiB
Go

// Package middleware provides HTTP middleware functions.
package middleware
import (
"log/slog"
"net"
"net/http"
"time"
basicauth "github.com/99designs/basicauth-go"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
metrics "github.com/slok/go-http-metrics/metrics/prometheus"
ghmm "github.com/slok/go-http-metrics/middleware"
"github.com/slok/go-http-metrics/middleware/std"
"go.uber.org/fx"
"sneak.berlin/go/pixa/internal/config"
"sneak.berlin/go/pixa/internal/logger"
)
// CORSMaxAgeSeconds is the max age, in seconds, for the CORS preflight cache
// (86400 s = 24 hours). It is passed as MaxAge in the options built by CORS.
const CORSMaxAgeSeconds = 86400
// Params defines dependencies for Middleware.
// It embeds fx.In and is the parameter type accepted by New.
type Params struct {
fx.In
// Logger is the source of the *slog.Logger used by Middleware; New calls
// its Get method.
Logger *logger.Logger
// Config is stored on the Middleware as-is; MetricsAuth reads
// MetricsUsername and MetricsPassword from it.
Config *config.Config
}
// Middleware provides HTTP middleware functions.
// Each exported method returns a func(http.Handler) http.Handler wrapper.
type Middleware struct {
// log receives the per-request entry written by Logging.
log *slog.Logger
// config supplies the metrics basic-auth credentials used by MetricsAuth.
config *config.Config
}
// New builds a Middleware from the supplied Params: the logger is obtained
// via params.Logger.Get() and the config pointer is kept as given. The
// lifecycle argument is accepted but not used, and the returned error is
// always nil.
func New(_ fx.Lifecycle, params Params) (*Middleware, error) {
	return &Middleware{
		log:    params.Logger.Get(),
		config: params.Config,
	}, nil
}
// ipFromHostPort returns the host part of a "host:port" address such as
// http.Request.RemoteAddr. It returns the empty string when hp cannot be
// split, for example when the port is missing.
//
// net.SplitHostPort already removes the square brackets around an IPv6
// literal ("[::1]:80" yields "::1") and rejects any other '[' in the
// address, so the host never needs further unwrapping here.
func ipFromHostPort(hp string) string {
	host, _, err := net.SplitHostPort(hp)
	if err != nil {
		return ""
	}

	return host
}
// loggingResponseWriter wraps an http.ResponseWriter and remembers the
// final status code of the response so that it can be logged afterwards.
type loggingResponseWriter struct {
	http.ResponseWriter
	// statusCode is the status to report; it starts at http.StatusOK, which
	// is what net/http sends when a handler never calls WriteHeader.
	statusCode int
	// wroteHeader is set once a final status has been committed, either by
	// WriteHeader or implicitly by the first Write.
	wroteHeader bool
}

// newLoggingResponseWriter wraps w with a recorded status of http.StatusOK.
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
	return &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
}

// WriteHeader records the first final status code and forwards every call
// to the wrapped writer unchanged.
//
// Later calls are still forwarded (so the wrapped writer can report them as
// superfluous) but no longer overwrite the recorded status, because only the
// first final status is actually sent. Interim 1xx codes other than
// 101 Switching Protocols are forwarded without being recorded, since a
// final status follows them.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	interim := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !interim && !lrw.wroteHeader {
		lrw.statusCode = code
		lrw.wroteHeader = true
	}

	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. The first Write commits the response
// header, so any WriteHeader call made afterwards must not change the
// recorded status.
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
	lrw.wroteHeader = true

	return lrw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so that http.ResponseController can
// reach optional capabilities (flush, hijack, deadlines) of the original.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// Logging returns a middleware that writes one structured "request" log
// entry for every request once the wrapped handler has returned. The entry
// carries the start time, method, URL, user agent, request ID (read from the
// request context under middleware.RequestIDKey), referer, protocol, remote
// IP, response status and latency in milliseconds.
func (s *Middleware) Logging() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			recorder := newLoggingResponseWriter(w)
			reqCtx := r.Context()

			// Deferred so the entry is emitted even when next panics.
			defer func() {
				elapsed := time.Since(began)
				// A missing or non-string ID simply logs as "".
				requestID, _ := reqCtx.Value(middleware.RequestIDKey).(string)

				s.log.InfoContext(
					reqCtx,
					"request",
					"request_start", began,
					"method", r.Method,
					"url", r.URL.String(),
					"useragent", r.UserAgent(),
					"request_id", requestID,
					"referer", r.Referer(),
					"proto", r.Proto,
					"remoteIP", ipFromHostPort(r.RemoteAddr),
					"status", recorder.statusCode,
					"latency_ms", elapsed.Milliseconds(),
				)
			}()

			next.ServeHTTP(recorder, r)
		}

		return http.HandlerFunc(handle)
	}
}
// CORS returns a CORS middleware configured for read-only, credential-less
// access from any origin: GET, HEAD and OPTIONS are allowed, the Link header
// is exposed, and the max age is CORSMaxAgeSeconds.
func (s *Middleware) CORS() func(http.Handler) http.Handler {
	opts := cors.Options{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "HEAD", "OPTIONS"},
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           CORSMaxAgeSeconds,
	}

	return cors.Handler(opts)
}
// Metrics returns a middleware that instruments handlers for Prometheus.
// The recorder and the go-http-metrics middleware are created once per call
// to Metrics and shared by every handler the returned function wraps.
func (s *Middleware) Metrics() func(http.Handler) http.Handler {
	recorder := metrics.NewRecorder(metrics.Config{})
	instrument := ghmm.New(ghmm.Config{Recorder: recorder})

	return func(next http.Handler) http.Handler {
		// The empty string is passed through to std.Handler unchanged.
		return std.Handler("", instrument, next)
	}
}
// MetricsAuth returns a basic auth middleware for the metrics endpoint.
// It accepts the single username/password pair taken from the config and
// passes "metrics" as the first argument to basicauth.New (presumably the
// auth realm).
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	// NOTE(review): if MetricsUsername and MetricsPassword are both empty,
	// this map pairs "" with "" — confirm the route is only mounted when
	// credentials are configured.
	credentials := map[string][]string{
		s.config.MetricsUsername: {s.config.MetricsPassword},
	}

	return basicauth.New("metrics", credentials)
}
// SecurityHeaders returns a middleware that adds security headers to responses.
//
// The following headers are set on the response before the wrapped handler
// runs, so a downstream handler can still override any of them:
//
//   - X-Content-Type-Options: nosniff
//   - X-Frame-Options: DENY
//   - Referrer-Policy: strict-origin-when-cross-origin
//   - X-XSS-Protection: 1; mode=block
func (s *Middleware) SecurityHeaders() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()

			// Stop browsers from guessing a content type other than the
			// declared one.
			h.Set("X-Content-Type-Options", "nosniff")
			// Disallow rendering responses inside frames.
			h.Set("X-Frame-Options", "DENY")
			// Send only the origin on cross-origin navigations.
			h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
			// Legacy header; only older browsers act on it.
			h.Set("X-XSS-Protection", "1; mode=block")

			next.ServeHTTP(w, r)
		})
	}
}