package middleware import ( "net/http" "github.com/gorilla/csrf" ) // CSRFToken retrieves the CSRF token from the request context. // Returns an empty string if the gorilla/csrf middleware has not run. func CSRFToken(r *http.Request) string { return csrf.Token(r) } // CSRF returns middleware that provides CSRF protection using the // gorilla/csrf library. The middleware uses the session authentication // key to sign a CSRF cookie and validates a masked token submitted via // the "csrf_token" form field (or the "X-CSRF-Token" header) on // POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing // token receive a 403 Forbidden response. // // In development mode, requests are marked as plaintext HTTP so that // gorilla/csrf skips the strict Referer-origin check (which is only // meaningful over TLS). func (m *Middleware) CSRF() func(http.Handler) http.Handler { protect := csrf.Protect( m.session.GetKey(), csrf.FieldName("csrf_token"), csrf.Secure(!m.params.Config.IsDev()), csrf.SameSite(csrf.SameSiteLaxMode), csrf.Path("/"), csrf.ErrorHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.log.Warn("csrf: token validation failed", "method", r.Method, "path", r.URL.Path, "remote_addr", r.RemoteAddr, "reason", csrf.FailureReason(r), ) http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) })), ) // In development (plaintext HTTP), signal gorilla/csrf to skip // the strict TLS Referer check by injecting the PlaintextHTTP // context key before the CSRF handler sees the request. if m.params.Config.IsDev() { return func(next http.Handler) http.Handler { csrfHandler := protect(next) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { csrfHandler.ServeHTTP(w, csrf.PlaintextHTTPRequest(r)) }) } } return protect }