diff --git a/internal/database/database.go b/internal/database/database.go index ddc01ed..b724b89 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -21,6 +21,10 @@ import ( //go:embed schema/*.sql var schemaFS embed.FS +// bootstrapVersion is the migration that creates the schema_migrations +// table itself. It is applied before the normal migration loop. +const bootstrapVersion = "000" + // Params defines dependencies for Database. type Params struct { fx.In @@ -143,17 +147,86 @@ func collectMigrations() ([]string, error) { return migrations, nil } -// ensureMigrationsTable creates the schema_migrations tracking table if -// it does not already exist. -func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { - _, err := db.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) +// bootstrapMigrationsTable ensures the schema_migrations table exists +// by applying 000.sql directly. For databases that already have the +// table (created by older code), it records version "000" for +// consistency. +func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error { + var tableExists int + + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableExists) if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) + return fmt.Errorf("failed to check for migrations table: %w", err) + } + + if tableExists > 0 { + return ensureBootstrapVersionRecorded(ctx, db, log) + } + + return applyBootstrapMigration(ctx, db, log) +} + +// ensureBootstrapVersionRecorded checks whether version "000" is already +// recorded in an existing schema_migrations table and inserts it if not. +func ensureBootstrapVersionRecorded(ctx context.Context, db *sql.DB, log *slog.Logger) error { + var recorded int + + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", + bootstrapVersion, + ).Scan(&recorded) + if err != nil { + return fmt.Errorf("failed to check bootstrap migration status: %w", err) + } + + if recorded > 0 { + return nil + } + + _, err = db.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + bootstrapVersion, + ) + if err != nil { + return fmt.Errorf("failed to record bootstrap migration: %w", err) + } + + if log != nil { + log.Info("recorded bootstrap migration for existing table", "version", bootstrapVersion) + } + + return nil +} + +// applyBootstrapMigration reads and executes 000.sql to create the +// schema_migrations table on a fresh database. +func applyBootstrapMigration(ctx context.Context, db *sql.DB, log *slog.Logger) error { + content, err := schemaFS.ReadFile("schema/000.sql") + if err != nil { + return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err) + } + + if log != nil { + log.Info("applying bootstrap migration", "version", bootstrapVersion) + } + + _, err = db.ExecContext(ctx, string(content)) + if err != nil { + return fmt.Errorf("failed to apply bootstrap migration: %w", err) + } + + _, err = db.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + bootstrapVersion, + ) + if err != nil { + return fmt.Errorf("failed to record bootstrap migration: %w", err) + } + + if log != nil { + log.Info("bootstrap migration applied successfully", "version", bootstrapVersion) } return nil @@ -164,7 +237,7 @@ func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { // This is exported so tests can apply the real schema without the full fx // lifecycle. func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { - if err := ensureMigrationsTable(ctx, db); err != nil { + if err := bootstrapMigrationsTable(ctx, db, log); err != nil { return err } diff --git a/internal/database/database_test.go b/internal/database/database_test.go index 9c2fc2c..ede0d1c 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test.go @@ -8,6 +8,20 @@ import ( _ "modernc.org/sqlite" // SQLite driver registration ) +// openTestDB returns a fresh in-memory SQLite database. +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to open test db: %v", err) + } + + t.Cleanup(func() { db.Close() }) + + return db +} + func TestParseMigrationVersion(t *testing.T) { tests := []struct { name string @@ -86,70 +100,180 @@ func TestParseMigrationVersion(t *testing.T) { } } -func TestApplyMigrations(t *testing.T) { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("failed to open in-memory database: %v", err) - } - defer db.Close() +func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() - // Apply migrations should succeed. - if err := ApplyMigrations(context.Background(), db, nil); err != nil { + if err := ApplyMigrations(ctx, db, nil); err != nil { t.Fatalf("ApplyMigrations failed: %v", err) } - // Verify the schema_migrations table recorded the version. - var version string - - err = db.QueryRowContext(context.Background(), - "SELECT version FROM schema_migrations LIMIT 1", - ).Scan(&version) + // The schema_migrations table must exist and contain at least + // version "000" (the bootstrap) and "001" (the initial schema). + rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version") if err != nil { t.Fatalf("failed to query schema_migrations: %v", err) } + defer rows.Close() - if version != "001" { - t.Errorf("expected version %q, got %q", "001", version) + var versions []string + for rows.Next() { + var v string + if err := rows.Scan(&v); err != nil { + t.Fatalf("failed to scan version: %v", err) + } + + versions = append(versions, v) } - // Verify a table from the migration exists (source_content). - var tableName string + if err := rows.Err(); err != nil { + t.Fatalf("row iteration error: %v", err) + } - err = db.QueryRowContext(context.Background(), - "SELECT name FROM sqlite_master WHERE type='table' AND name='source_content'", - ).Scan(&tableName) - if err != nil { - t.Fatalf("expected source_content table to exist: %v", err) + if len(versions) < 2 { + t.Fatalf("expected at least 2 migrations recorded, got %d: %v", len(versions), versions) + } + + if versions[0] != "000" { + t.Errorf("first recorded migration = %q, want %q", versions[0], "000") + } + + if versions[1] != "001" { + t.Errorf("second recorded migration = %q, want %q", versions[1], "001") + } + + // Verify that the application tables created by 001.sql exist. + for _, table := range []string{"source_content", "source_metadata", "output_content", "request_cache", "negative_cache", "cache_stats"} { + var count int + + err := db.QueryRow( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", + table, + ).Scan(&count) + if err != nil { + t.Fatalf("failed to check for table %s: %v", table, err) + } + + if count != 1 { + t.Errorf("table %s does not exist after migrations", table) + } } } -func TestApplyMigrationsIdempotent(t *testing.T) { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("failed to open in-memory database: %v", err) - } - defer db.Close() +func TestApplyMigrations_Idempotent(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() - // Apply twice should succeed (idempotent). - if err := ApplyMigrations(context.Background(), db, nil); err != nil { + if err := ApplyMigrations(ctx, db, nil); err != nil { t.Fatalf("first ApplyMigrations failed: %v", err) } - if err := ApplyMigrations(context.Background(), db, nil); err != nil { + // Running a second time must succeed without errors. + if err := ApplyMigrations(ctx, db, nil); err != nil { t.Fatalf("second ApplyMigrations failed: %v", err) } - // Should still have exactly one migration recorded. + // Verify no duplicate rows in schema_migrations. var count int - err = db.QueryRowContext(context.Background(), - "SELECT COUNT(*) FROM schema_migrations", - ).Scan(&count) + err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = '000'").Scan(&count) if err != nil { - t.Fatalf("failed to count schema_migrations: %v", err) + t.Fatalf("failed to count 000 rows: %v", err) } if count != 1 { - t.Errorf("expected 1 migration record, got %d", count) + t.Errorf("expected exactly 1 row for version 000, got %d", count) + } +} + +func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + + if err := bootstrapMigrationsTable(ctx, db, nil); err != nil { + t.Fatalf("bootstrapMigrationsTable failed: %v", err) + } + + // schema_migrations table must exist. + var tableCount int + + err := db.QueryRow( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableCount) + if err != nil { + t.Fatalf("failed to check for table: %v", err) + } + + if tableCount != 1 { + t.Fatalf("schema_migrations table not created") + } + + // Version "000" must be recorded. + var recorded int + + err = db.QueryRow( + "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'", + ).Scan(&recorded) + if err != nil { + t.Fatalf("failed to check version: %v", err) + } + + if recorded != 1 { + t.Errorf("expected version 000 to be recorded, got count %d", recorded) + } +} + +func TestBootstrapMigrationsTable_ExistingTableBackwardsCompat(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + + // Simulate an older database that created the table via inline SQL + // (without recording version "000"). + _, err := db.Exec(` + CREATE TABLE schema_migrations ( + version TEXT PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + if err != nil { + t.Fatalf("failed to create legacy table: %v", err) + } + + // Insert a fake migration to prove the table already existed. + _, err = db.Exec("INSERT INTO schema_migrations (version) VALUES ('001')") + if err != nil { + t.Fatalf("failed to insert legacy version: %v", err) + } + + if err := bootstrapMigrationsTable(ctx, db, nil); err != nil { + t.Fatalf("bootstrapMigrationsTable failed: %v", err) + } + + // Version "000" must now be recorded. + var recorded int + + err = db.QueryRow( + "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'", + ).Scan(&recorded) + if err != nil { + t.Fatalf("failed to check version: %v", err) + } + + if recorded != 1 { + t.Errorf("expected version 000 to be recorded for legacy DB, got count %d", recorded) + } + + // The existing "001" row must still be there. + var legacyCount int + + err = db.QueryRow( + "SELECT COUNT(*) FROM schema_migrations WHERE version = '001'", + ).Scan(&legacyCount) + if err != nil { + t.Fatalf("failed to check legacy version: %v", err) + } + + if legacyCount != 1 { + t.Errorf("legacy version 001 row missing after bootstrap") } } diff --git a/internal/database/schema/000.sql b/internal/database/schema/000.sql new file mode 100644 index 0000000..c05b915 --- /dev/null +++ b/internal/database/schema/000.sql @@ -0,0 +1,9 @@ +-- Migration 000: Schema migrations tracking table +-- This must be the first migration applied. The bootstrap logic in +-- database.go applies it directly (bypassing the normal migration +-- loop) when the schema_migrations table does not yet exist. + +CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP +);