diff --git a/internal/database/database.go b/internal/database/database.go index ddc01ed..c4af7e2 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -9,6 +9,7 @@ import ( "log/slog" "path/filepath" "sort" + "strconv" "strings" "go.uber.org/fx" @@ -21,6 +22,10 @@ import ( //go:embed schema/*.sql var schemaFS embed.FS +// bootstrapVersion is the migration that creates the schema_migrations +// table itself. It is applied before the normal migration loop. +const bootstrapVersion = 0 + // Params defines dependencies for Database. type Params struct { fx.In @@ -38,35 +43,40 @@ type Database struct { // ParseMigrationVersion extracts the numeric version prefix from a migration // filename. Filenames must follow the pattern ".sql" or // "_.sql", where version is a zero-padded numeric -// string (e.g. "001", "002"). Returns the version string and an error if -// the filename does not match the expected pattern. -func ParseMigrationVersion(filename string) (string, error) { +// string (e.g. "001", "002"). Returns the version as an integer and an +// error if the filename does not match the expected pattern. +func ParseMigrationVersion(filename string) (int, error) { name := strings.TrimSuffix(filename, filepath.Ext(filename)) if name == "" { - return "", fmt.Errorf("invalid migration filename %q: empty name", filename) + return 0, fmt.Errorf("invalid migration filename %q: empty name", filename) } // Split on underscore to separate version from description. // If there's no underscore, the entire stem is the version. - version := name + versionStr := name if idx := strings.IndexByte(name, '_'); idx >= 0 { - version = name[:idx] + versionStr = name[:idx] } - if version == "" { - return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename) + if versionStr == "" { + return 0, fmt.Errorf("invalid migration filename %q: empty version prefix", filename) } // Validate the version is purely numeric. - for _, ch := range version { + for _, ch := range versionStr { if ch < '0' || ch > '9' { - return "", fmt.Errorf( + return 0, fmt.Errorf( "invalid migration filename %q: version %q contains non-numeric character %q", - filename, version, string(ch), + filename, versionStr, string(ch), ) } } + version, err := strconv.Atoi(versionStr) + if err != nil { + return 0, fmt.Errorf("invalid migration filename %q: %w", filename, err) + } + return version, nil } @@ -143,17 +153,34 @@ func collectMigrations() ([]string, error) { return migrations, nil } -// ensureMigrationsTable creates the schema_migrations tracking table if -// it does not already exist. -func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { - _, err := db.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) +// bootstrapMigrationsTable ensures the schema_migrations table exists +// by applying 000.sql if the table is missing. +func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error { + var tableExists int + + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableExists) if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) + return fmt.Errorf("failed to check for migrations table: %w", err) + } + + if tableExists > 0 { + return nil + } + + content, err := schemaFS.ReadFile("schema/000.sql") + if err != nil { + return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err) + } + + if log != nil { + log.Info("applying bootstrap migration", "version", bootstrapVersion) + } + + _, err = db.ExecContext(ctx, string(content)) + if err != nil { + return fmt.Errorf("failed to apply bootstrap migration: %w", err) } return nil @@ -164,7 +191,7 @@ func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { // This is exported so tests can apply the real schema without the full fx // lifecycle. func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { - if err := ensureMigrationsTable(ctx, db); err != nil { + if err := bootstrapMigrationsTable(ctx, db, log); err != nil { return err } diff --git a/internal/database/database_test.go b/internal/database/database_test.go index 9c2fc2c..015ae22 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test.go @@ -8,37 +8,51 @@ import ( _ "modernc.org/sqlite" // SQLite driver registration ) +// openTestDB returns a fresh in-memory SQLite database. +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to open test db: %v", err) + } + + t.Cleanup(func() { db.Close() }) + + return db +} + func TestParseMigrationVersion(t *testing.T) { tests := []struct { name string filename string - want string + want int wantErr bool }{ { name: "version only", filename: "001.sql", - want: "001", + want: 1, }, { name: "version with description", filename: "001_initial_schema.sql", - want: "001", + want: 1, }, { name: "multi-digit version", filename: "042_add_indexes.sql", - want: "042", + want: 42, }, { name: "long version number", filename: "00001_long_prefix.sql", - want: "00001", + want: 1, }, { name: "description with multiple underscores", filename: "003_add_user_auth_tables.sql", - want: "003", + want: 3, }, { name: "empty filename", @@ -67,7 +81,7 @@ func TestParseMigrationVersion(t *testing.T) { got, err := ParseMigrationVersion(tt.filename) if tt.wantErr { if err == nil { - t.Errorf("ParseMigrationVersion(%q) expected error, got %q", tt.filename, got) + t.Errorf("ParseMigrationVersion(%q) expected error, got %d", tt.filename, got) } return @@ -80,76 +94,131 @@ func TestParseMigrationVersion(t *testing.T) { } if got != tt.want { - t.Errorf("ParseMigrationVersion(%q) = %q, want %q", tt.filename, got, tt.want) + t.Errorf("ParseMigrationVersion(%q) = %d, want %d", tt.filename, got, tt.want) } }) } } -func TestApplyMigrations(t *testing.T) { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("failed to open in-memory database: %v", err) - } - defer db.Close() +func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() - // Apply migrations should succeed. - if err := ApplyMigrations(context.Background(), db, nil); err != nil { + if err := ApplyMigrations(ctx, db, nil); err != nil { t.Fatalf("ApplyMigrations failed: %v", err) } - // Verify the schema_migrations table recorded the version. - var version string - - err = db.QueryRowContext(context.Background(), - "SELECT version FROM schema_migrations LIMIT 1", - ).Scan(&version) + // The schema_migrations table must exist and contain at least + // version 0 (the bootstrap) and 1 (the initial schema). + rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version") if err != nil { t.Fatalf("failed to query schema_migrations: %v", err) } + defer rows.Close() - if version != "001" { - t.Errorf("expected version %q, got %q", "001", version) + var versions []int + for rows.Next() { + var v int + if err := rows.Scan(&v); err != nil { + t.Fatalf("failed to scan version: %v", err) + } + + versions = append(versions, v) } - // Verify a table from the migration exists (source_content). - var tableName string + if err := rows.Err(); err != nil { + t.Fatalf("row iteration error: %v", err) + } - err = db.QueryRowContext(context.Background(), - "SELECT name FROM sqlite_master WHERE type='table' AND name='source_content'", - ).Scan(&tableName) - if err != nil { - t.Fatalf("expected source_content table to exist: %v", err) + if len(versions) < 2 { + t.Fatalf("expected at least 2 migrations recorded, got %d: %v", len(versions), versions) + } + + if versions[0] != 0 { + t.Errorf("first recorded migration = %d, want %d", versions[0], 0) + } + + if versions[1] != 1 { + t.Errorf("second recorded migration = %d, want %d", versions[1], 1) + } + + // Verify that the application tables created by 001.sql exist. + for _, table := range []string{"source_content", "source_metadata", "output_content", "request_cache", "negative_cache", "cache_stats"} { + var count int + + err := db.QueryRow( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", + table, + ).Scan(&count) + if err != nil { + t.Fatalf("failed to check for table %s: %v", table, err) + } + + if count != 1 { + t.Errorf("table %s does not exist after migrations", table) + } } } -func TestApplyMigrationsIdempotent(t *testing.T) { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("failed to open in-memory database: %v", err) - } - defer db.Close() +func TestApplyMigrations_Idempotent(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() - // Apply twice should succeed (idempotent). - if err := ApplyMigrations(context.Background(), db, nil); err != nil { + if err := ApplyMigrations(ctx, db, nil); err != nil { t.Fatalf("first ApplyMigrations failed: %v", err) } - if err := ApplyMigrations(context.Background(), db, nil); err != nil { + // Running a second time must succeed without errors. + if err := ApplyMigrations(ctx, db, nil); err != nil { t.Fatalf("second ApplyMigrations failed: %v", err) } - // Should still have exactly one migration recorded. + // Verify no duplicate rows in schema_migrations. var count int - err = db.QueryRowContext(context.Background(), - "SELECT COUNT(*) FROM schema_migrations", - ).Scan(&count) + err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = 0").Scan(&count) if err != nil { - t.Fatalf("failed to count schema_migrations: %v", err) + t.Fatalf("failed to count version 0 rows: %v", err) } if count != 1 { - t.Errorf("expected 1 migration record, got %d", count) + t.Errorf("expected exactly 1 row for version 0, got %d", count) + } +} + +func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + + if err := bootstrapMigrationsTable(ctx, db, nil); err != nil { + t.Fatalf("bootstrapMigrationsTable failed: %v", err) + } + + // schema_migrations table must exist. + var tableCount int + + err := db.QueryRow( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableCount) + if err != nil { + t.Fatalf("failed to check for table: %v", err) + } + + if tableCount != 1 { + t.Fatalf("schema_migrations table not created") + } + + // Version 0 must be recorded. + var recorded int + + err = db.QueryRow( + "SELECT COUNT(*) FROM schema_migrations WHERE version = 0", + ).Scan(&recorded) + if err != nil { + t.Fatalf("failed to check version: %v", err) + } + + if recorded != 1 { + t.Errorf("expected version 0 to be recorded, got count %d", recorded) } } diff --git a/internal/database/schema/000.sql b/internal/database/schema/000.sql new file mode 100644 index 0000000..e06a2da --- /dev/null +++ b/internal/database/schema/000.sql @@ -0,0 +1,9 @@ +-- Migration 000: Schema migrations tracking table +-- Applied as a bootstrap step before the normal migration loop. + +CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +INSERT OR IGNORE INTO schema_migrations (version) VALUES (0);