From 763e7226075404692c80852f7b0d9ffcc9d631dc Mon Sep 17 00:00:00 2001 From: clawbot Date: Sun, 15 Feb 2026 21:35:16 -0800 Subject: [PATCH] fix: prevent setup endpoint race condition (closes #26) Add mutex and INSERT ON CONFLICT to CreateUser to prevent TOCTOU race where concurrent requests could create multiple admin users. Changes: - Add sync.Mutex to auth.Service to serialize CreateUser calls - Add models.CreateUserAtomic using INSERT ... ON CONFLICT(username) DO NOTHING - Check RowsAffected to detect conflicts at the DB level (defense-in-depth) - Add concurrent race condition test (10 goroutines, only 1 succeeds) The existing UNIQUE constraint on users.username was already in place. This fix adds the application-level protection (items 1 & 2 from #26). --- internal/models/user.go | 40 +++++++++++++++++++++++++ internal/service/auth/auth.go | 34 +++++++++++++-------- internal/service/auth/auth_test.go | 47 ++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 12 deletions(-) diff --git a/internal/models/user.go b/internal/models/user.go index dd66fb3..1110309 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -135,6 +135,46 @@ func FindUserByUsername( return user, nil } +// CreateUserAtomic inserts a user using INSERT ... ON CONFLICT(username) DO NOTHING. +// It returns nil, nil if the insert was a no-op due to a conflict (user already exists). +func CreateUserAtomic( + ctx context.Context, + db *database.Database, + username, passwordHash string, +) (*User, error) { + query := "INSERT INTO users (username, password_hash) VALUES (?, ?) ON CONFLICT(username) DO NOTHING" + + result, err := db.Exec(ctx, query, username, passwordHash) + if err != nil { + return nil, fmt.Errorf("inserting user: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, fmt.Errorf("checking rows affected: %w", err) + } + + if rowsAffected == 0 { + // Conflict: user already exists + return nil, nil //nolint:nilnil // nil,nil means conflict (no insert happened) + } + + insertID, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("getting last insert id: %w", err) + } + + user := NewUser(db) + user.ID = insertID + + err = user.Reload(ctx) + if err != nil { + return nil, fmt.Errorf("reloading user: %w", err) + } + + return user, nil +} + // UserExists checks if any user exists in the database. func UserExists(ctx context.Context, db *database.Database) (bool, error) { var count int diff --git a/internal/service/auth/auth.go b/internal/service/auth/auth.go index 94447e9..0c7a853 100644 --- a/internal/service/auth/auth.go +++ b/internal/service/auth/auth.go @@ -10,6 +10,7 @@ import ( "log/slog" "net/http" "strings" + "sync" "time" "github.com/gorilla/sessions" @@ -60,10 +61,11 @@ type ServiceParams struct { // Service provides authentication functionality. type Service struct { - log *slog.Logger - db *database.Database - store *sessions.CookieStore - params *ServiceParams + log *slog.Logger + db *database.Database + store *sessions.CookieStore + params *ServiceParams + setupMu sync.Mutex } // New creates a new auth Service. @@ -163,11 +165,17 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) { } // CreateUser creates the initial admin user. +// It uses a mutex and INSERT ... ON CONFLICT to prevent race conditions +// where multiple concurrent requests could create duplicate admin users. func (svc *Service) CreateUser( ctx context.Context, username, password string, ) (*models.User, error) { - // Check if user already exists + // Serialize setup attempts to prevent TOCTOU race conditions. + svc.setupMu.Lock() + defer svc.setupMu.Unlock() + + // Check if any user already exists (setup already completed). exists, err := models.UserExists(ctx, svc.db) if err != nil { return nil, fmt.Errorf("failed to check if user exists: %w", err) @@ -183,14 +191,16 @@ func (svc *Service) CreateUser( return nil, fmt.Errorf("failed to hash password: %w", err) } - // Create user - user := models.NewUser(svc.db) - user.Username = username - user.PasswordHash = hash - - err = user.Save(ctx) + // Use INSERT ... ON CONFLICT to handle any remaining race at the DB level. + // This is defense-in-depth: the mutex above prevents the Go-level race, + // and the UNIQUE constraint + ON CONFLICT prevents the DB-level race. + user, err := models.CreateUserAtomic(ctx, svc.db, username, hash) if err != nil { - return nil, fmt.Errorf("failed to save user: %w", err) + return nil, fmt.Errorf("failed to create user: %w", err) + } + + if user == nil { + return nil, ErrUserExists } svc.log.Info("user created", "username", username) diff --git a/internal/service/auth/auth_test.go b/internal/service/auth/auth_test.go index 9ca9430..3d1b8a3 100644 --- a/internal/service/auth/auth_test.go +++ b/internal/service/auth/auth_test.go @@ -2,6 +2,7 @@ package auth_test import ( "context" + "fmt" "net/http" "net/http/httptest" "path/filepath" @@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) { }) } +func TestCreateUserRaceCondition(testingT *testing.T) { + testingT.Parallel() + + testingT.Run("concurrent setup requests create only one user", func(t *testing.T) { + t.Parallel() + + svc, cleanup := setupTestService(t) + defer cleanup() + + const goroutines = 10 + + results := make(chan error, goroutines) + start := make(chan struct{}) + + for i := range goroutines { + go func(idx int) { + <-start // Wait for all goroutines to be ready + + _, err := svc.CreateUser( + context.Background(), + fmt.Sprintf("admin%d", idx), + "password123456", + ) + results <- err + }(i) + } + + // Release all goroutines simultaneously + close(start) + + var successes, failures int + for range goroutines { + err := <-results + if err == nil { + successes++ + } else { + assert.ErrorIs(t, err, auth.ErrUserExists) + failures++ + } + } + + assert.Equal(t, 1, successes, "exactly one goroutine should succeed") + assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists") + }) +} + func TestAuthenticate(testingT *testing.T) { testingT.Parallel()