dnswatcher/internal/middleware/middleware.go
sneak 144a2df665 Initial scaffold with per-nameserver DNS monitoring model
Full project structure following upaas conventions: uber/fx DI, go-chi
routing, slog logging, Viper config. State persisted as JSON file with
per-nameserver record tracking for inconsistency detection. Stub
implementations for resolver, portcheck, tlscheck, and watcher.
2026-02-19 21:05:39 +01:00

206 lines
4.3 KiB
Go

// Package middleware provides HTTP middleware.
package middleware
import (
"log/slog"
"net"
"net/http"
"strings"
"time"
"github.com/99designs/basicauth-go"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/globals"
"sneak.berlin/go/dnswatcher/internal/logger"
)
// corsMaxAge is the value passed as MaxAge in the options built by CORS,
// i.e. the maximum age for cached CORS preflight responses.
// NOTE(review): presumably expressed in seconds (300 = 5 minutes) — confirm
// against the go-chi/cors Options documentation.
const corsMaxAge = 300
// Params contains dependencies for Middleware.
//
// It embeds fx.In and is received by value in New; presumably fx populates
// the fields through dependency injection — confirm against the fx wiring.
type Params struct {
fx.In
// Logger supplies the *slog.Logger (via Get) that New stores on Middleware.
Logger *logger.Logger
// Globals is injected but not read anywhere in the visible part of this file.
Globals *globals.Globals
// Config is read by MetricsAuth for MetricsUsername and MetricsPassword.
Config *config.Config
}
// Middleware provides HTTP middleware.
//
// Its methods (Logging, CORS, MetricsAuth) each return a
// func(http.Handler) http.Handler wrapper.
type Middleware struct {
// log is the logger obtained from Params.Logger.Get() in New; Logging
// writes one entry per request to it.
log *slog.Logger
// params points at the copy of Params that New received.
params *Params
}
// New builds a Middleware from its injected dependencies.
//
// The lifecycle argument is accepted but not used. The logger is taken from
// params.Logger.Get() and a pointer to the received params copy is retained.
// The returned error is always nil.
func New(
	_ fx.Lifecycle,
	params Params,
) (*Middleware, error) {
	mw := new(Middleware)
	mw.params = &params
	mw.log = params.Logger.Get()

	return mw, nil
}
// loggingResponseWriter wraps an http.ResponseWriter and remembers the final
// status code of the response so it can be logged once the handler returns.
type loggingResponseWriter struct {
	http.ResponseWriter

	// statusCode is the status that was (or will implicitly be) sent. It
	// starts at 200 because net/http sends 200 when a handler writes a body
	// without calling WriteHeader.
	statusCode int
	// wroteHeader reports whether a final (non-informational) status has
	// already been committed, explicitly or by the first Write.
	wroteHeader bool
}

// newLoggingResponseWriter wraps writer with the recorded status preset to
// http.StatusOK.
func newLoggingResponseWriter(
	writer http.ResponseWriter,
) *loggingResponseWriter {
	return &loggingResponseWriter{
		ResponseWriter: writer,
		statusCode:     http.StatusOK,
	}
}

// WriteHeader forwards code to the wrapped writer and records it as the
// response status.
//
// Only the first final status is recorded: net/http ignores later
// WriteHeader calls, so overwriting the recorded value would log a status
// that was never sent. 1xx informational codes (other than 101) may precede
// the final status and are therefore forwarded without being recorded.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	informational := code >= 100 && code < 200 &&
		code != http.StatusSwitchingProtocols

	if !lrw.wroteHeader && !informational {
		lrw.statusCode = code
		lrw.wroteHeader = true
	}

	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards data to the wrapped writer. The first Write commits the
// response header (implicitly 200 if WriteHeader was never called), so any
// later WriteHeader call must not change the recorded status.
func (lrw *loggingResponseWriter) Write(data []byte) (int, error) {
	lrw.wroteHeader = true

	return lrw.ResponseWriter.Write(data)
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional capabilities (flush, hijack, deadlines) of the underlying writer.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// Logging returns middleware that writes one structured "request" log entry
// per request after the wrapped handler has run.
//
// The entry carries the start time, method, URL, user agent, chi request ID,
// referer, protocol, client address (see realIP), response status and the
// latency in whole milliseconds.
func (m *Middleware) Logging() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			recorder := newLoggingResponseWriter(w)
			ctx := r.Context()

			// Deferred so the entry is emitted even if next panics.
			defer func() {
				elapsed := time.Since(began)
				requestID := middleware.GetReqID(ctx)

				m.log.InfoContext(
					ctx,
					"request",
					"request_start", began,
					"method", r.Method,
					"url", r.URL.String(),
					"useragent", r.UserAgent(),
					"request_id", requestID,
					"referer", r.Referer(),
					"proto", r.Proto,
					"remoteIP", realIP(r),
					"status", recorder.statusCode,
					"latency_ms", elapsed.Milliseconds(),
				)
			}()

			next.ServeHTTP(recorder, r)
		}

		return http.HandlerFunc(handle)
	}
}
// ipFromHostPort returns the host part of a "host:port" string (IPv6
// brackets removed). If hostPort cannot be split — for example because it
// has no port — it is returned unchanged.
func ipFromHostPort(hostPort string) string {
	if host, _, err := net.SplitHostPort(hostPort); err == nil {
		return host
	}

	return hostPort
}
// trustedProxyNets lists the networks whose peers are allowed to supply
// proxy headers: the RFC1918 private IPv4 ranges, IPv4 and IPv6 loopback,
// and the IPv6 unique-local range fc00::/7.
//
// The list is parsed once at package initialisation. A malformed entry is a
// programming error and panics immediately, instead of leaving a nil network
// in the slice that would crash later inside isTrustedProxy.
//
//nolint:gochecknoglobals // package-level constant nets parsed once
var trustedProxyNets = func() []*net.IPNet {
	cidrs := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"127.0.0.0/8",
		"::1/128",
		"fc00::/7",
	}

	nets := make([]*net.IPNet, 0, len(cidrs))

	for _, cidr := range cidrs {
		_, n, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(
				"middleware: invalid trusted proxy CIDR " +
					cidr + ": " + err.Error(),
			)
		}

		nets = append(nets, n)
	}

	return nets
}()
// isTrustedProxy reports whether ip lies inside any network listed in
// trustedProxyNets.
func isTrustedProxy(ip net.IP) bool {
	for i := range trustedProxyNets {
		if trustedProxyNets[i].Contains(ip) {
			return true
		}
	}

	return false
}
// realIP returns the best-known client address for r.
//
// The peer address (r.RemoteAddr with the port removed) is returned as-is
// unless it parses as an IP that isTrustedProxy accepts. For a trusted peer
// the result is, in order of preference: the trimmed X-Real-IP header, the
// trimmed first comma-separated entry of X-Forwarded-For, or the peer
// address when both are empty.
//
// Header values are returned as received; they are not validated as IP
// addresses.
// NOTE(review): the first X-Forwarded-For entry is only as trustworthy as
// the proxy's handling of client-supplied values — confirm the deployed
// proxy overwrites or sanitises the header before relying on this for
// anything beyond logging.
func realIP(r *http.Request) string {
	addr := ipFromHostPort(r.RemoteAddr)

	remoteIP := net.ParseIP(addr)
	if remoteIP == nil || !isTrustedProxy(remoteIP) {
		return addr
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	// Only the first list entry matters; Cut yields it without building a
	// slice, and yields "" when the header is absent.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := strings.TrimSpace(first); ip != "" {
		return ip
	}

	return addr
}
// CORS returns CORS middleware that allows any origin, the GET, POST, PUT,
// DELETE and OPTIONS methods, and a fixed set of request headers. It exposes
// the Link response header, disallows credentials, and uses corsMaxAge as
// the preflight max age.
func (m *Middleware) CORS() func(http.Handler) http.Handler {
	origins := []string{"*"}
	methods := []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
	requestHeaders := []string{
		"Accept",
		"Authorization",
		"Content-Type",
		"X-CSRF-Token",
	}

	opts := cors.Options{
		AllowedOrigins:   origins,
		AllowedMethods:   methods,
		AllowedHeaders:   requestHeaders,
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           corsMaxAge,
	}

	return cors.Handler(opts)
}
// MetricsAuth returns basic auth middleware for /metrics.
//
// When Config.MetricsUsername is set, requests are checked by basicauth
// under the "metrics" realm against that username and
// Config.MetricsPassword. When it is empty, the returned middleware passes
// the handler through unchanged, i.e. no authentication is applied.
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	cfg := m.params.Config

	if cfg.MetricsUsername != "" {
		credentials := map[string][]string{
			cfg.MetricsUsername: {cfg.MetricsPassword},
		}

		return basicauth.New("metrics", credentials)
	}

	return func(next http.Handler) http.Handler { return next }
}