Merge pull request 'fix: prevent setup endpoint race condition (closes #26)' (#31) from clawbot/upaas:fix/setup-race-condition-closes-26 into main
Reviewed-on: #31
This commit is contained in:
commit
297f6e64f4
@ -135,6 +135,61 @@ func FindUserByUsername(
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateFirstUser atomically checks that no users exist and inserts the admin user.
|
||||||
|
// Returns nil, nil if a user already exists (setup already completed).
|
||||||
|
func CreateFirstUser(
|
||||||
|
ctx context.Context,
|
||||||
|
db *database.Database,
|
||||||
|
username, passwordHash string,
|
||||||
|
) (*User, error) {
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("beginning transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
// Check if any user exists within the transaction.
|
||||||
|
var count int
|
||||||
|
|
||||||
|
err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM users").Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("checking user count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
return nil, nil //nolint:nilnil // nil,nil signals setup already completed
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tx.ExecContext(ctx,
|
||||||
|
"INSERT INTO users (username, password_hash) VALUES (?, ?)",
|
||||||
|
username, passwordHash,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("inserting user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
insertID, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting last insert id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := NewUser(db)
|
||||||
|
user.ID = insertID
|
||||||
|
|
||||||
|
err = user.Reload(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reloading user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserExists checks if any user exists in the database.
|
// UserExists checks if any user exists in the database.
|
||||||
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
|
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
|
||||||
var count int
|
var count int
|
||||||
|
|||||||
@ -163,34 +163,27 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser creates the initial admin user.
|
// CreateUser creates the initial admin user.
|
||||||
|
// It uses a DB transaction to atomically check that no users exist and insert
|
||||||
|
// the new admin user, preventing race conditions from concurrent setup requests.
|
||||||
func (svc *Service) CreateUser(
|
func (svc *Service) CreateUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
username, password string,
|
username, password string,
|
||||||
) (*models.User, error) {
|
) (*models.User, error) {
|
||||||
// Check if user already exists
|
// Hash password before starting transaction.
|
||||||
exists, err := models.UserExists(ctx, svc.db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
return nil, ErrUserExists
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hash password
|
|
||||||
hash, err := svc.HashPassword(password)
|
hash, err := svc.HashPassword(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create user
|
// Use a transaction so the "no users exist" check and the insert are atomic.
|
||||||
user := models.NewUser(svc.db)
|
// SQLite serializes write transactions, so concurrent requests will block here.
|
||||||
user.Username = username
|
user, err := models.CreateFirstUser(ctx, svc.db, username, hash)
|
||||||
user.PasswordHash = hash
|
|
||||||
|
|
||||||
err = user.Save(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to save user: %w", err)
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
return nil, ErrUserExists
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.log.Info("user created", "username", username)
|
svc.log.Info("user created", "username", username)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package auth_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateUserRaceCondition(testingT *testing.T) {
|
||||||
|
testingT.Parallel()
|
||||||
|
|
||||||
|
testingT.Run("concurrent setup requests create only one user", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
svc, cleanup := setupTestService(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
const goroutines = 10
|
||||||
|
|
||||||
|
results := make(chan error, goroutines)
|
||||||
|
start := make(chan struct{})
|
||||||
|
|
||||||
|
for i := range goroutines {
|
||||||
|
go func(idx int) {
|
||||||
|
<-start // Wait for all goroutines to be ready
|
||||||
|
|
||||||
|
_, err := svc.CreateUser(
|
||||||
|
context.Background(),
|
||||||
|
fmt.Sprintf("admin%d", idx),
|
||||||
|
"password123456",
|
||||||
|
)
|
||||||
|
results <- err
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release all goroutines simultaneously
|
||||||
|
close(start)
|
||||||
|
|
||||||
|
var successes, failures int
|
||||||
|
for range goroutines {
|
||||||
|
err := <-results
|
||||||
|
if err == nil {
|
||||||
|
successes++
|
||||||
|
} else {
|
||||||
|
assert.ErrorIs(t, err, auth.ErrUserExists)
|
||||||
|
failures++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1, successes, "exactly one goroutine should succeed")
|
||||||
|
assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthenticate(testingT *testing.T) {
|
func TestAuthenticate(testingT *testing.T) {
|
||||||
testingT.Parallel()
|
testingT.Parallel()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user