diff --git a/internal/db/db.go b/internal/db/db.go index 5539b0d..a6663ff 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -46,26 +46,6 @@ type Database struct { params *Params } -// GetDB returns the underlying sql.DB connection. -func (s *Database) GetDB() *sql.DB { - return s.db -} - -// NewChannel creates a Channel model instance with the db reference injected. -func (s *Database) NewChannel(id int64, name, topic, modes string, createdAt, updatedAt time.Time) *models.Channel { - c := &models.Channel{ - ID: id, - Name: name, - Topic: topic, - Modes: modes, - CreatedAt: createdAt, - UpdatedAt: updatedAt, - } - c.SetDB(s) - - return c -} - // New creates a new Database instance and registers lifecycle hooks. func New(lc fx.Lifecycle, params Params) (*Database, error) { s := new(Database) @@ -94,6 +74,455 @@ func New(lc fx.Lifecycle, params Params) (*Database, error) { return s, nil } +// NewTest creates a Database for testing, bypassing fx lifecycle. +// It connects to the given DSN and runs all migrations. +func NewTest(dsn string) (*Database, error) { + d, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, err + } + + s := &Database{ + db: d, + log: slog.Default(), + } + + // Item 9: Enable foreign keys + if _, err := d.Exec("PRAGMA foreign_keys = ON"); err != nil { + _ = d.Close() + return nil, fmt.Errorf("enable foreign keys: %w", err) + } + + ctx := context.Background() + + err = s.runMigrations(ctx) + if err != nil { + _ = d.Close() + + return nil, err + } + + return s, nil +} + +// GetDB returns the underlying sql.DB connection. +func (s *Database) GetDB() *sql.DB { + return s.db +} + +// Hydrate injects the database reference into any model that +// embeds Base. +func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) { + m.SetDB(s) +} + +// GetUserByID looks up a user by their ID. +func (s *Database) GetUserByID( + ctx context.Context, + id string, +) (*models.User, error) { + u := &models.User{} + s.Hydrate(u) + + err := s.db.QueryRowContext(ctx, ` + SELECT id, nick, password_hash, created_at, updated_at, last_seen_at + FROM users WHERE id = ?`, + id, + ).Scan( + &u.ID, &u.Nick, &u.PasswordHash, + &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, + ) + if err != nil { + return nil, err + } + + return u, nil +} + +// GetChannelByID looks up a channel by its ID. +func (s *Database) GetChannelByID( + ctx context.Context, + id string, +) (*models.Channel, error) { + c := &models.Channel{} + s.Hydrate(c) + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, topic, modes, created_at, updated_at + FROM channels WHERE id = ?`, + id, + ).Scan( + &c.ID, &c.Name, &c.Topic, &c.Modes, + &c.CreatedAt, &c.UpdatedAt, + ) + if err != nil { + return nil, err + } + + return c, nil +} + +// GetUserByNick looks up a user by their nick. +func (s *Database) GetUserByNick( + ctx context.Context, + nick string, +) (*models.User, error) { + u := &models.User{} + s.Hydrate(u) + + err := s.db.QueryRowContext(ctx, ` + SELECT id, nick, password_hash, created_at, updated_at, last_seen_at + FROM users WHERE nick = ?`, + nick, + ).Scan( + &u.ID, &u.Nick, &u.PasswordHash, + &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, + ) + if err != nil { + return nil, err + } + + return u, nil +} + +// GetUserByToken looks up a user by their auth token. +func (s *Database) GetUserByToken( + ctx context.Context, + token string, +) (*models.User, error) { + u := &models.User{} + s.Hydrate(u) + + err := s.db.QueryRowContext(ctx, ` + SELECT u.id, u.nick, u.password_hash, + u.created_at, u.updated_at, u.last_seen_at + FROM users u + JOIN auth_tokens t ON t.user_id = u.id + WHERE t.token = ?`, + token, + ).Scan( + &u.ID, &u.Nick, &u.PasswordHash, + &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, + ) + if err != nil { + return nil, err + } + + return u, nil +} + +// DeleteAuthToken removes an auth token from the database. +func (s *Database) DeleteAuthToken( + ctx context.Context, + token string, +) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM auth_tokens WHERE token = ?`, token, + ) + return err +} + +// UpdateUserLastSeen updates the last_seen_at timestamp for a user. +func (s *Database) UpdateUserLastSeen( + ctx context.Context, + userID string, +) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`, + userID, + ) + return err +} + +// CreateUser inserts a new user into the database. +func (s *Database) CreateUser( + ctx context.Context, + id, nick, passwordHash string, +) (*models.User, error) { + now := time.Now() + + _, err := s.db.ExecContext(ctx, + `INSERT INTO users (id, nick, password_hash) + VALUES (?, ?, ?)`, + id, nick, passwordHash, + ) + if err != nil { + return nil, err + } + + u := &models.User{ + ID: id, Nick: nick, PasswordHash: passwordHash, + CreatedAt: now, UpdatedAt: now, + } + s.Hydrate(u) + + return u, nil +} + +// CreateChannel inserts a new channel into the database. +func (s *Database) CreateChannel( + ctx context.Context, + id, name, topic, modes string, +) (*models.Channel, error) { + now := time.Now() + + _, err := s.db.ExecContext(ctx, + `INSERT INTO channels (id, name, topic, modes) + VALUES (?, ?, ?, ?)`, + id, name, topic, modes, + ) + if err != nil { + return nil, err + } + + c := &models.Channel{ + ID: id, Name: name, Topic: topic, Modes: modes, + CreatedAt: now, UpdatedAt: now, + } + s.Hydrate(c) + + return c, nil +} + +// AddChannelMember adds a user to a channel with the given modes. +func (s *Database) AddChannelMember( + ctx context.Context, + channelID, userID, modes string, +) (*models.ChannelMember, error) { + now := time.Now() + + _, err := s.db.ExecContext(ctx, + `INSERT INTO channel_members + (channel_id, user_id, modes) + VALUES (?, ?, ?)`, + channelID, userID, modes, + ) + if err != nil { + return nil, err + } + + cm := &models.ChannelMember{ + ChannelID: channelID, + UserID: userID, + Modes: modes, + JoinedAt: now, + } + s.Hydrate(cm) + + return cm, nil +} + +// CreateMessage inserts a new message into the database. +func (s *Database) CreateMessage( + ctx context.Context, + id, fromUserID, fromNick, target, msgType, body string, +) (*models.Message, error) { + now := time.Now() + + _, err := s.db.ExecContext(ctx, + `INSERT INTO messages + (id, from_user_id, from_nick, target, type, body) + VALUES (?, ?, ?, ?, ?, ?)`, + id, fromUserID, fromNick, target, msgType, body, + ) + if err != nil { + return nil, err + } + + m := &models.Message{ + ID: id, + FromUserID: fromUserID, + FromNick: fromNick, + Target: target, + Type: msgType, + Body: body, + Timestamp: now, + CreatedAt: now, + } + s.Hydrate(m) + + return m, nil +} + +// QueueMessage adds a message to a user's delivery queue. +func (s *Database) QueueMessage( + ctx context.Context, + userID, messageID string, +) (*models.MessageQueueEntry, error) { + now := time.Now() + + res, err := s.db.ExecContext(ctx, + `INSERT INTO message_queue (user_id, message_id) + VALUES (?, ?)`, + userID, messageID, + ) + if err != nil { + return nil, err + } + + entryID, err := res.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get last insert id: %w", err) + } + + mq := &models.MessageQueueEntry{ + ID: entryID, + UserID: userID, + MessageID: messageID, + QueuedAt: now, + } + s.Hydrate(mq) + + return mq, nil +} + +// DequeueMessages returns up to limit pending messages for a user, +// ordered by queue time (oldest first). +func (s *Database) DequeueMessages( + ctx context.Context, + userID string, + limit int, +) ([]*models.MessageQueueEntry, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, user_id, message_id, queued_at + FROM message_queue + WHERE user_id = ? + ORDER BY queued_at ASC + LIMIT ?`, + userID, limit, + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + entries := []*models.MessageQueueEntry{} + + for rows.Next() { + e := &models.MessageQueueEntry{} + s.Hydrate(e) + + err = rows.Scan(&e.ID, &e.UserID, &e.MessageID, &e.QueuedAt) + if err != nil { + return nil, err + } + + entries = append(entries, e) + } + + return entries, rows.Err() +} + +// AckMessages removes the given queue entry IDs, marking them as delivered. +func (s *Database) AckMessages( + ctx context.Context, + entryIDs []int64, +) error { + if len(entryIDs) == 0 { + return nil + } + + placeholders := make([]string, len(entryIDs)) + args := make([]interface{}, len(entryIDs)) + + for i, id := range entryIDs { + placeholders[i] = "?" + args[i] = id + } + + query := fmt.Sprintf( + "DELETE FROM message_queue WHERE id IN (%s)", + strings.Join(placeholders, ","), + ) + + _, err := s.db.ExecContext(ctx, query, args...) + + return err +} + +// CreateAuthToken inserts a new auth token for a user. +func (s *Database) CreateAuthToken( + ctx context.Context, + token, userID string, +) (*models.AuthToken, error) { + now := time.Now() + + _, err := s.db.ExecContext(ctx, + `INSERT INTO auth_tokens (token, user_id) + VALUES (?, ?)`, + token, userID, + ) + if err != nil { + return nil, err + } + + at := &models.AuthToken{Token: token, UserID: userID, CreatedAt: now} + s.Hydrate(at) + + return at, nil +} + +// CreateSession inserts a new session for a user. +func (s *Database) CreateSession( + ctx context.Context, + id, userID string, +) (*models.Session, error) { + now := time.Now() + + _, err := s.db.ExecContext(ctx, + `INSERT INTO sessions (id, user_id) + VALUES (?, ?)`, + id, userID, + ) + if err != nil { + return nil, err + } + + sess := &models.Session{ + ID: id, UserID: userID, + CreatedAt: now, LastActiveAt: now, + } + s.Hydrate(sess) + + return sess, nil +} + +// CreateServerLink inserts a new server link. +func (s *Database) CreateServerLink( + ctx context.Context, + id, name, url, sharedKeyHash string, + isActive bool, +) (*models.ServerLink, error) { + now := time.Now() + active := 0 + + if isActive { + active = 1 + } + + _, err := s.db.ExecContext(ctx, + `INSERT INTO server_links + (id, name, url, shared_key_hash, is_active) + VALUES (?, ?, ?, ?, ?)`, + id, name, url, sharedKeyHash, active, + ) + if err != nil { + return nil, err + } + + sl := &models.ServerLink{ + ID: id, + Name: name, + URL: url, + SharedKeyHash: sharedKeyHash, + IsActive: isActive, + CreatedAt: now, + } + s.Hydrate(sl) + + return sl, nil +} + func (s *Database) connect(ctx context.Context) error { dbURL := s.params.Config.DBURL if dbURL == "" { @@ -119,6 +548,11 @@ func (s *Database) connect(ctx context.Context) error { s.db = d s.log.Info("database connected") + // Item 9: Enable foreign keys on every connection + if _, err := s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + return fmt.Errorf("enable foreign keys: %w", err) + } + return s.runMigrations(ctx) } @@ -149,13 +583,18 @@ func (s *Database) runMigrations(ctx context.Context) error { return nil } -func (s *Database) bootstrapMigrationsTable(ctx context.Context) error { - _, err := s.db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations ( +func (s *Database) bootstrapMigrationsTable( + ctx context.Context, +) error { + _, err := s.db.ExecContext(ctx, + `CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP )`) if err != nil { - return fmt.Errorf("failed to create schema_migrations table: %w", err) + return fmt.Errorf( + "create schema_migrations table: %w", err, + ) } return nil @@ -164,17 +603,20 @@ func (s *Database) bootstrapMigrationsTable(ctx context.Context) error { func (s *Database) loadMigrations() ([]migration, error) { entries, err := fs.ReadDir(SchemaFiles, "schema") if err != nil { - return nil, fmt.Errorf("failed to read schema dir: %w", err) + return nil, fmt.Errorf("read schema dir: %w", err) } var migrations []migration for _, entry := range entries { - if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { + if entry.IsDir() || + !strings.HasSuffix(entry.Name(), ".sql") { continue } - parts := strings.SplitN(entry.Name(), "_", minMigrationParts) + parts := strings.SplitN( + entry.Name(), "_", minMigrationParts, + ) if len(parts) < minMigrationParts { continue } @@ -184,9 +626,13 @@ func (s *Database) loadMigrations() ([]migration, error) { continue } - content, err := SchemaFiles.ReadFile("schema/" + entry.Name()) + content, err := SchemaFiles.ReadFile( + "schema/" + entry.Name(), + ) if err != nil { - return nil, fmt.Errorf("failed to read migration %s: %w", entry.Name(), err) + return nil, fmt.Errorf( + "read migration %s: %w", entry.Name(), err, + ) } migrations = append(migrations, migration{ @@ -203,29 +649,66 @@ func (s *Database) loadMigrations() ([]migration, error) { return migrations, nil } -func (s *Database) applyMigrations(ctx context.Context, migrations []migration) error { +// Item 4: Wrap each migration in a transaction +func (s *Database) applyMigrations( + ctx context.Context, + migrations []migration, +) error { for _, m := range migrations { var exists int - err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&exists) + err := s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", + m.version, + ).Scan(&exists) if err != nil { - return fmt.Errorf("failed to check migration %d: %w", m.version, err) + return fmt.Errorf( + "check migration %d: %w", m.version, err, + ) } if exists > 0 { continue } - s.log.Info("applying migration", "version", m.version, "name", m.name) + s.log.Info( + "applying migration", + "version", m.version, "name", m.name, + ) - _, err = s.db.ExecContext(ctx, m.sql) + tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("failed to apply migration %d (%s): %w", m.version, m.name, err) + return fmt.Errorf( + "begin tx for migration %d: %w", m.version, err, + ) } - _, err = s.db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version) + _, err = tx.ExecContext(ctx, m.sql) if err != nil { - return fmt.Errorf("failed to record migration %d: %w", m.version, err) + _ = tx.Rollback() + + return fmt.Errorf( + "apply migration %d (%s): %w", + m.version, m.name, err, + ) + } + + _, err = tx.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + m.version, + ) + if err != nil { + _ = tx.Rollback() + + return fmt.Errorf( + "record migration %d: %w", m.version, err, + ) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf( + "commit migration %d: %w", m.version, err, + ) } } diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 0000000..c6de510 --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,494 @@ +package db_test + +import ( + "context" + "fmt" + "path/filepath" + "testing" + "time" + + "git.eeqj.de/sneak/chat/internal/db" +) + +const ( + nickAlice = "alice" + nickBob = "bob" + nickCharlie = "charlie" +) + +// setupTestDB creates a fresh database in a temp directory with +// all migrations applied. +func setupTestDB(t *testing.T) *db.Database { + t.Helper() + + dir := t.TempDir() + dsn := fmt.Sprintf( + "file:%s?_journal_mode=WAL", + filepath.Join(dir, "test.db"), + ) + + d, err := db.NewTest(dsn) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + + t.Cleanup(func() { _ = d.GetDB().Close() }) + + return d +} + +func TestCreateUser(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, err := d.CreateUser(ctx, "u1", nickAlice, "hash1") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + + if u.ID != "u1" || u.Nick != nickAlice { + t.Errorf("got user %+v", u) + } +} + +func TestCreateAuthToken(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, err := d.CreateUser(ctx, "u1", nickAlice, "h") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + + tok, err := d.CreateAuthToken(ctx, "tok1", "u1") + if err != nil { + t.Fatalf("CreateAuthToken: %v", err) + } + + if tok.Token != "tok1" || tok.UserID != "u1" { + t.Errorf("unexpected token: %+v", tok) + } + + u, err := tok.User(ctx) + if err != nil { + t.Fatalf("AuthToken.User: %v", err) + } + + if u.ID != "u1" || u.Nick != nickAlice { + t.Errorf("AuthToken.User got %+v", u) + } +} + +func TestCreateChannel(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, err := d.CreateChannel( + ctx, "c1", "#general", "welcome", "+n", + ) + if err != nil { + t.Fatalf("CreateChannel: %v", err) + } + + if ch.ID != "c1" || ch.Name != "#general" { + t.Errorf("unexpected channel: %+v", ch) + } +} + +func TestAddChannelMember(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#general", "", "") + + cm, err := d.AddChannelMember(ctx, "c1", "u1", "+o") + if err != nil { + t.Fatalf("AddChannelMember: %v", err) + } + + if cm.ChannelID != "c1" || cm.Modes != "+o" { + t.Errorf("unexpected member: %+v", cm) + } +} + +func TestCreateMessage(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + + msg, err := d.CreateMessage( + ctx, "m1", "u1", nickAlice, + "#general", "message", "hello", + ) + if err != nil { + t.Fatalf("CreateMessage: %v", err) + } + + if msg.ID != "m1" || msg.Body != "hello" { + t.Errorf("unexpected message: %+v", msg) + } +} + +func TestQueueMessage(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + _, _ = d.CreateMessage( + ctx, "m1", "u1", nickAlice, "u2", "message", "hi", + ) + + mq, err := d.QueueMessage(ctx, "u2", "m1") + if err != nil { + t.Fatalf("QueueMessage: %v", err) + } + + if mq.UserID != "u2" || mq.MessageID != "m1" { + t.Errorf("unexpected queue entry: %+v", mq) + } +} + +func TestCreateSession(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + + sess, err := d.CreateSession(ctx, "s1", "u1") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + if sess.ID != "s1" || sess.UserID != "u1" { + t.Errorf("unexpected session: %+v", sess) + } + + u, err := sess.User(ctx) + if err != nil { + t.Fatalf("Session.User: %v", err) + } + + if u.ID != "u1" { + t.Errorf("Session.User got %v, want u1", u.ID) + } +} + +func TestCreateServerLink(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + sl, err := d.CreateServerLink( + ctx, "sl1", "peer1", + "https://peer.example.com", "keyhash", true, + ) + if err != nil { + t.Fatalf("CreateServerLink: %v", err) + } + + if sl.ID != "sl1" || !sl.IsActive { + t.Errorf("unexpected server link: %+v", sl) + } +} + +func TestUserChannels(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#alpha", "", "") + _, _ = d.CreateChannel(ctx, "c2", "#beta", "", "") + _, _ = d.AddChannelMember(ctx, "c1", "u1", "") + _, _ = d.AddChannelMember(ctx, "c2", "u1", "") + + channels, err := u.Channels(ctx) + if err != nil { + t.Fatalf("User.Channels: %v", err) + } + + if len(channels) != 2 { + t.Fatalf("expected 2 channels, got %d", len(channels)) + } + + if channels[0].Name != "#alpha" { + t.Errorf("first channel: got %s", channels[0].Name) + } + + if channels[1].Name != "#beta" { + t.Errorf("second channel: got %s", channels[1].Name) + } +} + +func TestUserChannelsEmpty(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + + channels, err := u.Channels(ctx) + if err != nil { + t.Fatalf("User.Channels: %v", err) + } + + if len(channels) != 0 { + t.Errorf("expected 0 channels, got %d", len(channels)) + } +} + +func TestUserQueuedMessages(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + + for i := range 3 { + id := fmt.Sprintf("m%d", i) + + _, _ = d.CreateMessage( + ctx, id, "u2", nickBob, "u1", + "message", fmt.Sprintf("msg%d", i), + ) + + time.Sleep(10 * time.Millisecond) + + _, _ = d.QueueMessage(ctx, "u1", id) + } + + msgs, err := u.QueuedMessages(ctx) + if err != nil { + t.Fatalf("User.QueuedMessages: %v", err) + } + + if len(msgs) != 3 { + t.Fatalf("expected 3 messages, got %d", len(msgs)) + } + + for i, msg := range msgs { + want := fmt.Sprintf("msg%d", i) + if msg.Body != want { + t.Errorf("msg %d: got %q, want %q", i, msg.Body, want) + } + } +} + +func TestUserQueuedMessagesEmpty(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + u, _ := d.CreateUser(ctx, "u1", nickAlice, "h") + + msgs, err := u.QueuedMessages(ctx) + if err != nil { + t.Fatalf("User.QueuedMessages: %v", err) + } + + if len(msgs) != 0 { + t.Errorf("expected 0 messages, got %d", len(msgs)) + } +} + +func TestChannelMembers(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + _, _ = d.CreateUser(ctx, "u3", nickCharlie, "h") + _, _ = d.AddChannelMember(ctx, "c1", "u1", "+o") + _, _ = d.AddChannelMember(ctx, "c1", "u2", "+v") + _, _ = d.AddChannelMember(ctx, "c1", "u3", "") + + members, err := ch.Members(ctx) + if err != nil { + t.Fatalf("Channel.Members: %v", err) + } + + if len(members) != 3 { + t.Fatalf("expected 3 members, got %d", len(members)) + } + + nicks := map[string]bool{} + for _, m := range members { + nicks[m.Nick] = true + } + + for _, want := range []string{ + nickAlice, nickBob, nickCharlie, + } { + if !nicks[want] { + t.Errorf("missing nick %q", want) + } + } +} + +func TestChannelMembersEmpty(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#empty", "", "") + + members, err := ch.Members(ctx) + if err != nil { + t.Fatalf("Channel.Members: %v", err) + } + + if len(members) != 0 { + t.Errorf("expected 0, got %d", len(members)) + } +} + +func TestChannelRecentMessages(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + + for i := range 5 { + id := fmt.Sprintf("m%d", i) + + _, _ = d.CreateMessage( + ctx, id, "u1", nickAlice, "#general", + "message", fmt.Sprintf("msg%d", i), + ) + + time.Sleep(10 * time.Millisecond) + } + + msgs, err := ch.RecentMessages(ctx, 3) + if err != nil { + t.Fatalf("RecentMessages: %v", err) + } + + if len(msgs) != 3 { + t.Fatalf("expected 3, got %d", len(msgs)) + } + + if msgs[0].Body != "msg4" { + t.Errorf("first: got %q, want msg4", msgs[0].Body) + } + + if msgs[2].Body != "msg2" { + t.Errorf("last: got %q, want msg2", msgs[2].Body) + } +} + +func TestChannelRecentMessagesLargeLimit(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateMessage( + ctx, "m1", "u1", nickAlice, + "#general", "message", "only", + ) + + msgs, err := ch.RecentMessages(ctx, 100) + if err != nil { + t.Fatalf("RecentMessages: %v", err) + } + + if len(msgs) != 1 { + t.Errorf("expected 1, got %d", len(msgs)) + } +} + +func TestChannelMemberUser(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#general", "", "") + + cm, _ := d.AddChannelMember(ctx, "c1", "u1", "+o") + + u, err := cm.User(ctx) + if err != nil { + t.Fatalf("ChannelMember.User: %v", err) + } + + if u.ID != "u1" || u.Nick != nickAlice { + t.Errorf("got %+v", u) + } +} + +func TestChannelMemberChannel(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateChannel(ctx, "c1", "#general", "topic", "+n") + + cm, _ := d.AddChannelMember(ctx, "c1", "u1", "") + + ch, err := cm.Channel(ctx) + if err != nil { + t.Fatalf("ChannelMember.Channel: %v", err) + } + + if ch.ID != "c1" || ch.Topic != "topic" { + t.Errorf("got %+v", ch) + } +} + +func TestDMMessage(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, _ = d.CreateUser(ctx, "u1", nickAlice, "h") + _, _ = d.CreateUser(ctx, "u2", nickBob, "h") + + msg, err := d.CreateMessage( + ctx, "m1", "u1", nickAlice, "u2", "message", "hey", + ) + if err != nil { + t.Fatalf("CreateMessage DM: %v", err) + } + + if msg.Target != "u2" { + t.Errorf("target: got %q, want u2", msg.Target) + } +} diff --git a/internal/db/schema/002_schema.sql b/internal/db/schema/002_schema.sql new file mode 100644 index 0000000..58dcb70 --- /dev/null +++ b/internal/db/schema/002_schema.sql @@ -0,0 +1,89 @@ +-- All schema changes go into this file until 1.0.0 is tagged. +-- There will not be migrations during the early development phase. +-- After 1.0.0, new changes get their own numbered migration files. + +-- Users: accounts and authentication +CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, -- UUID + nick TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_seen_at DATETIME +); + +-- Auth tokens: one user can have multiple active tokens (multiple devices) +CREATE TABLE IF NOT EXISTS auth_tokens ( + token TEXT PRIMARY KEY, -- random token string + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME, -- NULL = no expiry + last_used_at DATETIME +); +CREATE INDEX IF NOT EXISTS idx_auth_tokens_user_id ON auth_tokens(user_id); + +-- Channels: chat rooms +CREATE TABLE IF NOT EXISTS channels ( + id TEXT PRIMARY KEY, -- UUID + name TEXT NOT NULL UNIQUE, -- #general, etc. + topic TEXT NOT NULL DEFAULT '', + modes TEXT NOT NULL DEFAULT '', -- +i, +m, +s, +t, +n + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Channel members: who is in which channel, with per-user modes +CREATE TABLE IF NOT EXISTS channel_members ( + channel_id TEXT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + modes TEXT NOT NULL DEFAULT '', -- +o (operator), +v (voice) + joined_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (channel_id, user_id) +); +CREATE INDEX IF NOT EXISTS idx_channel_members_user_id ON channel_members(user_id); + +-- Messages: channel and DM history (rotated per MAX_HISTORY) +CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, -- UUID + ts DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + from_user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + from_nick TEXT NOT NULL, -- denormalized for history + target TEXT NOT NULL, -- #channel name or user UUID for DMs + type TEXT NOT NULL DEFAULT 'message', -- message, action, notice, join, part, quit, topic, mode, nick, system + body TEXT NOT NULL DEFAULT '', + meta TEXT NOT NULL DEFAULT '{}', -- JSON extensible metadata + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_messages_target_ts ON messages(target, ts); +CREATE INDEX IF NOT EXISTS idx_messages_from_user ON messages(from_user_id); + +-- Message queue: per-user pending delivery (unread messages) +CREATE TABLE IF NOT EXISTS message_queue ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + message_id TEXT NOT NULL REFERENCES messages(id) ON DELETE CASCADE, + queued_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, message_id) +); +CREATE INDEX IF NOT EXISTS idx_message_queue_user_id ON message_queue(user_id, queued_at); + +-- Sessions: server-held session state +CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, -- UUID + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_active_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME -- idle timeout +); +CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id); + +-- Server links: federation peer configuration +CREATE TABLE IF NOT EXISTS server_links ( + id TEXT PRIMARY KEY, -- UUID + name TEXT NOT NULL UNIQUE, -- human-readable peer name + url TEXT NOT NULL, -- base URL of peer server + shared_key_hash TEXT NOT NULL, -- hashed shared secret + is_active INTEGER NOT NULL DEFAULT 1, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_seen_at DATETIME +); diff --git a/internal/db/schema/002_tables.sql b/internal/db/schema/002_tables.sql deleted file mode 100644 index 91f91c8..0000000 --- a/internal/db/schema/002_tables.sql +++ /dev/null @@ -1,8 +0,0 @@ -CREATE TABLE IF NOT EXISTS channels ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - topic TEXT NOT NULL DEFAULT '', - modes TEXT NOT NULL DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP -); diff --git a/internal/models/auth_token.go b/internal/models/auth_token.go new file mode 100644 index 0000000..f1646e2 --- /dev/null +++ b/internal/models/auth_token.go @@ -0,0 +1,27 @@ +package models + +import ( + "context" + "fmt" + "time" +) + +// AuthToken represents an authentication token for a user session. +type AuthToken struct { + Base + + Token string `json:"-"` + UserID string `json:"userId"` + CreatedAt time.Time `json:"createdAt"` + ExpiresAt *time.Time `json:"expiresAt,omitempty"` + LastUsedAt *time.Time `json:"lastUsedAt,omitempty"` +} + +// User returns the user who owns this token. +func (t *AuthToken) User(ctx context.Context) (*User, error) { + if ul := t.GetUserLookup(); ul != nil { + return ul.GetUserByID(ctx, t.UserID) + } + + return nil, fmt.Errorf("user lookup not available") +} diff --git a/internal/models/channel.go b/internal/models/channel.go index 2ad401e..addafc9 100644 --- a/internal/models/channel.go +++ b/internal/models/channel.go @@ -1,6 +1,7 @@ package models import ( + "context" "time" ) @@ -8,10 +9,88 @@ import ( type Channel struct { Base - ID int64 `json:"id"` + ID string `json:"id"` Name string `json:"name"` Topic string `json:"topic"` Modes string `json:"modes"` CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` } + +// Members returns all users who are members of this channel. +func (c *Channel) Members(ctx context.Context) ([]*ChannelMember, error) { + rows, err := c.GetDB().QueryContext(ctx, ` + SELECT cm.channel_id, cm.user_id, cm.modes, cm.joined_at, + u.nick + FROM channel_members cm + JOIN users u ON u.id = cm.user_id + WHERE cm.channel_id = ? + ORDER BY cm.joined_at`, + c.ID, + ) + if err != nil { + return nil, err + } + + defer func() { _ = rows.Close() }() + + members := []*ChannelMember{} + + for rows.Next() { + m := &ChannelMember{} + m.SetDB(c.db) + + err = rows.Scan( + &m.ChannelID, &m.UserID, &m.Modes, + &m.JoinedAt, &m.Nick, + ) + if err != nil { + return nil, err + } + + members = append(members, m) + } + + return members, rows.Err() +} + +// RecentMessages returns the most recent messages in this channel. +func (c *Channel) RecentMessages( + ctx context.Context, + limit int, +) ([]*Message, error) { + rows, err := c.GetDB().QueryContext(ctx, ` + SELECT id, ts, from_user_id, from_nick, + target, type, body, meta, created_at + FROM messages + WHERE target = ? + ORDER BY ts DESC + LIMIT ?`, + c.Name, limit, + ) + if err != nil { + return nil, err + } + + defer func() { _ = rows.Close() }() + + messages := []*Message{} + + for rows.Next() { + msg := &Message{} + msg.SetDB(c.db) + + err = rows.Scan( + &msg.ID, &msg.Timestamp, &msg.FromUserID, + &msg.FromNick, &msg.Target, &msg.Type, + &msg.Body, &msg.Meta, &msg.CreatedAt, + ) + if err != nil { + return nil, err + } + + messages = append(messages, msg) + } + + return messages, rows.Err() +} diff --git a/internal/models/channel_member.go b/internal/models/channel_member.go new file mode 100644 index 0000000..f93ed9d --- /dev/null +++ b/internal/models/channel_member.go @@ -0,0 +1,36 @@ +package models + +import ( + "context" + "fmt" + "time" +) + +// ChannelMember represents a user's membership in a channel. +type ChannelMember struct { + Base + + ChannelID string `json:"channelId"` + UserID string `json:"userId"` + Modes string `json:"modes"` + JoinedAt time.Time `json:"joinedAt"` + Nick string `json:"nick"` // denormalized from users table +} + +// User returns the full User for this membership. +func (cm *ChannelMember) User(ctx context.Context) (*User, error) { + if ul := cm.GetUserLookup(); ul != nil { + return ul.GetUserByID(ctx, cm.UserID) + } + + return nil, fmt.Errorf("user lookup not available") +} + +// Channel returns the full Channel for this membership. +func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) { + if cl := cm.GetChannelLookup(); cl != nil { + return cl.GetChannelByID(ctx, cm.ChannelID) + } + + return nil, fmt.Errorf("channel lookup not available") +} diff --git a/internal/models/message.go b/internal/models/message.go new file mode 100644 index 0000000..652ae0d --- /dev/null +++ b/internal/models/message.go @@ -0,0 +1,20 @@ +package models + +import ( + "time" +) + +// Message represents a chat message (channel or DM). +type Message struct { + Base + + ID string `json:"id"` + Timestamp time.Time `json:"ts"` + FromUserID string `json:"fromUserId"` + FromNick string `json:"from"` + Target string `json:"to"` + Type string `json:"type"` + Body string `json:"body"` + Meta string `json:"meta"` + CreatedAt time.Time `json:"createdAt"` +} diff --git a/internal/models/message_queue.go b/internal/models/message_queue.go new file mode 100644 index 0000000..616cbc3 --- /dev/null +++ b/internal/models/message_queue.go @@ -0,0 +1,15 @@ +package models + +import ( + "time" +) + +// MessageQueueEntry represents a pending message delivery for a user. +type MessageQueueEntry struct { + Base + + ID int64 `json:"id"` + UserID string `json:"userId"` + MessageID string `json:"messageId"` + QueuedAt time.Time `json:"queuedAt"` +} diff --git a/internal/models/model.go b/internal/models/model.go index 89be8d4..b65c6e8 100644 --- a/internal/models/model.go +++ b/internal/models/model.go @@ -1,14 +1,29 @@ // Package models defines the data models used by the chat application. +// All model structs embed Base, which provides database access for +// relation-fetching methods directly on model instances. package models -import "database/sql" +import ( + "context" + "database/sql" +) -// DB is the interface that models use to query relations. +// DB is the interface that models use to query the database. // This avoids a circular import with the db package. type DB interface { GetDB() *sql.DB } +// UserLookup provides user lookup by ID without circular imports. +type UserLookup interface { + GetUserByID(ctx context.Context, id string) (*User, error) +} + +// ChannelLookup provides channel lookup by ID without circular imports. +type ChannelLookup interface { + GetChannelByID(ctx context.Context, id string) (*Channel, error) +} + // Base is embedded in all model structs to provide database access. type Base struct { db DB @@ -18,3 +33,26 @@ type Base struct { func (b *Base) SetDB(d DB) { b.db = d } + +// GetDB returns the database interface for use in model methods. +func (b *Base) GetDB() *sql.DB { + return b.db.GetDB() +} + +// GetUserLookup returns the DB as a UserLookup if it implements the interface. +func (b *Base) GetUserLookup() UserLookup { + if ul, ok := b.db.(UserLookup); ok { + return ul + } + + return nil +} + +// GetChannelLookup returns the DB as a ChannelLookup if it implements the interface. +func (b *Base) GetChannelLookup() ChannelLookup { + if cl, ok := b.db.(ChannelLookup); ok { + return cl + } + + return nil +} diff --git a/internal/models/server_link.go b/internal/models/server_link.go new file mode 100644 index 0000000..004ef67 --- /dev/null +++ b/internal/models/server_link.go @@ -0,0 +1,18 @@ +package models + +import ( + "time" +) + +// ServerLink represents a federation peer server configuration. +type ServerLink struct { + Base + + ID string `json:"id"` + Name string `json:"name"` + URL string `json:"url"` + SharedKeyHash string `json:"-"` + IsActive bool `json:"isActive"` + CreatedAt time.Time `json:"createdAt"` + LastSeenAt *time.Time `json:"lastSeenAt,omitempty"` +} diff --git a/internal/models/session.go b/internal/models/session.go new file mode 100644 index 0000000..42231d9 --- /dev/null +++ b/internal/models/session.go @@ -0,0 +1,27 @@ +package models + +import ( + "context" + "fmt" + "time" +) + +// Session represents a server-held user session. +type Session struct { + Base + + ID string `json:"id"` + UserID string `json:"userId"` + CreatedAt time.Time `json:"createdAt"` + LastActiveAt time.Time `json:"lastActiveAt"` + ExpiresAt *time.Time `json:"expiresAt,omitempty"` +} + +// User returns the user who owns this session. +func (s *Session) User(ctx context.Context) (*User, error) { + if ul := s.GetUserLookup(); ul != nil { + return ul.GetUserByID(ctx, s.UserID) + } + + return nil, fmt.Errorf("user lookup not available") +} diff --git a/internal/models/user.go b/internal/models/user.go new file mode 100644 index 0000000..f3d778f --- /dev/null +++ b/internal/models/user.go @@ -0,0 +1,92 @@ +package models + +import ( + "context" + "time" +) + +// User represents a registered user account. +type User struct { + Base + + ID string `json:"id"` + Nick string `json:"nick"` + PasswordHash string `json:"-"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + LastSeenAt *time.Time `json:"lastSeenAt,omitempty"` +} + +// Channels returns all channels the user is a member of. +func (u *User) Channels(ctx context.Context) ([]*Channel, error) { + rows, err := u.GetDB().QueryContext(ctx, ` + SELECT c.id, c.name, c.topic, c.modes, c.created_at, c.updated_at + FROM channels c + JOIN channel_members cm ON cm.channel_id = c.id + WHERE cm.user_id = ? + ORDER BY c.name`, + u.ID, + ) + if err != nil { + return nil, err + } + + defer func() { _ = rows.Close() }() + + channels := []*Channel{} + + for rows.Next() { + c := &Channel{} + c.SetDB(u.db) + + err = rows.Scan( + &c.ID, &c.Name, &c.Topic, &c.Modes, + &c.CreatedAt, &c.UpdatedAt, + ) + if err != nil { + return nil, err + } + + channels = append(channels, c) + } + + return channels, rows.Err() +} + +// QueuedMessages returns undelivered messages for this user. +func (u *User) QueuedMessages(ctx context.Context) ([]*Message, error) { + rows, err := u.GetDB().QueryContext(ctx, ` + SELECT m.id, m.ts, m.from_user_id, m.from_nick, + m.target, m.type, m.body, m.meta, m.created_at + FROM messages m + JOIN message_queue mq ON mq.message_id = m.id + WHERE mq.user_id = ? + ORDER BY mq.queued_at ASC`, + u.ID, + ) + if err != nil { + return nil, err + } + + defer func() { _ = rows.Close() }() + + messages := []*Message{} + + for rows.Next() { + msg := &Message{} + msg.SetDB(u.db) + + err = rows.Scan( + &msg.ID, &msg.Timestamp, &msg.FromUserID, + &msg.FromNick, &msg.Target, &msg.Type, + &msg.Body, &msg.Meta, &msg.CreatedAt, + ) + if err != nil { + return nil, err + } + + messages = append(messages, msg) + } + + return messages, rows.Err() +}