Replace custom secureCompare with crypto/subtle.ConstantTimeCompare
All checks were successful
check / check (push) Successful in 2m7s

Remove the hand-rolled secureCompare function and use the standard
library's crypto/subtle.ConstantTimeCompare for CSRF token validation.
Remove the corresponding unit test for the deleted function; CSRF token
comparison is still covered by the integration tests.
This commit is contained in:
clawbot
2026-03-10 02:39:24 -07:00
parent 7f4c40caca
commit 5c69efb5bc
2 changed files with 2 additions and 23 deletions

View File

@@ -3,6 +3,7 @@ package middleware
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/subtle"
"encoding/hex" "encoding/hex"
"net/http" "net/http"
) )
@@ -75,7 +76,7 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
switch r.Method { switch r.Method {
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
submitted := r.FormValue(csrfFormField) submitted := r.FormValue(csrfFormField)
if !secureCompare(submitted, token) { if subtle.ConstantTimeCompare([]byte(submitted), []byte(token)) != 1 {
m.log.Warn("csrf: token mismatch", m.log.Warn("csrf: token mismatch",
"method", r.Method, "method", r.Method,
"path", r.URL.Path, "path", r.URL.Path,
@@ -99,16 +100,3 @@ func generateCSRFToken() (string, error) {
} }
return hex.EncodeToString(b), nil return hex.EncodeToString(b), nil
} }
// secureCompare performs a constant-time string comparison to prevent
// timing attacks on CSRF token validation.
func secureCompare(a, b string) bool {
if len(a) != len(b) {
return false
}
var result byte
for i := 0; i < len(a); i++ {
result |= a[i] ^ b[i]
}
return result == 0
}

View File

@@ -173,12 +173,3 @@ func TestGenerateCSRFToken(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.NotEqual(t, token, token2, "each generated token should be unique") assert.NotEqual(t, token, token2, "each generated token should be unique")
} }
func TestSecureCompare(t *testing.T) {
t.Parallel()
assert.True(t, secureCompare("abc", "abc"))
assert.False(t, secureCompare("abc", "abd"))
assert.False(t, secureCompare("abc", "ab"))
assert.False(t, secureCompare("", "a"))
assert.True(t, secureCompare("", ""))
}