upaas/internal/models/user.go
user 97a5aae2f7 simplify: replace mutex + ON CONFLICT with a single DB transaction
Remove the sync.Mutex and CreateUserAtomic (INSERT ON CONFLICT) in favor
of a single DB transaction in CreateFirstUser that atomically checks for
existing users and inserts. SQLite serializes write transactions, so this
is sufficient to prevent the race condition without application-level locking.
2026-02-15 21:41:52 -08:00

205 lines
4.4 KiB
Go

// Package models provides Active Record style database models.
package models
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"git.eeqj.de/sneak/upaas/internal/database"
)
// User represents a user in the system. It is an Active Record style model:
// each value carries the database handle it was created with and persists
// itself through Save, Delete and Reload.
type User struct {
	// db is the handle used by the persistence methods.
	db *database.Database

	// ID is the primary key; zero means the user has not been inserted yet.
	ID int64
	// Username and PasswordHash mirror the username and password_hash columns.
	Username, PasswordHash string
	// CreatedAt mirrors the created_at column; it is filled in by Reload.
	CreatedAt time.Time
}
// NewUser returns an empty, unsaved User bound to db. Its ID is zero, so the
// first call to Save on it performs an INSERT.
func NewUser(db *database.Database) *User {
	u := new(User)
	u.db = db

	return u
}
// Save persists u. A User whose ID is still zero has never been stored and is
// inserted (which also sets u.ID); any other User is updated in place.
func (u *User) Save(ctx context.Context) error {
	switch u.ID {
	case 0:
		return u.insert(ctx)
	default:
		return u.update(ctx)
	}
}
// Delete removes the row whose id equals u.ID from the users table. The
// in-memory User (including its ID) is left untouched.
func (u *User) Delete(ctx context.Context) error {
	const deleteSQL = "DELETE FROM users WHERE id = ?"

	if _, err := u.db.Exec(ctx, deleteSQL, u.ID); err != nil {
		return err
	}

	return nil
}
// Reload re-reads the row identified by u.ID and overwrites every field of u
// with the stored values. When no such row exists, the error comes straight
// from (*sql.Row).Scan, i.e. sql.ErrNoRows.
func (u *User) Reload(ctx context.Context) error {
	return u.scan(u.db.QueryRow(ctx,
		"SELECT id, username, password_hash, created_at FROM users WHERE id = ?",
		u.ID,
	))
}
// insert stores u as a new row and copies the generated id back into u.ID.
// It then reloads u so that columns the INSERT does not set (created_at,
// presumably a column default — confirm against the schema) reflect what
// was stored.
func (u *User) insert(ctx context.Context) error {
	res, execErr := u.db.Exec(ctx,
		"INSERT INTO users (username, password_hash) VALUES (?, ?)",
		u.Username, u.PasswordHash,
	)
	if execErr != nil {
		return execErr
	}

	newID, idErr := res.LastInsertId()
	if idErr != nil {
		return idErr
	}

	u.ID = newID

	return u.Reload(ctx)
}
// update writes u.Username and u.PasswordHash to the row whose id is u.ID.
// This method does not inspect how many rows were affected.
func (u *User) update(ctx context.Context) error {
	_, err := u.db.Exec(ctx,
		"UPDATE users SET username = ?, password_hash = ? WHERE id = ?",
		u.Username, u.PasswordHash, u.ID,
	)

	return err
}
// scan copies a single result row into u. The row's columns must be, in
// order: id, username, password_hash, created_at.
func (u *User) scan(row *sql.Row) error {
	if err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.CreatedAt); err != nil {
		return err
	}

	return nil
}
// FindUser looks up the user with the given primary key.
//
// It returns (nil, nil) when no row matches and a wrapped error for any other
// failure.
//
//nolint:nilnil // returning nil,nil is idiomatic for "not found" in Active Record
func FindUser(
	ctx context.Context,
	db *database.Database,
	id int64,
) (*User, error) {
	const query = "SELECT id, username, password_hash, created_at FROM users WHERE id = ?"

	found := NewUser(db)

	switch err := found.scan(db.QueryRow(ctx, query, id)); {
	case err == nil:
		return found, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	default:
		return nil, fmt.Errorf("scanning user: %w", err)
	}
}
// FindUserByUsername looks up the user whose username matches exactly.
//
// It returns (nil, nil) when no row matches and a wrapped error for any other
// failure.
//
//nolint:nilnil // returning nil,nil is idiomatic for "not found" in Active Record
func FindUserByUsername(
	ctx context.Context,
	db *database.Database,
	username string,
) (*User, error) {
	const query = "SELECT id, username, password_hash, created_at FROM users WHERE username = ?"

	found := NewUser(db)

	switch err := found.scan(db.QueryRow(ctx, query, username)); {
	case err == nil:
		return found, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	default:
		return nil, fmt.Errorf("scanning user by username: %w", err)
	}
}
// CreateFirstUser atomically checks that no users exist and inserts the admin user.
// Returns nil, nil if a user already exists (setup already completed).
//
// The existence check and the insert are a single conditional
// INSERT ... SELECT ... WHERE NOT EXISTS statement.
//
// Running SELECT COUNT(*) followed by INSERT inside a default (deferred)
// SQLite transaction is not enough. Two concurrent callers can both read a
// count of zero. SQLite then refuses to upgrade the loser's read transaction
// to a write transaction and fails its INSERT with SQLITE_BUSY ("database is
// locked"), so the caller sees an error instead of the nil,nil signal.
//
// A single write statement takes the write lock before its WHERE clause is
// evaluated. A concurrent caller therefore either waits for the first insert
// to commit and then inserts nothing (assuming a busy timeout is configured
// on the connection — not visible here), or fails with a busy error. It can
// never create a second "first" user.
func CreateFirstUser(
	ctx context.Context,
	db *database.Database,
	username, passwordHash string,
) (*User, error) {
	result, err := db.Exec(ctx,
		"INSERT INTO users (username, password_hash) "+
			"SELECT ?, ? WHERE NOT EXISTS (SELECT 1 FROM users)",
		username, passwordHash,
	)
	if err != nil {
		return nil, fmt.Errorf("inserting user: %w", err)
	}

	// Zero affected rows means the NOT EXISTS guard saw an existing user.
	affected, err := result.RowsAffected()
	if err != nil {
		return nil, fmt.Errorf("checking rows affected: %w", err)
	}

	if affected == 0 {
		return nil, nil //nolint:nilnil // nil,nil signals setup already completed
	}

	insertID, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("getting last insert id: %w", err)
	}

	user := NewUser(db)
	user.ID = insertID

	// Reload picks up columns the INSERT did not set (e.g. created_at).
	if err := user.Reload(ctx); err != nil {
		return nil, fmt.Errorf("reloading user: %w", err)
	}

	return user, nil
}
// UserExists reports whether the users table contains at least one row.
// Query or scan failures are returned wrapped with context.
func UserExists(ctx context.Context, db *database.Database) (bool, error) {
	var n int

	if err := db.QueryRow(ctx, "SELECT COUNT(*) FROM users").Scan(&n); err != nil {
		return false, fmt.Errorf("counting users: %w", err)
	}

	return n > 0, nil
}