Add gorilla/csrf middleware to protect all HTML-serving routes against
cross-site request forgery (CSRF) attacks. The webhook endpoint is excluded
because it uses secret-based authentication.
Changes:
- Add gorilla/csrf v1.7.3 dependency
- Add CSRF() middleware method using session secret as key
- Apply CSRF middleware to all HTML route groups in routes.go
- Pass the CSRF token to all templates via the addGlobals helper
- Add {{ .CSRFField }} / {{ $.CSRFField }} hidden inputs to all forms
Closes #11
208 lines · 5.0 KiB · Go
// Package middleware provides HTTP middleware.
|
|
package middleware
|
|
|
|
import (
	"log/slog"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/99designs/basicauth-go"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/cors"
	"github.com/gorilla/csrf"
	"go.uber.org/fx"

	"git.eeqj.de/sneak/upaas/internal/config"
	"git.eeqj.de/sneak/upaas/internal/globals"
	"git.eeqj.de/sneak/upaas/internal/logger"
	"git.eeqj.de/sneak/upaas/internal/service/auth"
)
|
|
|
|
// corsMaxAge is how long, in seconds, a browser may cache the result of a
// CORS preflight request (five minutes).
const corsMaxAge = 5 * 60
|
|
|
|
// Params contains dependencies for Middleware.
|
|
type Params struct {
|
|
fx.In
|
|
|
|
Logger *logger.Logger
|
|
Globals *globals.Globals
|
|
Config *config.Config
|
|
Auth *auth.Service
|
|
}
|
|
|
|
// Middleware provides HTTP middleware.
|
|
type Middleware struct {
|
|
log *slog.Logger
|
|
params *Params
|
|
}
|
|
|
|
// New creates a new Middleware instance.
|
|
func New(_ fx.Lifecycle, params Params) (*Middleware, error) {
|
|
return &Middleware{
|
|
log: params.Logger.Get(),
|
|
params: ¶ms,
|
|
}, nil
|
|
}
|
|
|
|
// loggingResponseWriter wraps http.ResponseWriter to capture status code.
|
|
type loggingResponseWriter struct {
|
|
http.ResponseWriter
|
|
|
|
statusCode int
|
|
}
|
|
|
|
func newLoggingResponseWriter(
|
|
writer http.ResponseWriter,
|
|
) *loggingResponseWriter {
|
|
return &loggingResponseWriter{writer, http.StatusOK}
|
|
}
|
|
|
|
func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
|
lrw.statusCode = code
|
|
lrw.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
// Logging returns middleware that emits one structured "request" log entry
// per request after the downstream handler returns.
//
// The entry is written from a deferred function, so it is still emitted if
// the handler panics. The logged status is whatever the response-writer
// wrapper captured.
func (m *Middleware) Logging() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(writer http.ResponseWriter, request *http.Request) {
			ctx := request.Context()
			began := time.Now()
			recorder := newLoggingResponseWriter(writer)

			defer func() {
				elapsed := time.Since(began)

				m.log.InfoContext(
					ctx,
					"request",
					"request_start", began,
					"method", request.Method,
					"url", request.URL.String(),
					"useragent", request.UserAgent(),
					"request_id", middleware.GetReqID(ctx),
					"referer", request.Referer(),
					"proto", request.Proto,
					"remoteIP", ipFromHostPort(request.RemoteAddr),
					"status", recorder.statusCode,
					"latency_ms", elapsed.Milliseconds(),
				)
			}()

			next.ServeHTTP(recorder, request)
		}

		return http.HandlerFunc(handle)
	}
}
|
|
|
|
// ipFromHostPort returns the host part of a "host:port" address such as
// http.Request.RemoteAddr. Input that net.SplitHostPort cannot parse is
// returned unchanged.
func ipFromHostPort(hostPort string) string {
	if host, _, err := net.SplitHostPort(hostPort); err == nil {
		return host
	}

	return hostPort
}
|
|
|
|
// CORS returns CORS middleware that accepts requests from any origin without
// credentials. Preflight results may be cached for corsMaxAge seconds.
func (m *Middleware) CORS() func(http.Handler) http.Handler {
	methods := []string{
		http.MethodGet,
		http.MethodPost,
		http.MethodPut,
		http.MethodDelete,
		http.MethodOptions,
	}

	return cors.Handler(cors.Options{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   methods,
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           corsMaxAge,
	})
}
|
|
|
|
// MetricsAuth returns HTTP basic-auth middleware (realm "metrics") for the
// metrics endpoint, using the configured metrics username and password.
//
// If no metrics username is configured, the returned middleware is a
// pass-through and the endpoint is served without authentication.
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
	cfg := m.params.Config

	if cfg.MetricsUsername != "" {
		credentials := map[string][]string{
			cfg.MetricsUsername: {cfg.MetricsPassword},
		}

		return basicauth.New("metrics", credentials)
	}

	return func(next http.Handler) http.Handler { return next }
}
|
|
|
|
// SessionAuth returns middleware that only lets requests with a current user
// through. If the lookup fails, or returns no user, the request is
// redirected to /login with 303 See Other.
func (m *Middleware) SessionAuth() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
			user, err := m.params.Auth.GetCurrentUser(request.Context(), request)
			if err == nil && user != nil {
				next.ServeHTTP(writer, request)

				return
			}

			http.Redirect(writer, request, "/login", http.StatusSeeOther)
		})
	}
}
|
|
|
|
// CSRF returns CSRF protection middleware using gorilla/csrf.
//
// Tokens are signed with the configured session secret, and the token cookie
// is scoped to "/". The cookie is deliberately not marked Secure so that the
// app keeps working over plain HTTP (e.g. local development); in production,
// TLS is expected to be terminated by a reverse proxy.
//
// Since v1.7.3, gorilla/csrf assumes every request was served over HTTPS.
// For unsafe methods it requires the Origin header to match https://<Host>,
// or, when Origin is absent, an https Referer. Over plain HTTP that rejects
// every form submission with 403. Requests that show no sign of HTTPS are
// therefore tagged with csrf.PlaintextHTTPRequest first. The tag only changes
// the expected scheme and skips the Referer fallback; the token is still
// validated.
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
	protect := csrf.Protect(
		[]byte(m.params.Config.SessionSecret),
		csrf.Secure(false), // Allow HTTP for development; reverse proxy handles TLS
		csrf.Path("/"),
	)

	return func(next http.Handler) http.Handler {
		protected := protect(next)

		return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
			if isPlaintextHTTP(request) {
				request = csrf.PlaintextHTTPRequest(request)
			}

			protected.ServeHTTP(writer, request)
		})
	}
}

// isPlaintextHTTP reports whether nothing about request indicates that the
// browser is using HTTPS. That is the case when the connection is not TLS,
// no proxy reported X-Forwarded-Proto: https, and the Origin header (if any)
// is not an https origin. The Origin check keeps TLS-terminating proxies
// that omit X-Forwarded-Proto behaving as they did before.
func isPlaintextHTTP(request *http.Request) bool {
	if request.TLS != nil {
		return false
	}

	// With chained proxies X-Forwarded-Proto may be a comma-separated list;
	// the first entry is the scheme the client used.
	forwarded, _, _ := strings.Cut(request.Header.Get("X-Forwarded-Proto"), ",")
	if strings.EqualFold(strings.TrimSpace(forwarded), "https") {
		return false
	}

	origin := strings.ToLower(request.Header.Get("Origin"))

	return !strings.HasPrefix(origin, "https://")
}
|
|
|
|
// SetupRequired returns middleware that enforces the initial-setup flow.
//
// While the auth service reports that setup is required, every path except
// /setup is redirected to /setup. Once setup is done, /setup itself is
// redirected to /. Both redirects use 303 See Other. If the setup status
// cannot be determined, the error is logged and a 500 is returned.
func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
			needsSetup, err := m.params.Auth.IsSetupRequired(request.Context())
			if err != nil {
				m.log.Error("failed to check setup status", "error", err)
				http.Error(writer, "Internal Server Error", http.StatusInternalServerError)

				return
			}

			onSetupPage := request.URL.Path == "/setup"

			switch {
			case needsSetup && !onSetupPage:
				http.Redirect(writer, request, "/setup", http.StatusSeeOther)
			case !needsSetup && onSetupPage:
				http.Redirect(writer, request, "/", http.StatusSeeOther)
			default:
				next.ServeHTTP(writer, request)
			}
		})
	}
}
|