fix: prevent setup endpoint race condition (closes #26) #31
@ -135,6 +135,46 @@ func FindUserByUsername(
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateUserAtomic inserts a user using INSERT ... ON CONFLICT(username) DO NOTHING.
|
||||
// It returns nil, nil if the insert was a no-op due to a conflict (user already exists).
|
||||
func CreateUserAtomic(
|
||||
ctx context.Context,
|
||||
db *database.Database,
|
||||
username, passwordHash string,
|
||||
) (*User, error) {
|
||||
query := "INSERT INTO users (username, password_hash) VALUES (?, ?) ON CONFLICT(username) DO NOTHING"
|
||||
|
||||
result, err := db.Exec(ctx, query, username, passwordHash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("inserting user: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("checking rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
// Conflict: user already exists
|
||||
return nil, nil //nolint:nilnil // nil,nil means conflict (no insert happened)
|
||||
}
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting last insert id: %w", err)
|
||||
}
|
||||
|
||||
user := NewUser(db)
|
||||
user.ID = insertID
|
||||
|
||||
err = user.Reload(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reloading user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// UserExists checks if any user exists in the database.
|
||||
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
|
||||
var count int
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
@ -60,10 +61,11 @@ type ServiceParams struct {
|
||||
|
||||
// Service provides authentication functionality.
|
||||
type Service struct {
|
||||
log *slog.Logger
|
||||
db *database.Database
|
||||
store *sessions.CookieStore
|
||||
params *ServiceParams
|
||||
log *slog.Logger
|
||||
db *database.Database
|
||||
store *sessions.CookieStore
|
||||
params *ServiceParams
|
||||
setupMu sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new auth Service.
|
||||
@ -163,11 +165,17 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates the initial admin user.
|
||||
// It uses a mutex and INSERT ... ON CONFLICT to prevent race conditions
|
||||
// where multiple concurrent requests could create duplicate admin users.
|
||||
func (svc *Service) CreateUser(
|
||||
ctx context.Context,
|
||||
username, password string,
|
||||
) (*models.User, error) {
|
||||
// Check if user already exists
|
||||
// Serialize setup attempts to prevent TOCTOU race conditions.
|
||||
svc.setupMu.Lock()
|
||||
defer svc.setupMu.Unlock()
|
||||
|
||||
// Check if any user already exists (setup already completed).
|
||||
exists, err := models.UserExists(ctx, svc.db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
||||
@ -183,14 +191,16 @@ func (svc *Service) CreateUser(
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Create user
|
||||
user := models.NewUser(svc.db)
|
||||
user.Username = username
|
||||
user.PasswordHash = hash
|
||||
|
||||
err = user.Save(ctx)
|
||||
// Use INSERT ... ON CONFLICT to handle any remaining race at the DB level.
|
||||
// This is defense-in-depth: the mutex above prevents the Go-level race,
|
||||
// and the UNIQUE constraint + ON CONFLICT prevents the DB-level race.
|
||||
user, err := models.CreateUserAtomic(ctx, svc.db, username, hash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save user: %w", err)
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
|
||||
svc.log.Info("user created", "username", username)
|
||||
|
||||
@ -2,6 +2,7 @@ package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateUserRaceCondition(testingT *testing.T) {
|
||||
testingT.Parallel()
|
||||
|
||||
testingT.Run("concurrent setup requests create only one user", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc, cleanup := setupTestService(t)
|
||||
defer cleanup()
|
||||
|
||||
const goroutines = 10
|
||||
|
||||
results := make(chan error, goroutines)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := range goroutines {
|
||||
go func(idx int) {
|
||||
<-start // Wait for all goroutines to be ready
|
||||
|
||||
_, err := svc.CreateUser(
|
||||
context.Background(),
|
||||
fmt.Sprintf("admin%d", idx),
|
||||
"password123456",
|
||||
)
|
||||
results <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Release all goroutines simultaneously
|
||||
close(start)
|
||||
|
||||
var successes, failures int
|
||||
for range goroutines {
|
||||
err := <-results
|
||||
if err == nil {
|
||||
successes++
|
||||
} else {
|
||||
assert.ErrorIs(t, err, auth.ErrUserExists)
|
||||
failures++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, successes, "exactly one goroutine should succeed")
|
||||
assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthenticate(testingT *testing.T) {
|
||||
testingT.Parallel()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user