All checks were successful
check / check (push) Successful in 2m19s
Security: - Add channel membership check before PRIVMSG (prevents non-members from sending) - Add membership check on history endpoint (channels require membership, DMs scoped to own nick) - Enforce MaxBytesReader on all POST request bodies - Fix rand.Read error being silently ignored in token generation Data integrity: - Fix TOCTOU race in GetOrCreateChannel using INSERT OR IGNORE + SELECT Build: - Add CGO_ENABLED=0 to golangci-lint install in Dockerfile (fixes alpine build) Linting: - Strict .golangci.yml: only wsl disabled (deprecated in v2) - Re-enable exhaustruct, depguard, godot, wrapcheck, varnamelen - Fix linters-settings -> linters.settings for v2 config format - Fix ALL lint findings in actual code (no linter config weakening) - Wrap all external package errors (wrapcheck) - Fill struct fields or add targeted nolint:exhaustruct where appropriate - Rename short variables (ts->timestamp, n->bufIndex, etc.) - Add depguard deny policy for io/ioutil and math/rand - Exclude G704 (SSRF) in gosec config (CLI client takes user-configured URLs) Tests: - Add security tests (TestNonMemberCannotSend, TestHistoryNonMember) - Split TestInsertAndPollMessages for reduced complexity - Fix parallel test safety (viper global state prevents parallelism) - Use t.Context() instead of context.Background() in tests Docker build verified passing locally.
183 lines
4.3 KiB
Go
183 lines
4.3 KiB
Go
// Package middleware provides HTTP middleware for the chat server.
|
|
package middleware
|
|
|
|
import (
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.eeqj.de/sneak/chat/internal/config"
|
|
"git.eeqj.de/sneak/chat/internal/globals"
|
|
"git.eeqj.de/sneak/chat/internal/logger"
|
|
basicauth "github.com/99designs/basicauth-go"
|
|
chimw "github.com/go-chi/chi/middleware"
|
|
"github.com/go-chi/cors"
|
|
metrics "github.com/slok/go-http-metrics/metrics/prometheus"
|
|
ghmm "github.com/slok/go-http-metrics/middleware"
|
|
"github.com/slok/go-http-metrics/middleware/std"
|
|
"github.com/spf13/viper"
|
|
"go.uber.org/fx"
|
|
)
|
|
|
|
// corsMaxAge is passed as the go-chi/cors MaxAge option: the number of
// seconds a browser may cache a preflight response (five minutes).
const corsMaxAge = 5 * 60
|
|
|
|
// Params defines the dependencies for creating Middleware.
|
|
type Params struct {
|
|
fx.In
|
|
|
|
Logger *logger.Logger
|
|
Globals *globals.Globals
|
|
Config *config.Config
|
|
}
|
|
|
|
// Middleware provides HTTP middleware handlers.
|
|
type Middleware struct {
|
|
log *slog.Logger
|
|
params *Params
|
|
}
|
|
|
|
// New creates a new Middleware instance.
|
|
func New(
|
|
_ fx.Lifecycle, params Params,
|
|
) (*Middleware, error) {
|
|
mware := &Middleware{
|
|
params: ¶ms,
|
|
log: params.Logger.Get(),
|
|
}
|
|
|
|
return mware, nil
|
|
}
|
|
|
|
// ipFromHostPort returns the host portion of a "host:port" address such
// as http.Request.RemoteAddr.
//
// IPv6 literals come back without their brackets ("[::1]:80" -> "::1"),
// and a zone suffix is kept ("[fe80::1%eth0]:80" -> "fe80::1%eth0").
// An address that net.SplitHostPort cannot parse (for example one with
// no port) yields "".
func ipFromHostPort(hostPort string) string {
	host, _, err := net.SplitHostPort(hostPort)
	if err != nil {
		// Best effort: callers get an empty string instead of an error.
		return ""
	}

	// net.SplitHostPort already strips the surrounding brackets from
	// IPv6 literals, so the host can be returned as-is.
	return host
}
|
|
|
|
// loggingResponseWriter wraps an http.ResponseWriter and remembers the
// status code actually sent to the client, so it can be logged after
// the handler returns.
type loggingResponseWriter struct {
	http.ResponseWriter

	// statusCode is the status that will be (or was) sent.
	statusCode int
	// wroteHeader is set once a final status has been committed,
	// either by WriteHeader or by an implicit 200 on the first Write.
	wroteHeader bool
}

// newLoggingResponseWriter wraps a ResponseWriter
// to capture the status code.
//
// statusCode starts at 200 because net/http sends 200 OK when a handler
// writes a body, or nothing at all, without calling WriteHeader.
func newLoggingResponseWriter(
	writer http.ResponseWriter,
) *loggingResponseWriter {
	return &loggingResponseWriter{
		ResponseWriter: writer,
		statusCode:     http.StatusOK,
		wroteHeader:    false,
	}
}

// WriteHeader records the status code and forwards the call.
//
// Only the first final status is recorded. net/http ignores later
// WriteHeader calls ("superfluous response.WriteHeader call"), so
// recording them would log a status the client never saw. Informational
// 1xx codes other than 101 may precede the final status, so they do not
// commit it.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	if !lrw.wroteHeader {
		lrw.statusCode = code
		lrw.wroteHeader = code >= http.StatusOK ||
			code == http.StatusSwitchingProtocols
	}

	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards body bytes to the wrapped writer. A Write before any
// final WriteHeader makes net/http send 200 OK, so that is what gets
// recorded.
func (lrw *loggingResponseWriter) Write(data []byte) (int, error) {
	if !lrw.wroteHeader {
		lrw.statusCode = http.StatusOK
		lrw.wroteHeader = true
	}

	return lrw.ResponseWriter.Write(data) //nolint:wrapcheck // transparent passthrough
}

// Unwrap exposes the wrapped ResponseWriter. http.NewResponseController
// uses it to reach optional interfaces such as http.Flusher, which this
// wrapper does not implement itself.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
|
|
|
|
// Logging returns middleware that logs each HTTP request.
|
|
func (mware *Middleware) Logging() func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(
|
|
func(
|
|
writer http.ResponseWriter,
|
|
request *http.Request,
|
|
) {
|
|
start := time.Now()
|
|
lrw := newLoggingResponseWriter(writer)
|
|
ctx := request.Context()
|
|
|
|
defer func() {
|
|
latency := time.Since(start)
|
|
|
|
reqID, _ := ctx.Value(
|
|
chimw.RequestIDKey,
|
|
).(string)
|
|
|
|
mware.log.InfoContext(
|
|
ctx, "request",
|
|
"request_start", start,
|
|
"method", request.Method,
|
|
"url", request.URL.String(),
|
|
"useragent", request.UserAgent(),
|
|
"request_id", reqID,
|
|
"referer", request.Referer(),
|
|
"proto", request.Proto,
|
|
"remoteIP",
|
|
ipFromHostPort(request.RemoteAddr),
|
|
"status", lrw.statusCode,
|
|
"latency_ms",
|
|
latency.Milliseconds(),
|
|
)
|
|
}()
|
|
|
|
next.ServeHTTP(lrw, request)
|
|
})
|
|
}
|
|
}
|
|
|
|
// CORS returns middleware that handles Cross-Origin Resource Sharing.
|
|
func (mware *Middleware) CORS() func(http.Handler) http.Handler {
|
|
return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields
|
|
AllowedOrigins: []string{"*"},
|
|
AllowedMethods: []string{
|
|
"GET", "POST", "PUT", "DELETE", "OPTIONS",
|
|
},
|
|
AllowedHeaders: []string{
|
|
"Accept", "Authorization",
|
|
"Content-Type", "X-CSRF-Token",
|
|
},
|
|
ExposedHeaders: []string{"Link"},
|
|
AllowCredentials: false,
|
|
MaxAge: corsMaxAge,
|
|
})
|
|
}
|
|
|
|
// Auth returns middleware that performs authentication.
|
|
func (mware *Middleware) Auth() func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(
|
|
func(
|
|
writer http.ResponseWriter,
|
|
request *http.Request,
|
|
) {
|
|
mware.log.Info("AUTH: before request")
|
|
next.ServeHTTP(writer, request)
|
|
})
|
|
}
|
|
}
|
|
|
|
// Metrics returns middleware that records HTTP metrics.
|
|
func (mware *Middleware) Metrics() func(http.Handler) http.Handler {
|
|
metricsMiddleware := ghmm.New(ghmm.Config{ //nolint:exhaustruct // optional fields
|
|
Recorder: metrics.NewRecorder(
|
|
metrics.Config{}, //nolint:exhaustruct // defaults
|
|
),
|
|
})
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return std.Handler("", metricsMiddleware, next)
|
|
}
|
|
}
|
|
|
|
// MetricsAuth returns middleware that protects metrics with basic auth.
|
|
func (mware *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
|
return basicauth.New(
|
|
"metrics",
|
|
map[string][]string{
|
|
viper.GetString("METRICS_USERNAME"): {
|
|
viper.GetString("METRICS_PASSWORD"),
|
|
},
|
|
},
|
|
)
|
|
}
|