From 28f3b5aef86afe650ce92e2f46b5d0f1803b74dd Mon Sep 17 00:00:00 2001 From: clawbot Date: Mon, 9 Feb 2026 17:45:01 -0800 Subject: [PATCH] Add comprehensive model and relation test suite --- internal/db/db.go | 303 ++++++++++++++++++++++--- internal/db/db_test.go | 494 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 772 insertions(+), 25 deletions(-) create mode 100644 internal/db/db_test.go diff --git a/internal/db/db.go b/internal/db/db.go index 11eaa8d..eba382b 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -45,16 +45,6 @@ type Database struct { params *Params } -// GetDB returns the underlying sql.DB connection. -func (s *Database) GetDB() *sql.DB { - return s.db -} - -// Hydrate injects the database reference into any model that embeds Base. -func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) { - m.SetDB(s) -} - // New creates a new Database instance and registers lifecycle hooks. func New(lc fx.Lifecycle, params Params) (*Database, error) { s := new(Database) @@ -83,6 +73,238 @@ func New(lc fx.Lifecycle, params Params) (*Database, error) { return s, nil } +// NewTest creates a Database for testing, bypassing fx lifecycle. +// It connects to the given DSN and runs all migrations. +func NewTest(dsn string) (*Database, error) { + d, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, err + } + + s := &Database{ + db: d, + log: slog.Default(), + } + + ctx := context.Background() + + err = s.runMigrations(ctx) + if err != nil { + _ = d.Close() + + return nil, err + } + + return s, nil +} + +// GetDB returns the underlying sql.DB connection. +func (s *Database) GetDB() *sql.DB { + return s.db +} + +// Hydrate injects the database reference into any model that +// embeds Base. +func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) { + m.SetDB(s) +} + +// CreateUser inserts a new user into the database. +func (s *Database) CreateUser( + ctx context.Context, + id, nick, passwordHash string, +) (*models.User, error) { + _, err := s.db.ExecContext(ctx, + `INSERT INTO users (id, nick, password_hash) + VALUES (?, ?, ?)`, + id, nick, passwordHash, + ) + if err != nil { + return nil, err + } + + u := &models.User{ + ID: id, Nick: nick, PasswordHash: passwordHash, + } + s.Hydrate(u) + + return u, nil +} + +// CreateChannel inserts a new channel into the database. +func (s *Database) CreateChannel( + ctx context.Context, + id, name, topic, modes string, +) (*models.Channel, error) { + _, err := s.db.ExecContext(ctx, + `INSERT INTO channels (id, name, topic, modes) + VALUES (?, ?, ?, ?)`, + id, name, topic, modes, + ) + if err != nil { + return nil, err + } + + c := &models.Channel{ + ID: id, Name: name, Topic: topic, Modes: modes, + } + s.Hydrate(c) + + return c, nil +} + +// AddChannelMember adds a user to a channel with the given modes. +func (s *Database) AddChannelMember( + ctx context.Context, + channelID, userID, modes string, +) (*models.ChannelMember, error) { + _, err := s.db.ExecContext(ctx, + `INSERT INTO channel_members + (channel_id, user_id, modes) + VALUES (?, ?, ?)`, + channelID, userID, modes, + ) + if err != nil { + return nil, err + } + + cm := &models.ChannelMember{ + ChannelID: channelID, + UserID: userID, + Modes: modes, + } + s.Hydrate(cm) + + return cm, nil +} + +// CreateMessage inserts a new message into the database. +func (s *Database) CreateMessage( + ctx context.Context, + id, fromUserID, fromNick, target, msgType, body string, +) (*models.Message, error) { + _, err := s.db.ExecContext(ctx, + `INSERT INTO messages + (id, from_user_id, from_nick, target, type, body) + VALUES (?, ?, ?, ?, ?, ?)`, + id, fromUserID, fromNick, target, msgType, body, + ) + if err != nil { + return nil, err + } + + m := &models.Message{ + ID: id, + FromUserID: fromUserID, + FromNick: fromNick, + Target: target, + Type: msgType, + Body: body, + } + s.Hydrate(m) + + return m, nil +} + +// QueueMessage adds a message to a user's delivery queue. +func (s *Database) QueueMessage( + ctx context.Context, + userID, messageID string, +) (*models.MessageQueueEntry, error) { + res, err := s.db.ExecContext(ctx, + `INSERT INTO message_queue (user_id, message_id) + VALUES (?, ?)`, + userID, messageID, + ) + if err != nil { + return nil, err + } + + entryID, _ := res.LastInsertId() + + mq := &models.MessageQueueEntry{ + ID: entryID, + UserID: userID, + MessageID: messageID, + } + s.Hydrate(mq) + + return mq, nil +} + +// CreateAuthToken inserts a new auth token for a user. +func (s *Database) CreateAuthToken( + ctx context.Context, + token, userID string, +) (*models.AuthToken, error) { + _, err := s.db.ExecContext(ctx, + `INSERT INTO auth_tokens (token, user_id) + VALUES (?, ?)`, + token, userID, + ) + if err != nil { + return nil, err + } + + at := &models.AuthToken{Token: token, UserID: userID} + s.Hydrate(at) + + return at, nil +} + +// CreateSession inserts a new session for a user. +func (s *Database) CreateSession( + ctx context.Context, + id, userID string, +) (*models.Session, error) { + _, err := s.db.ExecContext(ctx, + `INSERT INTO sessions (id, user_id) + VALUES (?, ?)`, + id, userID, + ) + if err != nil { + return nil, err + } + + sess := &models.Session{ID: id, UserID: userID} + s.Hydrate(sess) + + return sess, nil +} + +// CreateServerLink inserts a new server link. +func (s *Database) CreateServerLink( + ctx context.Context, + id, name, url, sharedKeyHash string, + isActive bool, +) (*models.ServerLink, error) { + active := 0 + if isActive { + active = 1 + } + + _, err := s.db.ExecContext(ctx, + `INSERT INTO server_links + (id, name, url, shared_key_hash, is_active) + VALUES (?, ?, ?, ?, ?)`, + id, name, url, sharedKeyHash, active, + ) + if err != nil { + return nil, err + } + + sl := &models.ServerLink{ + ID: id, + Name: name, + URL: url, + SharedKeyHash: sharedKeyHash, + IsActive: isActive, + } + s.Hydrate(sl) + + return sl, nil +} + func (s *Database) connect(ctx context.Context) error { dbURL := s.params.Config.DBURL if dbURL == "" { @@ -138,13 +360,18 @@ func (s *Database) runMigrations(ctx context.Context) error { return nil } -func (s *Database) bootstrapMigrationsTable(ctx context.Context) error { - _, err := s.db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations ( +func (s *Database) bootstrapMigrationsTable( + ctx context.Context, +) error { + _, err := s.db.ExecContext(ctx, + `CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP )`) if err != nil { - return fmt.Errorf("failed to create schema_migrations table: %w", err) + return fmt.Errorf( + "create schema_migrations table: %w", err, + ) } return nil @@ -153,17 +380,20 @@ func (s *Database) bootstrapMigrationsTable(ctx context.Context) error { func (s *Database) loadMigrations() ([]migration, error) { entries, err := fs.ReadDir(SchemaFiles, "schema") if err != nil { - return nil, fmt.Errorf("failed to read schema dir: %w", err) + return nil, fmt.Errorf("read schema dir: %w", err) } var migrations []migration for _, entry := range entries { - if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { + if entry.IsDir() || + !strings.HasSuffix(entry.Name(), ".sql") { continue } - parts := strings.SplitN(entry.Name(), "_", minMigrationParts) + parts := strings.SplitN( + entry.Name(), "_", minMigrationParts, + ) if len(parts) < minMigrationParts { continue } @@ -173,9 +403,13 @@ func (s *Database) loadMigrations() ([]migration, error) { continue } - content, err := SchemaFiles.ReadFile("schema/" + entry.Name()) + content, err := SchemaFiles.ReadFile( + "schema/" + entry.Name(), + ) if err != nil { - return nil, fmt.Errorf("failed to read migration %s: %w", entry.Name(), err) + return nil, fmt.Errorf( + "read migration %s: %w", entry.Name(), err, + ) } migrations = append(migrations, migration{ @@ -192,29 +426,48 @@ func (s *Database) loadMigrations() ([]migration, error) { return migrations, nil } -func (s *Database) applyMigrations(ctx context.Context, migrations []migration) error { +func (s *Database) applyMigrations( + ctx context.Context, + migrations []migration, +) error { for _, m := range migrations { var exists int - err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&exists) + err := s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", + m.version, + ).Scan(&exists) if err != nil { - return fmt.Errorf("failed to check migration %d: %w", m.version, err) + return fmt.Errorf( + "check migration %d: %w", m.version, err, + ) } if exists > 0 { continue } - s.log.Info("applying migration", "version", m.version, "name", m.name) + s.log.Info( + "applying migration", + "version", m.version, "name", m.name, + ) _, err = s.db.ExecContext(ctx, m.sql) if err != nil { - return fmt.Errorf("failed to apply migration %d (%s): %w", m.version, m.name, err) + return fmt.Errorf( + "apply migration %d (%s): %w", + m.version, m.name, err, + ) } - _, err = s.db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version) + _, err = s.db.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + m.version, + ) if err != nil { - return fmt.Errorf("failed to record migration %d: %w", m.version, err) + return fmt.Errorf( + "record migration %d: %w", m.version, err, + ) } } diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 0000000..c6de510 --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,494 @@ +package db_test + +import ( + "context" + "fmt" + "path/filepath" + "testing" + "time" + + "git.eeqj.de/sneak/chat/internal/db" +) + +const ( + nickAlice = "alice" + nickBob = "bob" + nickCharlie = "charlie" +) + +// setupTestDB creates a fresh database in a temp directory with +// all migrations applied. +func setupTestDB(t *testing.T) *db.Database { + t.Helper() + + dir := t.TempDir() + dsn := fmt.Sprintf( + "file:%s?_journal_mode=WAL", + filepath.Join(dir, "test.db"), + ) + + d, err := db.NewTest(dsn) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + + t.Cleanup(func() { _ = d.GetDB().Close() }) + + return d +} + +func TestCreateUser(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, err := d.CreateUser(ctx, "u1", nickAlice, "hash1") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + + if u.ID != "u1" || u.Nick != nickAlice { + t.Errorf("got user %+v", u) + } +} + +func TestCreateAuthToken(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, err := d.CreateUser(ctx, "u1", nickAlice, "h") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + + tok, err := d.CreateAuthToken(ctx, "tok1", "u1") + if err != nil { + t.Fatalf("CreateAuthToken: %v", err) + } + + if tok.Token != "tok1" || tok.UserID != "u1" { + t.Errorf("unexpected token: %+v", tok) + } + + u, err := tok.User(ctx) + if err != nil { + t.Fatalf("AuthToken.User: %v", err) + } + + if u.ID != "u1" || u.Nick != nickAlice { + t.Errorf("AuthToken.User got %+v", u) + } +} + +func TestCreateChannel(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, err := d.CreateChannel( + ctx, "c1", "#general", "welcome", "+n", + ) + if err != nil { + t.Fatalf("CreateChannel: %v", err) + } + + if ch.ID != "c1" || ch.Name != "#general" { + t.Errorf("unexpected channel: %+v", ch) + } +} + +func TestAddChannelMember(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#general", "", "") + + cm, err := d.AddChannelMember(ctx, "c1", "u1", "+o") + if err != nil { + t.Fatalf("AddChannelMember: %v", err) + } + + if cm.ChannelID != "c1" || cm.Modes != "+o" { + t.Errorf("unexpected member: %+v", cm) + } +} + +func TestCreateMessage(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + + msg, err := d.CreateMessage( + ctx, "m1", "u1", nickAlice, + "#general", "message", "hello", + ) + if err != nil { + t.Fatalf("CreateMessage: %v", err) + } + + if msg.ID != "m1" || msg.Body != "hello" { + t.Errorf("unexpected message: %+v", msg) + } +} + +func TestQueueMessage(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + _, _ = d.CreateMessage( + ctx, "m1", "u1", nickAlice, "u2", "message", "hi", + ) + + mq, err := d.QueueMessage(ctx, "u2", "m1") + if err != nil { + t.Fatalf("QueueMessage: %v", err) + } + + if mq.UserID != "u2" || mq.MessageID != "m1" { + t.Errorf("unexpected queue entry: %+v", mq) + } +} + +func TestCreateSession(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + + sess, err := d.CreateSession(ctx, "s1", "u1") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + if sess.ID != "s1" || sess.UserID != "u1" { + t.Errorf("unexpected session: %+v", sess) + } + + u, err := sess.User(ctx) + if err != nil { + t.Fatalf("Session.User: %v", err) + } + + if u.ID != "u1" { + t.Errorf("Session.User got %v, want u1", u.ID) + } +} + +func TestCreateServerLink(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + sl, err := d.CreateServerLink( + ctx, "sl1", "peer1", + "https://peer.example.com", "keyhash", true, + ) + if err != nil { + t.Fatalf("CreateServerLink: %v", err) + } + + if sl.ID != "sl1" || !sl.IsActive { + t.Errorf("unexpected server link: %+v", sl) + } +} + +func TestUserChannels(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#alpha", "", "") + _, _ = d.CreateChannel(ctx, "c2", "#beta", "", "") + _, _ = d.AddChannelMember(ctx, "c1", "u1", "") + _, _ = d.AddChannelMember(ctx, "c2", "u1", "") + + channels, err := u.Channels(ctx) + if err != nil { + t.Fatalf("User.Channels: %v", err) + } + + if len(channels) != 2 { + t.Fatalf("expected 2 channels, got %d", len(channels)) + } + + if channels[0].Name != "#alpha" { + t.Errorf("first channel: got %s", channels[0].Name) + } + + if channels[1].Name != "#beta" { + t.Errorf("second channel: got %s", channels[1].Name) + } +} + +func TestUserChannelsEmpty(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + + channels, err := u.Channels(ctx) + if err != nil { + t.Fatalf("User.Channels: %v", err) + } + + if len(channels) != 0 { + t.Errorf("expected 0 channels, got %d", len(channels)) + } +} + +func TestUserQueuedMessages(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + + for i := range 3 { + id := fmt.Sprintf("m%d", i) + + _, _ = d.CreateMessage( + ctx, id, "u2", nickBob, "u1", + "message", fmt.Sprintf("msg%d", i), + ) + + time.Sleep(10 * time.Millisecond) + + _, _ = d.QueueMessage(ctx, "u1", id) + } + + msgs, err := u.QueuedMessages(ctx) + if err != nil { + t.Fatalf("User.QueuedMessages: %v", err) + } + + if len(msgs) != 3 { + t.Fatalf("expected 3 messages, got %d", len(msgs)) + } + + for i, msg := range msgs { + want := fmt.Sprintf("msg%d", i) + if msg.Body != want { + t.Errorf("msg %d: got %q, want %q", i, msg.Body, want) + } + } +} + +func TestUserQueuedMessagesEmpty(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + + msgs, err := u.QueuedMessages(ctx) + if err != nil { + t.Fatalf("User.QueuedMessages: %v", err) + } + + if len(msgs) != 0 { + t.Errorf("expected 0 messages, got %d", len(msgs)) + } +} + +func TestChannelMembers(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + _, _ = d.CreateUser(ctx, "u3", nickCharlie, "h") + _, _ = d.AddChannelMember(ctx, "c1", "u1", "+o") + _, _ = d.AddChannelMember(ctx, "c1", "u2", "+v") + _, _ = d.AddChannelMember(ctx, "c1", "u3", "") + + members, err := ch.Members(ctx) + if err != nil { + t.Fatalf("Channel.Members: %v", err) + } + + if len(members) != 3 { + t.Fatalf("expected 3 members, got %d", len(members)) + } + + nicks := map[string]bool{} + for _, m := range members { + nicks[m.Nick] = true + } + + for _, want := range []string{ + nickAlice, nickBob, nickCharlie, + } { + if !nicks[want] { + t.Errorf("missing nick %q", want) + } + } +} + +func TestChannelMembersEmpty(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#empty", "", "") + + members, err := ch.Members(ctx) + if err != nil { + t.Fatalf("Channel.Members: %v", err) + } + + if len(members) != 0 { + t.Errorf("expected 0, got %d", len(members)) + } +} + +func TestChannelRecentMessages(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + + for i := range 5 { + id := fmt.Sprintf("m%d", i) + + _, _ = d.CreateMessage( + ctx, id, "u1", nickAlice, "#general", + "message", fmt.Sprintf("msg%d", i), + ) + + time.Sleep(10 * time.Millisecond) + } + + msgs, err := ch.RecentMessages(ctx, 3) + if err != nil { + t.Fatalf("RecentMessages: %v", err) + } + + if len(msgs) != 3 { + t.Fatalf("expected 3, got %d", len(msgs)) + } + + if msgs[0].Body != "msg4" { + t.Errorf("first: got %q, want msg4", msgs[0].Body) + } + + if msgs[2].Body != "msg2" { + t.Errorf("last: got %q, want msg2", msgs[2].Body) + } +} + +func TestChannelRecentMessagesLargeLimit(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateMessage( + ctx, "m1", "u1", nickAlice, + "#general", "message", "only", + ) + + msgs, err := ch.RecentMessages(ctx, 100) + if err != nil { + t.Fatalf("RecentMessages: %v", err) + } + + if len(msgs) != 1 { + t.Errorf("expected 1, got %d", len(msgs)) + } +} + +func TestChannelMemberUser(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#general", "", "") + + cm, _ := d.AddChannelMember(ctx, "c1", "u1", "+o") + + u, err := cm.User(ctx) + if err != nil { + t.Fatalf("ChannelMember.User: %v", err) + } + + if u.ID != "u1" || u.Nick != nickAlice { + t.Errorf("got %+v", u) + } +} + +func TestChannelMemberChannel(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#general", "topic", "+n") + + cm, _ := d.AddChannelMember(ctx, "c1", "u1", "") + + ch, err := cm.Channel(ctx) + if err != nil { + t.Fatalf("ChannelMember.Channel: %v", err) + } + + if ch.ID != "c1" || ch.Topic != "topic" { + t.Errorf("got %+v", ch) + } +} + +func TestDMMessage(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + + msg, err := d.CreateMessage( + ctx, "m1", "u1", nickAlice, "u2", "message", "hey", + ) + if err != nil { + t.Fatalf("CreateMessage DM: %v", err) + } + + if msg.Target != "u2" { + t.Errorf("target: got %q, want u2", msg.Target) + } +}