// Package middleware provides HTTP middleware functions. package middleware import ( "log/slog" "net" "net/http" "time" basicauth "github.com/99designs/basicauth-go" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" metrics "github.com/slok/go-http-metrics/metrics/prometheus" ghmm "github.com/slok/go-http-metrics/middleware" "github.com/slok/go-http-metrics/middleware/std" "go.uber.org/fx" "sneak.berlin/go/pixa/internal/config" "sneak.berlin/go/pixa/internal/logger" ) // CORSMaxAgeSeconds is the max age for CORS preflight cache (24 hours). const CORSMaxAgeSeconds = 86400 // Params defines dependencies for Middleware. type Params struct { fx.In Logger *logger.Logger Config *config.Config } // Middleware provides HTTP middleware functions. type Middleware struct { log *slog.Logger config *config.Config } // New creates a new Middleware instance. func New(_ fx.Lifecycle, params Params) (*Middleware, error) { s := &Middleware{ log: params.Logger.Get(), config: params.Config, } return s, nil } func ipFromHostPort(hp string) string { h, _, err := net.SplitHostPort(hp) if err != nil { return "" } if len(h) > 0 && h[0] == '[' { return h[1 : len(h)-1] } return h } type loggingResponseWriter struct { http.ResponseWriter statusCode int } func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { return &loggingResponseWriter{w, http.StatusOK} } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } // Logging returns a logging middleware. func (s *Middleware) Logging() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() lrw := newLoggingResponseWriter(w) ctx := r.Context() defer func() { latency := time.Since(start) reqID, _ := ctx.Value(middleware.RequestIDKey).(string) s.log.InfoContext(ctx, "request", "request_start", start, "method", r.Method, "url", r.URL.String(), "useragent", r.UserAgent(), "request_id", reqID, "referer", r.Referer(), "proto", r.Proto, "remoteIP", ipFromHostPort(r.RemoteAddr), "status", lrw.statusCode, "latency_ms", latency.Milliseconds(), ) }() next.ServeHTTP(lrw, r) }) } } // CORS returns a CORS middleware. func (s *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "HEAD", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: CORSMaxAgeSeconds, }) } // Metrics returns a Prometheus metrics middleware. func (s *Middleware) Metrics() func(http.Handler) http.Handler { mdlw := ghmm.New(ghmm.Config{ Recorder: metrics.NewRecorder(metrics.Config{}), }) return func(next http.Handler) http.Handler { return std.Handler("", mdlw, next) } } // MetricsAuth returns a basic auth middleware for the metrics endpoint. func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler { return basicauth.New( "metrics", map[string][]string{ s.config.MetricsUsername: { s.config.MetricsPassword, }, }, ) }