fix: prevent setup endpoint race condition (closes #26) #31

Merged
sneak merged 4 commits from :fix/setup-race-condition-closes-26 into main 2026-02-16 06:45:02 +01:00
3 changed files with 109 additions and 12 deletions
Showing only changes of commit 763e722607 - Show all commits

View File

@ -135,6 +135,46 @@ func FindUserByUsername(
return user, nil
}
// CreateUserAtomic inserts a user using INSERT ... ON CONFLICT(username) DO NOTHING.
// It returns nil, nil if the insert was a no-op due to a conflict (user already exists).
func CreateUserAtomic(
ctx context.Context,
db *database.Database,
username, passwordHash string,
) (*User, error) {
query := "INSERT INTO users (username, password_hash) VALUES (?, ?) ON CONFLICT(username) DO NOTHING"
result, err := db.Exec(ctx, query, username, passwordHash)
if err != nil {
return nil, fmt.Errorf("inserting user: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, fmt.Errorf("checking rows affected: %w", err)
}
if rowsAffected == 0 {
// Conflict: user already exists
return nil, nil //nolint:nilnil // nil,nil means conflict (no insert happened)
}
insertID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("getting last insert id: %w", err)
}
user := NewUser(db)
user.ID = insertID
err = user.Reload(ctx)
if err != nil {
return nil, fmt.Errorf("reloading user: %w", err)
}
return user, nil
}
// UserExists checks if any user exists in the database.
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
var count int

View File

@ -10,6 +10,7 @@ import (
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
@ -64,6 +65,7 @@ type Service struct {
db *database.Database
store *sessions.CookieStore
params *ServiceParams
setupMu sync.Mutex
}
// New creates a new auth Service.
@ -163,11 +165,17 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
}
// CreateUser creates the initial admin user.
// It uses a mutex and INSERT ... ON CONFLICT to prevent race conditions
// where multiple concurrent requests could create duplicate admin users.
func (svc *Service) CreateUser(
ctx context.Context,
username, password string,
) (*models.User, error) {
// Check if user already exists
// Serialize setup attempts to prevent TOCTOU race conditions.
svc.setupMu.Lock()
defer svc.setupMu.Unlock()
// Check if any user already exists (setup already completed).
exists, err := models.UserExists(ctx, svc.db)
if err != nil {
return nil, fmt.Errorf("failed to check if user exists: %w", err)
@ -183,14 +191,16 @@ func (svc *Service) CreateUser(
return nil, fmt.Errorf("failed to hash password: %w", err)
}
// Create user
user := models.NewUser(svc.db)
user.Username = username
user.PasswordHash = hash
err = user.Save(ctx)
// Use INSERT ... ON CONFLICT to handle any remaining race at the DB level.
// This is defense-in-depth: the mutex above prevents the Go-level race,
// and the UNIQUE constraint + ON CONFLICT prevents the DB-level race.
user, err := models.CreateUserAtomic(ctx, svc.db, username, hash)
if err != nil {
return nil, fmt.Errorf("failed to save user: %w", err)
return nil, fmt.Errorf("failed to create user: %w", err)
}
if user == nil {
return nil, ErrUserExists
}
svc.log.Info("user created", "username", username)

View File

@ -2,6 +2,7 @@ package auth_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"path/filepath"
@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) {
})
}
func TestCreateUserRaceCondition(testingT *testing.T) {
testingT.Parallel()
testingT.Run("concurrent setup requests create only one user", func(t *testing.T) {
t.Parallel()
svc, cleanup := setupTestService(t)
defer cleanup()
const goroutines = 10
results := make(chan error, goroutines)
start := make(chan struct{})
for i := range goroutines {
go func(idx int) {
<-start // Wait for all goroutines to be ready
_, err := svc.CreateUser(
context.Background(),
fmt.Sprintf("admin%d", idx),
"password123456",
)
results <- err
}(i)
}
// Release all goroutines simultaneously
close(start)
var successes, failures int
for range goroutines {
err := <-results
if err == nil {
successes++
} else {
assert.ErrorIs(t, err, auth.ErrUserExists)
failures++
}
}
assert.Equal(t, 1, successes, "exactly one goroutine should succeed")
assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists")
})
}
func TestAuthenticate(testingT *testing.T) {
testingT.Parallel()