// Package db provides database access and migration management. package db import ( "context" "database/sql" "embed" "fmt" "io/fs" "log/slog" "sort" "strconv" "strings" "time" "git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/logger" "git.eeqj.de/sneak/chat/internal/models" "go.uber.org/fx" _ "github.com/joho/godotenv/autoload" // loads .env file _ "modernc.org/sqlite" // SQLite driver ) const ( minMigrationParts = 2 ) // SchemaFiles contains embedded SQL migration files. // //go:embed schema/*.sql var SchemaFiles embed.FS // Params defines the dependencies for creating a Database. type Params struct { fx.In Logger *logger.Logger Config *config.Config } // Database manages the SQLite database connection and migrations. type Database struct { db *sql.DB log *slog.Logger params *Params } // New creates a new Database instance and registers lifecycle hooks. func New(lc fx.Lifecycle, params Params) (*Database, error) { s := new(Database) s.params = ¶ms s.log = params.Logger.Get() s.log.Info("Database instantiated") lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { s.log.Info("Database OnStart Hook") return s.connect(ctx) }, OnStop: func(_ context.Context) error { s.log.Info("Database OnStop Hook") if s.db != nil { return s.db.Close() } return nil }, }) return s, nil } // NewTest creates a Database for testing, bypassing fx lifecycle. // It connects to the given DSN and runs all migrations. func NewTest(dsn string) (*Database, error) { d, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } s := &Database{ db: d, log: slog.Default(), } // Item 9: Enable foreign keys _, err = d.Exec("PRAGMA foreign_keys = ON") //nolint:noctx // no context in sql.Open path if err != nil { _ = d.Close() return nil, fmt.Errorf("enable foreign keys: %w", err) } ctx := context.Background() err = s.runMigrations(ctx) if err != nil { _ = d.Close() return nil, err } return s, nil } // GetDB returns the underlying sql.DB connection. 
func (s *Database) GetDB() *sql.DB {
	// The handle is returned as stored; it is nil if no connection was opened.
	return s.db
}

// Hydrate passes this Database to m through its SetDB method. Any value
// exposing SetDB is accepted, such as the models that embed Base.
func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) {
	m.SetDB(s)
}

// GetUserByID fetches the single user whose id column equals id.
//
// The returned user has already been hydrated with this Database. Any error
// from the query or the scan is returned unchanged and the user is nil.
func (s *Database) GetUserByID(
	ctx context.Context,
	id string,
) (*models.User, error) {
	const query = ` SELECT id, nick, password_hash, created_at, updated_at, last_seen_at FROM users WHERE id = ?`

	user := new(models.User)
	s.Hydrate(user)

	row := s.db.QueryRowContext(ctx, query, id)
	if err := row.Scan(
		&user.ID,
		&user.Nick,
		&user.PasswordHash,
		&user.CreatedAt,
		&user.UpdatedAt,
		&user.LastSeenAt,
	); err != nil {
		return nil, err
	}

	return user, nil
}

// GetChannelByID fetches the single channel whose id column equals id.
//
// The returned channel has already been hydrated with this Database. Any
// error from the query or the scan is returned unchanged and the channel is
// nil.
func (s *Database) GetChannelByID(
	ctx context.Context,
	id string,
) (*models.Channel, error) {
	const query = ` SELECT id, name, topic, modes, created_at, updated_at FROM channels WHERE id = ?`

	channel := new(models.Channel)
	s.Hydrate(channel)

	row := s.db.QueryRowContext(ctx, query, id)
	if err := row.Scan(
		&channel.ID,
		&channel.Name,
		&channel.Topic,
		&channel.Modes,
		&channel.CreatedAt,
		&channel.UpdatedAt,
	); err != nil {
		return nil, err
	}

	return channel, nil
}

// GetUserByNickModel fetches the single user whose nick column equals nick.
//
// The returned user has already been hydrated with this Database. Any error
// from the query or the scan is returned unchanged and the user is nil.
func (s *Database) GetUserByNickModel(
	ctx context.Context,
	nick string,
) (*models.User, error) {
	const query = ` SELECT id, nick, password_hash, created_at, updated_at, last_seen_at FROM users WHERE nick = ?`

	user := new(models.User)
	s.Hydrate(user)

	row := s.db.QueryRowContext(ctx, query, nick)
	if err := row.Scan(
		&user.ID,
		&user.Nick,
		&user.PasswordHash,
		&user.CreatedAt,
		&user.UpdatedAt,
		&user.LastSeenAt,
	); err != nil {
		return nil, err
	}

	return user, nil
}

// GetUserByTokenModel looks up a user by their auth token.
func (s *Database) GetUserByTokenModel( ctx context.Context, token string, ) (*models.User, error) { u := &models.User{} s.Hydrate(u) err := s.db.QueryRowContext(ctx, ` SELECT u.id, u.nick, u.password_hash, u.created_at, u.updated_at, u.last_seen_at FROM users u JOIN auth_tokens t ON t.user_id = u.id WHERE t.token = ?`, token, ).Scan( &u.ID, &u.Nick, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, ) if err != nil { return nil, err } return u, nil } // DeleteAuthToken removes an auth token from the database. func (s *Database) DeleteAuthToken( ctx context.Context, token string, ) error { _, err := s.db.ExecContext(ctx, `DELETE FROM auth_tokens WHERE token = ?`, token, ) return err } // UpdateUserLastSeen updates the last_seen_at timestamp for a user. func (s *Database) UpdateUserLastSeen( ctx context.Context, userID string, ) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`, userID, ) return err } // CreateUserModel inserts a new user into the database. func (s *Database) CreateUserModel( ctx context.Context, id, nick, passwordHash string, ) (*models.User, error) { now := time.Now() _, err := s.db.ExecContext(ctx, `INSERT INTO users (id, nick, password_hash) VALUES (?, ?, ?)`, id, nick, passwordHash, ) if err != nil { return nil, err } u := &models.User{ ID: id, Nick: nick, PasswordHash: passwordHash, CreatedAt: now, UpdatedAt: now, } s.Hydrate(u) return u, nil } // CreateChannel inserts a new channel into the database. 
func (s *Database) CreateChannel( ctx context.Context, id, name, topic, modes string, ) (*models.Channel, error) { now := time.Now() _, err := s.db.ExecContext(ctx, `INSERT INTO channels (id, name, topic, modes) VALUES (?, ?, ?, ?)`, id, name, topic, modes, ) if err != nil { return nil, err } c := &models.Channel{ ID: id, Name: name, Topic: topic, Modes: modes, CreatedAt: now, UpdatedAt: now, } s.Hydrate(c) return c, nil } // AddChannelMember adds a user to a channel with the given modes. func (s *Database) AddChannelMember( ctx context.Context, channelID, userID, modes string, ) (*models.ChannelMember, error) { now := time.Now() _, err := s.db.ExecContext(ctx, `INSERT INTO channel_members (channel_id, user_id, modes) VALUES (?, ?, ?)`, channelID, userID, modes, ) if err != nil { return nil, err } cm := &models.ChannelMember{ ChannelID: channelID, UserID: userID, Modes: modes, JoinedAt: now, } s.Hydrate(cm) return cm, nil } // CreateMessage inserts a new message into the database. func (s *Database) CreateMessage( ctx context.Context, id, fromUserID, fromNick, target, msgType, body string, ) (*models.Message, error) { now := time.Now() _, err := s.db.ExecContext(ctx, `INSERT INTO messages (id, from_user_id, from_nick, target, type, body) VALUES (?, ?, ?, ?, ?, ?)`, id, fromUserID, fromNick, target, msgType, body, ) if err != nil { return nil, err } m := &models.Message{ ID: id, FromUserID: fromUserID, FromNick: fromNick, Target: target, Type: msgType, Body: body, Timestamp: now, CreatedAt: now, } s.Hydrate(m) return m, nil } // QueueMessage adds a message to a user's delivery queue. 
func (s *Database) QueueMessage( ctx context.Context, userID, messageID string, ) (*models.MessageQueueEntry, error) { now := time.Now() res, err := s.db.ExecContext(ctx, `INSERT INTO message_queue (user_id, message_id) VALUES (?, ?)`, userID, messageID, ) if err != nil { return nil, err } entryID, err := res.LastInsertId() if err != nil { return nil, fmt.Errorf("get last insert id: %w", err) } mq := &models.MessageQueueEntry{ ID: entryID, UserID: userID, MessageID: messageID, QueuedAt: now, } s.Hydrate(mq) return mq, nil } // DequeueMessages returns up to limit pending messages for a user, // ordered by queue time (oldest first). func (s *Database) DequeueMessages( ctx context.Context, userID string, limit int, ) ([]*models.MessageQueueEntry, error) { rows, err := s.db.QueryContext(ctx, ` SELECT id, user_id, message_id, queued_at FROM message_queue WHERE user_id = ? ORDER BY queued_at ASC LIMIT ?`, userID, limit, ) if err != nil { return nil, err } defer func() { _ = rows.Close() }() entries := []*models.MessageQueueEntry{} for rows.Next() { e := &models.MessageQueueEntry{} s.Hydrate(e) err = rows.Scan(&e.ID, &e.UserID, &e.MessageID, &e.QueuedAt) if err != nil { return nil, err } entries = append(entries, e) } return entries, rows.Err() } // AckMessages removes the given queue entry IDs, marking them as delivered. func (s *Database) AckMessages( ctx context.Context, entryIDs []int64, ) error { if len(entryIDs) == 0 { return nil } placeholders := make([]string, len(entryIDs)) args := make([]any, len(entryIDs)) for i, id := range entryIDs { placeholders[i] = "?" args[i] = id } query := fmt.Sprintf( //nolint:gosec // placeholders are ?, not user input "DELETE FROM message_queue WHERE id IN (%s)", strings.Join(placeholders, ","), ) _, err := s.db.ExecContext(ctx, query, args...) return err } // CreateAuthToken inserts a new auth token for a user. 
func (s *Database) CreateAuthToken( ctx context.Context, token, userID string, ) (*models.AuthToken, error) { now := time.Now() _, err := s.db.ExecContext(ctx, `INSERT INTO auth_tokens (token, user_id) VALUES (?, ?)`, token, userID, ) if err != nil { return nil, err } at := &models.AuthToken{Token: token, UserID: userID, CreatedAt: now} s.Hydrate(at) return at, nil } // CreateSession inserts a new session for a user. func (s *Database) CreateSession( ctx context.Context, id, userID string, ) (*models.Session, error) { now := time.Now() _, err := s.db.ExecContext(ctx, `INSERT INTO sessions (id, user_id) VALUES (?, ?)`, id, userID, ) if err != nil { return nil, err } sess := &models.Session{ ID: id, UserID: userID, CreatedAt: now, LastActiveAt: now, } s.Hydrate(sess) return sess, nil } // CreateServerLink inserts a new server link. func (s *Database) CreateServerLink( ctx context.Context, id, name, url, sharedKeyHash string, isActive bool, ) (*models.ServerLink, error) { now := time.Now() active := 0 if isActive { active = 1 } _, err := s.db.ExecContext(ctx, `INSERT INTO server_links (id, name, url, shared_key_hash, is_active) VALUES (?, ?, ?, ?, ?)`, id, name, url, sharedKeyHash, active, ) if err != nil { return nil, err } sl := &models.ServerLink{ ID: id, Name: name, URL: url, SharedKeyHash: sharedKeyHash, IsActive: isActive, CreatedAt: now, } s.Hydrate(sl) return sl, nil } func (s *Database) connect(ctx context.Context) error { dbURL := s.params.Config.DBURL if dbURL == "" { dbURL = "file:./data.db?_journal_mode=WAL" } s.log.Info("connecting to database", "url", dbURL) d, err := sql.Open("sqlite", dbURL) if err != nil { s.log.Error("failed to open database", "error", err) return err } err = d.PingContext(ctx) if err != nil { s.log.Error("failed to ping database", "error", err) return err } s.db = d s.log.Info("database connected") // Item 9: Enable foreign keys on every connection _, err = s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON") if err != nil { return 
fmt.Errorf("enable foreign keys: %w", err) } return s.runMigrations(ctx) } type migration struct { version int name string sql string } func (s *Database) runMigrations(ctx context.Context) error { err := s.bootstrapMigrationsTable(ctx) if err != nil { return err } migrations, err := s.loadMigrations() if err != nil { return err } err = s.applyMigrations(ctx, migrations) if err != nil { return err } s.log.Info("database migrations complete") return nil } func (s *Database) bootstrapMigrationsTable( ctx context.Context, ) error { _, err := s.db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP )`) if err != nil { return fmt.Errorf( "create schema_migrations table: %w", err, ) } return nil } func (s *Database) loadMigrations() ([]migration, error) { entries, err := fs.ReadDir(SchemaFiles, "schema") if err != nil { return nil, fmt.Errorf("read schema dir: %w", err) } var migrations []migration for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { continue } parts := strings.SplitN( entry.Name(), "_", minMigrationParts, ) if len(parts) < minMigrationParts { continue } version, err := strconv.Atoi(parts[0]) if err != nil { continue } content, err := SchemaFiles.ReadFile( "schema/" + entry.Name(), ) if err != nil { return nil, fmt.Errorf( "read migration %s: %w", entry.Name(), err, ) } migrations = append(migrations, migration{ version: version, name: entry.Name(), sql: string(content), }) } sort.Slice(migrations, func(i, j int) bool { return migrations[i].version < migrations[j].version }) return migrations, nil } // Item 4: Wrap each migration in a transaction func (s *Database) applyMigrations( ctx context.Context, migrations []migration, ) error { for _, m := range migrations { var exists int err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version, ).Scan(&exists) if err != nil { return 
fmt.Errorf( "check migration %d: %w", m.version, err, ) } if exists > 0 { continue } s.log.Info( "applying migration", "version", m.version, "name", m.name, ) err = s.executeMigration(ctx, m) if err != nil { return err } } return nil } func (s *Database) executeMigration( ctx context.Context, m migration, ) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf( "begin tx for migration %d: %w", m.version, err, ) } _, err = tx.ExecContext(ctx, m.sql) if err != nil { _ = tx.Rollback() return fmt.Errorf( "apply migration %d (%s): %w", m.version, m.name, err, ) } _, err = tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version, ) if err != nil { _ = tx.Rollback() return fmt.Errorf( "record migration %d: %w", m.version, err, ) } err = tx.Commit() if err != nil { return fmt.Errorf( "commit migration %d: %w", m.version, err, ) } return nil }