// Package middleware provides HTTP middleware for the chat server. package middleware import ( "log/slog" "net" "net/http" "time" "git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/globals" "git.eeqj.de/sneak/chat/internal/logger" basicauth "github.com/99designs/basicauth-go" chimw "github.com/go-chi/chi/middleware" "github.com/go-chi/cors" metrics "github.com/slok/go-http-metrics/metrics/prometheus" ghmm "github.com/slok/go-http-metrics/middleware" "github.com/slok/go-http-metrics/middleware/std" "github.com/spf13/viper" "go.uber.org/fx" ) const corsMaxAge = 300 // Params defines the dependencies for creating Middleware. type Params struct { fx.In Logger *logger.Logger Globals *globals.Globals Config *config.Config } // Middleware provides HTTP middleware handlers. type Middleware struct { log *slog.Logger params *Params } // New creates a new Middleware instance. func New( _ fx.Lifecycle, params Params, ) (*Middleware, error) { mware := &Middleware{ params: ¶ms, log: params.Logger.Get(), } return mware, nil } func ipFromHostPort(hostPort string) string { host, _, err := net.SplitHostPort(hostPort) if err != nil { return "" } if len(host) > 0 && host[0] == '[' { return host[1 : len(host)-1] } return host } type loggingResponseWriter struct { http.ResponseWriter statusCode int } // newLoggingResponseWriter wraps a ResponseWriter // to capture the status code. func newLoggingResponseWriter( writer http.ResponseWriter, ) *loggingResponseWriter { return &loggingResponseWriter{ ResponseWriter: writer, statusCode: http.StatusOK, } } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } // Logging returns middleware that logs each HTTP request. 
//
// The log entry is emitted once per request from a deferred call, after the
// downstream handler returns. It includes the method, URL, request ID, client
// IP, final status code and latency in milliseconds.
func (mware *Middleware) Logging() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(
			func(
				writer http.ResponseWriter,
				request *http.Request,
			) {
				start := time.Now()
				lrw := newLoggingResponseWriter(writer)
				ctx := request.Context()

				// Deferred so an entry is still written while a panic from
				// next unwinds (the status then reflects only what was
				// written before the panic).
				defer func() {
					latency := time.Since(start)

					mware.log.InfoContext(
						ctx,
						"request",
						"request_start", start,
						"method", request.Method,
						"url", request.URL.String(),
						"useragent", request.UserAgent(),
						// Empty unless chi's RequestID middleware ran
						// before this one.
						"request_id", chimw.GetReqID(ctx),
						"referer", request.Referer(),
						"proto", request.Proto,
						"remoteIP", ipFromHostPort(request.RemoteAddr),
						"status", lrw.statusCode,
						"latency_ms", latency.Milliseconds(),
					)
				}()

				next.ServeHTTP(lrw, request)
			})
	}
}

// CORS returns middleware that handles Cross-Origin Resource Sharing.
//
// Any origin is allowed, credentials are not, and preflight results may be
// cached for corsMaxAge.
func (mware *Middleware) CORS() func(http.Handler) http.Handler {
	return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields
		AllowedOrigins: []string{"*"},
		AllowedMethods: []string{
			"GET", "POST", "PUT", "DELETE", "OPTIONS",
		},
		AllowedHeaders: []string{
			"Accept", "Authorization",
			"Content-Type", "X-CSRF-Token",
		},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           corsMaxAge,
	})
}

// Auth returns middleware that performs authentication.
//
// NOTE(review): as written it only logs "AUTH: before request" and then
// always calls next; no credentials are checked. Treat it as a placeholder.
func (mware *Middleware) Auth() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(
			func(
				writer http.ResponseWriter,
				request *http.Request,
			) {
				// Use the request context, as Logging does, so
				// context-aware slog handlers can attach request data.
				mware.log.InfoContext(
					request.Context(),
					"AUTH: before request",
				)
				next.ServeHTTP(writer, request)
			})
	}
}

// Metrics returns middleware that records HTTP metrics.
//
// Handlers are wrapped with go-http-metrics, using a Prometheus recorder
// built from a zero-value metrics.Config (library defaults).
//
// NOTE(review): the Prometheus recorder appears to register its collectors
// with the default registerer when it is created. If so, calling Metrics more
// than once per process would panic on duplicate registration; confirm it is
// only wired up once. The empty handler ID passed to std.Handler makes the
// library label series by URL path, which may be high-cardinality for paths
// containing IDs.
func (mware *Middleware) Metrics() func(http.Handler) http.Handler {
	metricsMiddleware := ghmm.New(ghmm.Config{ //nolint:exhaustruct // optional fields
		Recorder: metrics.NewRecorder(
			metrics.Config{}, //nolint:exhaustruct // defaults
		),
	})

	return func(next http.Handler) http.Handler {
		return std.Handler("", metricsMiddleware, next)
	}
}

// MetricsAuth returns middleware that protects metrics with basic auth.
//
// The single accepted credential pair is read from the METRICS_USERNAME and
// METRICS_PASSWORD viper keys when MetricsAuth is called. If either value is
// empty, the middleware fails closed. It registers no users, so no request can
// authenticate. Without this check, an unconfigured deployment would accept
// the empty credentials ("Basic Og==").
func (mware *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	username := viper.GetString("METRICS_USERNAME")
	password := viper.GetString("METRICS_PASSWORD")

	if username == "" || password == "" {
		mware.log.Warn(
			"metrics basic auth credentials not configured; " +
				"denying all metrics requests",
		)

		// An empty credentials map means no username can match.
		return basicauth.New("metrics", map[string][]string{})
	}

	return basicauth.New(
		"metrics",
		map[string][]string{
			username: {password},
		},
	)
}