package middleware

import (
	"net/http"
	"strings"

	"github.com/gorilla/csrf"
)

// CSRFToken retrieves the CSRF token from the request context.
// Returns an empty string if the gorilla/csrf middleware has not run.
func CSRFToken(r *http.Request) string {
	return csrf.Token(r)
}

// isClientTLS reports whether the client-facing connection uses TLS.
// It returns true for a direct TLS connection (r.TLS is set) or when a
// TLS-terminating reverse proxy reports "https" in the X-Forwarded-Proto
// header.
//
// The header comparison is case-insensitive and ignores surrounding
// whitespace. When the header carries a comma-separated list (as chained
// proxies may produce, e.g. "https, http"), only the first entry is
// considered; by convention it describes the hop closest to the client.
//
// NOTE(review): the header is taken at face value; this assumes the
// deployment's proxy overwrites any client-supplied X-Forwarded-Proto.
// Confirm against the proxy configuration.
func isClientTLS(r *http.Request) bool {
	if r.TLS != nil {
		return true
	}

	// Only the first list entry matters; anything after the first comma
	// was appended by later hops.
	proto, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Proto"), ",")
	return strings.EqualFold(strings.TrimSpace(proto), "https")
}

// CSRF returns middleware that provides CSRF protection using the
// gorilla/csrf library. The middleware uses the session authentication
// key to sign a CSRF cookie and validates a masked token submitted via
// the "csrf_token" form field (or the "X-CSRF-Token" header) on
// POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing
// token receive a 403 Forbidden response.
//
// The middleware detects the client-facing transport protocol per-request
// using r.TLS and the X-Forwarded-Proto header. This allows correct
// behavior in all deployment scenarios:
//
//   - Direct HTTPS: strict Referer/Origin checks, Secure cookies.
//   - Behind a TLS-terminating reverse proxy: strict checks (the
//     browser is on HTTPS, so Origin/Referer headers use https://),
//     Secure cookies (the browser sees HTTPS from the proxy).
//   - Direct HTTP: relaxed Referer/Origin checks via PlaintextHTTPRequest,
//     non-Secure cookies so the browser sends them over HTTP.
//
// Two gorilla/csrf instances are maintained — one with Secure cookies
// (for TLS) and one without (for plaintext HTTP) — because the
// csrf.Secure option is set at creation time, not per-request.
func (m *Middleware) CSRF() func(http.Handler) http.Handler { csrfErrorHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.log.Warn("csrf: token validation failed", "method", r.Method, "path", r.URL.Path, "remote_addr", r.RemoteAddr, "reason", csrf.FailureReason(r), ) http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) }) key := m.session.GetKey() baseOpts := []csrf.Option{ csrf.FieldName("csrf_token"), csrf.SameSite(csrf.SameSiteLaxMode), csrf.Path("/"), csrf.ErrorHandler(csrfErrorHandler), } // Two middleware instances with different Secure flags but the // same signing key, so cookies are interchangeable between them. tlsProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(true))...) httpProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(false))...) return func(next http.Handler) http.Handler { tlsCSRF := tlsProtect(next) httpCSRF := httpProtect(next) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if isClientTLS(r) { // Client is on TLS (directly or via reverse proxy). // Use Secure cookies and strict Origin/Referer checks. tlsCSRF.ServeHTTP(w, r) } else { // Plaintext HTTP: use non-Secure cookies and tell // gorilla/csrf to use "http" for scheme comparisons, // skipping the strict Referer check that assumes TLS. httpCSRF.ServeHTTP(w, csrf.PlaintextHTTPRequest(r)) } }) } }