diff --git a/internal/models/user.go b/internal/models/user.go index dd66fb3..9a48f09 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -135,6 +135,61 @@ func FindUserByUsername( return user, nil } +// CreateFirstUser atomically checks that no users exist and inserts the admin user. +// Returns nil, nil if a user already exists (setup already completed). +func CreateFirstUser( + ctx context.Context, + db *database.Database, + username, passwordHash string, +) (*User, error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("beginning transaction: %w", err) + } + + defer func() { _ = tx.Rollback() }() + + // Check if any user exists within the transaction. + var count int + + err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM users").Scan(&count) + if err != nil { + return nil, fmt.Errorf("checking user count: %w", err) + } + + if count > 0 { + return nil, nil //nolint:nilnil // nil,nil signals setup already completed + } + + result, err := tx.ExecContext(ctx, + "INSERT INTO users (username, password_hash) VALUES (?, ?)", + username, passwordHash, + ) + if err != nil { + return nil, fmt.Errorf("inserting user: %w", err) + } + + err = tx.Commit() + if err != nil { + return nil, fmt.Errorf("committing transaction: %w", err) + } + + insertID, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("getting last insert id: %w", err) + } + + user := NewUser(db) + user.ID = insertID + + err = user.Reload(ctx) + if err != nil { + return nil, fmt.Errorf("reloading user: %w", err) + } + + return user, nil +} + // UserExists checks if any user exists in the database. func UserExists(ctx context.Context, db *database.Database) (bool, error) { var count int diff --git a/internal/service/auth/auth.go b/internal/service/auth/auth.go index 94447e9..aa83128 100644 --- a/internal/service/auth/auth.go +++ b/internal/service/auth/auth.go @@ -163,34 +163,27 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) { } // CreateUser creates the initial admin user. +// It uses a DB transaction to atomically check that no users exist and insert +// the new admin user, preventing race conditions from concurrent setup requests. func (svc *Service) CreateUser( ctx context.Context, username, password string, ) (*models.User, error) { - // Check if user already exists - exists, err := models.UserExists(ctx, svc.db) - if err != nil { - return nil, fmt.Errorf("failed to check if user exists: %w", err) - } - - if exists { - return nil, ErrUserExists - } - - // Hash password + // Hash password before starting transaction. hash, err := svc.HashPassword(password) if err != nil { return nil, fmt.Errorf("failed to hash password: %w", err) } - // Create user - user := models.NewUser(svc.db) - user.Username = username - user.PasswordHash = hash - - err = user.Save(ctx) + // Use a transaction so the "no users exist" check and the insert are atomic. + // SQLite serializes write transactions, so concurrent requests will block here. + user, err := models.CreateFirstUser(ctx, svc.db, username, hash) if err != nil { - return nil, fmt.Errorf("failed to save user: %w", err) + return nil, fmt.Errorf("failed to create user: %w", err) + } + + if user == nil { + return nil, ErrUserExists } svc.log.Info("user created", "username", username) diff --git a/internal/service/auth/auth_test.go b/internal/service/auth/auth_test.go index 9ca9430..3d1b8a3 100644 --- a/internal/service/auth/auth_test.go +++ b/internal/service/auth/auth_test.go @@ -2,6 +2,7 @@ package auth_test import ( "context" + "fmt" "net/http" "net/http/httptest" "path/filepath" @@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) { }) } +func TestCreateUserRaceCondition(testingT *testing.T) { + testingT.Parallel() + + testingT.Run("concurrent setup requests create only one user", func(t *testing.T) { + t.Parallel() + + svc, cleanup := setupTestService(t) + defer cleanup() + + const goroutines = 10 + + results := make(chan error, goroutines) + start := make(chan struct{}) + + for i := range goroutines { + go func(idx int) { + <-start // Wait for all goroutines to be ready + + _, err := svc.CreateUser( + context.Background(), + fmt.Sprintf("admin%d", idx), + "password123456", + ) + results <- err + }(i) + } + + // Release all goroutines simultaneously + close(start) + + var successes, failures int + for range goroutines { + err := <-results + if err == nil { + successes++ + } else { + assert.ErrorIs(t, err, auth.ErrUserExists) + failures++ + } + } + + assert.Equal(t, 1, successes, "exactly one goroutine should succeed") + assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists") + }) +} + func TestAuthenticate(testingT *testing.T) { testingT.Parallel()