// Package db provides database access and migration management. package db import ( "context" "database/sql" "embed" "fmt" "io/fs" "log/slog" "sort" "strconv" "strings" "git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/logger" "go.uber.org/fx" _ "github.com/joho/godotenv/autoload" // .env _ "modernc.org/sqlite" // driver ) const minMigrationParts = 2 // SchemaFiles contains embedded SQL migration files. // //go:embed schema/*.sql var SchemaFiles embed.FS // Params defines the dependencies for creating a Database. type Params struct { fx.In Logger *logger.Logger Config *config.Config } // Database manages the SQLite connection and migrations. type Database struct { conn *sql.DB log *slog.Logger params *Params } // New creates a new Database and registers lifecycle hooks. func New( lifecycle fx.Lifecycle, params Params, ) (*Database, error) { database := &Database{ //nolint:exhaustruct // conn set in OnStart params: ¶ms, log: params.Logger.Get(), } database.log.Info("Database instantiated") lifecycle.Append(fx.Hook{ OnStart: func(ctx context.Context) error { database.log.Info("Database OnStart Hook") return database.connect(ctx) }, OnStop: func(_ context.Context) error { database.log.Info("Database OnStop Hook") if database.conn != nil { closeErr := database.conn.Close() if closeErr != nil { return fmt.Errorf( "close db: %w", closeErr, ) } } return nil }, }) return database, nil } // GetDB returns the underlying sql.DB connection. func (database *Database) GetDB() *sql.DB { return database.conn } func (database *Database) connect(ctx context.Context) error { dbURL := database.params.Config.DBURL if dbURL == "" { dbURL = "file:./data.db?_journal_mode=WAL&_busy_timeout=5000" } database.log.Info( "connecting to database", "url", dbURL, ) conn, err := sql.Open("sqlite", dbURL) if err != nil { return fmt.Errorf("open database: %w", err) } err = conn.PingContext(ctx) if err != nil { return fmt.Errorf("ping database: %w", err) } conn.SetMaxOpenConns(1) database.conn = conn database.log.Info("database connected") _, err = database.conn.ExecContext( ctx, "PRAGMA foreign_keys = ON", ) if err != nil { return fmt.Errorf("enable foreign keys: %w", err) } _, err = database.conn.ExecContext( ctx, "PRAGMA busy_timeout = 5000", ) if err != nil { return fmt.Errorf("set busy timeout: %w", err) } return database.runMigrations(ctx) } type migration struct { version int name string sql string } func (database *Database) runMigrations( ctx context.Context, ) error { _, err := database.conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)`) if err != nil { return fmt.Errorf( "create schema_migrations: %w", err, ) } migrations, err := database.loadMigrations() if err != nil { return err } for _, mig := range migrations { err = database.applyMigration(ctx, mig) if err != nil { return err } } database.log.Info("database migrations complete") return nil } func (database *Database) applyMigration( ctx context.Context, mig migration, ) error { var exists int err := database.conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM schema_migrations WHERE version = ?`, mig.version, ).Scan(&exists) if err != nil { return fmt.Errorf( "check migration %d: %w", mig.version, err, ) } if exists > 0 { return nil } database.log.Info( "applying migration", "version", mig.version, "name", mig.name, ) return database.execMigration(ctx, mig) } func (database *Database) execMigration( ctx context.Context, mig migration, ) error { transaction, err := database.conn.BeginTx(ctx, nil) if err != nil { return fmt.Errorf( "begin tx for migration %d: %w", mig.version, err, ) } _, err = transaction.ExecContext(ctx, mig.sql) if err != nil { _ = transaction.Rollback() return fmt.Errorf( "apply migration %d (%s): %w", mig.version, mig.name, err, ) } _, err = transaction.ExecContext(ctx, `INSERT INTO schema_migrations (version) VALUES (?)`, mig.version, ) if err != nil { _ = transaction.Rollback() return fmt.Errorf( "record migration %d: %w", mig.version, err, ) } err = transaction.Commit() if err != nil { return fmt.Errorf( "commit migration %d: %w", mig.version, err, ) } return nil } func (database *Database) loadMigrations() ( []migration, error, ) { entries, err := fs.ReadDir(SchemaFiles, "schema") if err != nil { return nil, fmt.Errorf( "read schema dir: %w", err, ) } migrations := make([]migration, 0, len(entries)) for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { continue } parts := strings.SplitN( entry.Name(), "_", minMigrationParts, ) if len(parts) < minMigrationParts { continue } version, parseErr := strconv.Atoi(parts[0]) if parseErr != nil { continue } content, readErr := SchemaFiles.ReadFile( "schema/" + entry.Name(), ) if readErr != nil { return nil, fmt.Errorf( "read migration %s: %w", entry.Name(), readErr, ) } migrations = append(migrations, migration{ version: version, name: entry.Name(), sql: string(content), }) } sort.Slice(migrations, func(i, j int) bool { return migrations[i].version < migrations[j].version }) return migrations, nil }