Add a mutex and INSERT ... ON CONFLICT to CreateUser to prevent a TOCTOU race in which concurrent requests could create multiple admin users.

Changes:
- Add a sync.Mutex to auth.Service to serialize CreateUser calls
- Add models.CreateUserAtomic, which uses INSERT ... ON CONFLICT(username) DO NOTHING
- Check RowsAffected to detect conflicts at the DB level (defense in depth)
- Add a concurrent race-condition test (10 goroutines; only 1 succeeds)

The UNIQUE constraint on users.username was already in place; this fix adds the application-level protection (items 1 and 2 from #26).
361 lines
8.3 KiB
Go
361 lines
8.3 KiB
Go
package auth_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/fx"
|
|
|
|
"git.eeqj.de/sneak/upaas/internal/config"
|
|
"git.eeqj.de/sneak/upaas/internal/database"
|
|
"git.eeqj.de/sneak/upaas/internal/globals"
|
|
"git.eeqj.de/sneak/upaas/internal/logger"
|
|
"git.eeqj.de/sneak/upaas/internal/service/auth"
|
|
)
|
|
|
|
func setupTestService(t *testing.T) (*auth.Service, func()) {
|
|
t.Helper()
|
|
|
|
// Create temp directory
|
|
tmpDir := t.TempDir()
|
|
|
|
// Set up globals
|
|
globals.SetAppname("upaas-test")
|
|
globals.SetVersion("test")
|
|
|
|
globalsInst, err := globals.New(fx.Lifecycle(nil))
|
|
require.NoError(t, err)
|
|
|
|
loggerInst, err := logger.New(
|
|
fx.Lifecycle(nil),
|
|
logger.Params{Globals: globalsInst},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
// Create test config
|
|
cfg := &config.Config{
|
|
Port: 8080,
|
|
DataDir: tmpDir,
|
|
SessionSecret: "test-secret-key-at-least-32-chars",
|
|
}
|
|
|
|
// Create database
|
|
dbInst, err := database.New(fx.Lifecycle(nil), database.Params{
|
|
Logger: loggerInst,
|
|
Config: cfg,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Connect database manually for tests
|
|
dbPath := filepath.Join(tmpDir, "upaas.db")
|
|
cfg.DataDir = tmpDir
|
|
_ = dbPath // database will create this
|
|
|
|
// Create service
|
|
svc, err := auth.New(fx.Lifecycle(nil), auth.ServiceParams{
|
|
Logger: loggerInst,
|
|
Config: cfg,
|
|
Database: dbInst,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// t.TempDir() automatically cleans up after test
|
|
cleanup := func() {}
|
|
|
|
return svc, cleanup
|
|
}
|
|
|
|
func TestSessionCookieSecureFlag(testingT *testing.T) {
|
|
testingT.Parallel()
|
|
|
|
testingT.Run("secure flag is true when debug is false", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
|
|
globals.SetAppname("upaas-test")
|
|
globals.SetVersion("test")
|
|
|
|
globalsInst, err := globals.New(fx.Lifecycle(nil))
|
|
require.NoError(t, err)
|
|
|
|
loggerInst, err := logger.New(
|
|
fx.Lifecycle(nil),
|
|
logger.Params{Globals: globalsInst},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
cfg := &config.Config{
|
|
Port: 8080,
|
|
DataDir: tmpDir,
|
|
SessionSecret: "test-secret-key-at-least-32-chars",
|
|
Debug: false,
|
|
}
|
|
|
|
dbInst, err := database.New(fx.Lifecycle(nil), database.Params{
|
|
Logger: loggerInst,
|
|
Config: cfg,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
svc, err := auth.New(fx.Lifecycle(nil), auth.ServiceParams{
|
|
Logger: loggerInst,
|
|
Config: cfg,
|
|
Database: dbInst,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create user and session, check cookie has Secure flag
|
|
_, err = svc.CreateUser(context.Background(), "admin", "password123")
|
|
require.NoError(t, err)
|
|
|
|
user, err := svc.Authenticate(context.Background(), "admin", "password123")
|
|
require.NoError(t, err)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
|
|
err = svc.CreateSession(recorder, request, user)
|
|
require.NoError(t, err)
|
|
|
|
cookies := recorder.Result().Cookies()
|
|
require.NotEmpty(t, cookies)
|
|
|
|
var sessionCookie *http.Cookie
|
|
for _, c := range cookies {
|
|
if c.Name == "upaas_session" {
|
|
sessionCookie = c
|
|
break
|
|
}
|
|
}
|
|
require.NotNil(t, sessionCookie, "session cookie should exist")
|
|
assert.True(t, sessionCookie.Secure, "session cookie should have Secure flag in production mode")
|
|
})
|
|
}
|
|
|
|
func TestHashPassword(testingT *testing.T) {
|
|
testingT.Parallel()
|
|
|
|
testingT.Run("hashes password successfully", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
hash, err := svc.HashPassword("testpassword")
|
|
require.NoError(t, err)
|
|
assert.NotEmpty(t, hash)
|
|
assert.NotEqual(t, "testpassword", hash)
|
|
assert.Contains(t, hash, "$") // salt$hash format
|
|
})
|
|
|
|
testingT.Run("produces different hashes for same password", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
hash1, err := svc.HashPassword("testpassword")
|
|
require.NoError(t, err)
|
|
|
|
hash2, err := svc.HashPassword("testpassword")
|
|
require.NoError(t, err)
|
|
|
|
assert.NotEqual(t, hash1, hash2) // Different salts
|
|
})
|
|
}
|
|
|
|
func TestVerifyPassword(testingT *testing.T) {
|
|
testingT.Parallel()
|
|
|
|
testingT.Run("verifies correct password", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
hash, err := svc.HashPassword("correctpassword")
|
|
require.NoError(t, err)
|
|
|
|
valid := svc.VerifyPassword(hash, "correctpassword")
|
|
assert.True(t, valid)
|
|
})
|
|
|
|
testingT.Run("rejects incorrect password", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
hash, err := svc.HashPassword("correctpassword")
|
|
require.NoError(t, err)
|
|
|
|
valid := svc.VerifyPassword(hash, "wrongpassword")
|
|
assert.False(t, valid)
|
|
})
|
|
|
|
testingT.Run("rejects empty password", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
hash, err := svc.HashPassword("correctpassword")
|
|
require.NoError(t, err)
|
|
|
|
valid := svc.VerifyPassword(hash, "")
|
|
assert.False(t, valid)
|
|
})
|
|
|
|
testingT.Run("rejects invalid hash format", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
valid := svc.VerifyPassword("invalid-hash", "password")
|
|
assert.False(t, valid)
|
|
})
|
|
}
|
|
|
|
func TestIsSetupRequired(testingT *testing.T) {
|
|
testingT.Parallel()
|
|
|
|
testingT.Run("returns true when no users exist", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
required, err := svc.IsSetupRequired(context.Background())
|
|
require.NoError(t, err)
|
|
assert.True(t, required)
|
|
})
|
|
}
|
|
|
|
func TestCreateUser(testingT *testing.T) {
|
|
testingT.Parallel()
|
|
|
|
testingT.Run("creates user successfully", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
user, err := svc.CreateUser(context.Background(), "admin", "password123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, user)
|
|
|
|
assert.Equal(t, "admin", user.Username)
|
|
assert.NotEmpty(t, user.PasswordHash)
|
|
assert.NotZero(t, user.ID)
|
|
})
|
|
|
|
testingT.Run("rejects duplicate user", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
_, err := svc.CreateUser(context.Background(), "admin", "password123")
|
|
require.NoError(t, err)
|
|
|
|
_, err = svc.CreateUser(context.Background(), "admin2", "password456")
|
|
assert.ErrorIs(t, err, auth.ErrUserExists)
|
|
})
|
|
}
|
|
|
|
func TestCreateUserRaceCondition(testingT *testing.T) {
|
|
testingT.Parallel()
|
|
|
|
testingT.Run("concurrent setup requests create only one user", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
const goroutines = 10
|
|
|
|
results := make(chan error, goroutines)
|
|
start := make(chan struct{})
|
|
|
|
for i := range goroutines {
|
|
go func(idx int) {
|
|
<-start // Wait for all goroutines to be ready
|
|
|
|
_, err := svc.CreateUser(
|
|
context.Background(),
|
|
fmt.Sprintf("admin%d", idx),
|
|
"password123456",
|
|
)
|
|
results <- err
|
|
}(i)
|
|
}
|
|
|
|
// Release all goroutines simultaneously
|
|
close(start)
|
|
|
|
var successes, failures int
|
|
for range goroutines {
|
|
err := <-results
|
|
if err == nil {
|
|
successes++
|
|
} else {
|
|
assert.ErrorIs(t, err, auth.ErrUserExists)
|
|
failures++
|
|
}
|
|
}
|
|
|
|
assert.Equal(t, 1, successes, "exactly one goroutine should succeed")
|
|
assert.Equal(t, goroutines-1, failures, "all other goroutines should fail with ErrUserExists")
|
|
})
|
|
}
|
|
|
|
func TestAuthenticate(testingT *testing.T) {
|
|
testingT.Parallel()
|
|
|
|
testingT.Run("authenticates valid credentials", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
_, err := svc.CreateUser(context.Background(), "admin", "password123")
|
|
require.NoError(t, err)
|
|
|
|
user, err := svc.Authenticate(context.Background(), "admin", "password123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, user)
|
|
assert.Equal(t, "admin", user.Username)
|
|
})
|
|
|
|
testingT.Run("rejects invalid password", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
_, err := svc.CreateUser(context.Background(), "admin", "password123")
|
|
require.NoError(t, err)
|
|
|
|
_, err = svc.Authenticate(context.Background(), "admin", "wrongpassword")
|
|
assert.ErrorIs(t, err, auth.ErrInvalidCredentials)
|
|
})
|
|
|
|
testingT.Run("rejects unknown user", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc, cleanup := setupTestService(t)
|
|
defer cleanup()
|
|
|
|
_, err := svc.Authenticate(context.Background(), "nonexistent", "password")
|
|
assert.ErrorIs(t, err, auth.ErrInvalidCredentials)
|
|
})
|
|
}
|