From 31a6299ce7b6c60fc60edccb67b82530b15fd3eb Mon Sep 17 00:00:00 2001 From: user Date: Thu, 26 Mar 2026 06:43:53 -0700 Subject: [PATCH] Move schema_migrations table creation into 000.sql with INTEGER version column Refactors the migration system to follow the pixa pattern: - Add 000.sql bootstrap migration that creates schema_migrations with INTEGER PRIMARY KEY version column - Go code no longer creates the migrations table inline; it reads and executes 000.sql as a bootstrap step before the normal migration loop - Export ParseMigrationVersion and ApplyMigrations for test use - Add legacy TEXT-to-INTEGER conversion for existing databases that stored migration versions as filenames (e.g. '001_initial.sql') - Wrap individual migration application in transactions for safety - Add comprehensive tests for version parsing, fresh database bootstrap, idempotent re-application, and legacy conversion --- internal/database/migrations.go | 318 ++++++++++++++++++++++----- internal/database/migrations/000.sql | 9 + internal/database/migrations_test.go | 192 ++++++++++++++++ 3 files changed, 468 insertions(+), 51 deletions(-) create mode 100644 internal/database/migrations/000.sql create mode 100644 internal/database/migrations_test.go diff --git a/internal/database/migrations.go b/internal/database/migrations.go index 03a69fa..4636340 100644 --- a/internal/database/migrations.go +++ b/internal/database/migrations.go @@ -2,36 +2,74 @@ package database import ( "context" + "database/sql" "embed" + "errors" "fmt" "io/fs" + "log/slog" "sort" + "strconv" "strings" ) //go:embed migrations/*.sql var migrationsFS embed.FS -func (d *Database) migrate(ctx context.Context) error { - // Create migrations table if not exists - _, err := d.database.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) +// bootstrapVersion is the migration that creates the schema_migrations +// table itself. It is applied before the normal migration loop. +const bootstrapVersion = 0 + +// ErrInvalidMigrationFilename indicates a migration filename does not follow +// the expected ".sql" or "_.sql" pattern. +var ErrInvalidMigrationFilename = errors.New("invalid migration filename") + +// ParseMigrationVersion extracts the numeric version prefix from a migration +// filename. Filenames must follow the pattern ".sql" or +// "_.sql", where version is a zero-padded numeric +// string (e.g. "001", "002"). Returns the version as an integer and an +// error if the filename does not match the expected pattern. +func ParseMigrationVersion(filename string) (int, error) { + name := strings.TrimSuffix(filename, ".sql") + if name == "" || name == filename { + return 0, fmt.Errorf("%w: %q has no .sql extension or is empty", ErrInvalidMigrationFilename, filename) } - // Get list of migration files + // Split on underscore to separate version from description. + // If there's no underscore, the entire stem is the version. + versionStr, _, _ := strings.Cut(name, "_") + + if versionStr == "" { + return 0, fmt.Errorf("%w: %q has empty version prefix", ErrInvalidMigrationFilename, filename) + } + + // Validate the version is purely numeric. + for _, ch := range versionStr { + if ch < '0' || ch > '9' { + return 0, fmt.Errorf( + "%w: %q version %q contains non-numeric character %q", + ErrInvalidMigrationFilename, filename, versionStr, string(ch), + ) + } + } + + version, err := strconv.Atoi(versionStr) + if err != nil { + return 0, fmt.Errorf("%w: %q: %w", ErrInvalidMigrationFilename, filename, err) + } + + return version, nil +} + +// collectMigrations reads the embedded migrations directory and returns +// migration filenames sorted lexicographically. +func collectMigrations() ([]string, error) { entries, err := fs.ReadDir(migrationsFS, "migrations") if err != nil { - return fmt.Errorf("failed to read migrations directory: %w", err) + return nil, fmt.Errorf("failed to read migrations directory: %w", err) } - // Sort migrations by name - migrations := make([]string, 0, len(entries)) + var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { @@ -41,54 +79,231 @@ func (d *Database) migrate(ctx context.Context) error { sort.Strings(migrations) - // Apply each migration - for _, migration := range migrations { - applied, err := d.isMigrationApplied(ctx, migration) - if err != nil { - return fmt.Errorf("failed to check migration %s: %w", migration, err) - } + return migrations, nil +} - if applied { - d.log.Debug("migration already applied", "migration", migration) +// bootstrapMigrationsTable ensures the schema_migrations table exists by +// applying 000.sql if the table is missing. For databases with a legacy +// TEXT-based schema_migrations table, it converts to the new INTEGER format. +func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error { + var tableExists int - continue - } + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableExists) + if err != nil { + return fmt.Errorf("failed to check for migrations table: %w", err) + } - err = d.applyMigration(ctx, migration) - if err != nil { - return fmt.Errorf("failed to apply migration %s: %w", migration, err) - } + if tableExists == 0 { + return applyBootstrapMigration(ctx, db, log) + } - d.log.Info("migration applied", "migration", migration) + // Table exists — check for and convert legacy TEXT-based versions. + return convertLegacyMigrations(ctx, db, log) +} + +// applyBootstrapMigration reads and executes 000.sql to create the +// schema_migrations table on a fresh database. +func applyBootstrapMigration(ctx context.Context, db *sql.DB, log *slog.Logger) error { + content, err := migrationsFS.ReadFile("migrations/000.sql") + if err != nil { + return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err) + } + + if log != nil { + log.Info("applying bootstrap migration", "version", bootstrapVersion) + } + + _, err = db.ExecContext(ctx, string(content)) + if err != nil { + return fmt.Errorf("failed to apply bootstrap migration: %w", err) } return nil } -func (d *Database) isMigrationApplied(ctx context.Context, version string) (bool, error) { - var count int +// convertLegacyMigrations converts a schema_migrations table that uses +// TEXT filename-based versions (e.g. "001_initial.sql") to INTEGER versions +// (e.g. 1). This is a one-time migration for existing databases. +func convertLegacyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { + // Check if any version looks like a legacy filename (contains underscore). + var legacyCount int - err := d.database.QueryRowContext( - ctx, - "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", - version, - ).Scan(&count) + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE INSTR(CAST(version AS TEXT), '_') > 0", + ).Scan(&legacyCount) if err != nil { - return false, fmt.Errorf("failed to query migration status: %w", err) + return fmt.Errorf("failed to check for legacy migrations: %w", err) } - return count > 0, nil + if legacyCount == 0 { + return ensureBootstrapVersion(ctx, db) + } + + if log != nil { + log.Info("converting legacy schema_migrations from TEXT to INTEGER format", + "legacy_entries", legacyCount) + } + + intVersions, err := readLegacyVersions(ctx, db) + if err != nil { + return err + } + + err = rebuildMigrationsTable(ctx, db, intVersions) + if err != nil { + return err + } + + if log != nil { + log.Info("legacy migration conversion complete", "versions_converted", len(intVersions)) + } + + return nil } -func (d *Database) applyMigration(ctx context.Context, filename string) error { - content, err := migrationsFS.ReadFile("migrations/" + filename) +// ensureBootstrapVersion inserts version 0 if it is not already present. +func ensureBootstrapVersion(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, + "INSERT OR IGNORE INTO schema_migrations (version) VALUES (0)") if err != nil { - return fmt.Errorf("failed to read migration file: %w", err) + return fmt.Errorf("failed to ensure bootstrap version: %w", err) } - transaction, err := d.database.BeginTx(ctx, nil) + return nil +} + +// readLegacyVersions reads all version entries from the legacy schema_migrations +// table and parses them into integer versions. +func readLegacyVersions(ctx context.Context, db *sql.DB) ([]int, error) { + rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations") if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) + return nil, fmt.Errorf("failed to read legacy migrations: %w", err) + } + + defer func() { _ = rows.Close() }() + + var intVersions []int + + for rows.Next() { + var version string + + scanErr := rows.Scan(&version) + if scanErr != nil { + return nil, fmt.Errorf("failed to scan legacy version: %w", scanErr) + } + + v, parseErr := ParseMigrationVersion(version) + if parseErr != nil { + return nil, fmt.Errorf("failed to parse legacy version %q: %w", version, parseErr) + } + + intVersions = append(intVersions, v) + } + + rowsErr := rows.Err() + if rowsErr != nil { + return nil, fmt.Errorf("failed to iterate legacy versions: %w", rowsErr) + } + + return intVersions, nil +} + +// rebuildMigrationsTable drops the old schema_migrations table, recreates it +// via 000.sql, and re-inserts the given integer versions. +func rebuildMigrationsTable(ctx context.Context, db *sql.DB, versions []int) error { + _, err := db.ExecContext(ctx, "DROP TABLE schema_migrations") + if err != nil { + return fmt.Errorf("failed to drop legacy migrations table: %w", err) + } + + content, err := migrationsFS.ReadFile("migrations/000.sql") + if err != nil { + return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err) + } + + _, err = db.ExecContext(ctx, string(content)) + if err != nil { + return fmt.Errorf("failed to create new migrations table: %w", err) + } + + for _, v := range versions { + _, insErr := db.ExecContext(ctx, + "INSERT OR IGNORE INTO schema_migrations (version) VALUES (?)", v) + if insErr != nil { + return fmt.Errorf("failed to insert converted version %d: %w", v, insErr) + } + } + + return nil +} + +// ApplyMigrations applies all pending migrations to db. An optional logger +// may be provided for informational output; pass nil for silent operation. +// This is exported so tests can apply the real schema without the full fx +// lifecycle. +func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { + bootstrapErr := bootstrapMigrationsTable(ctx, db, log) + if bootstrapErr != nil { + return bootstrapErr + } + + migrations, err := collectMigrations() + if err != nil { + return err + } + + for _, migration := range migrations { + version, parseErr := ParseMigrationVersion(migration) + if parseErr != nil { + return parseErr + } + + // Check if already applied. + var count int + + err = db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", + version, + ).Scan(&count) + if err != nil { + return fmt.Errorf("failed to check migration %s: %w", migration, err) + } + + if count > 0 { + if log != nil { + log.Debug("migration already applied", "version", version) + } + + continue + } + + // Apply migration in a transaction. + applyErr := applyMigrationTx(ctx, db, migration, version) + if applyErr != nil { + return applyErr + } + + if log != nil { + log.Info("migration applied", "version", version) + } + } + + return nil +} + +// applyMigrationTx reads and executes a migration file within a transaction, +// recording the version in schema_migrations on success. +func applyMigrationTx(ctx context.Context, db *sql.DB, filename string, version int) error { + content, err := migrationsFS.ReadFile("migrations/" + filename) + if err != nil { + return fmt.Errorf("failed to read migration %s: %w", filename, err) + } + + transaction, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction for migration %s: %w", filename, err) } defer func() { @@ -97,26 +312,27 @@ func (d *Database) applyMigration(ctx context.Context, filename string) error { } }() - // Execute migration _, err = transaction.ExecContext(ctx, string(content)) if err != nil { - return fmt.Errorf("failed to execute migration: %w", err) + return fmt.Errorf("failed to execute migration %s: %w", filename, err) } - // Record migration - _, err = transaction.ExecContext( - ctx, + _, err = transaction.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", - filename, + version, ) if err != nil { - return fmt.Errorf("failed to record migration: %w", err) + return fmt.Errorf("failed to record migration %s: %w", filename, err) } err = transaction.Commit() if err != nil { - return fmt.Errorf("failed to commit migration: %w", err) + return fmt.Errorf("failed to commit migration %s: %w", filename, err) } return nil } + +func (d *Database) migrate(ctx context.Context) error { + return ApplyMigrations(ctx, d.database, d.log) +} diff --git a/internal/database/migrations/000.sql b/internal/database/migrations/000.sql new file mode 100644 index 0000000..e06a2da --- /dev/null +++ b/internal/database/migrations/000.sql @@ -0,0 +1,9 @@ +-- Migration 000: Schema migrations tracking table +-- Applied as a bootstrap step before the normal migration loop. + +CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +INSERT OR IGNORE INTO schema_migrations (version) VALUES (0); diff --git a/internal/database/migrations_test.go b/internal/database/migrations_test.go new file mode 100644 index 0000000..12f0f64 --- /dev/null +++ b/internal/database/migrations_test.go @@ -0,0 +1,192 @@ +package database_test + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sneak.berlin/go/upaas/internal/database" +) + +func TestParseMigrationVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + filename string + wantVersion int + wantErr bool + }{ + {filename: "000.sql", wantVersion: 0}, + {filename: "001_initial.sql", wantVersion: 1}, + {filename: "002_remove_container_id.sql", wantVersion: 2}, + {filename: "007_add_resource_limits.sql", wantVersion: 7}, + {filename: "100_large_version.sql", wantVersion: 100}, + {filename: ".sql", wantErr: true}, + {filename: "_foo.sql", wantErr: true}, + {filename: "abc_foo.sql", wantErr: true}, + {filename: "1a2_bad.sql", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.filename, func(t *testing.T) { + t.Parallel() + + version, err := database.ParseMigrationVersion(tt.filename) + if tt.wantErr { + assert.Error(t, err) + + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantVersion, version) + }) + } +} + +func TestApplyMigrationsFreshDatabase(t *testing.T) { + t.Parallel() + + db, err := sql.Open("sqlite3", ":memory:?_foreign_keys=on") + require.NoError(t, err) + + defer func() { _ = db.Close() }() + + ctx := context.Background() + + err = database.ApplyMigrations(ctx, db, nil) + require.NoError(t, err) + + // Verify schema_migrations table exists with INTEGER version column. + var version int + + err = db.QueryRowContext(ctx, + "SELECT version FROM schema_migrations WHERE version = 0", + ).Scan(&version) + require.NoError(t, err) + assert.Equal(t, 0, version) + + // Verify that all migrations were recorded. + var count int + + err = db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations", + ).Scan(&count) + require.NoError(t, err) + // 000 bootstrap + 001 through 007 = 8 entries. + assert.Equal(t, 8, count) + + // Verify application tables were created by the migrations. + var tableCount int + + err = db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='users'", + ).Scan(&tableCount) + require.NoError(t, err) + assert.Equal(t, 1, tableCount) +} + +func TestApplyMigrationsIdempotent(t *testing.T) { + t.Parallel() + + db, err := sql.Open("sqlite3", ":memory:?_foreign_keys=on") + require.NoError(t, err) + + defer func() { _ = db.Close() }() + + ctx := context.Background() + + // Apply twice — second run should be a no-op. + err = database.ApplyMigrations(ctx, db, nil) + require.NoError(t, err) + + err = database.ApplyMigrations(ctx, db, nil) + require.NoError(t, err) + + var count int + + err = db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations", + ).Scan(&count) + require.NoError(t, err) + assert.Equal(t, 8, count) +} + +func TestApplyMigrationsLegacyConversion(t *testing.T) { + t.Parallel() + + db, err := sql.Open("sqlite3", ":memory:?_foreign_keys=on") + require.NoError(t, err) + + defer func() { _ = db.Close() }() + + ctx := context.Background() + + // Simulate the old TEXT-based schema_migrations table. + _, err = db.ExecContext(ctx, ` + CREATE TABLE schema_migrations ( + version TEXT PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + require.NoError(t, err) + + // Insert legacy filename-based versions (as if migrations 001-007 were applied). + legacyVersions := []string{ + "001_initial.sql", + "002_remove_container_id.sql", + "003_add_ports.sql", + "004_add_commit_url.sql", + "005_add_webhook_secret_hash.sql", + "006_add_previous_image_id.sql", + "007_add_resource_limits.sql", + } + + for _, v := range legacyVersions { + _, err = db.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", v) + require.NoError(t, err) + } + + // Also create the application tables that would exist on a real database + // so that the migrations (001-007) don't fail when re-applied. + // Actually, since the legacy versions will be converted and recognized + // as already-applied, the DDL migrations won't re-run. We just need + // to make sure ApplyMigrations succeeds. + + // Run ApplyMigrations — this should convert legacy versions and + // skip all already-applied migrations. + err = database.ApplyMigrations(ctx, db, nil) + require.NoError(t, err) + + // Verify version 0 is now present (from bootstrap). + var zeroVersion int + + err = db.QueryRowContext(ctx, + "SELECT version FROM schema_migrations WHERE version = 0", + ).Scan(&zeroVersion) + require.NoError(t, err) + assert.Equal(t, 0, zeroVersion) + + // Verify all 8 versions are present (0 through 7). + var count int + + err = db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations", + ).Scan(&count) + require.NoError(t, err) + assert.Equal(t, 8, count) + + // Verify versions are stored as integers, not text filenames. + var maxVersion int + + err = db.QueryRowContext(ctx, + "SELECT MAX(version) FROM schema_migrations", + ).Scan(&maxVersion) + require.NoError(t, err) + assert.Equal(t, 7, maxVersion) +}