From 5504495e0c349c20892280b2a0d6c10aa5cb7f95 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 17 Mar 2026 01:56:57 -0700 Subject: [PATCH] Move schema_migrations table creation from Go code into 000.sql The schema_migrations table definition now lives in internal/database/schema/000.sql instead of being hardcoded as an inline SQL string in database.go. A bootstrap step checks sqlite_master for the table and applies 000.sql when it is missing. Existing databases that already have the table (created by older inline code) get version 000 back-filled so the normal migration loop skips the file. Also deduplicates the migration logic: both the Database.runMigrations method and the exported ApplyMigrations helper now delegate to a single applyMigrations function. Adds database_test.go with tests for fresh migration, idempotency, bootstrap on a fresh DB, and backwards compatibility with legacy DBs. --- internal/database/database.go | 193 +++++++++++++++------------- internal/database/database_test.go | 199 +++++++++++++++++++++++++++++ internal/database/schema/000.sql | 9 ++ 3 files changed, 310 insertions(+), 91 deletions(-) create mode 100644 internal/database/database_test.go create mode 100644 internal/database/schema/000.sql diff --git a/internal/database/database.go b/internal/database/database.go index be80f1c..3270e3e 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -21,6 +21,10 @@ import ( //go:embed schema/*.sql var schemaFS embed.FS +// bootstrapVersion is the migration that creates the schema_migrations +// table itself. It is applied before the normal migration loop. +const bootstrapVersion = "000" + // Params defines dependencies for Database. type Params struct { fx.In @@ -84,43 +88,36 @@ func (s *Database) connect(ctx context.Context) error { s.db = db s.log.Info("database connected") - return s.runMigrations(ctx) + return applyMigrations(ctx, s.db, s.log) } -func (s *Database) runMigrations(ctx context.Context) error { - // Create migrations tracking table - _, err := s.db.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) +// applyMigrations bootstraps the migrations table from 000.sql and then +// applies every remaining migration that has not been recorded yet. +func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { + if err := bootstrapMigrationsTable(ctx, db, log); err != nil { + return err } - // Get list of migration files entries, err := schemaFS.ReadDir("schema") if err != nil { return fmt.Errorf("failed to read schema directory: %w", err) } - // Sort migration files by name (001.sql, 002.sql, etc.) var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { migrations = append(migrations, entry.Name()) } } + sort.Strings(migrations) - // Apply each migration that hasn't been applied yet for _, migration := range migrations { version := strings.TrimSuffix(migration, filepath.Ext(migration)) - // Check if already applied var count int - err := s.db.QueryRowContext(ctx, + + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) @@ -129,26 +126,24 @@ func (s *Database) runMigrations(ctx context.Context) error { } if count > 0 { - s.log.Debug("migration already applied", "version", version) + logDebug(log, "migration already applied", "version", version) continue } - // Read and apply migration content, err := schemaFS.ReadFile(filepath.Join("schema", migration)) if err != nil { return fmt.Errorf("failed to read migration %s: %w", migration, err) } - s.log.Info("applying migration", "version", version) + logInfo(log, "applying migration", "version", version) - _, err = s.db.ExecContext(ctx, string(content)) + _, err = db.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to apply migration %s: %w", migration, err) } - // Record migration as applied - _, err = s.db.ExecContext(ctx, + _, err = db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) @@ -156,12 +151,81 @@ func (s *Database) runMigrations(ctx context.Context) error { return fmt.Errorf("failed to record migration %s: %w", migration, err) } - s.log.Info("migration applied successfully", "version", version) + logInfo(log, "migration applied successfully", "version", version) } return nil } +// bootstrapMigrationsTable ensures the schema_migrations table exists +// by applying 000.sql directly. For databases that already have the +// table (created by older code), it records version "000" for +// consistency. +func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error { + var tableExists int + + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableExists) + if err != nil { + return fmt.Errorf("failed to check for migrations table: %w", err) + } + + if tableExists > 0 { + // Table already exists (from older inline-SQL code or a + // previous run). Make sure version "000" is recorded so the + // normal loop skips the bootstrap file. + var recorded int + + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", + bootstrapVersion, + ).Scan(&recorded) + if err != nil { + return fmt.Errorf("failed to check bootstrap migration status: %w", err) + } + + if recorded == 0 { + _, err = db.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + bootstrapVersion, + ) + if err != nil { + return fmt.Errorf("failed to record bootstrap migration: %w", err) + } + + logInfo(log, "recorded bootstrap migration for existing table", "version", bootstrapVersion) + } + + return nil + } + + // Table does not exist — apply 000.sql to create it. + content, err := schemaFS.ReadFile("schema/000.sql") + if err != nil { + return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err) + } + + logInfo(log, "applying bootstrap migration", "version", bootstrapVersion) + + _, err = db.ExecContext(ctx, string(content)) + if err != nil { + return fmt.Errorf("failed to apply bootstrap migration: %w", err) + } + + _, err = db.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + bootstrapVersion, + ) + if err != nil { + return fmt.Errorf("failed to record bootstrap migration: %w", err) + } + + logInfo(log, "bootstrap migration applied successfully", "version", bootstrapVersion) + + return nil +} + // DB returns the underlying sql.DB. func (s *Database) DB() *sql.DB { return s.db @@ -171,72 +235,19 @@ func (s *Database) DB() *sql.DB { // This is useful for testing where you want to use the real schema // without the full fx lifecycle. func ApplyMigrations(db *sql.DB) error { - ctx := context.Background() - - // Create migrations tracking table - _, err := db.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) - } - - // Get list of migration files - entries, err := schemaFS.ReadDir("schema") - if err != nil { - return fmt.Errorf("failed to read schema directory: %w", err) - } - - // Sort migration files by name (001.sql, 002.sql, etc.) - var migrations []string - for _, entry := range entries { - if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { - migrations = append(migrations, entry.Name()) - } - } - sort.Strings(migrations) - - // Apply each migration that hasn't been applied yet - for _, migration := range migrations { - version := strings.TrimSuffix(migration, filepath.Ext(migration)) - - // Check if already applied - var count int - err := db.QueryRowContext(ctx, - "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", - version, - ).Scan(&count) - if err != nil { - return fmt.Errorf("failed to check migration status: %w", err) - } - - if count > 0 { - continue - } - - // Read and apply migration - content, err := schemaFS.ReadFile(filepath.Join("schema", migration)) - if err != nil { - return fmt.Errorf("failed to read migration %s: %w", migration, err) - } - - _, err = db.ExecContext(ctx, string(content)) - if err != nil { - return fmt.Errorf("failed to apply migration %s: %w", migration, err) - } - - // Record migration as applied - _, err = db.ExecContext(ctx, - "INSERT INTO schema_migrations (version) VALUES (?)", - version, - ) - if err != nil { - return fmt.Errorf("failed to record migration %s: %w", migration, err) - } - } - - return nil + return applyMigrations(context.Background(), db, nil) +} + +// logInfo logs at info level when a logger is available. +func logInfo(log *slog.Logger, msg string, args ...any) { + if log != nil { + log.Info(msg, args...) + } +} + +// logDebug logs at debug level when a logger is available. +func logDebug(log *slog.Logger, msg string, args ...any) { + if log != nil { + log.Debug(msg, args...) + } } diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 0000000..735c49f --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,199 @@ +package database + +import ( + "context" + "database/sql" + "testing" + + _ "modernc.org/sqlite" // SQLite driver registration +) + +// openTestDB returns a fresh in-memory SQLite database. +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to open test db: %v", err) + } + + t.Cleanup(func() { db.Close() }) + + return db +} + +func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) { + db := openTestDB(t) + + if err := ApplyMigrations(db); err != nil { + t.Fatalf("ApplyMigrations failed: %v", err) + } + + // The schema_migrations table must exist and contain at least + // version "000" (the bootstrap) and "001" (the initial schema). + rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version") + if err != nil { + t.Fatalf("failed to query schema_migrations: %v", err) + } + defer rows.Close() + + var versions []string + for rows.Next() { + var v string + if err := rows.Scan(&v); err != nil { + t.Fatalf("failed to scan version: %v", err) + } + + versions = append(versions, v) + } + + if err := rows.Err(); err != nil { + t.Fatalf("row iteration error: %v", err) + } + + if len(versions) < 2 { + t.Fatalf("expected at least 2 migrations recorded, got %d: %v", len(versions), versions) + } + + if versions[0] != "000" { + t.Errorf("first recorded migration = %q, want %q", versions[0], "000") + } + + if versions[1] != "001" { + t.Errorf("second recorded migration = %q, want %q", versions[1], "001") + } + + // Verify that the application tables created by 001.sql exist. + for _, table := range []string{"source_content", "source_metadata", "output_content", "request_cache", "negative_cache", "cache_stats"} { + var count int + + err := db.QueryRow( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", + table, + ).Scan(&count) + if err != nil { + t.Fatalf("failed to check for table %s: %v", table, err) + } + + if count != 1 { + t.Errorf("table %s does not exist after migrations", table) + } + } +} + +func TestApplyMigrations_Idempotent(t *testing.T) { + db := openTestDB(t) + + if err := ApplyMigrations(db); err != nil { + t.Fatalf("first ApplyMigrations failed: %v", err) + } + + // Running a second time must succeed without errors. + if err := ApplyMigrations(db); err != nil { + t.Fatalf("second ApplyMigrations failed: %v", err) + } + + // Verify no duplicate rows in schema_migrations. + var count int + + err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = '000'").Scan(&count) + if err != nil { + t.Fatalf("failed to count 000 rows: %v", err) + } + + if count != 1 { + t.Errorf("expected exactly 1 row for version 000, got %d", count) + } +} + +func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + + if err := bootstrapMigrationsTable(ctx, db, nil); err != nil { + t.Fatalf("bootstrapMigrationsTable failed: %v", err) + } + + // schema_migrations table must exist. + var tableCount int + + err := db.QueryRow( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableCount) + if err != nil { + t.Fatalf("failed to check for table: %v", err) + } + + if tableCount != 1 { + t.Fatalf("schema_migrations table not created") + } + + // Version "000" must be recorded. + var recorded int + + err = db.QueryRow( + "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'", + ).Scan(&recorded) + if err != nil { + t.Fatalf("failed to check version: %v", err) + } + + if recorded != 1 { + t.Errorf("expected version 000 to be recorded, got count %d", recorded) + } +} + +func TestBootstrapMigrationsTable_ExistingTableBackwardsCompat(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + + // Simulate an older database that created the table via inline SQL + // (without recording version "000"). + _, err := db.Exec(` + CREATE TABLE schema_migrations ( + version TEXT PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + if err != nil { + t.Fatalf("failed to create legacy table: %v", err) + } + + // Insert a fake migration to prove the table already existed. + _, err = db.Exec("INSERT INTO schema_migrations (version) VALUES ('001')") + if err != nil { + t.Fatalf("failed to insert legacy version: %v", err) + } + + if err := bootstrapMigrationsTable(ctx, db, nil); err != nil { + t.Fatalf("bootstrapMigrationsTable failed: %v", err) + } + + // Version "000" must now be recorded. + var recorded int + + err = db.QueryRow( + "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'", + ).Scan(&recorded) + if err != nil { + t.Fatalf("failed to check version: %v", err) + } + + if recorded != 1 { + t.Errorf("expected version 000 to be recorded for legacy DB, got count %d", recorded) + } + + // The existing "001" row must still be there. + var legacyCount int + + err = db.QueryRow( + "SELECT COUNT(*) FROM schema_migrations WHERE version = '001'", + ).Scan(&legacyCount) + if err != nil { + t.Fatalf("failed to check legacy version: %v", err) + } + + if legacyCount != 1 { + t.Errorf("legacy version 001 row missing after bootstrap") + } +} diff --git a/internal/database/schema/000.sql b/internal/database/schema/000.sql new file mode 100644 index 0000000..c05b915 --- /dev/null +++ b/internal/database/schema/000.sql @@ -0,0 +1,9 @@ +-- Migration 000: Schema migrations tracking table +-- This must be the first migration applied. The bootstrap logic in +-- database.go applies it directly (bypassing the normal migration +-- loop) when the schema_migrations table does not yet exist. + +CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP +);