package database

import (
	"context"
	"embed"
	"fmt"
	"io/fs"
	"sort"
	"strings"
)

// migrationsFS holds the migrations/*.sql files embedded into the binary at
// build time.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS

// migrate brings the schema up to date. It applies, in filename order, every
// embedded migrations/*.sql file whose name is not yet recorded in the
// schema_migrations table.
//
// Each migration is applied together with the row that records it, so a
// migration that fails is not recorded. It will be attempted again the next
// time migrate runs. Processing stops at the first error.
//
// Versions are ordered lexicographically by filename, so numeric prefixes
// must be zero-padded: "0002_x.sql" sorts before "0010_y.sql", but "2_x.sql"
// would sort after "10_y.sql".
func (d *Database) migrate(ctx context.Context) error {
	// Bookkeeping table: one row per applied migration, keyed by filename.
	_, err := d.database.ExecContext(ctx, `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version TEXT PRIMARY KEY,
			applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)
	`)
	if err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	entries, err := fs.ReadDir(migrationsFS, "migrations")
	if err != nil {
		return fmt.Errorf("failed to read migrations directory: %w", err)
	}

	migrations := make([]string, 0, len(entries))
	for _, entry := range entries {
		if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
			migrations = append(migrations, entry.Name())
		}
	}
	// fs.ReadDir already returns entries sorted by name. Sort explicitly anyway
	// so the apply order does not silently depend on that guarantee.
	sort.Strings(migrations)

	for _, migration := range migrations {
		applied, err := d.isMigrationApplied(ctx, migration)
		if err != nil {
			return fmt.Errorf("failed to check migration %s: %w", migration, err)
		}
		if applied {
			d.log.Debug("migration already applied", "migration", migration)
			continue
		}

		if err := d.applyMigration(ctx, migration); err != nil {
			return fmt.Errorf("failed to apply migration %s: %w", migration, err)
		}
		d.log.Info("migration applied", "migration", migration)
	}

	return nil
}

// isMigrationApplied reports whether version already has a row in
// schema_migrations.
func (d *Database) isMigrationApplied(ctx context.Context, version string) (bool, error) {
	var count int
	err := d.database.QueryRowContext(
		ctx,
		"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
		version,
	).Scan(&count)
	if err != nil {
		return false, fmt.Errorf("failed to query migration status: %w", err)
	}
	return count > 0, nil
}

// applyMigration runs the embedded file migrations/<filename> and records
// filename in schema_migrations. Both statements run inside one transaction.
// It commits only if both succeed; on any error or panic the transaction is
// rolled back.
//
// NOTE(review): the whole file is sent in a single ExecContext call, which
// relies on the driver accepting multiple statements per Exec. Confirm this
// for the driver in use. Whether DDL inside the file is truly rolled back
// also depends on the database engine.
func (d *Database) applyMigration(ctx context.Context, filename string) error {
	content, err := migrationsFS.ReadFile("migrations/" + filename)
	if err != nil {
		return fmt.Errorf("failed to read migration file: %w", err)
	}

	tx, err := d.database.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back unconditionally. This covers every early return and any
	// panic, without relying on each error path having assigned to a shared
	// err variable.
	// Assuming tx is a database/sql *sql.Tx, Rollback after a successful
	// Commit is a no-op that returns sql.ErrTxDone. Its error is therefore
	// deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, string(content)); err != nil {
		return fmt.Errorf("failed to execute migration: %w", err)
	}

	if _, err := tx.ExecContext(
		ctx,
		"INSERT INTO schema_migrations (version) VALUES (?)",
		filename,
	); err != nil {
		return fmt.Errorf("failed to record migration: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit migration: %w", err)
	}
	return nil
}