// Package database provides SQLite database access. package database import ( "context" "database/sql" "embed" "fmt" "log/slog" "path/filepath" "sort" "strings" "go.uber.org/fx" "sneak.berlin/go/pixa/internal/config" "sneak.berlin/go/pixa/internal/logger" _ "modernc.org/sqlite" // SQLite driver registration ) //go:embed schema/*.sql var schemaFS embed.FS // Params defines dependencies for Database. type Params struct { fx.In Logger *logger.Logger Config *config.Config } // Database wraps the SQL database connection. type Database struct { db *sql.DB log *slog.Logger config *config.Config } // ParseMigrationVersion extracts the numeric version prefix from a migration // filename. Filenames must follow the pattern ".sql" or // "_.sql", where version is a zero-padded numeric // string (e.g. "001", "002"). Returns the version string and an error if // the filename does not match the expected pattern. func ParseMigrationVersion(filename string) (string, error) { name := strings.TrimSuffix(filename, filepath.Ext(filename)) if name == "" { return "", fmt.Errorf("invalid migration filename %q: empty name", filename) } // Split on underscore to separate version from description. // If there's no underscore, the entire stem is the version. version := name if idx := strings.IndexByte(name, '_'); idx >= 0 { version = name[:idx] } if version == "" { return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename) } // Validate the version is purely numeric. for _, ch := range version { if ch < '0' || ch > '9' { return "", fmt.Errorf( "invalid migration filename %q: version %q contains non-numeric character %q", filename, version, string(ch), ) } } return version, nil } // New creates a new Database instance. func New(lc fx.Lifecycle, params Params) (*Database, error) { s := &Database{ log: params.Logger.Get(), config: params.Config, } s.log.Info("Database instantiated") lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { s.log.Info("Database OnStart Hook") return s.connect(ctx) }, OnStop: func(_ context.Context) error { s.log.Info("Database OnStop Hook") if s.db != nil { return s.db.Close() } return nil }, }) return s, nil } func (s *Database) connect(ctx context.Context) error { dbURL := s.config.DBURL s.log.Info("connecting to database", "url", dbURL) db, err := sql.Open("sqlite", dbURL) if err != nil { s.log.Error("failed to open database", "error", err) return err } if err := db.PingContext(ctx); err != nil { s.log.Error("failed to ping database", "error", err) return err } s.db = db s.log.Info("database connected") return s.runMigrations(ctx) } // collectMigrations reads the embedded schema directory and returns // migration filenames sorted lexicographically. func collectMigrations() ([]string, error) { entries, err := schemaFS.ReadDir("schema") if err != nil { return nil, fmt.Errorf("failed to read schema directory: %w", err) } var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { migrations = append(migrations, entry.Name()) } } sort.Strings(migrations) return migrations, nil } // ensureMigrationsTable creates the schema_migrations tracking table if // it does not already exist. func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { _, err := db.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version TEXT PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP ) `) if err != nil { return fmt.Errorf("failed to create migrations table: %w", err) } return nil } // applyMigrations applies all pending migrations to db, using log for // informational output (may be nil for silent operation). func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { if err := ensureMigrationsTable(ctx, db); err != nil { return err } migrations, err := collectMigrations() if err != nil { return err } for _, migration := range migrations { version, parseErr := ParseMigrationVersion(migration) if parseErr != nil { return parseErr } // Check if already applied. var count int err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) if err != nil { return fmt.Errorf("failed to check migration status: %w", err) } if count > 0 { if log != nil { log.Debug("migration already applied", "version", version) } continue } // Read and apply migration. content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration)) if readErr != nil { return fmt.Errorf("failed to read migration %s: %w", migration, readErr) } if log != nil { log.Info("applying migration", "version", version) } _, execErr := db.ExecContext(ctx, string(content)) if execErr != nil { return fmt.Errorf("failed to apply migration %s: %w", migration, execErr) } // Record migration as applied. _, recErr := db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) if recErr != nil { return fmt.Errorf("failed to record migration %s: %w", migration, recErr) } if log != nil { log.Info("migration applied successfully", "version", version) } } return nil } func (s *Database) runMigrations(ctx context.Context) error { return applyMigrations(ctx, s.db, s.log) } // DB returns the underlying sql.DB. func (s *Database) DB() *sql.DB { return s.db } // ApplyMigrations applies all migrations to the given database. // This is useful for testing where you want to use the real schema // without the full fx lifecycle. func ApplyMigrations(db *sql.DB) error { return applyMigrations(context.Background(), db, nil) }