// Package middleware provides HTTP middleware for the chat server. package middleware import ( "log/slog" "net" "net/http" "time" "git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/globals" "git.eeqj.de/sneak/chat/internal/logger" basicauth "github.com/99designs/basicauth-go" "github.com/go-chi/chi/middleware" "github.com/go-chi/cors" metrics "github.com/slok/go-http-metrics/metrics/prometheus" ghmm "github.com/slok/go-http-metrics/middleware" "github.com/slok/go-http-metrics/middleware/std" "github.com/spf13/viper" "go.uber.org/fx" ) const corsMaxAge = 300 // Params defines the dependencies for creating Middleware. type Params struct { fx.In Logger *logger.Logger Globals *globals.Globals Config *config.Config } // Middleware provides HTTP middleware handlers. type Middleware struct { log *slog.Logger params *Params } // New creates a new Middleware instance. func New(_ fx.Lifecycle, params Params) (*Middleware, error) { s := new(Middleware) s.params = ¶ms s.log = params.Logger.Get() return s, nil } func ipFromHostPort(hp string) string { h, _, err := net.SplitHostPort(hp) if err != nil { return "" } if len(h) > 0 && h[0] == '[' { return h[1 : len(h)-1] } return h } type loggingResponseWriter struct { http.ResponseWriter statusCode int } // newLoggingResponseWriter wraps a ResponseWriter to capture the status code. func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { return &loggingResponseWriter{w, http.StatusOK} } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } // Logging returns middleware that logs each HTTP request. func (s *Middleware) Logging() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() lrw := newLoggingResponseWriter(w) ctx := r.Context() defer func() { latency := time.Since(start) reqID, _ := ctx.Value(middleware.RequestIDKey).(string) s.log.InfoContext(ctx, "request", "request_start", start, "method", r.Method, "url", r.URL.String(), "useragent", r.UserAgent(), "request_id", reqID, "referer", r.Referer(), "proto", r.Proto, "remoteIP", ipFromHostPort(r.RemoteAddr), "status", lrw.statusCode, "latency_ms", latency.Milliseconds(), ) }() next.ServeHTTP(lrw, r) }) } } // CORS returns middleware that handles Cross-Origin Resource Sharing. func (s *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: corsMaxAge, }) } // Auth returns middleware that performs authentication. func (s *Middleware) Auth() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.log.Info("AUTH: before request") next.ServeHTTP(w, r) }) } } // Metrics returns middleware that records HTTP metrics. func (s *Middleware) Metrics() func(http.Handler) http.Handler { mdlw := ghmm.New(ghmm.Config{ Recorder: metrics.NewRecorder(metrics.Config{}), }) return func(next http.Handler) http.Handler { return std.Handler("", mdlw, next) } } // MetricsAuth returns middleware that protects metrics with basic auth. func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler { return basicauth.New( "metrics", map[string][]string{ viper.GetString("METRICS_USERNAME"): { viper.GetString("METRICS_PASSWORD"), }, }, ) }