// Package middleware provides HTTP middleware constructors: structured
// request logging, CORS, Prometheus request metrics, and HTTP basic auth
// for the metrics endpoint.
package middleware

import (
	"net"
	"net/http"
	"time"

	basicauth "github.com/99designs/basicauth-go"
	"github.com/go-chi/chi/middleware"
	"github.com/go-chi/cors"
	"github.com/rs/zerolog"
	metrics "github.com/slok/go-http-metrics/metrics/prometheus"
	ghmm "github.com/slok/go-http-metrics/middleware"
	"github.com/slok/go-http-metrics/middleware/std"
	"github.com/spf13/viper"
	"go.uber.org/fx"

	"sneak.berlin/go/directory/internal/config"
	"sneak.berlin/go/directory/internal/globals"
	"sneak.berlin/go/directory/internal/logger"
)

// MiddlewareParams are the dependencies fx injects into New.
type MiddlewareParams struct {
	fx.In
	Logger  *logger.Logger
	Globals *globals.Globals
	Config  *config.Config
}

// Middleware holds the dependencies shared by the middleware constructors
// defined as methods on it.
type Middleware struct {
	log    *zerolog.Logger
	params *MiddlewareParams
}

// New builds a Middleware from its injected dependencies. The lifecycle
// argument is currently unused, and the returned error is always nil.
func New(lc fx.Lifecycle, params MiddlewareParams) (*Middleware, error) {
	s := new(Middleware)
	s.params = &params
	s.log = params.Logger.Get()
	return s, nil
}

// ipFromHostPort extracts the client IP from an address such as
// http.Request.RemoteAddr.
//
// It accepts "host:port" (net.SplitHostPort already strips the brackets
// from IPv6 forms like "[::1]:8080") and also a bare IP literal with no
// port, which is what RemoteAddr holds when some upstream middleware has
// rewritten it from proxy headers. Anything else yields "".
//
// Adapted from
// https://learning-cloud-native-go.github.io/docs/a6.adding_zerolog_logger/
func ipFromHostPort(hp string) string {
	host, _, err := net.SplitHostPort(hp)
	if err == nil {
		return host
	}
	// No parsable port: keep the value only if it is itself an IP address.
	if net.ParseIP(hp) != nil {
		return hp
	}
	return ""
}

// loggingResponseWriter wraps an http.ResponseWriter and remembers the
// final status code sent to the client so it can be logged afterwards.
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// NewLoggingResponseWriter wraps w. The recorded status starts at 200
// because net/http sends 200 when a handler never calls WriteHeader.
func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
	return &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
}

// WriteHeader records code unless a final status was already sent, then
// forwards it. net/http ignores superfluous WriteHeader calls, so recording
// them would log a status the client never saw.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	if !lrw.wroteHeader {
		lrw.statusCode = code
		// 1xx informational responses (except 101 Switching Protocols) are
		// followed by the real status, so keep waiting for it.
		if code < 100 || code > 199 || code == http.StatusSwitchingProtocols {
			lrw.wroteHeader = true
		}
	}
	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards b. The first Write implicitly commits the current status
// (200 if WriteHeader was never called), so later WriteHeader calls are
// no longer recorded.
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
	lrw.wroteHeader = true
	return lrw.ResponseWriter.Write(b)
}

// Unwrap exposes the wrapped writer so http.ResponseController can reach
// optional interfaces (flushing, hijacking, deadlines) this wrapper hides.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}

// Logging returns middleware that writes one zerolog Info event per request
// after the downstream handler finishes. The event carries the start time,
// method, URL, user agent, request ID, referer, protocol, client IP, status
// code and latency in milliseconds.
//
// The request ID comes from middleware.GetReqID, which yields "" instead of
// panicking when chi's RequestID middleware is not installed before this one.
func (s *Middleware) Logging() func(http.Handler) http.Handler {
	// FIXME this should use https://github.com/google/go-cloud/blob/master/server/requestlog/requestlog.go
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			lrw := NewLoggingResponseWriter(w)
			ctx := r.Context()
			// Deferred so the entry is still written if next panics.
			defer func() {
				latency := time.Since(start)
				s.log.Info().
					Time("request_start", start).
					Str("method", r.Method).
					Str("url", r.URL.String()).
					Str("useragent", r.UserAgent()).
					Str("request_id", middleware.GetReqID(ctx)).
					Str("referer", r.Referer()).
					Str("proto", r.Proto).
					Str("remoteIP", ipFromHostPort(r.RemoteAddr)).
					Int("status", lrw.statusCode).
					Int("latency_ms", int(latency.Milliseconds())).
					Send()
			}()
			next.ServeHTTP(lrw, r)
		})
	}
}

// CORS returns go-chi/cors middleware. It allows any origin, the listed
// methods and headers, exposes "Link", disables credentials, and caches
// preflight responses for 300 seconds.
func (s *Middleware) CORS() func(http.Handler) http.Handler {
	return cors.Handler(cors.Options{
		// CHANGEME! these are defaults, change them to suit your needs or
		// read from environment/viper.
		// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
		AllowedOrigins: []string{"*"},
		// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           300, // Maximum value not ignored by any of major browsers
	})
}

// Metrics returns middleware that records per-request HTTP metrics through
// slok/go-http-metrics with a Prometheus recorder. The handler ID is left
// empty.
//
// NOTE(review): metrics.Config{} presumably registers its collectors with
// the default Prometheus registerer, so calling Metrics more than once may
// panic on duplicate registration. Confirm before reusing it.
func (s *Middleware) Metrics() func(http.Handler) http.Handler {
	mdlw := ghmm.New(ghmm.Config{
		Recorder: metrics.NewRecorder(metrics.Config{}),
	})
	return func(next http.Handler) http.Handler {
		return std.Handler("", mdlw, next)
	}
}

// MetricsAuth returns HTTP basic-auth middleware (realm "metrics") that
// accepts only the METRICS_USERNAME / METRICS_PASSWORD pair read from viper.
//
// If either setting is empty, it instead returns middleware that rejects
// every request with 403. Otherwise an unconfigured deployment would
// register an empty username/password pair as valid credentials.
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	username := viper.GetString("METRICS_USERNAME")
	password := viper.GetString("METRICS_PASSWORD")
	if username == "" || password == "" {
		s.log.Warn().Msg("METRICS_USERNAME or METRICS_PASSWORD not set; denying all metrics requests")
		return func(http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
			})
		}
	}
	return basicauth.New(
		"metrics",
		map[string][]string{
			username: {password},
		},
	)
}