When `Commit()` failed, its error was assigned to `commitErr` instead of `err`, so the deferred rollback — which only runs when `err` is non-nil — was skipped.
123 lines
2.7 KiB
Go
123 lines
2.7 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"fmt"
|
|
"io/fs"
|
|
"sort"
|
|
"strings"
|
|
)
|
|
|
|
// migrationsFS holds the SQL migration files compiled into the binary.
// Only files matching migrations/*.sql are embedded; they are read back
// through the "migrations" directory of this filesystem.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
|
|
|
|
func (d *Database) migrate(ctx context.Context) error {
|
|
// Create migrations table if not exists
|
|
_, err := d.database.ExecContext(ctx, `
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version TEXT PRIMARY KEY,
|
|
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create migrations table: %w", err)
|
|
}
|
|
|
|
// Get list of migration files
|
|
entries, err := fs.ReadDir(migrationsFS, "migrations")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read migrations directory: %w", err)
|
|
}
|
|
|
|
// Sort migrations by name
|
|
migrations := make([]string, 0, len(entries))
|
|
|
|
for _, entry := range entries {
|
|
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
|
|
migrations = append(migrations, entry.Name())
|
|
}
|
|
}
|
|
|
|
sort.Strings(migrations)
|
|
|
|
// Apply each migration
|
|
for _, migration := range migrations {
|
|
applied, err := d.isMigrationApplied(ctx, migration)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check migration %s: %w", migration, err)
|
|
}
|
|
|
|
if applied {
|
|
d.log.Debug("migration already applied", "migration", migration)
|
|
|
|
continue
|
|
}
|
|
|
|
err = d.applyMigration(ctx, migration)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
|
|
}
|
|
|
|
d.log.Info("migration applied", "migration", migration)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *Database) isMigrationApplied(ctx context.Context, version string) (bool, error) {
|
|
var count int
|
|
|
|
err := d.database.QueryRowContext(
|
|
ctx,
|
|
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
|
|
version,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to query migration status: %w", err)
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|
|
|
|
func (d *Database) applyMigration(ctx context.Context, filename string) error {
|
|
content, err := migrationsFS.ReadFile("migrations/" + filename)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read migration file: %w", err)
|
|
}
|
|
|
|
transaction, err := d.database.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
_ = transaction.Rollback()
|
|
}
|
|
}()
|
|
|
|
// Execute migration
|
|
_, err = transaction.ExecContext(ctx, string(content))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to execute migration: %w", err)
|
|
}
|
|
|
|
// Record migration
|
|
_, err = transaction.ExecContext(
|
|
ctx,
|
|
"INSERT INTO schema_migrations (version) VALUES (?)",
|
|
filename,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to record migration: %w", err)
|
|
}
|
|
|
|
err = transaction.Commit()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to commit migration: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|