diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index ac86471..7e84f3a 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -3,6 +3,7 @@ package middleware import ( "context" "crypto/rand" + "crypto/subtle" "encoding/hex" "net/http" ) @@ -75,7 +76,7 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler { switch r.Method { case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: submitted := r.FormValue(csrfFormField) - if !secureCompare(submitted, token) { + if subtle.ConstantTimeCompare([]byte(submitted), []byte(token)) != 1 { m.log.Warn("csrf: token mismatch", "method", r.Method, "path", r.URL.Path, @@ -99,16 +100,3 @@ func generateCSRFToken() (string, error) { } return hex.EncodeToString(b), nil } - -// secureCompare performs a constant-time string comparison to prevent -// timing attacks on CSRF token validation. -func secureCompare(a, b string) bool { - if len(a) != len(b) { - return false - } - var result byte - for i := 0; i < len(a); i++ { - result |= a[i] ^ b[i] - } - return result == 0 -} diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go index f0022ac..3940e61 100644 --- a/internal/middleware/csrf_test.go +++ b/internal/middleware/csrf_test.go @@ -173,12 +173,3 @@ func TestGenerateCSRFToken(t *testing.T) { require.NoError(t, err) assert.NotEqual(t, token, token2, "each generated token should be unique") } - -func TestSecureCompare(t *testing.T) { - t.Parallel() - assert.True(t, secureCompare("abc", "abc")) - assert.False(t, secureCompare("abc", "abd")) - assert.False(t, secureCompare("abc", "ab")) - assert.False(t, secureCompare("", "a")) - assert.True(t, secureCompare("", "")) -}