// Package middleware provides HTTP middleware. package middleware import ( "log/slog" "net" "net/http" "strings" "time" "github.com/99designs/basicauth-go" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "go.uber.org/fx" "sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/globals" "sneak.berlin/go/dnswatcher/internal/logger" ) // corsMaxAge is the maximum age for CORS preflight responses. const corsMaxAge = 300 // Params contains dependencies for Middleware. type Params struct { fx.In Logger *logger.Logger Globals *globals.Globals Config *config.Config } // Middleware provides HTTP middleware. type Middleware struct { log *slog.Logger params *Params } // New creates a new Middleware instance. func New( _ fx.Lifecycle, params Params, ) (*Middleware, error) { return &Middleware{ log: params.Logger.Get(), params: ¶ms, }, nil } // loggingResponseWriter wraps http.ResponseWriter to capture status. type loggingResponseWriter struct { http.ResponseWriter statusCode int } func newLoggingResponseWriter( writer http.ResponseWriter, ) *loggingResponseWriter { return &loggingResponseWriter{writer, http.StatusOK} } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } // Logging returns a request logging middleware. func (m *Middleware) Logging() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func( writer http.ResponseWriter, request *http.Request, ) { start := time.Now() lrw := newLoggingResponseWriter(writer) ctx := request.Context() defer func() { latency := time.Since(start) reqID := middleware.GetReqID(ctx) m.log.InfoContext(ctx, "request", "request_start", start, "method", request.Method, "url", request.URL.String(), "useragent", request.UserAgent(), "request_id", reqID, "referer", request.Referer(), "proto", request.Proto, "remoteIP", realIP(request), "status", lrw.statusCode, "latency_ms", latency.Milliseconds(), ) }() next.ServeHTTP(lrw, request) }) } } func ipFromHostPort(hostPort string) string { host, _, err := net.SplitHostPort(hostPort) if err != nil { return hostPort } return host } // trustedProxyNets are RFC1918 and loopback CIDRs. // //nolint:gochecknoglobals // package-level constant nets parsed once var trustedProxyNets = func() []*net.IPNet { cidrs := []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", } nets := make([]*net.IPNet, 0, len(cidrs)) for _, cidr := range cidrs { _, n, _ := net.ParseCIDR(cidr) nets = append(nets, n) } return nets }() func isTrustedProxy(ip net.IP) bool { for _, n := range trustedProxyNets { if n.Contains(ip) { return true } } return false } // realIP extracts the client's real IP address from the request. // Proxy headers are only trusted from RFC1918/loopback addresses. func realIP(r *http.Request) string { addr := ipFromHostPort(r.RemoteAddr) remoteIP := net.ParseIP(addr) if remoteIP == nil || !isTrustedProxy(remoteIP) { return addr } if ip := strings.TrimSpace( r.Header.Get("X-Real-IP"), ); ip != "" { return ip } if xff := r.Header.Get("X-Forwarded-For"); xff != "" { if parts := strings.SplitN( xff, ",", 2, //nolint:mnd ); len(parts) > 0 { if ip := strings.TrimSpace(parts[0]); ip != "" { return ip } } } return addr } // CORS returns CORS middleware. func (m *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{ "GET", "POST", "PUT", "DELETE", "OPTIONS", }, AllowedHeaders: []string{ "Accept", "Authorization", "Content-Type", "X-CSRF-Token", }, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: corsMaxAge, }) } // MetricsAuth returns basic auth middleware for /metrics. func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler { if m.params.Config.MetricsUsername == "" { return func(next http.Handler) http.Handler { return next } } return basicauth.New( "metrics", map[string][]string{ m.params.Config.MetricsUsername: { m.params.Config.MetricsPassword, }, }, ) }