Add models package with embedded DB interface pattern
- internal/models/model.go: DB interface + Base struct for all models - internal/models/channel.go: Channel model with DB access for relation queries - Database.NewChannel() factory injects db reference into model instances - Uses interface to avoid circular imports (models -> db)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"git.eeqj.de/sneak/chat/internal/config"
|
"git.eeqj.de/sneak/chat/internal/config"
|
||||||
"git.eeqj.de/sneak/chat/internal/logger"
|
"git.eeqj.de/sneak/chat/internal/logger"
|
||||||
|
"git.eeqj.de/sneak/chat/internal/models"
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
|
|
||||||
_ "github.com/joho/godotenv/autoload"
|
_ "github.com/joho/godotenv/autoload"
|
||||||
@@ -29,11 +31,30 @@ type DatabaseParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
db *sql.DB
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
params *DatabaseParams
|
params *DatabaseParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDB implements models.db so Database can be embedded in model structs.
|
||||||
|
func (s *Database) GetDB() *sql.DB {
|
||||||
|
return s.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChannel creates a Channel model instance with the db reference injected.
|
||||||
|
func (s *Database) NewChannel(id int64, name, topic, modes string, createdAt, updatedAt time.Time) *models.Channel {
|
||||||
|
c := &models.Channel{
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
Topic: topic,
|
||||||
|
Modes: modes,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
UpdatedAt: updatedAt,
|
||||||
|
}
|
||||||
|
c.SetDB(s)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
||||||
s := new(Database)
|
s := new(Database)
|
||||||
s.params = ¶ms
|
s.params = ¶ms
|
||||||
@@ -48,8 +69,8 @@ func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
|||||||
},
|
},
|
||||||
OnStop: func(ctx context.Context) error {
|
OnStop: func(ctx context.Context) error {
|
||||||
s.log.Info("Database OnStop Hook")
|
s.log.Info("Database OnStop Hook")
|
||||||
if s.DB != nil {
|
if s.db != nil {
|
||||||
return s.DB.Close()
|
return s.db.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -76,7 +97,7 @@ func (s *Database) connect(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.DB = d
|
s.db = d
|
||||||
s.log.Info("database connected")
|
s.log.Info("database connected")
|
||||||
|
|
||||||
return s.runMigrations(ctx)
|
return s.runMigrations(ctx)
|
||||||
@@ -91,7 +112,7 @@ type migration struct {
|
|||||||
func (s *Database) runMigrations(ctx context.Context) error {
|
func (s *Database) runMigrations(ctx context.Context) error {
|
||||||
// Bootstrap: create schema_migrations table directly (migration 001 also does this,
|
// Bootstrap: create schema_migrations table directly (migration 001 also does this,
|
||||||
// but we need it to exist before we can check which migrations have run)
|
// but we need it to exist before we can check which migrations have run)
|
||||||
_, err := s.DB.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations (
|
_, err := s.db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||||
version INTEGER PRIMARY KEY,
|
version INTEGER PRIMARY KEY,
|
||||||
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
)`)
|
)`)
|
||||||
@@ -139,7 +160,7 @@ func (s *Database) runMigrations(ctx context.Context) error {
|
|||||||
|
|
||||||
for _, m := range migrations {
|
for _, m := range migrations {
|
||||||
var exists int
|
var exists int
|
||||||
err := s.DB.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&exists)
|
err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&exists)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check migration %d: %w", m.version, err)
|
return fmt.Errorf("failed to check migration %d: %w", m.version, err)
|
||||||
}
|
}
|
||||||
@@ -148,10 +169,10 @@ func (s *Database) runMigrations(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.log.Info("applying migration", "version", m.version, "name", m.name)
|
s.log.Info("applying migration", "version", m.version, "name", m.name)
|
||||||
if _, err := s.DB.ExecContext(ctx, m.sql); err != nil {
|
if _, err := s.db.ExecContext(ctx, m.sql); err != nil {
|
||||||
return fmt.Errorf("failed to apply migration %d (%s): %w", m.version, m.name, err)
|
return fmt.Errorf("failed to apply migration %d (%s): %w", m.version, m.name, err)
|
||||||
}
|
}
|
||||||
if _, err := s.DB.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version); err != nil {
|
if _, err := s.db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version); err != nil {
|
||||||
return fmt.Errorf("failed to record migration %d: %w", m.version, err)
|
return fmt.Errorf("failed to record migration %d: %w", m.version, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -159,3 +180,4 @@ func (s *Database) runMigrations(ctx context.Context) error {
|
|||||||
s.log.Info("database migrations complete")
|
s.log.Info("database migrations complete")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
21
internal/models/channel.go
Normal file
21
internal/models/channel.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Channel represents a chat channel.
|
||||||
|
type Channel struct {
|
||||||
|
Base
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
Modes string `json:"modes"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example relation method — will be fleshed out when we add channel_members:
|
||||||
|
// func (c *Channel) Members(ctx context.Context) ([]*User, error) {
|
||||||
|
// return c.DB().GetChannelMembers(ctx, c.ID)
|
||||||
|
// }
|
||||||
18
internal/models/model.go
Normal file
18
internal/models/model.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import "database/sql"
|
||||||
|
|
||||||
|
// DB is the interface that models use to query relations.
|
||||||
|
// This avoids a circular import with the db package.
|
||||||
|
type DB interface {
|
||||||
|
GetDB() *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base is embedded in all model structs to provide database access.
|
||||||
|
type Base struct {
|
||||||
|
db DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Base) SetDB(d DB) {
|
||||||
|
b.db = d
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user