package middleware import ( "log/slog" "net" "net/http" "time" "git.eeqj.de/sneak/webhooker/internal/config" "git.eeqj.de/sneak/webhooker/internal/globals" "git.eeqj.de/sneak/webhooker/internal/logger" basicauth "github.com/99designs/basicauth-go" "github.com/go-chi/chi/middleware" "github.com/go-chi/cors" metrics "github.com/slok/go-http-metrics/metrics/prometheus" ghmm "github.com/slok/go-http-metrics/middleware" "github.com/slok/go-http-metrics/middleware/std" "go.uber.org/fx" ) // nolint:revive // MiddlewareParams is a standard fx naming convention type MiddlewareParams struct { fx.In Logger *logger.Logger Globals *globals.Globals Config *config.Config } type Middleware struct { log *slog.Logger params *MiddlewareParams } func New(lc fx.Lifecycle, params MiddlewareParams) (*Middleware, error) { s := new(Middleware) s.params = ¶ms s.log = params.Logger.Get() return s, nil } // the following is from // https://learning-cloud-native-go.github.io/docs/a6.adding_zerolog_logger/ func ipFromHostPort(hp string) string { h, _, err := net.SplitHostPort(hp) if err != nil { return "" } if len(h) > 0 && h[0] == '[' { return h[1 : len(h)-1] } return h } type loggingResponseWriter struct { http.ResponseWriter statusCode int } // nolint:revive // unexported type is only used internally func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { return &loggingResponseWriter{w, http.StatusOK} } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } // type Middleware func(http.Handler) http.Handler // this returns a Middleware that is designed to do every request through the // mux, note the signature: func (s *Middleware) Logging() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() lrw := NewLoggingResponseWriter(w) ctx := r.Context() defer func() { latency := time.Since(start) requestID := "" if reqID := ctx.Value(middleware.RequestIDKey); reqID != nil { if id, ok := reqID.(string); ok { requestID = id } } s.log.Info("http request", "request_start", start, "method", r.Method, "url", r.URL.String(), "useragent", r.UserAgent(), "request_id", requestID, "referer", r.Referer(), "proto", r.Proto, "remoteIP", ipFromHostPort(r.RemoteAddr), "status", lrw.statusCode, "latency_ms", latency.Milliseconds(), ) }() next.ServeHTTP(lrw, r) }) } } func (s *Middleware) CORS() func(http.Handler) http.Handler { return cors.Handler(cors.Options{ // CHANGEME! these are defaults, change them to suit your needs or // read from environment/viper. // AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts AllowedOrigins: []string{"*"}, // AllowOriginFunc: func(r *http.Request, origin string) bool { return true }, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: 300, // Maximum value not ignored by any of major browsers }) } func (s *Middleware) Auth() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // TODO: implement proper authentication s.log.Debug("AUTH: before request") next.ServeHTTP(w, r) }) } } func (s *Middleware) Metrics() func(http.Handler) http.Handler { mdlw := ghmm.New(ghmm.Config{ Recorder: metrics.NewRecorder(metrics.Config{}), }) return func(next http.Handler) http.Handler { return std.Handler("", mdlw, next) } } func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler { return basicauth.New( "metrics", map[string][]string{ s.params.Config.MetricsUsername: { s.params.Config.MetricsPassword, }, }, ) }