diff --git a/internal/database/database.go b/internal/database/database.go index be80f1c..62b2ee4 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -35,6 +35,41 @@ type Database struct { config *config.Config } +// ParseMigrationVersion extracts the numeric version prefix from a migration +// filename. Filenames must follow the pattern ".sql" or +// "_.sql", where version is a zero-padded numeric +// string (e.g. "001", "002"). Returns the version string and an error if +// the filename does not match the expected pattern. +func ParseMigrationVersion(filename string) (string, error) { + name := strings.TrimSuffix(filename, filepath.Ext(filename)) + if name == "" { + return "", fmt.Errorf("invalid migration filename %q: empty name", filename) + } + + // Split on underscore to separate version from description. + // If there's no underscore, the entire stem is the version. + version := name + if idx := strings.IndexByte(name, '_'); idx >= 0 { + version = name[:idx] + } + + if version == "" { + return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename) + } + + // Validate the version is purely numeric. + for _, ch := range version { + if ch < '0' || ch > '9' { + return "", fmt.Errorf( + "invalid migration filename %q: version %q contains non-numeric character %q", + filename, version, string(ch), + ) + } + } + + return version, nil +} + // New creates a new Database instance. func New(lc fx.Lifecycle, params Params) (*Database, error) { s := &Database{ @@ -87,9 +122,31 @@ func (s *Database) connect(ctx context.Context) error { return s.runMigrations(ctx) } -func (s *Database) runMigrations(ctx context.Context) error { - // Create migrations tracking table - _, err := s.db.ExecContext(ctx, ` +// collectMigrations reads the embedded schema directory and returns +// migration filenames sorted lexicographically. +func collectMigrations() ([]string, error) { + entries, err := schemaFS.ReadDir("schema") + if err != nil { + return nil, fmt.Errorf("failed to read schema directory: %w", err) + } + + var migrations []string + + for _, entry := range entries { + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { + migrations = append(migrations, entry.Name()) + } + } + + sort.Strings(migrations) + + return migrations, nil +} + +// ensureMigrationsTable creates the schema_migrations tracking table if +// it does not already exist. +func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version TEXT PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP @@ -99,28 +156,31 @@ func (s *Database) runMigrations(ctx context.Context) error { return fmt.Errorf("failed to create migrations table: %w", err) } - // Get list of migration files - entries, err := schemaFS.ReadDir("schema") + return nil +} + +// applyMigrations applies all pending migrations to db, using log for +// informational output (may be nil for silent operation). +func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error { + if err := ensureMigrationsTable(ctx, db); err != nil { + return err + } + + migrations, err := collectMigrations() if err != nil { - return fmt.Errorf("failed to read schema directory: %w", err) + return err } - // Sort migration files by name (001.sql, 002.sql, etc.) - var migrations []string - for _, entry := range entries { - if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { - migrations = append(migrations, entry.Name()) - } - } - sort.Strings(migrations) - - // Apply each migration that hasn't been applied yet for _, migration := range migrations { - version := strings.TrimSuffix(migration, filepath.Ext(migration)) + version, parseErr := ParseMigrationVersion(migration) + if parseErr != nil { + return parseErr + } - // Check if already applied + // Check if already applied. var count int - err := s.db.QueryRowContext(ctx, + + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) @@ -129,39 +189,49 @@ func (s *Database) runMigrations(ctx context.Context) error { } if count > 0 { - s.log.Debug("migration already applied", "version", version) + if log != nil { + log.Debug("migration already applied", "version", version) + } continue } - // Read and apply migration - content, err := schemaFS.ReadFile(filepath.Join("schema", migration)) - if err != nil { - return fmt.Errorf("failed to read migration %s: %w", migration, err) + // Read and apply migration. + content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration)) + if readErr != nil { + return fmt.Errorf("failed to read migration %s: %w", migration, readErr) } - s.log.Info("applying migration", "version", version) - - _, err = s.db.ExecContext(ctx, string(content)) - if err != nil { - return fmt.Errorf("failed to apply migration %s: %w", migration, err) + if log != nil { + log.Info("applying migration", "version", version) } - // Record migration as applied - _, err = s.db.ExecContext(ctx, + _, execErr := db.ExecContext(ctx, string(content)) + if execErr != nil { + return fmt.Errorf("failed to apply migration %s: %w", migration, execErr) + } + + // Record migration as applied. + _, recErr := db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) - if err != nil { - return fmt.Errorf("failed to record migration %s: %w", migration, err) + if recErr != nil { + return fmt.Errorf("failed to record migration %s: %w", migration, recErr) } - s.log.Info("migration applied successfully", "version", version) + if log != nil { + log.Info("migration applied successfully", "version", version) + } } return nil } +func (s *Database) runMigrations(ctx context.Context) error { + return applyMigrations(ctx, s.db, s.log) +} + // DB returns the underlying sql.DB. func (s *Database) DB() *sql.DB { return s.db @@ -171,72 +241,5 @@ func (s *Database) DB() *sql.DB { // This is useful for testing where you want to use the real schema // without the full fx lifecycle. func ApplyMigrations(db *sql.DB) error { - ctx := context.Background() - - // Create migrations tracking table - _, err := db.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) - } - - // Get list of migration files - entries, err := schemaFS.ReadDir("schema") - if err != nil { - return fmt.Errorf("failed to read schema directory: %w", err) - } - - // Sort migration files by name (001.sql, 002.sql, etc.) - var migrations []string - for _, entry := range entries { - if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { - migrations = append(migrations, entry.Name()) - } - } - sort.Strings(migrations) - - // Apply each migration that hasn't been applied yet - for _, migration := range migrations { - version := strings.TrimSuffix(migration, filepath.Ext(migration)) - - // Check if already applied - var count int - err := db.QueryRowContext(ctx, - "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", - version, - ).Scan(&count) - if err != nil { - return fmt.Errorf("failed to check migration status: %w", err) - } - - if count > 0 { - continue - } - - // Read and apply migration - content, err := schemaFS.ReadFile(filepath.Join("schema", migration)) - if err != nil { - return fmt.Errorf("failed to read migration %s: %w", migration, err) - } - - _, err = db.ExecContext(ctx, string(content)) - if err != nil { - return fmt.Errorf("failed to apply migration %s: %w", migration, err) - } - - // Record migration as applied - _, err = db.ExecContext(ctx, - "INSERT INTO schema_migrations (version) VALUES (?)", - version, - ) - if err != nil { - return fmt.Errorf("failed to record migration %s: %w", migration, err) - } - } - - return nil + return applyMigrations(context.Background(), db, nil) } diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 0000000..ab893ea --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,155 @@ +package database + +import ( + "context" + "database/sql" + "testing" + + _ "modernc.org/sqlite" // SQLite driver registration +) + +func TestParseMigrationVersion(t *testing.T) { + tests := []struct { + name string + filename string + want string + wantErr bool + }{ + { + name: "version only", + filename: "001.sql", + want: "001", + }, + { + name: "version with description", + filename: "001_initial_schema.sql", + want: "001", + }, + { + name: "multi-digit version", + filename: "042_add_indexes.sql", + want: "042", + }, + { + name: "long version number", + filename: "00001_long_prefix.sql", + want: "00001", + }, + { + name: "description with multiple underscores", + filename: "003_add_user_auth_tables.sql", + want: "003", + }, + { + name: "empty filename", + filename: ".sql", + wantErr: true, + }, + { + name: "leading underscore", + filename: "_description.sql", + wantErr: true, + }, + { + name: "non-numeric version", + filename: "abc_migration.sql", + wantErr: true, + }, + { + name: "mixed alphanumeric version", + filename: "001a_migration.sql", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseMigrationVersion(tt.filename) + if tt.wantErr { + if err == nil { + t.Errorf("ParseMigrationVersion(%q) expected error, got %q", tt.filename, got) + } + + return + } + + if err != nil { + t.Errorf("ParseMigrationVersion(%q) unexpected error: %v", tt.filename, err) + + return + } + + if got != tt.want { + t.Errorf("ParseMigrationVersion(%q) = %q, want %q", tt.filename, got, tt.want) + } + }) + } +} + +func TestApplyMigrations(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to open in-memory database: %v", err) + } + defer db.Close() + + // Apply migrations should succeed. + if err := ApplyMigrations(db); err != nil { + t.Fatalf("ApplyMigrations failed: %v", err) + } + + // Verify the schema_migrations table recorded the version. + var version string + + err = db.QueryRowContext(context.Background(), + "SELECT version FROM schema_migrations LIMIT 1", + ).Scan(&version) + if err != nil { + t.Fatalf("failed to query schema_migrations: %v", err) + } + + if version != "001" { + t.Errorf("expected version %q, got %q", "001", version) + } + + // Verify a table from the migration exists (source_content). + var tableName string + + err = db.QueryRowContext(context.Background(), + "SELECT name FROM sqlite_master WHERE type='table' AND name='source_content'", + ).Scan(&tableName) + if err != nil { + t.Fatalf("expected source_content table to exist: %v", err) + } +} + +func TestApplyMigrationsIdempotent(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to open in-memory database: %v", err) + } + defer db.Close() + + // Apply twice should succeed (idempotent). + if err := ApplyMigrations(db); err != nil { + t.Fatalf("first ApplyMigrations failed: %v", err) + } + + if err := ApplyMigrations(db); err != nil { + t.Fatalf("second ApplyMigrations failed: %v", err) + } + + // Should still have exactly one migration recorded. + var count int + + err = db.QueryRowContext(context.Background(), + "SELECT COUNT(*) FROM schema_migrations", + ).Scan(&count) + if err != nil { + t.Fatalf("failed to count schema_migrations: %v", err) + } + + if count != 1 { + t.Errorf("expected 1 migration record, got %d", count) + } +} diff --git a/internal/database/schema/001.sql b/internal/database/schema/001_initial_schema.sql similarity index 100% rename from internal/database/schema/001.sql rename to internal/database/schema/001_initial_schema.sql