fix: prevent setup endpoint race condition (closes #26) #31
@ -135,28 +135,42 @@ func FindUserByUsername(
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateUserAtomic inserts a user using INSERT ... ON CONFLICT(username) DO NOTHING.
|
||||
// It returns nil, nil if the insert was a no-op due to a conflict (user already exists).
|
||||
func CreateUserAtomic(
|
||||
// CreateFirstUser atomically checks that no users exist and inserts the admin user.
|
||||
// Returns nil, nil if a user already exists (setup already completed).
|
||||
func CreateFirstUser(
|
||||
ctx context.Context,
|
||||
db *database.Database,
|
||||
username, passwordHash string,
|
||||
) (*User, error) {
|
||||
query := "INSERT INTO users (username, password_hash) VALUES (?, ?) ON CONFLICT(username) DO NOTHING"
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("beginning transaction: %w", err)
|
||||
}
|
||||
|
||||
result, err := db.Exec(ctx, query, username, passwordHash)
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Check if any user exists within the transaction.
|
||||
var count int
|
||||
|
||||
err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM users").Scan(&count)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("checking user count: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil, nil //nolint:nilnil // nil,nil signals setup already completed
|
||||
}
|
||||
|
||||
result, err := tx.ExecContext(ctx,
|
||||
"INSERT INTO users (username, password_hash) VALUES (?, ?)",
|
||||
username, passwordHash,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("inserting user: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("checking rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
// Conflict: user already exists
|
||||
return nil, nil //nolint:nilnil // nil,nil means conflict (no insert happened)
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("committing transaction: %w", err)
|
||||
}
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
|
||||
@ -10,7 +10,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
@ -61,11 +60,10 @@ type ServiceParams struct {
|
||||
|
||||
// Service provides authentication functionality.
|
||||
type Service struct {
|
||||
log *slog.Logger
|
||||
db *database.Database
|
||||
store *sessions.CookieStore
|
||||
params *ServiceParams
|
||||
setupMu sync.Mutex
|
||||
log *slog.Logger
|
||||
db *database.Database
|
||||
store *sessions.CookieStore
|
||||
params *ServiceParams
|
||||
}
|
||||
|
||||
// New creates a new auth Service.
|
||||
@ -165,36 +163,21 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates the initial admin user.
|
||||
// It uses a mutex and INSERT ... ON CONFLICT to prevent race conditions
|
||||
// where multiple concurrent requests could create duplicate admin users.
|
||||
// It uses a DB transaction to atomically check that no users exist and insert
|
||||
// the new admin user, preventing race conditions from concurrent setup requests.
|
||||
func (svc *Service) CreateUser(
|
||||
ctx context.Context,
|
||||
username, password string,
|
||||
) (*models.User, error) {
|
||||
// Serialize setup attempts to prevent TOCTOU race conditions.
|
||||
svc.setupMu.Lock()
|
||||
defer svc.setupMu.Unlock()
|
||||
|
||||
// Check if any user already exists (setup already completed).
|
||||
exists, err := models.UserExists(ctx, svc.db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
|
||||
// Hash password
|
||||
// Hash password before starting transaction.
|
||||
hash, err := svc.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Use INSERT ... ON CONFLICT to handle any remaining race at the DB level.
|
||||
// This is defense-in-depth: the mutex above prevents the Go-level race,
|
||||
// and the UNIQUE constraint + ON CONFLICT prevents the DB-level race.
|
||||
user, err := models.CreateUserAtomic(ctx, svc.db, username, hash)
|
||||
// Use a transaction so the "no users exist" check and the insert are atomic.
|
||||
// SQLite serializes write transactions, so concurrent requests will block here.
|
||||
user, err := models.CreateFirstUser(ctx, svc.db, username, hash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user