fix: prevent setup endpoint race condition (closes #26) #31

Merged
sneak merged 4 commits from :fix/setup-race-condition-closes-26 into main 2026-02-16 06:45:02 +01:00
3 changed files with 113 additions and 18 deletions

View File

@ -135,6 +135,61 @@ func FindUserByUsername(
return user, nil
}
// CreateFirstUser atomically checks that no users exist and inserts the admin user.
// Returns nil, nil if a user already exists (setup already completed).
func CreateFirstUser(
ctx context.Context,
db *database.Database,
username, passwordHash string,
) (*User, error) {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("beginning transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
// Check if any user exists within the transaction.
var count int
err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM users").Scan(&count)
if err != nil {
return nil, fmt.Errorf("checking user count: %w", err)
}
if count > 0 {
return nil, nil //nolint:nilnil // nil,nil signals setup already completed
}
result, err := tx.ExecContext(ctx,
"INSERT INTO users (username, password_hash) VALUES (?, ?)",
username, passwordHash,
)
if err != nil {
return nil, fmt.Errorf("inserting user: %w", err)
}
err = tx.Commit()
if err != nil {
return nil, fmt.Errorf("committing transaction: %w", err)
}
insertID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("getting last insert id: %w", err)
}
user := NewUser(db)
user.ID = insertID
err = user.Reload(ctx)
if err != nil {
return nil, fmt.Errorf("reloading user: %w", err)
}
return user, nil
}
// UserExists checks if any user exists in the database.
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
var count int

View File

@ -163,34 +163,27 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
}
// CreateUser creates the initial admin user.
// It uses a DB transaction to atomically check that no users exist and insert
// the new admin user, preventing race conditions from concurrent setup requests.
func (svc *Service) CreateUser(
ctx context.Context,
username, password string,
) (*models.User, error) {
// Check if user already exists
exists, err := models.UserExists(ctx, svc.db)
if err != nil {
return nil, fmt.Errorf("failed to check if user exists: %w", err)
}
if exists {
return nil, ErrUserExists
}
// Hash password
// Hash password before starting transaction.
hash, err := svc.HashPassword(password)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
// Create user
user := models.NewUser(svc.db)
user.Username = username
user.PasswordHash = hash
err = user.Save(ctx)
// Use a transaction so the "no users exist" check and the insert are atomic.
// SQLite serializes write transactions, so concurrent requests will block here.
user, err := models.CreateFirstUser(ctx, svc.db, username, hash)
if err != nil {
return nil, fmt.Errorf("failed to save user: %w", err)
return nil, fmt.Errorf("failed to create user: %w", err)
}
if user == nil {
return nil, ErrUserExists
}
svc.log.Info("user created", "username", username)

View File

@ -2,6 +2,7 @@ package auth_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"path/filepath"
@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) {
})
}
func TestCreateUserRaceCondition(testingT *testing.T) {
testingT.Parallel()
testingT.Run("concurrent setup requests create only one user", func(t *testing.T) {
t.Parallel()
svc, cleanup := setupTestService(t)
defer cleanup()
const goroutines = 10
results := make(chan error, goroutines)
start := make(chan struct{})
for i := range goroutines {
go func(idx int) {
<-start // Wait for all goroutines to be ready
_, err := svc.CreateUser(
context.Background(),
fmt.Sprintf("admin%d", idx),
"password123456",
)
results <- err
}(i)
}
// Release all goroutines simultaneously
close(start)
var successes, failures int
for range goroutines {
err := <-results
if err == nil {
successes++
} else {
assert.ErrorIs(t, err, auth.ErrUserExists)
failures++
}
}
assert.Equal(t, 1, successes, "exactly one goroutine should succeed")
assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists")
})
}
func TestAuthenticate(testingT *testing.T) {
testingT.Parallel()