package database

import (
	"context"
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"sort"
	"strconv"
	"strings"
)

// migrationsFS holds every migrations/*.sql file, embedded at build time.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS

// bootstrapVersion is the migration that creates the schema_migrations
// table itself. It is applied before the normal migration loop.
const bootstrapVersion = 0

// ErrInvalidMigrationFilename indicates a migration filename does not follow
// the expected "<version>.sql" or "<version>_<description>.sql" pattern.
var ErrInvalidMigrationFilename = errors.New("invalid migration filename")

// ParseMigrationVersion extracts the numeric version prefix from a migration
// filename. Filenames must follow the pattern "<version>.sql" or
// "<version>_<description>.sql", where version is a zero-padded numeric
// string (e.g. "001", "002"). Returns the version as an integer, or an error
// wrapping ErrInvalidMigrationFilename if the filename does not match the
// expected pattern (including a version too large to fit in an int).
func ParseMigrationVersion(filename string) (int, error) {
	name := strings.TrimSuffix(filename, ".sql")
	// name == filename means the ".sql" suffix was absent.
	if name == "" || name == filename {
		return 0, fmt.Errorf("%w: %q has no .sql extension or is empty", ErrInvalidMigrationFilename, filename)
	}

	// Split on the first underscore to separate version from description.
	// If there's no underscore, the entire stem is the version.
	versionStr, _, _ := strings.Cut(name, "_")
	if versionStr == "" {
		return 0, fmt.Errorf("%w: %q has empty version prefix", ErrInvalidMigrationFilename, filename)
	}

	// Validate the version is purely ASCII digits; strconv.Atoi alone would
	// also accept a leading sign such as "+1" or "-1".
	for _, ch := range versionStr {
		if ch < '0' || ch > '9' {
			return 0, fmt.Errorf(
				"%w: %q version %q contains non-numeric character %q",
				ErrInvalidMigrationFilename, filename, versionStr, string(ch),
			)
		}
	}

	version, err := strconv.Atoi(versionStr)
	if err != nil {
		return 0, fmt.Errorf("%w: %q: %w", ErrInvalidMigrationFilename, filename, err)
	}

	return version, nil
}

// collectMigrations reads the embedded migrations directory and returns
// migration filenames sorted lexicographically.
func collectMigrations() ([]string, error) { entries, err := fs.ReadDir(migrationsFS, "migrations") if err != nil { return nil, fmt.Errorf("failed to read migrations directory: %w", err) } var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { migrations = append(migrations, entry.Name()) } } sort.Strings(migrations) return migrations, nil } // bootstrapMigrationsTable ensures the schema_migrations table exists by // applying 000_migration.sql if the table is missing. For databases with a // legacy TEXT-based schema_migrations table, it converts to the new INTEGER // format. func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error { var tableExists int err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", ).Scan(&tableExists) if err != nil { return fmt.Errorf("failed to check for migrations table: %w", err) } if tableExists == 0 { return applyBootstrapMigration(ctx, db, log) } // Table exists — check for and convert legacy TEXT-based versions. return convertLegacyMigrations(ctx, db, log) } // applyBootstrapMigration reads and executes 000_migration.sql to create the // schema_migrations table on a fresh database. func applyBootstrapMigration(ctx context.Context, db *sql.DB, log *slog.Logger) error { content, err := migrationsFS.ReadFile("migrations/000_migration.sql") if err != nil { return fmt.Errorf("failed to read bootstrap migration 000_migration.sql: %w", err) } if log != nil { log.Info("applying bootstrap migration", "version", bootstrapVersion) } _, err = db.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to apply bootstrap migration: %w", err) } return nil } // convertLegacyMigrations converts a schema_migrations table that uses // TEXT filename-based versions (e.g. "001_initial.sql") to INTEGER versions // (e.g. 1). This is a one-time migration for existing databases. 
func convertLegacyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { // Check if any version looks like a legacy filename (contains underscore). var legacyCount int err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE INSTR(CAST(version AS TEXT), '_') > 0", ).Scan(&legacyCount) if err != nil { return fmt.Errorf("failed to check for legacy migrations: %w", err) } if legacyCount == 0 { return ensureBootstrapVersion(ctx, db) } if log != nil { log.Info("converting legacy schema_migrations from TEXT to INTEGER format", "legacy_entries", legacyCount) } intVersions, err := readLegacyVersions(ctx, db) if err != nil { return err } err = rebuildMigrationsTable(ctx, db, intVersions) if err != nil { return err } if log != nil { log.Info("legacy migration conversion complete", "versions_converted", len(intVersions)) } return nil } // ensureBootstrapVersion inserts version 0 if it is not already present. func ensureBootstrapVersion(ctx context.Context, db *sql.DB) error { _, err := db.ExecContext(ctx, "INSERT OR IGNORE INTO schema_migrations (version) VALUES (0)") if err != nil { return fmt.Errorf("failed to ensure bootstrap version: %w", err) } return nil } // readLegacyVersions reads all version entries from the legacy schema_migrations // table and parses them into integer versions. 
func readLegacyVersions(ctx context.Context, db *sql.DB) ([]int, error) { rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations") if err != nil { return nil, fmt.Errorf("failed to read legacy migrations: %w", err) } defer func() { _ = rows.Close() }() var intVersions []int for rows.Next() { var version string scanErr := rows.Scan(&version) if scanErr != nil { return nil, fmt.Errorf("failed to scan legacy version: %w", scanErr) } v, parseErr := ParseMigrationVersion(version) if parseErr != nil { return nil, fmt.Errorf("failed to parse legacy version %q: %w", version, parseErr) } intVersions = append(intVersions, v) } rowsErr := rows.Err() if rowsErr != nil { return nil, fmt.Errorf("failed to iterate legacy versions: %w", rowsErr) } return intVersions, nil } // rebuildMigrationsTable drops the old schema_migrations table, recreates it // via 000_migration.sql, and re-inserts the given integer versions. The entire // operation runs in a single transaction so a crash cannot lose migration data. 
func rebuildMigrationsTable(ctx context.Context, db *sql.DB, versions []int) error { content, err := migrationsFS.ReadFile("migrations/000_migration.sql") if err != nil { return fmt.Errorf("failed to read bootstrap migration 000_migration.sql: %w", err) } tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction for legacy conversion: %w", err) } defer func() { if err != nil { _ = tx.Rollback() } }() _, err = tx.ExecContext(ctx, "DROP TABLE schema_migrations") if err != nil { return fmt.Errorf("failed to drop legacy migrations table: %w", err) } _, err = tx.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to create new migrations table: %w", err) } for _, v := range versions { _, err = tx.ExecContext(ctx, "INSERT OR IGNORE INTO schema_migrations (version) VALUES (?)", v) if err != nil { return fmt.Errorf("failed to insert converted version %d: %w", v, err) } } err = tx.Commit() if err != nil { return fmt.Errorf("failed to commit legacy conversion: %w", err) } return nil } // ApplyMigrations applies all pending migrations to db. An optional logger // may be provided for informational output; pass nil for silent operation. // This is exported so tests can apply the real schema without the full fx // lifecycle. func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { bootstrapErr := bootstrapMigrationsTable(ctx, db, log) if bootstrapErr != nil { return bootstrapErr } migrations, err := collectMigrations() if err != nil { return err } for _, migration := range migrations { version, parseErr := ParseMigrationVersion(migration) if parseErr != nil { return parseErr } // Check if already applied. 
var count int err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) if err != nil { return fmt.Errorf("failed to check migration %s: %w", migration, err) } if count > 0 { if log != nil { log.Debug("migration already applied", "version", version) } continue } // Apply migration in a transaction. applyErr := applyMigrationTx(ctx, db, migration, version) if applyErr != nil { return applyErr } if log != nil { log.Info("migration applied", "version", version) } } return nil } // applyMigrationTx reads and executes a migration file within a transaction, // recording the version in schema_migrations on success. func applyMigrationTx(ctx context.Context, db *sql.DB, filename string, version int) error { content, err := migrationsFS.ReadFile("migrations/" + filename) if err != nil { return fmt.Errorf("failed to read migration %s: %w", filename, err) } transaction, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction for migration %s: %w", filename, err) } defer func() { if err != nil { _ = transaction.Rollback() } }() _, err = transaction.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to execute migration %s: %w", filename, err) } _, err = transaction.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) if err != nil { return fmt.Errorf("failed to record migration %s: %w", filename, err) } err = transaction.Commit() if err != nil { return fmt.Errorf("failed to commit migration %s: %w", filename, err) } return nil } func (d *Database) migrate(ctx context.Context) error { return ApplyMigrations(ctx, d.database, d.log) }