package middleware import ( "context" "crypto/rand" "encoding/hex" "net/http" ) const ( // csrfTokenLength is the byte length of generated CSRF tokens. // 32 bytes = 64 hex characters, providing 256 bits of entropy. csrfTokenLength = 32 // csrfSessionKey is the session key where the CSRF token is stored. csrfSessionKey = "csrf_token" // csrfFormField is the HTML form field name for the CSRF token. csrfFormField = "csrf_token" ) // csrfContextKey is the context key type for CSRF tokens. type csrfContextKey struct{} // CSRFToken retrieves the CSRF token from the request context. // Returns an empty string if no token is present. func CSRFToken(r *http.Request) string { if token, ok := r.Context().Value(csrfContextKey{}).(string); ok { return token } return "" } // CSRF returns middleware that provides CSRF protection for state-changing // requests. For every request, it ensures a CSRF token exists in the // session and makes it available via the request context. For POST, PUT, // PATCH, and DELETE requests, it validates the submitted csrf_token form // field against the session token. Requests with an invalid or missing // token receive a 403 Forbidden response. func (m *Middleware) CSRF() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sess, err := m.session.Get(r) if err != nil { m.log.Error("csrf: failed to get session", "error", err) http.Error(w, "Forbidden", http.StatusForbidden) return } // Ensure a CSRF token exists in the session token, ok := sess.Values[csrfSessionKey].(string) if !ok { token = "" } if token == "" { token, err = generateCSRFToken() if err != nil { m.log.Error("csrf: failed to generate token", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } sess.Values[csrfSessionKey] = token if saveErr := m.session.Save(r, w, sess); saveErr != nil { m.log.Error("csrf: failed to save session", "error", saveErr) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } } // Store token in context for templates ctx := context.WithValue(r.Context(), csrfContextKey{}, token) r = r.WithContext(ctx) // Validate token on state-changing methods switch r.Method { case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: submitted := r.FormValue(csrfFormField) if !secureCompare(submitted, token) { m.log.Warn("csrf: token mismatch", "method", r.Method, "path", r.URL.Path, "remote_addr", r.RemoteAddr, ) http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) return } } next.ServeHTTP(w, r) }) } } // generateCSRFToken creates a cryptographically random hex-encoded token. func generateCSRFToken() (string, error) { b := make([]byte, csrfTokenLength) if _, err := rand.Read(b); err != nil { return "", err } return hex.EncodeToString(b), nil } // secureCompare performs a constant-time string comparison to prevent // timing attacks on CSRF token validation. func secureCompare(a, b string) bool { if len(a) != len(b) { return false } var result byte for i := 0; i < len(a); i++ { result |= a[i] ^ b[i] } return result == 0 }