Všechny kontroly byly úspěšné
check / check (push) Successful in 1m8s
Remove the unexported applyMigrations() and the runMigrations() method. ApplyMigrations() is now the single implementation, accepting a context and an optional logger. connect() calls it directly with its own context and logger; all other callers have been updated to pass context.Background() and a nil logger.
237 řádky
5.5 KiB
Go
237 řádky
5.5 KiB
Go
// Package database provides SQLite database access.
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"log/slog"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"go.uber.org/fx"
|
|
"sneak.berlin/go/pixa/internal/config"
|
|
"sneak.berlin/go/pixa/internal/logger"
|
|
|
|
_ "modernc.org/sqlite" // SQLite driver registration
|
|
)
|
|
|
|
// schemaFS holds the SQL migration scripts, embedded into the binary at build
// time. Only files matching schema/*.sql are included; they are listed by
// collectMigrations and read by ApplyMigrations.
//
//go:embed schema/*.sql
var schemaFS embed.FS
|
|
|
|
// Params defines dependencies for Database. It embeds fx.In so that fx fills
// in the fields when it calls New.
type Params struct {
	fx.In

	// Logger supplies the *slog.Logger used by Database (obtained via Get).
	Logger *logger.Logger
	// Config supplies DBURL, the data source name handed to sql.Open.
	Config *config.Config
}
|
|
|
|
// Database wraps the SQL database connection.
type Database struct {
	// db is the open handle; it is nil until connect has opened and pinged
	// the database during the fx OnStart hook.
	db *sql.DB
	// log is the structured logger taken from Params.Logger.
	log *slog.Logger
	// config is the application configuration; connect reads DBURL from it.
	config *config.Config
}
|
|
|
|
// ParseMigrationVersion returns the version encoded at the start of a
// migration filename.
//
// The file extension (as reported by filepath.Ext) is discarded first; the
// version is then everything before the first underscore, or the whole stem
// when there is no underscore. Accepted shapes are therefore
// "<version>.sql" and "<version>_<description>.sql", e.g. "001.sql" or
// "002_add_users.sql".
//
// The version must be non-empty and consist solely of ASCII digits.
// Zero-padding (e.g. "001") is the expected convention but is not enforced
// here. An error describing the problem is returned for any other shape.
func ParseMigrationVersion(filename string) (string, error) {
	stem := strings.TrimSuffix(filename, filepath.Ext(filename))
	if len(stem) == 0 {
		return "", fmt.Errorf("invalid migration filename %q: empty name", filename)
	}

	// strings.Cut yields the whole stem when no underscore is present.
	prefix, _, _ := strings.Cut(stem, "_")
	if len(prefix) == 0 {
		return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename)
	}

	for _, r := range prefix {
		if '0' <= r && r <= '9' {
			continue
		}

		return "", fmt.Errorf(
			"invalid migration filename %q: version %q contains non-numeric character %q",
			filename, prefix, string(r),
		)
	}

	return prefix, nil
}
|
|
|
|
// New creates a new Database instance.
|
|
func New(lc fx.Lifecycle, params Params) (*Database, error) {
|
|
s := &Database{
|
|
log: params.Logger.Get(),
|
|
config: params.Config,
|
|
}
|
|
|
|
s.log.Info("Database instantiated")
|
|
|
|
lc.Append(fx.Hook{
|
|
OnStart: func(ctx context.Context) error {
|
|
s.log.Info("Database OnStart Hook")
|
|
|
|
return s.connect(ctx)
|
|
},
|
|
OnStop: func(_ context.Context) error {
|
|
s.log.Info("Database OnStop Hook")
|
|
if s.db != nil {
|
|
return s.db.Close()
|
|
}
|
|
|
|
return nil
|
|
},
|
|
})
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Database) connect(ctx context.Context) error {
|
|
dbURL := s.config.DBURL
|
|
|
|
s.log.Info("connecting to database", "url", dbURL)
|
|
|
|
db, err := sql.Open("sqlite", dbURL)
|
|
if err != nil {
|
|
s.log.Error("failed to open database", "error", err)
|
|
|
|
return err
|
|
}
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
s.log.Error("failed to ping database", "error", err)
|
|
|
|
return err
|
|
}
|
|
|
|
s.db = db
|
|
s.log.Info("database connected")
|
|
|
|
return ApplyMigrations(ctx, s.db, s.log)
|
|
}
|
|
|
|
// collectMigrations reads the embedded schema directory and returns
|
|
// migration filenames sorted lexicographically.
|
|
func collectMigrations() ([]string, error) {
|
|
entries, err := schemaFS.ReadDir("schema")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read schema directory: %w", err)
|
|
}
|
|
|
|
var migrations []string
|
|
|
|
for _, entry := range entries {
|
|
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
|
|
migrations = append(migrations, entry.Name())
|
|
}
|
|
}
|
|
|
|
sort.Strings(migrations)
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
// ensureMigrationsTable makes sure the schema_migrations bookkeeping table
// exists. The statement uses CREATE TABLE IF NOT EXISTS, so calling it on a
// database that already has the table is a no-op.
func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
	const createMigrationsTable = `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version TEXT PRIMARY KEY,
			applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)
	`

	if _, err := db.ExecContext(ctx, createMigrationsTable); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	return nil
}
|
|
|
|
// ApplyMigrations applies all pending migrations to db. An optional logger
|
|
// may be provided for informational output; pass nil for silent operation.
|
|
// This is exported so tests can apply the real schema without the full fx
|
|
// lifecycle.
|
|
func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
|
|
if err := ensureMigrationsTable(ctx, db); err != nil {
|
|
return err
|
|
}
|
|
|
|
migrations, err := collectMigrations()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, migration := range migrations {
|
|
version, parseErr := ParseMigrationVersion(migration)
|
|
if parseErr != nil {
|
|
return parseErr
|
|
}
|
|
|
|
// Check if already applied.
|
|
var count int
|
|
|
|
err := db.QueryRowContext(ctx,
|
|
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
|
|
version,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check migration status: %w", err)
|
|
}
|
|
|
|
if count > 0 {
|
|
if log != nil {
|
|
log.Debug("migration already applied", "version", version)
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
// Read and apply migration.
|
|
content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration))
|
|
if readErr != nil {
|
|
return fmt.Errorf("failed to read migration %s: %w", migration, readErr)
|
|
}
|
|
|
|
if log != nil {
|
|
log.Info("applying migration", "version", version)
|
|
}
|
|
|
|
_, execErr := db.ExecContext(ctx, string(content))
|
|
if execErr != nil {
|
|
return fmt.Errorf("failed to apply migration %s: %w", migration, execErr)
|
|
}
|
|
|
|
// Record migration as applied.
|
|
_, recErr := db.ExecContext(ctx,
|
|
"INSERT INTO schema_migrations (version) VALUES (?)",
|
|
version,
|
|
)
|
|
if recErr != nil {
|
|
return fmt.Errorf("failed to record migration %s: %w", migration, recErr)
|
|
}
|
|
|
|
if log != nil {
|
|
log.Info("migration applied successfully", "version", version)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DB returns the underlying sql.DB.
|
|
func (s *Database) DB() *sql.DB {
|
|
return s.db
|
|
}
|