// Package database provides SQLite database access. package database import ( "context" "database/sql" "embed" "fmt" "log/slog" "path/filepath" "sort" "strings" "go.uber.org/fx" "sneak.berlin/go/pixa/internal/config" "sneak.berlin/go/pixa/internal/logger" _ "modernc.org/sqlite" // SQLite driver registration ) //go:embed schema/*.sql var schemaFS embed.FS // Params defines dependencies for Database. type Params struct { fx.In Logger *logger.Logger Config *config.Config } // Database wraps the SQL database connection. type Database struct { db *sql.DB log *slog.Logger config *config.Config } // New creates a new Database instance. func New(lc fx.Lifecycle, params Params) (*Database, error) { s := &Database{ log: params.Logger.Get(), config: params.Config, } s.log.Info("Database instantiated") lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { s.log.Info("Database OnStart Hook") return s.connect(ctx) }, OnStop: func(_ context.Context) error { s.log.Info("Database OnStop Hook") if s.db != nil { return s.db.Close() } return nil }, }) return s, nil } func (s *Database) connect(ctx context.Context) error { dbURL := s.config.DBURL s.log.Info("connecting to database", "url", dbURL) db, err := sql.Open("sqlite", dbURL) if err != nil { s.log.Error("failed to open database", "error", err) return err } if err := db.PingContext(ctx); err != nil { s.log.Error("failed to ping database", "error", err) return err } s.db = db s.log.Info("database connected") return s.runMigrations(ctx) } func (s *Database) runMigrations(ctx context.Context) error { // Create migrations tracking table _, err := s.db.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version TEXT PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP ) `) if err != nil { return fmt.Errorf("failed to create migrations table: %w", err) } // Get list of migration files entries, err := schemaFS.ReadDir("schema") if err != nil { return fmt.Errorf("failed to read schema directory: %w", err) } // Sort migration files by name (001.sql, 002.sql, etc.) var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { migrations = append(migrations, entry.Name()) } } sort.Strings(migrations) // Apply each migration that hasn't been applied yet for _, migration := range migrations { version := strings.TrimSuffix(migration, filepath.Ext(migration)) // Check if already applied var count int err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) if err != nil { return fmt.Errorf("failed to check migration status: %w", err) } if count > 0 { s.log.Debug("migration already applied", "version", version) continue } // Read and apply migration content, err := schemaFS.ReadFile(filepath.Join("schema", migration)) if err != nil { return fmt.Errorf("failed to read migration %s: %w", migration, err) } s.log.Info("applying migration", "version", version) _, err = s.db.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to apply migration %s: %w", migration, err) } // Record migration as applied _, err = s.db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) if err != nil { return fmt.Errorf("failed to record migration %s: %w", migration, err) } s.log.Info("migration applied successfully", "version", version) } return nil } // DB returns the underlying sql.DB. func (s *Database) DB() *sql.DB { return s.db } // ApplyMigrations applies all migrations to the given database. // This is useful for testing where you want to use the real schema // without the full fx lifecycle. func ApplyMigrations(db *sql.DB) error { ctx := context.Background() // Create migrations tracking table _, err := db.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version TEXT PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP ) `) if err != nil { return fmt.Errorf("failed to create migrations table: %w", err) } // Get list of migration files entries, err := schemaFS.ReadDir("schema") if err != nil { return fmt.Errorf("failed to read schema directory: %w", err) } // Sort migration files by name (001.sql, 002.sql, etc.) var migrations []string for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { migrations = append(migrations, entry.Name()) } } sort.Strings(migrations) // Apply each migration that hasn't been applied yet for _, migration := range migrations { version := strings.TrimSuffix(migration, filepath.Ext(migration)) // Check if already applied var count int err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version, ).Scan(&count) if err != nil { return fmt.Errorf("failed to check migration status: %w", err) } if count > 0 { continue } // Read and apply migration content, err := schemaFS.ReadFile(filepath.Join("schema", migration)) if err != nil { return fmt.Errorf("failed to read migration %s: %w", migration, err) } _, err = db.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("failed to apply migration %s: %w", migration, err) } // Record migration as applied _, err = db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", version, ) if err != nil { return fmt.Errorf("failed to record migration %s: %w", migration, err) } } return nil }