fix: prevent setup endpoint race condition (closes #26)
Add mutex and INSERT ON CONFLICT to CreateUser to prevent TOCTOU race where concurrent requests could create multiple admin users. Changes: - Add sync.Mutex to auth.Service to serialize CreateUser calls - Add models.CreateUserAtomic using INSERT ... ON CONFLICT(username) DO NOTHING - Check RowsAffected to detect conflicts at the DB level (defense-in-depth) - Add concurrent race condition test (10 goroutines, only 1 succeeds) The existing UNIQUE constraint on users.username was already in place. This fix adds the application-level protection (items 1 & 2 from #26).
This commit is contained in:
parent
97ee1e212f
commit
763e722607
@ -135,6 +135,46 @@ func FindUserByUsername(
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateUserAtomic inserts a user using INSERT ... ON CONFLICT(username) DO NOTHING.
|
||||||
|
// It returns nil, nil if the insert was a no-op due to a conflict (user already exists).
|
||||||
|
func CreateUserAtomic(
|
||||||
|
ctx context.Context,
|
||||||
|
db *database.Database,
|
||||||
|
username, passwordHash string,
|
||||||
|
) (*User, error) {
|
||||||
|
query := "INSERT INTO users (username, password_hash) VALUES (?, ?) ON CONFLICT(username) DO NOTHING"
|
||||||
|
|
||||||
|
result, err := db.Exec(ctx, query, username, passwordHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("inserting user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("checking rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowsAffected == 0 {
|
||||||
|
// Conflict: user already exists
|
||||||
|
return nil, nil //nolint:nilnil // nil,nil means conflict (no insert happened)
|
||||||
|
}
|
||||||
|
|
||||||
|
insertID, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting last insert id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := NewUser(db)
|
||||||
|
user.ID = insertID
|
||||||
|
|
||||||
|
err = user.Reload(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reloading user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserExists checks if any user exists in the database.
|
// UserExists checks if any user exists in the database.
|
||||||
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
|
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
|
||||||
var count int
|
var count int
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
@ -64,6 +65,7 @@ type Service struct {
|
|||||||
db *database.Database
|
db *database.Database
|
||||||
store *sessions.CookieStore
|
store *sessions.CookieStore
|
||||||
params *ServiceParams
|
params *ServiceParams
|
||||||
|
setupMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new auth Service.
|
// New creates a new auth Service.
|
||||||
@ -163,11 +165,17 @@ func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser creates the initial admin user.
|
// CreateUser creates the initial admin user.
|
||||||
|
// It uses a mutex and INSERT ... ON CONFLICT to prevent race conditions
|
||||||
|
// where multiple concurrent requests could create duplicate admin users.
|
||||||
func (svc *Service) CreateUser(
|
func (svc *Service) CreateUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
username, password string,
|
username, password string,
|
||||||
) (*models.User, error) {
|
) (*models.User, error) {
|
||||||
// Check if user already exists
|
// Serialize setup attempts to prevent TOCTOU race conditions.
|
||||||
|
svc.setupMu.Lock()
|
||||||
|
defer svc.setupMu.Unlock()
|
||||||
|
|
||||||
|
// Check if any user already exists (setup already completed).
|
||||||
exists, err := models.UserExists(ctx, svc.db)
|
exists, err := models.UserExists(ctx, svc.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
return nil, fmt.Errorf("failed to check if user exists: %w", err)
|
||||||
@ -183,14 +191,16 @@ func (svc *Service) CreateUser(
|
|||||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create user
|
// Use INSERT ... ON CONFLICT to handle any remaining race at the DB level.
|
||||||
user := models.NewUser(svc.db)
|
// This is defense-in-depth: the mutex above prevents the Go-level race,
|
||||||
user.Username = username
|
// and the UNIQUE constraint + ON CONFLICT prevents the DB-level race.
|
||||||
user.PasswordHash = hash
|
user, err := models.CreateUserAtomic(ctx, svc.db, username, hash)
|
||||||
|
|
||||||
err = user.Save(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to save user: %w", err)
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
return nil, ErrUserExists
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.log.Info("user created", "username", username)
|
svc.log.Info("user created", "username", username)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package auth_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -270,6 +271,52 @@ func TestCreateUser(testingT *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateUserRaceCondition(testingT *testing.T) {
|
||||||
|
testingT.Parallel()
|
||||||
|
|
||||||
|
testingT.Run("concurrent setup requests create only one user", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
svc, cleanup := setupTestService(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
const goroutines = 10
|
||||||
|
|
||||||
|
results := make(chan error, goroutines)
|
||||||
|
start := make(chan struct{})
|
||||||
|
|
||||||
|
for i := range goroutines {
|
||||||
|
go func(idx int) {
|
||||||
|
<-start // Wait for all goroutines to be ready
|
||||||
|
|
||||||
|
_, err := svc.CreateUser(
|
||||||
|
context.Background(),
|
||||||
|
fmt.Sprintf("admin%d", idx),
|
||||||
|
"password123456",
|
||||||
|
)
|
||||||
|
results <- err
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release all goroutines simultaneously
|
||||||
|
close(start)
|
||||||
|
|
||||||
|
var successes, failures int
|
||||||
|
for range goroutines {
|
||||||
|
err := <-results
|
||||||
|
if err == nil {
|
||||||
|
successes++
|
||||||
|
} else {
|
||||||
|
assert.ErrorIs(t, err, auth.ErrUserExists)
|
||||||
|
failures++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1, successes, "exactly one goroutine should succeed")
|
||||||
|
assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthenticate(testingT *testing.T) {
|
func TestAuthenticate(testingT *testing.T) {
|
||||||
testingT.Parallel()
|
testingT.Parallel()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user