fix: prevent setup endpoint race condition (closes #26)

Add mutex and INSERT ON CONFLICT to CreateUser to prevent TOCTOU race
where concurrent requests could create multiple admin users.

Changes:
- Add sync.Mutex to auth.Service to serialize CreateUser calls
- Add models.CreateUserAtomic using INSERT ... ON CONFLICT(username) DO NOTHING
- Check RowsAffected to detect conflicts at the DB level (defense-in-depth)
- Add concurrent race condition test (10 goroutines, only 1 succeeds)

The existing UNIQUE constraint on users.username was already in place.
This fix adds the application-level protection (items 1 & 2 from #26).
This commit is contained in:
clawbot 2026-02-15 21:35:16 -08:00
parent 97ee1e212f
commit 763e722607
3 changed files with 109 additions and 12 deletions

View File

@ -135,6 +135,46 @@ func FindUserByUsername(
return user, nil return user, nil
} }
// CreateUserAtomic inserts a user using INSERT ... ON CONFLICT(username) DO NOTHING.
// It returns nil, nil if the insert was a no-op due to a conflict (user already exists).
func CreateUserAtomic(
ctx context.Context,
db *database.Database,
username, passwordHash string,
) (*User, error) {
query := "INSERT INTO users (username, password_hash) VALUES (?, ?) ON CONFLICT(username) DO NOTHING"
result, err := db.Exec(ctx, query, username, passwordHash)
if err != nil {
return nil, fmt.Errorf("inserting user: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, fmt.Errorf("checking rows affected: %w", err)
}
if rowsAffected == 0 {
// Conflict: user already exists
return nil, nil //nolint:nilnil // nil,nil means conflict (no insert happened)
}
insertID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("getting last insert id: %w", err)
}
user := NewUser(db)
user.ID = insertID
err = user.Reload(ctx)
if err != nil {
return nil, fmt.Errorf("reloading user: %w", err)
}
return user, nil
}
// UserExists checks if any user exists in the database. // UserExists checks if any user exists in the database.
func UserExists(ctx context.Context, db *database.Database) (bool, error) { func UserExists(ctx context.Context, db *database.Database) (bool, error) {
var count int var count int

View File

@ -10,6 +10,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"strings" "strings"
"sync"
"time" "time"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
@ -60,10 +61,11 @@ type ServiceParams struct {
// Service provides authentication functionality. // Service provides authentication functionality.
type Service struct { type Service struct {
log *slog.Logger log *slog.Logger
db *database.Database db *database.Database
store *sessions.CookieStore store *sessions.CookieStore
params *ServiceParams params *ServiceParams
setupMu sync.Mutex
} }
// New creates a new auth Service. // New creates a new auth Service.
@ -163,11 +165,17 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
} }
// CreateUser creates the initial admin user. // CreateUser creates the initial admin user.
// It uses a mutex and INSERT ... ON CONFLICT to prevent race conditions
// where multiple concurrent requests could create duplicate admin users.
func (svc *Service) CreateUser( func (svc *Service) CreateUser(
ctx context.Context, ctx context.Context,
username, password string, username, password string,
) (*models.User, error) { ) (*models.User, error) {
// Check if user already exists // Serialize setup attempts to prevent TOCTOU race conditions.
svc.setupMu.Lock()
defer svc.setupMu.Unlock()
// Check if any user already exists (setup already completed).
exists, err := models.UserExists(ctx, svc.db) exists, err := models.UserExists(ctx, svc.db)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to check if user exists: %w", err) return nil, fmt.Errorf("failed to check if user exists: %w", err)
@ -183,14 +191,16 @@ func (svc *Service) CreateUser(
return nil, fmt.Errorf("failed to hash password: %w", err) return nil, fmt.Errorf("failed to hash password: %w", err)
} }
// Create user // Use INSERT ... ON CONFLICT to handle any remaining race at the DB level.
user := models.NewUser(svc.db) // This is defense-in-depth: the mutex above prevents the Go-level race,
user.Username = username // and the UNIQUE constraint + ON CONFLICT prevents the DB-level race.
user.PasswordHash = hash user, err := models.CreateUserAtomic(ctx, svc.db, username, hash)
err = user.Save(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to save user: %w", err) return nil, fmt.Errorf("failed to create user: %w", err)
}
if user == nil {
return nil, ErrUserExists
} }
svc.log.Info("user created", "username", username) svc.log.Info("user created", "username", username)

View File

@ -2,6 +2,7 @@ package auth_test
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath" "path/filepath"
@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) {
}) })
} }
func TestCreateUserRaceCondition(testingT *testing.T) {
testingT.Parallel()
testingT.Run("concurrent setup requests create only one user", func(t *testing.T) {
t.Parallel()
svc, cleanup := setupTestService(t)
defer cleanup()
const goroutines = 10
results := make(chan error, goroutines)
start := make(chan struct{})
for i := range goroutines {
go func(idx int) {
<-start // Wait for all goroutines to be ready
_, err := svc.CreateUser(
context.Background(),
fmt.Sprintf("admin%d", idx),
"password123456",
)
results <- err
}(i)
}
// Release all goroutines simultaneously
close(start)
var successes, failures int
for range goroutines {
err := <-results
if err == nil {
successes++
} else {
assert.ErrorIs(t, err, auth.ErrUserExists)
failures++
}
}
assert.Equal(t, 1, successes, "exactly one goroutine should succeed")
assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists")
})
}
func TestAuthenticate(testingT *testing.T) { func TestAuthenticate(testingT *testing.T) {
testingT.Parallel() testingT.Parallel()