Remove the sync.Mutex and CreateUserAtomic (INSERT ON CONFLICT) in favor of a single DB transaction in CreateFirstUser that atomically checks for existing users and inserts. SQLite serializes write transactions, so this is sufficient to prevent the race condition without application-level locking.
205 lines
4.4 KiB
Go
205 lines
4.4 KiB
Go
// Package models provides Active Record style database models.
|
|
package models
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"git.eeqj.de/sneak/upaas/internal/database"
|
|
)
|
|
|
|
// User represents a user in the system.
|
|
type User struct {
|
|
db *database.Database
|
|
|
|
ID int64
|
|
Username string
|
|
PasswordHash string
|
|
CreatedAt time.Time
|
|
}
|
|
|
|
// NewUser creates a new User with a database reference.
|
|
func NewUser(db *database.Database) *User {
|
|
return &User{db: db}
|
|
}
|
|
|
|
// Save inserts or updates the user in the database.
|
|
func (u *User) Save(ctx context.Context) error {
|
|
if u.ID == 0 {
|
|
return u.insert(ctx)
|
|
}
|
|
|
|
return u.update(ctx)
|
|
}
|
|
|
|
// Delete removes the user from the database.
|
|
func (u *User) Delete(ctx context.Context) error {
|
|
_, err := u.db.Exec(ctx, "DELETE FROM users WHERE id = ?", u.ID)
|
|
|
|
return err
|
|
}
|
|
|
|
// Reload refreshes the user from the database.
|
|
func (u *User) Reload(ctx context.Context) error {
|
|
query := "SELECT id, username, password_hash, created_at FROM users WHERE id = ?"
|
|
|
|
row := u.db.QueryRow(ctx, query, u.ID)
|
|
|
|
return u.scan(row)
|
|
}
|
|
|
|
func (u *User) insert(ctx context.Context) error {
|
|
query := "INSERT INTO users (username, password_hash) VALUES (?, ?)"
|
|
|
|
result, err := u.db.Exec(ctx, query, u.Username, u.PasswordHash)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
u.ID = id
|
|
|
|
return u.Reload(ctx)
|
|
}
|
|
|
|
func (u *User) update(ctx context.Context) error {
|
|
query := "UPDATE users SET username = ?, password_hash = ? WHERE id = ?"
|
|
|
|
_, err := u.db.Exec(ctx, query, u.Username, u.PasswordHash, u.ID)
|
|
|
|
return err
|
|
}
|
|
|
|
func (u *User) scan(row *sql.Row) error {
|
|
return row.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.CreatedAt)
|
|
}
|
|
|
|
// FindUser finds a user by ID.
|
|
//
|
|
//nolint:nilnil // returning nil,nil is idiomatic for "not found" in Active Record
|
|
func FindUser(
|
|
ctx context.Context,
|
|
db *database.Database,
|
|
id int64,
|
|
) (*User, error) {
|
|
user := NewUser(db)
|
|
|
|
row := db.QueryRow(ctx,
|
|
"SELECT id, username, password_hash, created_at FROM users WHERE id = ?",
|
|
id,
|
|
)
|
|
|
|
err := user.scan(row)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("scanning user: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// FindUserByUsername finds a user by username.
|
|
//
|
|
//nolint:nilnil // returning nil,nil is idiomatic for "not found" in Active Record
|
|
func FindUserByUsername(
|
|
ctx context.Context,
|
|
db *database.Database,
|
|
username string,
|
|
) (*User, error) {
|
|
user := NewUser(db)
|
|
|
|
row := db.QueryRow(ctx,
|
|
"SELECT id, username, password_hash, created_at FROM users WHERE username = ?",
|
|
username,
|
|
)
|
|
|
|
err := user.scan(row)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("scanning user by username: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// CreateFirstUser atomically checks that no users exist and inserts the admin user.
|
|
// Returns nil, nil if a user already exists (setup already completed).
|
|
func CreateFirstUser(
|
|
ctx context.Context,
|
|
db *database.Database,
|
|
username, passwordHash string,
|
|
) (*User, error) {
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("beginning transaction: %w", err)
|
|
}
|
|
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
// Check if any user exists within the transaction.
|
|
var count int
|
|
|
|
err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM users").Scan(&count)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("checking user count: %w", err)
|
|
}
|
|
|
|
if count > 0 {
|
|
return nil, nil //nolint:nilnil // nil,nil signals setup already completed
|
|
}
|
|
|
|
result, err := tx.ExecContext(ctx,
|
|
"INSERT INTO users (username, password_hash) VALUES (?, ?)",
|
|
username, passwordHash,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("inserting user: %w", err)
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("committing transaction: %w", err)
|
|
}
|
|
|
|
insertID, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting last insert id: %w", err)
|
|
}
|
|
|
|
user := NewUser(db)
|
|
user.ID = insertID
|
|
|
|
err = user.Reload(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reloading user: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// UserExists checks if any user exists in the database.
|
|
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
|
|
var count int
|
|
|
|
row := db.QueryRow(ctx, "SELECT COUNT(*) FROM users")
|
|
|
|
err := row.Scan(&count)
|
|
if err != nil {
|
|
return false, fmt.Errorf("counting users: %w", err)
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|