package database import ( "context" "database/sql" "testing" _ "modernc.org/sqlite" // SQLite driver registration ) // openTestDB returns a fresh in-memory SQLite database. func openTestDB(t *testing.T) *sql.DB { t.Helper() db, err := sql.Open("sqlite", ":memory:") if err != nil { t.Fatalf("failed to open test db: %v", err) } t.Cleanup(func() { db.Close() }) return db } func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) { db := openTestDB(t) if err := ApplyMigrations(db); err != nil { t.Fatalf("ApplyMigrations failed: %v", err) } // The schema_migrations table must exist and contain at least // version "000" (the bootstrap) and "001" (the initial schema). rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version") if err != nil { t.Fatalf("failed to query schema_migrations: %v", err) } defer rows.Close() var versions []string for rows.Next() { var v string if err := rows.Scan(&v); err != nil { t.Fatalf("failed to scan version: %v", err) } versions = append(versions, v) } if err := rows.Err(); err != nil { t.Fatalf("row iteration error: %v", err) } if len(versions) < 2 { t.Fatalf("expected at least 2 migrations recorded, got %d: %v", len(versions), versions) } if versions[0] != "000" { t.Errorf("first recorded migration = %q, want %q", versions[0], "000") } if versions[1] != "001" { t.Errorf("second recorded migration = %q, want %q", versions[1], "001") } // Verify that the application tables created by 001.sql exist. for _, table := range []string{"source_content", "source_metadata", "output_content", "request_cache", "negative_cache", "cache_stats"} { var count int err := db.QueryRow( "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table, ).Scan(&count) if err != nil { t.Fatalf("failed to check for table %s: %v", table, err) } if count != 1 { t.Errorf("table %s does not exist after migrations", table) } } } func TestApplyMigrations_Idempotent(t *testing.T) { db := openTestDB(t) if err := ApplyMigrations(db); err != nil { t.Fatalf("first ApplyMigrations failed: %v", err) } // Running a second time must succeed without errors. if err := ApplyMigrations(db); err != nil { t.Fatalf("second ApplyMigrations failed: %v", err) } // Verify no duplicate rows in schema_migrations. var count int err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = '000'").Scan(&count) if err != nil { t.Fatalf("failed to count 000 rows: %v", err) } if count != 1 { t.Errorf("expected exactly 1 row for version 000, got %d", count) } } func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) { db := openTestDB(t) ctx := context.Background() if err := bootstrapMigrationsTable(ctx, db, nil); err != nil { t.Fatalf("bootstrapMigrationsTable failed: %v", err) } // schema_migrations table must exist. var tableCount int err := db.QueryRow( "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", ).Scan(&tableCount) if err != nil { t.Fatalf("failed to check for table: %v", err) } if tableCount != 1 { t.Fatalf("schema_migrations table not created") } // Version "000" must be recorded. var recorded int err = db.QueryRow( "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'", ).Scan(&recorded) if err != nil { t.Fatalf("failed to check version: %v", err) } if recorded != 1 { t.Errorf("expected version 000 to be recorded, got count %d", recorded) } } func TestBootstrapMigrationsTable_ExistingTableBackwardsCompat(t *testing.T) { db := openTestDB(t) ctx := context.Background() // Simulate an older database that created the table via inline SQL // (without recording version "000"). _, err := db.Exec(` CREATE TABLE schema_migrations ( version TEXT PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP ) `) if err != nil { t.Fatalf("failed to create legacy table: %v", err) } // Insert a fake migration to prove the table already existed. _, err = db.Exec("INSERT INTO schema_migrations (version) VALUES ('001')") if err != nil { t.Fatalf("failed to insert legacy version: %v", err) } if err := bootstrapMigrationsTable(ctx, db, nil); err != nil { t.Fatalf("bootstrapMigrationsTable failed: %v", err) } // Version "000" must now be recorded. var recorded int err = db.QueryRow( "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'", ).Scan(&recorded) if err != nil { t.Fatalf("failed to check version: %v", err) } if recorded != 1 { t.Errorf("expected version 000 to be recorded for legacy DB, got count %d", recorded) } // The existing "001" row must still be there. var legacyCount int err = db.QueryRow( "SELECT COUNT(*) FROM schema_migrations WHERE version = '001'", ).Scan(&legacyCount) if err != nil { t.Fatalf("failed to check legacy version: %v", err) } if legacyCount != 1 { t.Errorf("legacy version 001 row missing after bootstrap") } }