// Package database provides SQLite database access. package database import ( "context" "database/sql" "embed" "fmt" "log/slog" "path/filepath" "sort" "strings" "go.uber.org/fx" "sneak.berlin/go/pixa/internal/config" "sneak.berlin/go/pixa/internal/logger" _ "modernc.org/sqlite" // SQLite driver registration ) //go:embed schema/*.sql var schemaFS embed.FS // bootstrapVersion is the migration that creates the schema_migrations // table itself. It is applied before the normal migration loop. const bootstrapVersion = "000" // Params defines dependencies for Database. type Params struct { fx.In Logger *logger.Logger Config *config.Config } // Database wraps the SQL database connection. type Database struct { db *sql.DB log *slog.Logger config *config.Config } // ParseMigrationVersion extracts the numeric version prefix from a migration // filename. Filenames must follow the pattern ".sql" or // "_.sql", where version is a zero-padded numeric // string (e.g. "001", "002"). Returns the version string and an error if // the filename does not match the expected pattern. func ParseMigrationVersion(filename string) (string, error) { name := strings.TrimSuffix(filename, filepath.Ext(filename)) if name == "" { return "", fmt.Errorf("invalid migration filename %q: empty name", filename) } // Split on underscore to separate version from description. // If there's no underscore, the entire stem is the version. version := name if idx := strings.IndexByte(name, '_'); idx >= 0 { version = name[:idx] } if version == "" { return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename) } // Validate the version is purely numeric. for _, ch := range version { if ch < '0' || ch > '9' { return "", fmt.Errorf( "invalid migration filename %q: version %q contains non-numeric character %q", filename, version, string(ch), ) } } return version, nil } // New creates a new Database instance. func New(lc fx.Lifecycle, params Params) (*Database, error) { s := &Database{ log: params.Logger.Get(), config: params.Config, } s.log.Info("Database instantiated") lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { s.log.Info("Database OnStart Hook") return s.connect(ctx) }, OnStop: func(_ context.Context) error { s.log.Info("Database OnStop Hook") if s.db != nil { return s.db.Close() } return nil }, }) return s, nil } func (s *Database) connect(ctx context.Context) error { dbURL := s.config.DBURL s.log.Info("connecting to database", "url", dbURL) db, err := sql.Open("sqlite", dbURL) if err != nil { s.log.Error("failed to open database", "error", err) return err } if err := db.PingContext(ctx); err != nil { s.log.Error("failed to ping database", "error", err) return err } s.db = db s.log.Info("database connected") return ApplyMigrations(ctx, s.db, s.log) } // collectMigrations reads the embedded schema directory and returns // migration filenames sorted lexicographically. func collectMigrations() ([]string, error) { entries, err := schemaFS.ReadDir("schema") if err != nil { return nil, fmt.Errorf("failed to read schema directory: %w", err) } var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { migrations = append(migrations, entry.Name()) } } sort.Strings(migrations) return migrations, nil } // bootstrapMigrationsTable ensures the schema_migrations table exists // by applying 000.sql if the table is missing. func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error { var tableExists int err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", ).Scan(&tableExists) if err != nil { return fmt.Errorf("failed to check for migrations table: %w", err) } if tableExists > 0 { return nil } content, err := schemaFS.ReadFile("schema/000.sql") if err != nil { return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err) } if log != nil { log.Info("applying bootstrap migration", "version", bootstrapVersion) } _, err = db.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to apply bootstrap migration: %w", err) } _, err = db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", bootstrapVersion, ) if err != nil { return fmt.Errorf("failed to record bootstrap migration: %w", err) } return nil } // ApplyMigrations applies all pending migrations to db. An optional logger // may be provided for informational output; pass nil for silent operation. // This is exported so tests can apply the real schema without the full fx // lifecycle. func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { if err := bootstrapMigrationsTable(ctx, db, log); err != nil { return err } migrations, err := collectMigrations() if err != nil { return err } for _, migration := range migrations { version, parseErr := ParseMigrationVersion(migration) if parseErr != nil { return parseErr } // Check if already applied. var count int err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) if err != nil { return fmt.Errorf("failed to check migration status: %w", err) } if count > 0 { if log != nil { log.Debug("migration already applied", "version", version) } continue } // Read and apply migration. content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration)) if readErr != nil { return fmt.Errorf("failed to read migration %s: %w", migration, readErr) } if log != nil { log.Info("applying migration", "version", version) } _, execErr := db.ExecContext(ctx, string(content)) if execErr != nil { return fmt.Errorf("failed to apply migration %s: %w", migration, execErr) } // Record migration as applied. _, recErr := db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) if recErr != nil { return fmt.Errorf("failed to record migration %s: %w", migration, recErr) } if log != nil { log.Info("migration applied successfully", "version", version) } } return nil } // DB returns the underlying sql.DB. func (s *Database) DB() *sql.DB { return s.db }