package database import ( "context" "database/sql" "embed" "errors" "fmt" "io/fs" "log/slog" "sort" "strconv" "strings" ) //go:embed migrations/*.sql var migrationsFS embed.FS // bootstrapVersion is the migration that creates the schema_migrations // table itself. It is applied before the normal migration loop. const bootstrapVersion = 0 // ErrInvalidMigrationFilename indicates a migration filename does not follow // the expected ".sql" or "_.sql" pattern. var ErrInvalidMigrationFilename = errors.New("invalid migration filename") // ParseMigrationVersion extracts the numeric version prefix from a migration // filename. Filenames must follow the pattern ".sql" or // "_.sql", where version is a zero-padded numeric // string (e.g. "001", "002"). Returns the version as an integer and an // error if the filename does not match the expected pattern. func ParseMigrationVersion(filename string) (int, error) { name := strings.TrimSuffix(filename, ".sql") if name == "" || name == filename { return 0, fmt.Errorf("%w: %q has no .sql extension or is empty", ErrInvalidMigrationFilename, filename) } // Split on underscore to separate version from description. // If there's no underscore, the entire stem is the version. versionStr, _, _ := strings.Cut(name, "_") if versionStr == "" { return 0, fmt.Errorf("%w: %q has empty version prefix", ErrInvalidMigrationFilename, filename) } // Validate the version is purely numeric. for _, ch := range versionStr { if ch < '0' || ch > '9' { return 0, fmt.Errorf( "%w: %q version %q contains non-numeric character %q", ErrInvalidMigrationFilename, filename, versionStr, string(ch), ) } } version, err := strconv.Atoi(versionStr) if err != nil { return 0, fmt.Errorf("%w: %q: %w", ErrInvalidMigrationFilename, filename, err) } return version, nil } // collectMigrations reads the embedded migrations directory and returns // migration filenames sorted lexicographically. func collectMigrations() ([]string, error) { entries, err := fs.ReadDir(migrationsFS, "migrations") if err != nil { return nil, fmt.Errorf("failed to read migrations directory: %w", err) } var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { migrations = append(migrations, entry.Name()) } } sort.Strings(migrations) return migrations, nil } // bootstrapMigrationsTable ensures the schema_migrations table exists by // applying 000_migration.sql if the table is missing. func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error { var tableExists int err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", ).Scan(&tableExists) if err != nil { return fmt.Errorf("failed to check for migrations table: %w", err) } if tableExists == 0 { return applyBootstrapMigration(ctx, db, log) } return nil } // applyBootstrapMigration reads and executes 000_migration.sql to create the // schema_migrations table on a fresh database. func applyBootstrapMigration(ctx context.Context, db *sql.DB, log *slog.Logger) error { content, err := migrationsFS.ReadFile("migrations/000_migration.sql") if err != nil { return fmt.Errorf("failed to read bootstrap migration 000_migration.sql: %w", err) } if log != nil { log.Info("applying bootstrap migration", "version", bootstrapVersion) } _, err = db.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to apply bootstrap migration: %w", err) } return nil } // ApplyMigrations applies all pending migrations to db. An optional logger // may be provided for informational output; pass nil for silent operation. // This is exported so tests can apply the real schema without the full fx // lifecycle. func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { bootstrapErr := bootstrapMigrationsTable(ctx, db, log) if bootstrapErr != nil { return bootstrapErr } migrations, err := collectMigrations() if err != nil { return err } for _, migration := range migrations { version, parseErr := ParseMigrationVersion(migration) if parseErr != nil { return parseErr } // Check if already applied. var count int err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) if err != nil { return fmt.Errorf("failed to check migration %s: %w", migration, err) } if count > 0 { if log != nil { log.Debug("migration already applied", "version", version) } continue } // Apply migration in a transaction. applyErr := applyMigrationTx(ctx, db, migration, version) if applyErr != nil { return applyErr } if log != nil { log.Info("migration applied", "version", version) } } return nil } // applyMigrationTx reads and executes a migration file within a transaction, // recording the version in schema_migrations on success. func applyMigrationTx(ctx context.Context, db *sql.DB, filename string, version int) error { content, err := migrationsFS.ReadFile("migrations/" + filename) if err != nil { return fmt.Errorf("failed to read migration %s: %w", filename, err) } transaction, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction for migration %s: %w", filename, err) } defer func() { if err != nil { _ = transaction.Rollback() } }() _, err = transaction.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to execute migration %s: %w", filename, err) } _, err = transaction.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) if err != nil { return fmt.Errorf("failed to record migration %s: %w", filename, err) } err = transaction.Commit() if err != nil { return fmt.Errorf("failed to commit migration %s: %w", filename, err) } return nil } func (d *Database) migrate(ctx context.Context) error { return ApplyMigrations(ctx, d.database, d.log) }