simplify: replace mutex + ON CONFLICT with a single DB transaction
Remove the sync.Mutex and CreateUserAtomic (INSERT ON CONFLICT) in favor of a single DB transaction in CreateFirstUser that atomically checks for existing users and inserts. SQLite serializes write transactions, so this is sufficient to prevent the race condition without application-level locking.
This commit is contained in:
parent
763e722607
commit
97a5aae2f7
@ -135,28 +135,42 @@ func FindUserByUsername(
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateUserAtomic inserts a user using INSERT ... ON CONFLICT(username) DO NOTHING.
|
// CreateFirstUser atomically checks that no users exist and inserts the admin user.
|
||||||
// It returns nil, nil if the insert was a no-op due to a conflict (user already exists).
|
// Returns nil, nil if a user already exists (setup already completed).
|
||||||
func CreateUserAtomic(
|
func CreateFirstUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db *database.Database,
|
db *database.Database,
|
||||||
username, passwordHash string,
|
username, passwordHash string,
|
||||||
) (*User, error) {
|
) (*User, error) {
|
||||||
query := "INSERT INTO users (username, password_hash) VALUES (?, ?) ON CONFLICT(username) DO NOTHING"
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("beginning transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
result, err := db.Exec(ctx, query, username, passwordHash)
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
// Check if any user exists within the transaction.
|
||||||
|
var count int
|
||||||
|
|
||||||
|
err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM users").Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("checking user count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
return nil, nil //nolint:nilnil // nil,nil signals setup already completed
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tx.ExecContext(ctx,
|
||||||
|
"INSERT INTO users (username, password_hash) VALUES (?, ?)",
|
||||||
|
username, passwordHash,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("inserting user: %w", err)
|
return nil, fmt.Errorf("inserting user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
if err = tx.Commit(); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("committing transaction: %w", err)
|
||||||
return nil, fmt.Errorf("checking rows affected: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
|
||||||
// Conflict: user already exists
|
|
||||||
return nil, nil //nolint:nilnil // nil,nil means conflict (no insert happened)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
insertID, err := result.LastInsertId()
|
insertID, err := result.LastInsertId()
|
||||||
|
|||||||
@ -10,7 +10,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
@ -65,7 +64,6 @@ type Service struct {
|
|||||||
db *database.Database
|
db *database.Database
|
||||||
store *sessions.CookieStore
|
store *sessions.CookieStore
|
||||||
params *ServiceParams
|
params *ServiceParams
|
||||||
setupMu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new auth Service.
|
// New creates a new auth Service.
|
||||||
@ -165,36 +163,21 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser creates the initial admin user.
|
// CreateUser creates the initial admin user.
|
||||||
// It uses a mutex and INSERT ... ON CONFLICT to prevent race conditions
|
// It uses a DB transaction to atomically check that no users exist and insert
|
||||||
// where multiple concurrent requests could create duplicate admin users.
|
// the new admin user, preventing race conditions from concurrent setup requests.
|
||||||
func (svc *Service) CreateUser(
|
func (svc *Service) CreateUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
username, password string,
|
username, password string,
|
||||||
) (*models.User, error) {
|
) (*models.User, error) {
|
||||||
// Serialize setup attempts to prevent TOCTOU race conditions.
|
// Hash password before starting transaction.
|
||||||
svc.setupMu.Lock()
|
|
||||||
defer svc.setupMu.Unlock()
|
|
||||||
|
|
||||||
// Check if any user already exists (setup already completed).
|
|
||||||
exists, err := models.UserExists(ctx, svc.db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
return nil, ErrUserExists
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hash password
|
|
||||||
hash, err := svc.HashPassword(password)
|
hash, err := svc.HashPassword(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use INSERT ... ON CONFLICT to handle any remaining race at the DB level.
|
// Use a transaction so the "no users exist" check and the insert are atomic.
|
||||||
// This is defense-in-depth: the mutex above prevents the Go-level race,
|
// SQLite serializes write transactions, so concurrent requests will block here.
|
||||||
// and the UNIQUE constraint + ON CONFLICT prevents the DB-level race.
|
user, err := models.CreateFirstUser(ctx, svc.db, username, hash)
|
||||||
user, err := models.CreateUserAtomic(ctx, svc.db, username, hash)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user