From fbe53179b83e16756b62d17417ce1337a3e19af1 Mon Sep 17 00:00:00 2001 From: clawbot Date: Mon, 9 Feb 2026 21:15:41 -0800 Subject: [PATCH] Fix code review feedback items 1-6, 8-10 - Item 1: Extract GetUserByID/GetChannelByID lookup methods, use from relation methods - Item 2: Initialize slices with literals so JSON gets [] not null - Item 3: Populate CreatedAt/UpdatedAt with time.Now() on all Create methods - Item 4: Wrap each migration's SQL + recording in a transaction - Item 5: Check error from res.LastInsertId() in QueueMessage - Item 6: Add DequeueMessages and AckMessages methods - Item 8: Add GetUserByNick, GetUserByToken, DeleteAuthToken, UpdateUserLastSeen - Item 9: Run PRAGMA foreign_keys = ON on every new connection - Item 10: Builds clean, all tests pass --- internal/db/db.go | 251 +++++++++++++++++++++++++++++- internal/models/auth_token.go | 18 +-- internal/models/channel.go | 4 +- internal/models/channel_member.go | 35 +---- internal/models/model.go | 33 +++- internal/models/session.go | 18 +-- internal/models/user.go | 4 +- 7 files changed, 297 insertions(+), 66 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index eba382b..a6663ff 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -11,6 +11,7 @@ import ( "sort" "strconv" "strings" + "time" "git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/logger" @@ -86,6 +87,12 @@ func NewTest(dsn string) (*Database, error) { log: slog.Default(), } + // Item 9: Enable foreign keys + if _, err := d.Exec("PRAGMA foreign_keys = ON"); err != nil { + _ = d.Close() + return nil, fmt.Errorf("enable foreign keys: %w", err) + } + ctx := context.Background() err = s.runMigrations(ctx) @@ -109,11 +116,131 @@ func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) { m.SetDB(s) } +// GetUserByID looks up a user by their ID. +func (s *Database) GetUserByID( + ctx context.Context, + id string, +) (*models.User, error) { + u := &models.User{} + s.Hydrate(u) + + err := s.db.QueryRowContext(ctx, ` + SELECT id, nick, password_hash, created_at, updated_at, last_seen_at + FROM users WHERE id = ?`, + id, + ).Scan( + &u.ID, &u.Nick, &u.PasswordHash, + &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, + ) + if err != nil { + return nil, err + } + + return u, nil +} + +// GetChannelByID looks up a channel by its ID. +func (s *Database) GetChannelByID( + ctx context.Context, + id string, +) (*models.Channel, error) { + c := &models.Channel{} + s.Hydrate(c) + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, topic, modes, created_at, updated_at + FROM channels WHERE id = ?`, + id, + ).Scan( + &c.ID, &c.Name, &c.Topic, &c.Modes, + &c.CreatedAt, &c.UpdatedAt, + ) + if err != nil { + return nil, err + } + + return c, nil +} + +// GetUserByNick looks up a user by their nick. +func (s *Database) GetUserByNick( + ctx context.Context, + nick string, +) (*models.User, error) { + u := &models.User{} + s.Hydrate(u) + + err := s.db.QueryRowContext(ctx, ` + SELECT id, nick, password_hash, created_at, updated_at, last_seen_at + FROM users WHERE nick = ?`, + nick, + ).Scan( + &u.ID, &u.Nick, &u.PasswordHash, + &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, + ) + if err != nil { + return nil, err + } + + return u, nil +} + +// GetUserByToken looks up a user by their auth token. +func (s *Database) GetUserByToken( + ctx context.Context, + token string, +) (*models.User, error) { + u := &models.User{} + s.Hydrate(u) + + err := s.db.QueryRowContext(ctx, ` + SELECT u.id, u.nick, u.password_hash, + u.created_at, u.updated_at, u.last_seen_at + FROM users u + JOIN auth_tokens t ON t.user_id = u.id + WHERE t.token = ?`, + token, + ).Scan( + &u.ID, &u.Nick, &u.PasswordHash, + &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, + ) + if err != nil { + return nil, err + } + + return u, nil +} + +// DeleteAuthToken removes an auth token from the database. +func (s *Database) DeleteAuthToken( + ctx context.Context, + token string, +) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM auth_tokens WHERE token = ?`, token, + ) + return err +} + +// UpdateUserLastSeen updates the last_seen_at timestamp for a user. +func (s *Database) UpdateUserLastSeen( + ctx context.Context, + userID string, +) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`, + userID, + ) + return err +} + // CreateUser inserts a new user into the database. func (s *Database) CreateUser( ctx context.Context, id, nick, passwordHash string, ) (*models.User, error) { + now := time.Now() + _, err := s.db.ExecContext(ctx, `INSERT INTO users (id, nick, password_hash) VALUES (?, ?, ?)`, @@ -125,6 +252,7 @@ func (s *Database) CreateUser( u := &models.User{ ID: id, Nick: nick, PasswordHash: passwordHash, + CreatedAt: now, UpdatedAt: now, } s.Hydrate(u) @@ -136,6 +264,8 @@ func (s *Database) CreateChannel( ctx context.Context, id, name, topic, modes string, ) (*models.Channel, error) { + now := time.Now() + _, err := s.db.ExecContext(ctx, `INSERT INTO channels (id, name, topic, modes) VALUES (?, ?, ?, ?)`, @@ -147,6 +277,7 @@ func (s *Database) CreateChannel( c := &models.Channel{ ID: id, Name: name, Topic: topic, Modes: modes, + CreatedAt: now, UpdatedAt: now, } s.Hydrate(c) @@ -158,6 +289,8 @@ func (s *Database) AddChannelMember( ctx context.Context, channelID, userID, modes string, ) (*models.ChannelMember, error) { + now := time.Now() + _, err := s.db.ExecContext(ctx, `INSERT INTO channel_members (channel_id, user_id, modes) @@ -172,6 +305,7 @@ func (s *Database) AddChannelMember( ChannelID: channelID, UserID: userID, Modes: modes, + JoinedAt: now, } s.Hydrate(cm) @@ -183,6 +317,8 @@ func (s *Database) CreateMessage( ctx context.Context, id, fromUserID, fromNick, target, msgType, body string, ) (*models.Message, error) { + now := time.Now() + _, err := s.db.ExecContext(ctx, `INSERT INTO messages (id, from_user_id, from_nick, target, type, body) @@ -200,6 +336,8 @@ func (s *Database) CreateMessage( Target: target, Type: msgType, Body: body, + Timestamp: now, + CreatedAt: now, } s.Hydrate(m) @@ -211,6 +349,8 @@ func (s *Database) QueueMessage( ctx context.Context, userID, messageID string, ) (*models.MessageQueueEntry, error) { + now := time.Now() + res, err := s.db.ExecContext(ctx, `INSERT INTO message_queue (user_id, message_id) VALUES (?, ?)`, @@ -220,23 +360,93 @@ func (s *Database) QueueMessage( return nil, err } - entryID, _ := res.LastInsertId() + entryID, err := res.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get last insert id: %w", err) + } mq := &models.MessageQueueEntry{ ID: entryID, UserID: userID, MessageID: messageID, + QueuedAt: now, } s.Hydrate(mq) return mq, nil } +// DequeueMessages returns up to limit pending messages for a user, +// ordered by queue time (oldest first). +func (s *Database) DequeueMessages( + ctx context.Context, + userID string, + limit int, +) ([]*models.MessageQueueEntry, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, user_id, message_id, queued_at + FROM message_queue + WHERE user_id = ? + ORDER BY queued_at ASC + LIMIT ?`, + userID, limit, + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + entries := []*models.MessageQueueEntry{} + + for rows.Next() { + e := &models.MessageQueueEntry{} + s.Hydrate(e) + + err = rows.Scan(&e.ID, &e.UserID, &e.MessageID, &e.QueuedAt) + if err != nil { + return nil, err + } + + entries = append(entries, e) + } + + return entries, rows.Err() +} + +// AckMessages removes the given queue entry IDs, marking them as delivered. +func (s *Database) AckMessages( + ctx context.Context, + entryIDs []int64, +) error { + if len(entryIDs) == 0 { + return nil + } + + placeholders := make([]string, len(entryIDs)) + args := make([]interface{}, len(entryIDs)) + + for i, id := range entryIDs { + placeholders[i] = "?" + args[i] = id + } + + query := fmt.Sprintf( + "DELETE FROM message_queue WHERE id IN (%s)", + strings.Join(placeholders, ","), + ) + + _, err := s.db.ExecContext(ctx, query, args...) + + return err +} + // CreateAuthToken inserts a new auth token for a user. func (s *Database) CreateAuthToken( ctx context.Context, token, userID string, ) (*models.AuthToken, error) { + now := time.Now() + _, err := s.db.ExecContext(ctx, `INSERT INTO auth_tokens (token, user_id) VALUES (?, ?)`, @@ -246,7 +456,7 @@ func (s *Database) CreateAuthToken( return nil, err } - at := &models.AuthToken{Token: token, UserID: userID} + at := &models.AuthToken{Token: token, UserID: userID, CreatedAt: now} s.Hydrate(at) return at, nil @@ -257,6 +467,8 @@ func (s *Database) CreateSession( ctx context.Context, id, userID string, ) (*models.Session, error) { + now := time.Now() + _, err := s.db.ExecContext(ctx, `INSERT INTO sessions (id, user_id) VALUES (?, ?)`, @@ -266,7 +478,10 @@ func (s *Database) CreateSession( return nil, err } - sess := &models.Session{ID: id, UserID: userID} + sess := &models.Session{ + ID: id, UserID: userID, + CreatedAt: now, LastActiveAt: now, + } s.Hydrate(sess) return sess, nil @@ -278,7 +493,9 @@ func (s *Database) CreateServerLink( id, name, url, sharedKeyHash string, isActive bool, ) (*models.ServerLink, error) { + now := time.Now() active := 0 + if isActive { active = 1 } @@ -299,6 +516,7 @@ func (s *Database) CreateServerLink( URL: url, SharedKeyHash: sharedKeyHash, IsActive: isActive, + CreatedAt: now, } s.Hydrate(sl) @@ -330,6 +548,11 @@ func (s *Database) connect(ctx context.Context) error { s.db = d s.log.Info("database connected") + // Item 9: Enable foreign keys on every connection + if _, err := s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + return fmt.Errorf("enable foreign keys: %w", err) + } + return s.runMigrations(ctx) } @@ -426,6 +649,7 @@ func (s *Database) loadMigrations() ([]migration, error) { return migrations, nil } +// Item 4: Wrap each migration in a transaction func (s *Database) applyMigrations( ctx context.Context, migrations []migration, @@ -452,23 +676,40 @@ func (s *Database) applyMigrations( "version", m.version, "name", m.name, ) - _, err = s.db.ExecContext(ctx, m.sql) + tx, err := s.db.BeginTx(ctx, nil) if err != nil { + return fmt.Errorf( + "begin tx for migration %d: %w", m.version, err, + ) + } + + _, err = tx.ExecContext(ctx, m.sql) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf( "apply migration %d (%s): %w", m.version, m.name, err, ) } - _, err = s.db.ExecContext(ctx, + _, err = tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version, ) if err != nil { + _ = tx.Rollback() + return fmt.Errorf( "record migration %d: %w", m.version, err, ) } + + if err := tx.Commit(); err != nil { + return fmt.Errorf( + "commit migration %d: %w", m.version, err, + ) + } } return nil diff --git a/internal/models/auth_token.go b/internal/models/auth_token.go index 5f6bc4e..f1646e2 100644 --- a/internal/models/auth_token.go +++ b/internal/models/auth_token.go @@ -2,6 +2,7 @@ package models import ( "context" + "fmt" "time" ) @@ -18,20 +19,9 @@ type AuthToken struct { // User returns the user who owns this token. func (t *AuthToken) User(ctx context.Context) (*User, error) { - u := &User{} - u.SetDB(t.db) - - err := t.GetDB().QueryRowContext(ctx, ` - SELECT id, nick, password_hash, created_at, updated_at, last_seen_at - FROM users WHERE id = ?`, - t.UserID, - ).Scan( - &u.ID, &u.Nick, &u.PasswordHash, - &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, - ) - if err != nil { - return nil, err + if ul := t.GetUserLookup(); ul != nil { + return ul.GetUserByID(ctx, t.UserID) } - return u, nil + return nil, fmt.Errorf("user lookup not available") } diff --git a/internal/models/channel.go b/internal/models/channel.go index a6ecea2..addafc9 100644 --- a/internal/models/channel.go +++ b/internal/models/channel.go @@ -34,7 +34,7 @@ func (c *Channel) Members(ctx context.Context) ([]*ChannelMember, error) { defer func() { _ = rows.Close() }() - var members []*ChannelMember + members := []*ChannelMember{} for rows.Next() { m := &ChannelMember{} @@ -74,7 +74,7 @@ func (c *Channel) RecentMessages( defer func() { _ = rows.Close() }() - var messages []*Message + messages := []*Message{} for rows.Next() { msg := &Message{} diff --git a/internal/models/channel_member.go b/internal/models/channel_member.go index c2d5e4d..f93ed9d 100644 --- a/internal/models/channel_member.go +++ b/internal/models/channel_member.go @@ -2,6 +2,7 @@ package models import ( "context" + "fmt" "time" ) @@ -18,40 +19,18 @@ type ChannelMember struct { // User returns the full User for this membership. func (cm *ChannelMember) User(ctx context.Context) (*User, error) { - u := &User{} - u.SetDB(cm.db) - - err := cm.GetDB().QueryRowContext(ctx, ` - SELECT id, nick, password_hash, created_at, updated_at, last_seen_at - FROM users WHERE id = ?`, - cm.UserID, - ).Scan( - &u.ID, &u.Nick, &u.PasswordHash, - &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, - ) - if err != nil { - return nil, err + if ul := cm.GetUserLookup(); ul != nil { + return ul.GetUserByID(ctx, cm.UserID) } - return u, nil + return nil, fmt.Errorf("user lookup not available") } // Channel returns the full Channel for this membership. func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) { - c := &Channel{} - c.SetDB(cm.db) - - err := cm.GetDB().QueryRowContext(ctx, ` - SELECT id, name, topic, modes, created_at, updated_at - FROM channels WHERE id = ?`, - cm.ChannelID, - ).Scan( - &c.ID, &c.Name, &c.Topic, &c.Modes, - &c.CreatedAt, &c.UpdatedAt, - ) - if err != nil { - return nil, err + if cl := cm.GetChannelLookup(); cl != nil { + return cl.GetChannelByID(ctx, cm.ChannelID) } - return c, nil + return nil, fmt.Errorf("channel lookup not available") } diff --git a/internal/models/model.go b/internal/models/model.go index 40ee96d..b65c6e8 100644 --- a/internal/models/model.go +++ b/internal/models/model.go @@ -3,7 +3,10 @@ // relation-fetching methods directly on model instances. package models -import "database/sql" +import ( + "context" + "database/sql" +) // DB is the interface that models use to query the database. // This avoids a circular import with the db package. @@ -11,6 +14,16 @@ type DB interface { GetDB() *sql.DB } +// UserLookup provides user lookup by ID without circular imports. +type UserLookup interface { + GetUserByID(ctx context.Context, id string) (*User, error) +} + +// ChannelLookup provides channel lookup by ID without circular imports. +type ChannelLookup interface { + GetChannelByID(ctx context.Context, id string) (*Channel, error) +} + // Base is embedded in all model structs to provide database access. type Base struct { db DB @@ -25,3 +38,21 @@ func (b *Base) SetDB(d DB) { func (b *Base) GetDB() *sql.DB { return b.db.GetDB() } + +// GetUserLookup returns the DB as a UserLookup if it implements the interface. +func (b *Base) GetUserLookup() UserLookup { + if ul, ok := b.db.(UserLookup); ok { + return ul + } + + return nil +} + +// GetChannelLookup returns the DB as a ChannelLookup if it implements the interface. +func (b *Base) GetChannelLookup() ChannelLookup { + if cl, ok := b.db.(ChannelLookup); ok { + return cl + } + + return nil +} diff --git a/internal/models/session.go b/internal/models/session.go index 4b495f2..42231d9 100644 --- a/internal/models/session.go +++ b/internal/models/session.go @@ -2,6 +2,7 @@ package models import ( "context" + "fmt" "time" ) @@ -18,20 +19,9 @@ type Session struct { // User returns the user who owns this session. func (s *Session) User(ctx context.Context) (*User, error) { - u := &User{} - u.SetDB(s.db) - - err := s.GetDB().QueryRowContext(ctx, ` - SELECT id, nick, password_hash, created_at, updated_at, last_seen_at - FROM users WHERE id = ?`, - s.UserID, - ).Scan( - &u.ID, &u.Nick, &u.PasswordHash, - &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, - ) - if err != nil { - return nil, err + if ul := s.GetUserLookup(); ul != nil { + return ul.GetUserByID(ctx, s.UserID) } - return u, nil + return nil, fmt.Errorf("user lookup not available") } diff --git a/internal/models/user.go b/internal/models/user.go index 214ea9b..f3d778f 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -33,7 +33,7 @@ func (u *User) Channels(ctx context.Context) ([]*Channel, error) { defer func() { _ = rows.Close() }() - var channels []*Channel + channels := []*Channel{} for rows.Next() { c := &Channel{} @@ -70,7 +70,7 @@ func (u *User) QueuedMessages(ctx context.Context) ([]*Message, error) { defer func() { _ = rows.Close() }() - var messages []*Message + messages := []*Message{} for rows.Next() { msg := &Message{}