// Package db opens the application's SQLite database and applies the SQL
// schema migrations embedded from schema/*.sql.
package db

import (
	"context"
	"database/sql"
	"embed"
	"fmt"
	"io/fs"
	"log/slog"
	"sort"
	"strconv"
	"strings"

	"git.eeqj.de/sneak/chat/internal/config"
	"git.eeqj.de/sneak/chat/internal/logger"
	"go.uber.org/fx"

	_ "github.com/joho/godotenv/autoload"
	_ "modernc.org/sqlite"
)

// defaultDBURL is used when Config.DBURL is empty. modernc.org/sqlite applies
// connection settings via "_pragma=name(value)" query parameters; the
// mattn/go-sqlite3-style "_journal_mode=WAL" is ignored by this driver, so
// WAL mode is requested through _pragma instead.
const defaultDBURL = "file:./data.db?_pragma=journal_mode(WAL)"

// SchemaFiles holds the embedded migration files. Each file is expected to be
// named "<version>_<description>.sql" (e.g. "001_initial.sql"); files that do
// not follow that pattern are skipped with a warning.
//
//go:embed schema/*.sql
var SchemaFiles embed.FS

// DatabaseParams are the fx-injected dependencies of Database.
type DatabaseParams struct {
	fx.In
	Logger *logger.Logger
	Config *config.Config
}

// Database owns the SQLite connection pool. DB is nil until the fx OnStart
// hook has connected to and migrated the database.
type Database struct {
	DB     *sql.DB
	log    *slog.Logger
	params *DatabaseParams
}

// New constructs a Database and registers fx lifecycle hooks: OnStart opens
// the database and applies pending migrations, OnStop closes the pool if it
// was opened. The returned error is always nil.
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
	s := new(Database)
	s.params = &params
	s.log = params.Logger.Get()
	s.log.Info("Database instantiated")

	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			s.log.Info("Database OnStart Hook")
			return s.connect(ctx)
		},
		OnStop: func(ctx context.Context) error {
			s.log.Info("Database OnStop Hook")
			if s.DB != nil {
				return s.DB.Close()
			}
			return nil
		},
	})

	return s, nil
}

// connect opens the database named by Config.DBURL (or defaultDBURL when that
// is empty), verifies it with a ping, and applies pending migrations.
//
// On any failure the pool is closed and s.DB is left nil. fx only runs OnStop
// for hooks whose OnStart succeeded, so nothing else would release it.
func (s *Database) connect(ctx context.Context) error {
	dbURL := s.params.Config.DBURL
	if dbURL == "" {
		dbURL = defaultDBURL
	}

	s.log.Info("connecting to database", "url", dbURL)

	d, err := sql.Open("sqlite", dbURL)
	if err != nil {
		s.log.Error("failed to open database", "error", err)
		return err
	}

	if err := d.PingContext(ctx); err != nil {
		s.log.Error("failed to ping database", "error", err)
		_ = d.Close() // already failing; the ping error is the one worth reporting
		return err
	}

	s.DB = d
	s.log.Info("database connected")

	if err := s.runMigrations(ctx); err != nil {
		_ = d.Close() // already failing; the migration error is the one worth reporting
		s.DB = nil
		return err
	}

	return nil
}

// migration is one embedded schema file, identified by the integer version
// parsed from its filename prefix.
type migration struct {
	version int
	name    string
	sql     string
}

// runMigrations applies, in ascending version order, every embedded migration
// whose version is not yet recorded in schema_migrations.
//
// Each migration's SQL and its schema_migrations row are committed in a single
// transaction, so a migration that fails part-way leaves neither partial schema
// changes nor a record behind. Consequently migration files must not contain
// their own BEGIN/COMMIT, and PRAGMAs that SQLite ignores or rejects inside a
// transaction (e.g. foreign_keys, journal_mode) will not take effect there.
func (s *Database) runMigrations(ctx context.Context) error {
	// Bootstrap: create schema_migrations table directly (migration 001 also does this,
	// but we need it to exist before we can check which migrations have run)
	_, err := s.DB.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations (
		version INTEGER PRIMARY KEY,
		applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
	)`)
	if err != nil {
		return fmt.Errorf("failed to create schema_migrations table: %w", err)
	}

	migrations, err := s.loadMigrations()
	if err != nil {
		return err
	}

	for _, m := range migrations {
		var exists int
		err := s.DB.QueryRowContext(ctx,
			"SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&exists)
		if err != nil {
			return fmt.Errorf("failed to check migration %d: %w", m.version, err)
		}
		if exists > 0 {
			continue
		}

		s.log.Info("applying migration", "version", m.version, "name", m.name)
		if err := s.applyMigration(ctx, m); err != nil {
			return err
		}
	}

	s.log.Info("database migrations complete")
	return nil
}

// loadMigrations reads the embedded schema directory and returns its
// migrations sorted by ascending version. Two files sharing a version number
// are rejected: otherwise, once the first was recorded, the second would be
// silently treated as already applied and never run.
func (s *Database) loadMigrations() ([]migration, error) {
	entries, err := fs.ReadDir(SchemaFiles, "schema")
	if err != nil {
		return nil, fmt.Errorf("failed to read schema dir: %w", err)
	}

	migrations := make([]migration, 0, len(entries))
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(name, ".sql") {
			continue
		}

		// Parse version from filename (e.g. "001_initial.sql" -> 1).
		prefix, _, ok := strings.Cut(name, "_")
		if !ok {
			s.log.Warn("skipping schema file without version prefix", "file", name)
			continue
		}
		version, err := strconv.Atoi(prefix)
		if err != nil {
			s.log.Warn("skipping schema file with non-numeric version", "file", name)
			continue
		}

		// embed.FS paths always use forward slashes, regardless of OS.
		content, err := SchemaFiles.ReadFile("schema/" + name)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration %s: %w", name, err)
		}

		migrations = append(migrations, migration{
			version: version,
			name:    name,
			sql:     string(content),
		})
	}

	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].version < migrations[j].version
	})

	// After sorting, duplicates are adjacent.
	for i := 1; i < len(migrations); i++ {
		if migrations[i].version == migrations[i-1].version {
			return nil, fmt.Errorf("failed to load migrations: duplicate version %d (%s and %s)",
				migrations[i].version, migrations[i-1].name, migrations[i].name)
		}
	}

	return migrations, nil
}

// applyMigration executes one migration's SQL and records its version in
// schema_migrations within a single transaction; either both persist or
// neither does.
func (s *Database) applyMigration(ctx context.Context, m migration) error {
	tx, err := s.DB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction for migration %d: %w", m.version, err)
	}
	// Rollback after a successful Commit is a no-op (it returns sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, m.sql); err != nil {
		return fmt.Errorf("failed to apply migration %d (%s): %w", m.version, m.name, err)
	}

	if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version); err != nil {
		return fmt.Errorf("failed to record migration %d: %w", m.version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit migration %d: %w", m.version, err)
	}

	return nil
}