diff --git a/internal/broker/broker.go b/internal/broker/broker.go new file mode 100644 index 0000000..7d82b0c --- /dev/null +++ b/internal/broker/broker.go @@ -0,0 +1,60 @@ +// Package broker provides an in-memory pub/sub for long-poll notifications. +package broker + +import ( + "sync" +) + +// Broker notifies waiting clients when new messages are available. +type Broker struct { + mu sync.Mutex + listeners map[int64][]chan struct{} // userID -> list of waiting channels +} + +// New creates a new Broker. +func New() *Broker { + return &Broker{ + listeners: make(map[int64][]chan struct{}), + } +} + +// Wait returns a channel that will be closed when a message is available for the user. +func (b *Broker) Wait(userID int64) chan struct{} { + ch := make(chan struct{}, 1) + b.mu.Lock() + b.listeners[userID] = append(b.listeners[userID], ch) + b.mu.Unlock() + return ch +} + +// Notify wakes up all waiting clients for a user. +func (b *Broker) Notify(userID int64) { + b.mu.Lock() + waiters := b.listeners[userID] + delete(b.listeners, userID) + b.mu.Unlock() + + for _, ch := range waiters { + select { + case ch <- struct{}{}: + default: + } + } +} + +// Remove removes a specific wait channel (for cleanup on timeout). +func (b *Broker) Remove(userID int64, ch chan struct{}) { + b.mu.Lock() + defer b.mu.Unlock() + + waiters := b.listeners[userID] + for i, w := range waiters { + if w == ch { + b.listeners[userID] = append(waiters[:i], waiters[i+1:]...) + break + } + } + if len(b.listeners[userID]) == 0 { + delete(b.listeners, userID) + } +} diff --git a/internal/db/db.go b/internal/db/db.go index 5062827..e23d4ec 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -11,11 +11,9 @@ import ( "sort" "strconv" "strings" - "time" "git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/logger" - "git.eeqj.de/sneak/chat/internal/models" "go.uber.org/fx" _ "github.com/joho/godotenv/autoload" // loads .env file @@ -57,16 +55,13 @@ func New(lc fx.Lifecycle, params Params) (*Database, error) { lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { s.log.Info("Database OnStart Hook") - return s.connect(ctx) }, OnStop: func(_ context.Context) error { s.log.Info("Database OnStop Hook") - if s.db != nil { return s.db.Close() } - return nil }, }) @@ -74,460 +69,11 @@ func New(lc fx.Lifecycle, params Params) (*Database, error) { return s, nil } -// NewTest creates a Database for testing, bypassing fx lifecycle. -// It connects to the given DSN and runs all migrations. -func NewTest(dsn string) (*Database, error) { - d, err := sql.Open("sqlite", dsn) - if err != nil { - return nil, err - } - - s := &Database{ - db: d, - log: slog.Default(), - } - - // Item 9: Enable foreign keys - _, err = d.Exec("PRAGMA foreign_keys = ON") //nolint:noctx // no context in sql.Open path - if err != nil { - _ = d.Close() - - return nil, fmt.Errorf("enable foreign keys: %w", err) - } - - ctx := context.Background() - - err = s.runMigrations(ctx) - if err != nil { - _ = d.Close() - - return nil, err - } - - return s, nil -} - // GetDB returns the underlying sql.DB connection. func (s *Database) GetDB() *sql.DB { return s.db } -// Hydrate injects the database reference into any model that -// embeds Base. -func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) { - m.SetDB(s) -} - -// GetUserByID looks up a user by their ID. -func (s *Database) GetUserByID( - ctx context.Context, - id string, -) (*models.User, error) { - u := &models.User{} - s.Hydrate(u) - - err := s.db.QueryRowContext(ctx, ` - SELECT id, nick, password_hash, created_at, updated_at, last_seen_at - FROM users WHERE id = ?`, - id, - ).Scan( - &u.ID, &u.Nick, &u.PasswordHash, - &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, - ) - if err != nil { - return nil, err - } - - return u, nil -} - -// GetChannelByID looks up a channel by its ID. -func (s *Database) GetChannelByID( - ctx context.Context, - id string, -) (*models.Channel, error) { - c := &models.Channel{} - s.Hydrate(c) - - err := s.db.QueryRowContext(ctx, ` - SELECT id, name, topic, modes, created_at, updated_at - FROM channels WHERE id = ?`, - id, - ).Scan( - &c.ID, &c.Name, &c.Topic, &c.Modes, - &c.CreatedAt, &c.UpdatedAt, - ) - if err != nil { - return nil, err - } - - return c, nil -} - -// GetUserByNickModel looks up a user by their nick. -func (s *Database) GetUserByNickModel( - ctx context.Context, - nick string, -) (*models.User, error) { - u := &models.User{} - s.Hydrate(u) - - err := s.db.QueryRowContext(ctx, ` - SELECT id, nick, password_hash, created_at, updated_at, last_seen_at - FROM users WHERE nick = ?`, - nick, - ).Scan( - &u.ID, &u.Nick, &u.PasswordHash, - &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, - ) - if err != nil { - return nil, err - } - - return u, nil -} - -// GetUserByTokenModel looks up a user by their auth token. -func (s *Database) GetUserByTokenModel( - ctx context.Context, - token string, -) (*models.User, error) { - u := &models.User{} - s.Hydrate(u) - - err := s.db.QueryRowContext(ctx, ` - SELECT u.id, u.nick, u.password_hash, - u.created_at, u.updated_at, u.last_seen_at - FROM users u - JOIN auth_tokens t ON t.user_id = u.id - WHERE t.token = ?`, - token, - ).Scan( - &u.ID, &u.Nick, &u.PasswordHash, - &u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt, - ) - if err != nil { - return nil, err - } - - return u, nil -} - -// DeleteAuthToken removes an auth token from the database. -func (s *Database) DeleteAuthToken( - ctx context.Context, - token string, -) error { - _, err := s.db.ExecContext(ctx, - `DELETE FROM auth_tokens WHERE token = ?`, token, - ) - - return err -} - -// UpdateUserLastSeen updates the last_seen_at timestamp for a user. -func (s *Database) UpdateUserLastSeen( - ctx context.Context, - userID string, -) error { - _, err := s.db.ExecContext(ctx, - `UPDATE users SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`, - userID, - ) - - return err -} - -// CreateUserModel inserts a new user into the database. -func (s *Database) CreateUserModel( - ctx context.Context, - id, nick, passwordHash string, -) (*models.User, error) { - now := time.Now() - - _, err := s.db.ExecContext(ctx, - `INSERT INTO users (id, nick, password_hash) - VALUES (?, ?, ?)`, - id, nick, passwordHash, - ) - if err != nil { - return nil, err - } - - u := &models.User{ - ID: id, Nick: nick, PasswordHash: passwordHash, - CreatedAt: now, UpdatedAt: now, - } - s.Hydrate(u) - - return u, nil -} - -// CreateChannel inserts a new channel into the database. -func (s *Database) CreateChannel( - ctx context.Context, - id, name, topic, modes string, -) (*models.Channel, error) { - now := time.Now() - - _, err := s.db.ExecContext(ctx, - `INSERT INTO channels (id, name, topic, modes) - VALUES (?, ?, ?, ?)`, - id, name, topic, modes, - ) - if err != nil { - return nil, err - } - - c := &models.Channel{ - ID: id, Name: name, Topic: topic, Modes: modes, - CreatedAt: now, UpdatedAt: now, - } - s.Hydrate(c) - - return c, nil -} - -// AddChannelMember adds a user to a channel with the given modes. -func (s *Database) AddChannelMember( - ctx context.Context, - channelID, userID, modes string, -) (*models.ChannelMember, error) { - now := time.Now() - - _, err := s.db.ExecContext(ctx, - `INSERT INTO channel_members - (channel_id, user_id, modes) - VALUES (?, ?, ?)`, - channelID, userID, modes, - ) - if err != nil { - return nil, err - } - - cm := &models.ChannelMember{ - ChannelID: channelID, - UserID: userID, - Modes: modes, - JoinedAt: now, - } - s.Hydrate(cm) - - return cm, nil -} - -// CreateMessage inserts a new message into the database. -func (s *Database) CreateMessage( - ctx context.Context, - id, fromUserID, fromNick, target, msgType, body string, -) (*models.Message, error) { - now := time.Now() - - _, err := s.db.ExecContext(ctx, - `INSERT INTO messages - (id, from_user_id, from_nick, target, type, body) - VALUES (?, ?, ?, ?, ?, ?)`, - id, fromUserID, fromNick, target, msgType, body, - ) - if err != nil { - return nil, err - } - - m := &models.Message{ - ID: id, - FromUserID: fromUserID, - FromNick: fromNick, - Target: target, - Type: msgType, - Body: body, - Timestamp: now, - CreatedAt: now, - } - s.Hydrate(m) - - return m, nil -} - -// QueueMessage adds a message to a user's delivery queue. -func (s *Database) QueueMessage( - ctx context.Context, - userID, messageID string, -) (*models.MessageQueueEntry, error) { - now := time.Now() - - res, err := s.db.ExecContext(ctx, - `INSERT INTO message_queue (user_id, message_id) - VALUES (?, ?)`, - userID, messageID, - ) - if err != nil { - return nil, err - } - - entryID, err := res.LastInsertId() - if err != nil { - return nil, fmt.Errorf("get last insert id: %w", err) - } - - mq := &models.MessageQueueEntry{ - ID: entryID, - UserID: userID, - MessageID: messageID, - QueuedAt: now, - } - s.Hydrate(mq) - - return mq, nil -} - -// DequeueMessages returns up to limit pending messages for a user, -// ordered by queue time (oldest first). -func (s *Database) DequeueMessages( - ctx context.Context, - userID string, - limit int, -) ([]*models.MessageQueueEntry, error) { - rows, err := s.db.QueryContext(ctx, ` - SELECT id, user_id, message_id, queued_at - FROM message_queue - WHERE user_id = ? - ORDER BY queued_at ASC - LIMIT ?`, - userID, limit, - ) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - entries := []*models.MessageQueueEntry{} - - for rows.Next() { - e := &models.MessageQueueEntry{} - s.Hydrate(e) - - err = rows.Scan(&e.ID, &e.UserID, &e.MessageID, &e.QueuedAt) - if err != nil { - return nil, err - } - - entries = append(entries, e) - } - - return entries, rows.Err() -} - -// AckMessages removes the given queue entry IDs, marking them as delivered. -func (s *Database) AckMessages( - ctx context.Context, - entryIDs []int64, -) error { - if len(entryIDs) == 0 { - return nil - } - - placeholders := make([]string, len(entryIDs)) - args := make([]any, len(entryIDs)) - - for i, id := range entryIDs { - placeholders[i] = "?" - args[i] = id - } - - query := fmt.Sprintf( //nolint:gosec // placeholders are ?, not user input - "DELETE FROM message_queue WHERE id IN (%s)", - strings.Join(placeholders, ","), - ) - - _, err := s.db.ExecContext(ctx, query, args...) - - return err -} - -// CreateAuthToken inserts a new auth token for a user. -func (s *Database) CreateAuthToken( - ctx context.Context, - token, userID string, -) (*models.AuthToken, error) { - now := time.Now() - - _, err := s.db.ExecContext(ctx, - `INSERT INTO auth_tokens (token, user_id) - VALUES (?, ?)`, - token, userID, - ) - if err != nil { - return nil, err - } - - at := &models.AuthToken{Token: token, UserID: userID, CreatedAt: now} - s.Hydrate(at) - - return at, nil -} - -// CreateSession inserts a new session for a user. -func (s *Database) CreateSession( - ctx context.Context, - id, userID string, -) (*models.Session, error) { - now := time.Now() - - _, err := s.db.ExecContext(ctx, - `INSERT INTO sessions (id, user_id) - VALUES (?, ?)`, - id, userID, - ) - if err != nil { - return nil, err - } - - sess := &models.Session{ - ID: id, UserID: userID, - CreatedAt: now, LastActiveAt: now, - } - s.Hydrate(sess) - - return sess, nil -} - -// CreateServerLink inserts a new server link. -func (s *Database) CreateServerLink( - ctx context.Context, - id, name, url, sharedKeyHash string, - isActive bool, -) (*models.ServerLink, error) { - now := time.Now() - active := 0 - - if isActive { - active = 1 - } - - _, err := s.db.ExecContext(ctx, - `INSERT INTO server_links - (id, name, url, shared_key_hash, is_active) - VALUES (?, ?, ?, ?, ?)`, - id, name, url, sharedKeyHash, active, - ) - if err != nil { - return nil, err - } - - sl := &models.ServerLink{ - ID: id, - Name: name, - URL: url, - SharedKeyHash: sharedKeyHash, - IsActive: isActive, - CreatedAt: now, - } - s.Hydrate(sl) - - return sl, nil -} - func (s *Database) connect(ctx context.Context) error { dbURL := s.params.Config.DBURL if dbURL == "" { @@ -539,23 +85,19 @@ func (s *Database) connect(ctx context.Context) error { d, err := sql.Open("sqlite", dbURL) if err != nil { s.log.Error("failed to open database", "error", err) - return err } err = d.PingContext(ctx) if err != nil { s.log.Error("failed to ping database", "error", err) - return err } s.db = d s.log.Info("database connected") - // Item 9: Enable foreign keys on every connection - _, err = s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON") - if err != nil { + if _, err := s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { return fmt.Errorf("enable foreign keys: %w", err) } @@ -569,9 +111,13 @@ type migration struct { } func (s *Database) runMigrations(ctx context.Context) error { - err := s.bootstrapMigrationsTable(ctx) + _, err := s.db.ExecContext(ctx, + `CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP + )`) if err != nil { - return err + return fmt.Errorf("create schema_migrations table: %w", err) } migrations, err := s.loadMigrations() @@ -579,30 +125,47 @@ func (s *Database) runMigrations(ctx context.Context) error { return err } - err = s.applyMigrations(ctx, migrations) - if err != nil { - return err + for _, m := range migrations { + var exists int + err := s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", + m.version, + ).Scan(&exists) + if err != nil { + return fmt.Errorf("check migration %d: %w", m.version, err) + } + if exists > 0 { + continue + } + + s.log.Info("applying migration", "version", m.version, "name", m.name) + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin tx for migration %d: %w", m.version, err) + } + + _, err = tx.ExecContext(ctx, m.sql) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("apply migration %d (%s): %w", m.version, m.name, err) + } + + _, err = tx.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + m.version, + ) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("record migration %d: %w", m.version, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit migration %d: %w", m.version, err) + } } s.log.Info("database migrations complete") - - return nil -} - -func (s *Database) bootstrapMigrationsTable( - ctx context.Context, -) error { - _, err := s.db.ExecContext(ctx, - `CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`) - if err != nil { - return fmt.Errorf( - "create schema_migrations table: %w", err, - ) - } - return nil } @@ -613,16 +176,12 @@ func (s *Database) loadMigrations() ([]migration, error) { } var migrations []migration - for _, entry := range entries { - if entry.IsDir() || - !strings.HasSuffix(entry.Name(), ".sql") { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { continue } - parts := strings.SplitN( - entry.Name(), "_", minMigrationParts, - ) + parts := strings.SplitN(entry.Name(), "_", minMigrationParts) if len(parts) < minMigrationParts { continue } @@ -632,13 +191,9 @@ func (s *Database) loadMigrations() ([]migration, error) { continue } - content, err := SchemaFiles.ReadFile( - "schema/" + entry.Name(), - ) + content, err := SchemaFiles.ReadFile("schema/" + entry.Name()) if err != nil { - return nil, fmt.Errorf( - "read migration %s: %w", entry.Name(), err, - ) + return nil, fmt.Errorf("read migration %s: %w", entry.Name(), err) } migrations = append(migrations, migration{ @@ -654,82 +209,3 @@ func (s *Database) loadMigrations() ([]migration, error) { return migrations, nil } - -// Item 4: Wrap each migration in a transaction -func (s *Database) applyMigrations( - ctx context.Context, - migrations []migration, -) error { - for _, m := range migrations { - var exists int - - err := s.db.QueryRowContext(ctx, - "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", - m.version, - ).Scan(&exists) - if err != nil { - return fmt.Errorf( - "check migration %d: %w", m.version, err, - ) - } - - if exists > 0 { - continue - } - - s.log.Info( - "applying migration", - "version", m.version, "name", m.name, - ) - - err = s.executeMigration(ctx, m) - if err != nil { - return err - } - } - - return nil -} - -func (s *Database) executeMigration( - ctx context.Context, - m migration, -) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf( - "begin tx for migration %d: %w", m.version, err, - ) - } - - _, err = tx.ExecContext(ctx, m.sql) - if err != nil { - _ = tx.Rollback() - - return fmt.Errorf( - "apply migration %d (%s): %w", - m.version, m.name, err, - ) - } - - _, err = tx.ExecContext(ctx, - "INSERT INTO schema_migrations (version) VALUES (?)", - m.version, - ) - if err != nil { - _ = tx.Rollback() - - return fmt.Errorf( - "record migration %d: %w", m.version, err, - ) - } - - err = tx.Commit() - if err != nil { - return fmt.Errorf( - "commit migration %d: %w", m.version, err, - ) - } - - return nil -} diff --git a/internal/db/db_test.go b/internal/db/db_test.go deleted file mode 100644 index b3cf841..0000000 --- a/internal/db/db_test.go +++ /dev/null @@ -1,425 +0,0 @@ -package db_test - -import ( - "context" - "fmt" - "path/filepath" - "testing" - "time" - - "git.eeqj.de/sneak/chat/internal/db" -) - -const ( - nickAlice = "alice" - nickBob = "bob" - nickCharlie = "charlie" -) - -// setupTestDB creates a fresh database in a temp directory with -// all migrations applied. -func setupTestDB(t *testing.T) *db.Database { - t.Helper() - - dir := t.TempDir() - dsn := fmt.Sprintf( - "file:%s?_journal_mode=WAL", - filepath.Join(dir, "test.db"), - ) - - d, err := db.NewTest(dsn) - if err != nil { - t.Fatalf("failed to create test database: %v", err) - } - - t.Cleanup(func() { _ = d.GetDB().Close() }) - - return d -} - -func TestCreateUser(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - id, token, err := d.CreateUser(ctx, nickAlice) - if err != nil { - t.Fatalf("CreateUser: %v", err) - } - - if id <= 0 { - t.Errorf("expected positive id, got %d", id) - } - - if token == "" { - t.Error("expected non-empty token") - } -} - -func TestGetUserByToken(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, token, _ := d.CreateUser(ctx, nickAlice) - - id, nick, err := d.GetUserByToken(ctx, token) - if err != nil { - t.Fatalf("GetUserByToken: %v", err) - } - - if id <= 0 || nick != nickAlice { - t.Errorf( - "got id=%d nick=%s, want nick=%s", - id, nick, nickAlice, - ) - } -} - -func TestGetUserByNick(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - origID, _, _ := d.CreateUser(ctx, nickAlice) - - id, err := d.GetUserByNick(ctx, nickAlice) - if err != nil { - t.Fatalf("GetUserByNick: %v", err) - } - - if id != origID { - t.Errorf("got id %d, want %d", id, origID) - } -} - -func TestGetOrCreateChannel(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - id1, err := d.GetOrCreateChannel(ctx, "#general") - if err != nil { - t.Fatalf("GetOrCreateChannel: %v", err) - } - - if id1 <= 0 { - t.Errorf("expected positive id, got %d", id1) - } - - // Same channel returns same ID. - id2, err := d.GetOrCreateChannel(ctx, "#general") - if err != nil { - t.Fatalf("GetOrCreateChannel(2): %v", err) - } - - if id1 != id2 { - t.Errorf("got different ids: %d vs %d", id1, id2) - } -} - -func TestJoinAndListChannels(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid, _, _ := d.CreateUser(ctx, nickAlice) - ch1, _ := d.GetOrCreateChannel(ctx, "#alpha") - ch2, _ := d.GetOrCreateChannel(ctx, "#beta") - - _ = d.JoinChannel(ctx, ch1, uid) - _ = d.JoinChannel(ctx, ch2, uid) - - channels, err := d.ListChannels(ctx, uid) - if err != nil { - t.Fatalf("ListChannels: %v", err) - } - - if len(channels) != 2 { - t.Fatalf("expected 2 channels, got %d", len(channels)) - } -} - -func TestListChannelsEmpty(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid, _, _ := d.CreateUser(ctx, nickAlice) - - channels, err := d.ListChannels(ctx, uid) - if err != nil { - t.Fatalf("ListChannels: %v", err) - } - - if len(channels) != 0 { - t.Errorf("expected 0 channels, got %d", len(channels)) - } -} - -func TestPartChannel(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid, _, _ := d.CreateUser(ctx, nickAlice) - chID, _ := d.GetOrCreateChannel(ctx, "#general") - - _ = d.JoinChannel(ctx, chID, uid) - _ = d.PartChannel(ctx, chID, uid) - - channels, err := d.ListChannels(ctx, uid) - if err != nil { - t.Fatalf("ListChannels: %v", err) - } - - if len(channels) != 0 { - t.Errorf("expected 0 after part, got %d", len(channels)) - } -} - -func TestSendAndGetMessages(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid, _, _ := d.CreateUser(ctx, nickAlice) - chID, _ := d.GetOrCreateChannel(ctx, "#general") - _ = d.JoinChannel(ctx, chID, uid) - - _, err := d.SendMessage(ctx, chID, uid, "hello world") - if err != nil { - t.Fatalf("SendMessage: %v", err) - } - - msgs, err := d.GetMessages(ctx, chID, 0, 0) - if err != nil { - t.Fatalf("GetMessages: %v", err) - } - - if len(msgs) != 1 { - t.Fatalf("expected 1 message, got %d", len(msgs)) - } - - if msgs[0].Content != "hello world" { - t.Errorf( - "got content %q, want %q", - msgs[0].Content, "hello world", - ) - } -} - -func TestChannelMembers(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid1, _, _ := d.CreateUser(ctx, nickAlice) - uid2, _, _ := d.CreateUser(ctx, nickBob) - uid3, _, _ := d.CreateUser(ctx, nickCharlie) - chID, _ := d.GetOrCreateChannel(ctx, "#general") - - _ = d.JoinChannel(ctx, chID, uid1) - _ = d.JoinChannel(ctx, chID, uid2) - _ = d.JoinChannel(ctx, chID, uid3) - - members, err := d.ChannelMembers(ctx, chID) - if err != nil { - t.Fatalf("ChannelMembers: %v", err) - } - - if len(members) != 3 { - t.Fatalf("expected 3 members, got %d", len(members)) - } -} - -func TestChannelMembersEmpty(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - chID, _ := d.GetOrCreateChannel(ctx, "#empty") - - members, err := d.ChannelMembers(ctx, chID) - if err != nil { - t.Fatalf("ChannelMembers: %v", err) - } - - if len(members) != 0 { - t.Errorf("expected 0, got %d", len(members)) - } -} - -func TestSendDM(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid1, _, _ := d.CreateUser(ctx, nickAlice) - uid2, _, _ := d.CreateUser(ctx, nickBob) - - msgID, err := d.SendDM(ctx, uid1, uid2, "hey bob") - if err != nil { - t.Fatalf("SendDM: %v", err) - } - - if msgID <= 0 { - t.Errorf("expected positive msgID, got %d", msgID) - } - - msgs, err := d.GetDMs(ctx, uid1, uid2, 0, 0) - if err != nil { - t.Fatalf("GetDMs: %v", err) - } - - if len(msgs) != 1 { - t.Fatalf("expected 1 DM, got %d", len(msgs)) - } - - if msgs[0].Content != "hey bob" { - t.Errorf("got %q, want %q", msgs[0].Content, "hey bob") - } -} - -func TestPollMessages(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid1, _, _ := d.CreateUser(ctx, nickAlice) - uid2, _, _ := d.CreateUser(ctx, nickBob) - chID, _ := d.GetOrCreateChannel(ctx, "#general") - - _ = d.JoinChannel(ctx, chID, uid1) - _ = d.JoinChannel(ctx, chID, uid2) - - _, _ = d.SendMessage(ctx, chID, uid2, "hello") - _, _ = d.SendDM(ctx, uid2, uid1, "private") - - time.Sleep(10 * time.Millisecond) - - msgs, err := d.PollMessages(ctx, uid1, 0, 0) - if err != nil { - t.Fatalf("PollMessages: %v", err) - } - - if len(msgs) < 2 { - t.Fatalf("expected >=2 messages, got %d", len(msgs)) - } -} - -func TestChangeNick(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, token, _ := d.CreateUser(ctx, nickAlice) - - err := d.ChangeNick(ctx, 1, "alice2") - if err != nil { - t.Fatalf("ChangeNick: %v", err) - } - - _, nick, err := d.GetUserByToken(ctx, token) - if err != nil { - t.Fatalf("GetUserByToken: %v", err) - } - - if nick != "alice2" { - t.Errorf("got nick %q, want alice2", nick) - } -} - -func TestSetTopic(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid, _, _ := d.CreateUser(ctx, nickAlice) - _, _ = d.GetOrCreateChannel(ctx, "#general") - - err := d.SetTopic(ctx, "#general", uid, "new topic") - if err != nil { - t.Fatalf("SetTopic: %v", err) - } - - channels, err := d.ListAllChannels(ctx) - if err != nil { - t.Fatalf("ListAllChannels: %v", err) - } - - found := false - - for _, ch := range channels { - if ch.Name == "#general" && ch.Topic == "new topic" { - found = true - } - } - - if !found { - t.Error("topic was not updated") - } -} - -func TestGetMessagesBefore(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - uid, _, _ := d.CreateUser(ctx, nickAlice) - chID, _ := d.GetOrCreateChannel(ctx, "#general") - - _ = d.JoinChannel(ctx, chID, uid) - - for i := range 5 { - _, _ = d.SendMessage( - ctx, chID, uid, - fmt.Sprintf("msg%d", i), - ) - - time.Sleep(10 * time.Millisecond) - } - - msgs, err := d.GetMessagesBefore(ctx, chID, 0, 3) - if err != nil { - t.Fatalf("GetMessagesBefore: %v", err) - } - - if len(msgs) != 3 { - t.Fatalf("expected 3, got %d", len(msgs)) - } -} - -func TestListAllChannels(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, _ = d.GetOrCreateChannel(ctx, "#alpha") - _, _ = d.GetOrCreateChannel(ctx, "#beta") - - channels, err := d.ListAllChannels(ctx) - if err != nil { - t.Fatalf("ListAllChannels: %v", err) - } - - if len(channels) != 2 { - t.Errorf("expected 2, got %d", len(channels)) - } -} diff --git a/internal/db/queries.go b/internal/db/queries.go index f3567b4..cbe9c16 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -3,144 +3,31 @@ package db import ( "context" "crypto/rand" - "database/sql" "encoding/hex" + "encoding/json" "fmt" "time" -) -const ( - defaultMessageLimit = 50 - defaultPollLimit = 100 - tokenBytes = 32 + "github.com/google/uuid" ) func generateToken() string { - b := make([]byte, tokenBytes) + b := make([]byte, 32) _, _ = rand.Read(b) - return hex.EncodeToString(b) } -// CreateUser registers a new user with the given nick and -// returns the user with token. -func (s *Database) CreateUser( - ctx context.Context, - nick string, -) (int64, string, error) { - token := generateToken() - now := time.Now() - - res, err := s.db.ExecContext(ctx, - "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)", - nick, token, now, now) - if err != nil { - return 0, "", fmt.Errorf("create user: %w", err) - } - - id, _ := res.LastInsertId() - - return id, token, nil -} - -// GetUserByToken returns user id and nick for a given auth -// token. -func (s *Database) GetUserByToken( - ctx context.Context, - token string, -) (int64, string, error) { - var id int64 - - var nick string - - err := s.db.QueryRowContext( - ctx, - "SELECT id, nick FROM users WHERE token = ?", - token, - ).Scan(&id, &nick) - if err != nil { - return 0, "", err - } - - // Update last_seen - _, _ = s.db.ExecContext( - ctx, - "UPDATE users SET last_seen = ? WHERE id = ?", - time.Now(), id, - ) - - return id, nick, nil -} - -// GetUserByNick returns user id for a given nick. -func (s *Database) GetUserByNick( - ctx context.Context, - nick string, -) (int64, error) { - var id int64 - - err := s.db.QueryRowContext( - ctx, - "SELECT id FROM users WHERE nick = ?", - nick, - ).Scan(&id) - - return id, err -} - -// GetOrCreateChannel returns the channel id, creating it if -// needed. -func (s *Database) GetOrCreateChannel( - ctx context.Context, - name string, -) (int64, error) { - var id int64 - - err := s.db.QueryRowContext( - ctx, - "SELECT id FROM channels WHERE name = ?", - name, - ).Scan(&id) - if err == nil { - return id, nil - } - - now := time.Now() - - res, err := s.db.ExecContext(ctx, - "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)", - name, now, now) - if err != nil { - return 0, fmt.Errorf("create channel: %w", err) - } - - id, _ = res.LastInsertId() - - return id, nil -} - -// JoinChannel adds a user to a channel. -func (s *Database) JoinChannel( - ctx context.Context, - channelID, userID int64, -) error { - _, err := s.db.ExecContext(ctx, - "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", - channelID, userID, time.Now()) - - return err -} - -// PartChannel removes a user from a channel. -func (s *Database) PartChannel( - ctx context.Context, - channelID, userID int64, -) error { - _, err := s.db.ExecContext(ctx, - "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", - channelID, userID) - - return err +// IRCMessage is the IRC envelope format for all messages. +type IRCMessage struct { + ID string `json:"id"` + Command string `json:"command"` + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + Body json.RawMessage `json:"body,omitempty"` + TS string `json:"ts"` + Meta json.RawMessage `json:"meta,omitempty"` + // Internal DB fields (not in JSON) + DBID int64 `json:"-"` } // ChannelInfo is a lightweight channel representation. @@ -150,46 +37,6 @@ type ChannelInfo struct { Topic string `json:"topic"` } -// ListChannels returns all channels the user has joined. -func (s *Database) ListChannels( - ctx context.Context, - userID int64, -) ([]ChannelInfo, error) { - rows, err := s.db.QueryContext(ctx, - `SELECT c.id, c.name, c.topic FROM channels c - INNER JOIN channel_members cm ON cm.channel_id = c.id - WHERE cm.user_id = ? ORDER BY c.name`, userID) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - var channels []ChannelInfo - - for rows.Next() { - var ch ChannelInfo - - err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic) - if err != nil { - return nil, err - } - - channels = append(channels, ch) - } - - err = rows.Err() - if err != nil { - return nil, err - } - - if channels == nil { - channels = []ChannelInfo{} - } - - return channels, nil -} - // MemberInfo represents a channel member. type MemberInfo struct { ID int64 `json:"id"` @@ -197,11 +44,130 @@ type MemberInfo struct { LastSeen time.Time `json:"lastSeen"` } +// CreateUser registers a new user with the given nick. +func (s *Database) CreateUser(ctx context.Context, nick string) (int64, string, error) { + token := generateToken() + now := time.Now() + res, err := s.db.ExecContext(ctx, + "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)", + nick, token, now, now) + if err != nil { + return 0, "", fmt.Errorf("create user: %w", err) + } + id, _ := res.LastInsertId() + return id, token, nil +} + +// GetUserByToken returns user id and nick for a given auth token. +func (s *Database) GetUserByToken(ctx context.Context, token string) (int64, string, error) { + var id int64 + var nick string + err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick) + if err != nil { + return 0, "", err + } + _, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id) + return id, nick, nil +} + +// GetUserByNick returns user id for a given nick. +func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error) { + var id int64 + err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id) + return id, err +} + +// GetOrCreateChannel returns the channel id, creating it if needed. +func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) { + var id int64 + err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id) + if err == nil { + return id, nil + } + now := time.Now() + res, err := s.db.ExecContext(ctx, + "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)", + name, now, now) + if err != nil { + return 0, fmt.Errorf("create channel: %w", err) + } + id, _ = res.LastInsertId() + return id, nil +} + +// JoinChannel adds a user to a channel. +func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error { + _, err := s.db.ExecContext(ctx, + "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", + channelID, userID, time.Now()) + return err +} + +// PartChannel removes a user from a channel. +func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error { + _, err := s.db.ExecContext(ctx, + "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", + channelID, userID) + return err +} + +// DeleteChannelIfEmpty deletes a channel if it has no members. +func (s *Database) DeleteChannelIfEmpty(ctx context.Context, channelID int64) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM channels WHERE id = ? AND NOT EXISTS + (SELECT 1 FROM channel_members WHERE channel_id = ?)`, + channelID, channelID) + return err +} + +// ListChannels returns all channels the user has joined. +func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT c.id, c.name, c.topic FROM channels c + INNER JOIN channel_members cm ON cm.channel_id = c.id + WHERE cm.user_id = ? ORDER BY c.name`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var channels []ChannelInfo + for rows.Next() { + var ch ChannelInfo + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { + return nil, err + } + channels = append(channels, ch) + } + if channels == nil { + channels = []ChannelInfo{} + } + return channels, nil +} + +// ListAllChannels returns all channels. +func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) { + rows, err := s.db.QueryContext(ctx, + "SELECT id, name, topic FROM channels ORDER BY name") + if err != nil { + return nil, err + } + defer rows.Close() + var channels []ChannelInfo + for rows.Next() { + var ch ChannelInfo + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { + return nil, err + } + channels = append(channels, ch) + } + if channels == nil { + channels = []ChannelInfo{} + } + return channels, nil +} + // ChannelMembers returns all members of a channel. -func (s *Database) ChannelMembers( - ctx context.Context, - channelID int64, -) ([]MemberInfo, error) { +func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT u.id, u.nick, u.last_seen FROM users u INNER JOIN channel_members cm ON cm.user_id = u.id @@ -209,503 +175,215 @@ func (s *Database) ChannelMembers( if err != nil { return nil, err } - - defer func() { _ = rows.Close() }() - + defer rows.Close() var members []MemberInfo - for rows.Next() { var m MemberInfo - - err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen) - if err != nil { + if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil { return nil, err } - members = append(members, m) } - - err = rows.Err() - if err != nil { - return nil, err - } - if members == nil { members = []MemberInfo{} } - return members, nil } -// MessageInfo represents a chat message. -type MessageInfo struct { - ID int64 `json:"id"` - Channel string `json:"channel,omitempty"` - Nick string `json:"nick"` - Content string `json:"content"` - IsDM bool `json:"isDm,omitempty"` - DMTarget string `json:"dmTarget,omitempty"` - CreatedAt time.Time `json:"createdAt"` -} - -// GetMessages returns messages for a channel, optionally -// after a given ID. -func (s *Database) GetMessages( - ctx context.Context, - channelID int64, - afterID int64, - limit int, -) ([]MessageInfo, error) { - if limit <= 0 { - limit = defaultMessageLimit - } - +// GetChannelMemberIDs returns user IDs of all members in a channel. +func (s *Database) GetChannelMemberIDs(ctx context.Context, channelID int64) ([]int64, error) { rows, err := s.db.QueryContext(ctx, - `SELECT m.id, c.name, u.nick, m.content, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN channels c ON c.id = m.channel_id - WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id > ? - ORDER BY m.id ASC LIMIT ?`, channelID, afterID, limit) + "SELECT user_id FROM channel_members WHERE channel_id = ?", channelID) if err != nil { return nil, err } - - defer func() { _ = rows.Close() }() - - var msgs []MessageInfo - + defer rows.Close() + var ids []int64 for rows.Next() { - var m MessageInfo - - err := rows.Scan( - &m.ID, &m.Channel, &m.Nick, - &m.Content, &m.CreatedAt, - ) - if err != nil { + var id int64 + if err := rows.Scan(&id); err != nil { return nil, err } - - msgs = append(msgs, m) + ids = append(ids, id) } + return ids, nil +} - err = rows.Err() +// GetUserChannelIDs returns channel IDs the user is a member of. +func (s *Database) GetUserChannelIDs(ctx context.Context, userID int64) ([]int64, error) { + rows, err := s.db.QueryContext(ctx, + "SELECT channel_id FROM channel_members WHERE user_id = ?", userID) if err != nil { return nil, err } - - if msgs == nil { - msgs = []MessageInfo{} + defer rows.Close() + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) } - - return msgs, nil + return ids, nil } -// SendMessage inserts a channel message. -func (s *Database) SendMessage( - ctx context.Context, - channelID, userID int64, - content string, -) (int64, error) { +// InsertMessage stores a message and returns its DB ID. +func (s *Database) InsertMessage(ctx context.Context, command, from, to string, body json.RawMessage, meta json.RawMessage) (int64, string, error) { + msgUUID := uuid.New().String() + now := time.Now().UTC() + if body == nil { + body = json.RawMessage("[]") + } + if meta == nil { + meta = json.RawMessage("{}") + } res, err := s.db.ExecContext(ctx, - "INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)", - channelID, userID, content, time.Now()) + `INSERT INTO messages (uuid, command, msg_from, msg_to, body, meta, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + msgUUID, command, from, to, string(body), string(meta), now) if err != nil { - return 0, err + return 0, "", err } - - return res.LastInsertId() + id, _ := res.LastInsertId() + return id, msgUUID, nil } -// SendDM inserts a direct message. -func (s *Database) SendDM( - ctx context.Context, - fromID, toID int64, - content string, -) (int64, error) { - res, err := s.db.ExecContext(ctx, - "INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)", - fromID, content, toID, time.Now()) - if err != nil { - return 0, err - } - - return res.LastInsertId() +// EnqueueMessage adds a message to a user's delivery queue. +func (s *Database) EnqueueMessage(ctx context.Context, userID, messageID int64) error { + _, err := s.db.ExecContext(ctx, + "INSERT OR IGNORE INTO client_queues (user_id, message_id, created_at) VALUES (?, ?, ?)", + userID, messageID, time.Now()) + return err } -// GetDMs returns direct messages between two users after a -// given ID. -func (s *Database) GetDMs( - ctx context.Context, - userA, userB int64, - afterID int64, - limit int, -) ([]MessageInfo, error) { +// PollMessages returns queued messages for a user after a given queue ID. +func (s *Database) PollMessages(ctx context.Context, userID int64, afterQueueID int64, limit int) ([]IRCMessage, int64, error) { if limit <= 0 { - limit = defaultMessageLimit + limit = 100 } - rows, err := s.db.QueryContext(ctx, - `SELECT m.id, u.nick, m.content, t.nick, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN users t ON t.id = m.dm_target_id - WHERE m.is_dm = 1 AND m.id > ? - AND ((m.user_id = ? AND m.dm_target_id = ?) - OR (m.user_id = ? AND m.dm_target_id = ?)) - ORDER BY m.id ASC LIMIT ?`, - afterID, userA, userB, userB, userA, limit) + `SELECT cq.id, m.uuid, m.command, m.msg_from, m.msg_to, m.body, m.meta, m.created_at + FROM client_queues cq + INNER JOIN messages m ON m.id = cq.message_id + WHERE cq.user_id = ? AND cq.id > ? + ORDER BY cq.id ASC LIMIT ?`, userID, afterQueueID, limit) if err != nil { - return nil, err + return nil, afterQueueID, err } + defer rows.Close() - defer func() { _ = rows.Close() }() - - var msgs []MessageInfo - + var msgs []IRCMessage + var lastQID int64 for rows.Next() { - var m MessageInfo - - err := rows.Scan( - &m.ID, &m.Nick, &m.Content, - &m.DMTarget, &m.CreatedAt, - ) - if err != nil { - return nil, err + var m IRCMessage + var qID int64 + var body, meta string + var ts time.Time + if err := rows.Scan(&qID, &m.ID, &m.Command, &m.From, &m.To, &body, &meta, &ts); err != nil { + return nil, afterQueueID, err } - - m.IsDM = true - + m.Body = json.RawMessage(body) + m.Meta = json.RawMessage(meta) + m.TS = ts.Format(time.RFC3339Nano) + m.DBID = qID + lastQID = qID msgs = append(msgs, m) } - - err = rows.Err() - if err != nil { - return nil, err - } - if msgs == nil { - msgs = []MessageInfo{} + msgs = []IRCMessage{} } - - return msgs, nil + if lastQID == 0 { + lastQID = afterQueueID + } + return msgs, lastQID, nil } -// PollMessages returns all new messages (channel + DM) for -// a user after a given ID. -func (s *Database) PollMessages( - ctx context.Context, - userID int64, - afterID int64, - limit int, -) ([]MessageInfo, error) { +// GetHistory returns message history for a target (channel or DM nick pair). +func (s *Database) GetHistory(ctx context.Context, target string, beforeID int64, limit int) ([]IRCMessage, error) { if limit <= 0 { - limit = defaultPollLimit + limit = 50 } - - rows, err := s.db.QueryContext(ctx, - `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, - m.is_dm, COALESCE(t.nick, ''), m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - LEFT JOIN channels c ON c.id = m.channel_id - LEFT JOIN users t ON t.id = m.dm_target_id - WHERE m.id > ? AND ( - (m.is_dm = 0 AND m.channel_id IN - (SELECT channel_id FROM channel_members - WHERE user_id = ?)) - OR (m.is_dm = 1 - AND (m.user_id = ? OR m.dm_target_id = ?)) - ) - ORDER BY m.id ASC LIMIT ?`, - afterID, userID, userID, userID, limit) + var query string + var args []any + if beforeID > 0 { + query = `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at + FROM messages WHERE msg_to = ? AND id < ? AND command = 'PRIVMSG' + ORDER BY id DESC LIMIT ?` + args = []any{target, beforeID, limit} + } else { + query = `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at + FROM messages WHERE msg_to = ? AND command = 'PRIVMSG' + ORDER BY id DESC LIMIT ?` + args = []any{target, limit} + } + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - - defer func() { _ = rows.Close() }() - - msgs := make([]MessageInfo, 0) - + defer rows.Close() + var msgs []IRCMessage for rows.Next() { - var ( - m MessageInfo - isDM int - ) - - err := rows.Scan( - &m.ID, &m.Channel, &m.Nick, &m.Content, - &isDM, &m.DMTarget, &m.CreatedAt, - ) - if err != nil { + var m IRCMessage + var dbID int64 + var body, meta string + var ts time.Time + if err := rows.Scan(&dbID, &m.ID, &m.Command, &m.From, &m.To, &body, &meta, &ts); err != nil { return nil, err } - - m.IsDM = isDM == 1 + m.Body = json.RawMessage(body) + m.Meta = json.RawMessage(meta) + m.TS = ts.Format(time.RFC3339Nano) + m.DBID = dbID msgs = append(msgs, m) } - - err = rows.Err() - if err != nil { - return nil, err - } - - return msgs, nil -} - -func scanChannelMessages( - rows *sql.Rows, -) ([]MessageInfo, error) { - var msgs []MessageInfo - - for rows.Next() { - var m MessageInfo - - err := rows.Scan( - &m.ID, &m.Channel, &m.Nick, - &m.Content, &m.CreatedAt, - ) - if err != nil { - return nil, err - } - - msgs = append(msgs, m) - } - - err := rows.Err() - if err != nil { - return nil, err - } - if msgs == nil { - msgs = []MessageInfo{} + msgs = []IRCMessage{} } - - return msgs, nil -} - -func scanDMMessages( - rows *sql.Rows, -) ([]MessageInfo, error) { - var msgs []MessageInfo - - for rows.Next() { - var m MessageInfo - - err := rows.Scan( - &m.ID, &m.Nick, &m.Content, - &m.DMTarget, &m.CreatedAt, - ) - if err != nil { - return nil, err - } - - m.IsDM = true - - msgs = append(msgs, m) - } - - err := rows.Err() - if err != nil { - return nil, err - } - - if msgs == nil { - msgs = []MessageInfo{} - } - - return msgs, nil -} - -func reverseMessages(msgs []MessageInfo) { + // Reverse to ascending order for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } -} - -// GetMessagesBefore returns channel messages before a given -// ID (for history scrollback). -func (s *Database) GetMessagesBefore( - ctx context.Context, - channelID int64, - beforeID int64, - limit int, -) ([]MessageInfo, error) { - if limit <= 0 { - limit = defaultMessageLimit - } - - var query string - - var args []any - - if beforeID > 0 { - query = `SELECT m.id, c.name, u.nick, m.content, - m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN channels c ON c.id = m.channel_id - WHERE m.channel_id = ? AND m.is_dm = 0 - AND m.id < ? - ORDER BY m.id DESC LIMIT ?` - args = []any{channelID, beforeID, limit} - } else { - query = `SELECT m.id, c.name, u.nick, m.content, - m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN channels c ON c.id = m.channel_id - WHERE m.channel_id = ? AND m.is_dm = 0 - ORDER BY m.id DESC LIMIT ?` - args = []any{channelID, limit} - } - - rows, err := s.db.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - msgs, scanErr := scanChannelMessages(rows) - if scanErr != nil { - return nil, scanErr - } - - // Reverse to ascending order. - reverseMessages(msgs) - - return msgs, nil -} - -// GetDMsBefore returns DMs between two users before a given -// ID (for history scrollback). -func (s *Database) GetDMsBefore( - ctx context.Context, - userA, userB int64, - beforeID int64, - limit int, -) ([]MessageInfo, error) { - if limit <= 0 { - limit = defaultMessageLimit - } - - var query string - - var args []any - - if beforeID > 0 { - query = `SELECT m.id, u.nick, m.content, t.nick, - m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN users t ON t.id = m.dm_target_id - WHERE m.is_dm = 1 AND m.id < ? - AND ((m.user_id = ? AND m.dm_target_id = ?) - OR (m.user_id = ? AND m.dm_target_id = ?)) - ORDER BY m.id DESC LIMIT ?` - args = []any{ - beforeID, userA, userB, userB, userA, limit, - } - } else { - query = `SELECT m.id, u.nick, m.content, t.nick, - m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN users t ON t.id = m.dm_target_id - WHERE m.is_dm = 1 - AND ((m.user_id = ? AND m.dm_target_id = ?) - OR (m.user_id = ? AND m.dm_target_id = ?)) - ORDER BY m.id DESC LIMIT ?` - args = []any{userA, userB, userB, userA, limit} - } - - rows, err := s.db.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - msgs, scanErr := scanDMMessages(rows) - if scanErr != nil { - return nil, scanErr - } - - // Reverse to ascending order. - reverseMessages(msgs) - return msgs, nil } // ChangeNick updates a user's nickname. -func (s *Database) ChangeNick( - ctx context.Context, - userID int64, - newNick string, -) error { +func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error { _, err := s.db.ExecContext(ctx, - "UPDATE users SET nick = ? WHERE id = ?", - newNick, userID) - + "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) return err } // SetTopic sets the topic for a channel. -func (s *Database) SetTopic( - ctx context.Context, - channelName string, - _ int64, - topic string, -) error { +func (s *Database) SetTopic(ctx context.Context, channelName string, topic string) error { _, err := s.db.ExecContext(ctx, - "UPDATE channels SET topic = ? WHERE name = ?", - topic, channelName) - + "UPDATE channels SET topic = ?, updated_at = ? WHERE name = ?", topic, time.Now(), channelName) return err } -// GetServerName returns the server name (unused, config -// provides this). -func (s *Database) GetServerName() string { - return "" +// DeleteUser removes a user and all their data. +func (s *Database) DeleteUser(ctx context.Context, userID int64) error { + _, err := s.db.ExecContext(ctx, "DELETE FROM users WHERE id = ?", userID) + return err } -// ListAllChannels returns all channels. -func (s *Database) ListAllChannels( - ctx context.Context, -) ([]ChannelInfo, error) { +// GetAllChannelMembershipsForUser returns (channelID, channelName) for all channels a user is in. +func (s *Database) GetAllChannelMembershipsForUser(ctx context.Context, userID int64) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, - "SELECT id, name, topic FROM channels ORDER BY name") + `SELECT c.id, c.name, c.topic FROM channels c + INNER JOIN channel_members cm ON cm.channel_id = c.id + WHERE cm.user_id = ?`, userID) if err != nil { return nil, err } - - defer func() { _ = rows.Close() }() - + defer rows.Close() var channels []ChannelInfo - for rows.Next() { var ch ChannelInfo - - err := rows.Scan( - &ch.ID, &ch.Name, &ch.Topic, - ) - if err != nil { + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { return nil, err } - channels = append(channels, ch) } - - err = rows.Err() - if err != nil { - return nil, err - } - - if channels == nil { - channels = []ChannelInfo{} - } - return channels, nil } diff --git a/internal/db/schema/001_initial.sql b/internal/db/schema/001_initial.sql index 3741469..8434f78 100644 --- a/internal/db/schema/001_initial.sql +++ b/internal/db/schema/001_initial.sql @@ -1,4 +1,54 @@ -CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - applied_at DATETIME DEFAULT CURRENT_TIMESTAMP +-- Chat server schema (pre-1.0 consolidated) +PRAGMA foreign_keys = ON; + +-- Users: IRC-style sessions (no passwords, just nick + token) +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + nick TEXT NOT NULL UNIQUE, + token TEXT NOT NULL UNIQUE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + last_seen DATETIME DEFAULT CURRENT_TIMESTAMP ); +CREATE INDEX IF NOT EXISTS idx_users_token ON users(token); + +-- Channels +CREATE TABLE IF NOT EXISTS channels ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + topic TEXT NOT NULL DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +-- Channel members +CREATE TABLE IF NOT EXISTS channel_members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + joined_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(channel_id, user_id) +); + +-- Messages: IRC envelope format +CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + uuid TEXT NOT NULL UNIQUE, + command TEXT NOT NULL DEFAULT 'PRIVMSG', + msg_from TEXT NOT NULL DEFAULT '', + msg_to TEXT NOT NULL DEFAULT '', + body TEXT NOT NULL DEFAULT '[]', + meta TEXT NOT NULL DEFAULT '{}', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_messages_to_id ON messages(msg_to, id); +CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at); + +-- Per-client message queues for fan-out delivery +CREATE TABLE IF NOT EXISTS client_queues ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, message_id) +); +CREATE INDEX IF NOT EXISTS idx_client_queues_user ON client_queues(user_id, id); diff --git a/internal/db/schema/002_schema.sql b/internal/db/schema/002_schema.sql deleted file mode 100644 index 58dcb70..0000000 --- a/internal/db/schema/002_schema.sql +++ /dev/null @@ -1,89 +0,0 @@ --- All schema changes go into this file until 1.0.0 is tagged. --- There will not be migrations during the early development phase. --- After 1.0.0, new changes get their own numbered migration files. - --- Users: accounts and authentication -CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, -- UUID - nick TEXT NOT NULL UNIQUE, - password_hash TEXT NOT NULL, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - last_seen_at DATETIME -); - --- Auth tokens: one user can have multiple active tokens (multiple devices) -CREATE TABLE IF NOT EXISTS auth_tokens ( - token TEXT PRIMARY KEY, -- random token string - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at DATETIME, -- NULL = no expiry - last_used_at DATETIME -); -CREATE INDEX IF NOT EXISTS idx_auth_tokens_user_id ON auth_tokens(user_id); - --- Channels: chat rooms -CREATE TABLE IF NOT EXISTS channels ( - id TEXT PRIMARY KEY, -- UUID - name TEXT NOT NULL UNIQUE, -- #general, etc. - topic TEXT NOT NULL DEFAULT '', - modes TEXT NOT NULL DEFAULT '', -- +i, +m, +s, +t, +n - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP -); - --- Channel members: who is in which channel, with per-user modes -CREATE TABLE IF NOT EXISTS channel_members ( - channel_id TEXT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - modes TEXT NOT NULL DEFAULT '', -- +o (operator), +v (voice) - joined_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (channel_id, user_id) -); -CREATE INDEX IF NOT EXISTS idx_channel_members_user_id ON channel_members(user_id); - --- Messages: channel and DM history (rotated per MAX_HISTORY) -CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, -- UUID - ts DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - from_user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - from_nick TEXT NOT NULL, -- denormalized for history - target TEXT NOT NULL, -- #channel name or user UUID for DMs - type TEXT NOT NULL DEFAULT 'message', -- message, action, notice, join, part, quit, topic, mode, nick, system - body TEXT NOT NULL DEFAULT '', - meta TEXT NOT NULL DEFAULT '{}', -- JSON extensible metadata - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP -); -CREATE INDEX IF NOT EXISTS idx_messages_target_ts ON messages(target, ts); -CREATE INDEX IF NOT EXISTS idx_messages_from_user ON messages(from_user_id); - --- Message queue: per-user pending delivery (unread messages) -CREATE TABLE IF NOT EXISTS message_queue ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - message_id TEXT NOT NULL REFERENCES messages(id) ON DELETE CASCADE, - queued_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - UNIQUE(user_id, message_id) -); -CREATE INDEX IF NOT EXISTS idx_message_queue_user_id ON message_queue(user_id, queued_at); - --- Sessions: server-held session state -CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, -- UUID - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - last_active_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at DATETIME -- idle timeout -); -CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id); - --- Server links: federation peer configuration -CREATE TABLE IF NOT EXISTS server_links ( - id TEXT PRIMARY KEY, -- UUID - name TEXT NOT NULL UNIQUE, -- human-readable peer name - url TEXT NOT NULL, -- base URL of peer server - shared_key_hash TEXT NOT NULL, -- hashed shared secret - is_active INTEGER NOT NULL DEFAULT 1, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - last_seen_at DATETIME -); diff --git a/internal/db/schema/003_users.sql b/internal/db/schema/003_users.sql deleted file mode 100644 index a89aad8..0000000 --- a/internal/db/schema/003_users.sql +++ /dev/null @@ -1,53 +0,0 @@ --- Migration 003: Replace UUID-based tables with simple integer-keyed --- tables for the HTTP API. Drops the 002 tables and recreates them. - -PRAGMA foreign_keys = OFF; - -DROP TABLE IF EXISTS message_queue; -DROP TABLE IF EXISTS sessions; -DROP TABLE IF EXISTS server_links; -DROP TABLE IF EXISTS messages; -DROP TABLE IF EXISTS channel_members; -DROP TABLE IF EXISTS auth_tokens; -DROP TABLE IF EXISTS channels; -DROP TABLE IF EXISTS users; - -CREATE TABLE users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - nick TEXT NOT NULL UNIQUE, - token TEXT NOT NULL UNIQUE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - last_seen DATETIME DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE channels ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - topic TEXT NOT NULL DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE channel_members ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - joined_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(channel_id, user_id) -); - -CREATE TABLE messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - content TEXT NOT NULL, - is_dm INTEGER NOT NULL DEFAULT 0, - dm_target_id INTEGER REFERENCES users(id) ON DELETE CASCADE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP -); - -CREATE INDEX idx_messages_channel ON messages(channel_id, created_at); -CREATE INDEX idx_messages_dm ON messages(user_id, dm_target_id, created_at); -CREATE INDEX idx_users_token ON users(token); - -PRAGMA foreign_keys = ON; diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 764a0fe..d02f734 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -6,186 +6,125 @@ import ( "net/http" "strconv" "strings" + "time" - "git.eeqj.de/sneak/chat/internal/db" "github.com/go-chi/chi" ) -const ( - maxNickLen = 32 - defaultHistory = 50 -) - -// authUser extracts the user from the Authorization header -// (Bearer token). -func (s *Handlers) authUser( - r *http.Request, -) (int64, string, error) { +// authUser extracts the user from the Authorization header (Bearer token). +func (s *Handlers) authUser(r *http.Request) (int64, string, error) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { return 0, "", sql.ErrNoRows } - token := strings.TrimPrefix(auth, "Bearer ") - return s.params.Database.GetUserByToken(r.Context(), token) } -func (s *Handlers) requireAuth( - w http.ResponseWriter, - r *http.Request, -) (int64, string, bool) { +func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, string, bool) { uid, nick, err := s.authUser(r) if err != nil { - s.respondJSON( - w, r, - map[string]string{"error": "unauthorized"}, - http.StatusUnauthorized, - ) - + s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized) return 0, "", false } - return uid, nick, true } -func (s *Handlers) respondError( - w http.ResponseWriter, - r *http.Request, - msg string, - code int, -) { - s.respondJSON(w, r, map[string]string{"error": msg}, code) -} - -func (s *Handlers) internalError( - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - s.log.Error(msg, "error", err) - s.respondError(w, r, "internal error", http.StatusInternalServerError) -} - -// bodyLines extracts body as string lines from a request body -// field. -func bodyLines(body any) []string { - switch v := body.(type) { - case []any: - lines := make([]string, 0, len(v)) - - for _, item := range v { - if s, ok := item.(string); ok { - lines = append(lines, s) - } - } - - return lines - case []string: - return v - default: - return nil +// fanOut stores a message and enqueues it to all specified user IDs, then notifies them. +func (s *Handlers) fanOut(ctx *http.Request, command, from, to string, body json.RawMessage, userIDs []int64) error { + dbID, _, err := s.params.Database.InsertMessage(ctx.Context(), command, from, to, body, nil) + if err != nil { + return err } + for _, uid := range userIDs { + if err := s.params.Database.EnqueueMessage(ctx.Context(), uid, dbID); err != nil { + s.log.Error("enqueue failed", "error", err, "user_id", uid) + } + s.broker.Notify(uid) + } + return nil } -// HandleCreateSession creates a new user session and returns -// the auth token. +// fanOutRaw stores and fans out, returning the message DB ID. +func (s *Handlers) fanOutDirect(ctx *http.Request, command, from, to string, body json.RawMessage, userIDs []int64) (int64, string, error) { + dbID, msgUUID, err := s.params.Database.InsertMessage(ctx.Context(), command, from, to, body, nil) + if err != nil { + return 0, "", err + } + for _, uid := range userIDs { + if err := s.params.Database.EnqueueMessage(ctx.Context(), uid, dbID); err != nil { + s.log.Error("enqueue failed", "error", err, "user_id", uid) + } + s.broker.Notify(uid) + } + return dbID, msgUUID, nil +} + +// getChannelMembers gets all member IDs for a channel by name. +func (s *Handlers) getChannelMemberIDs(r *http.Request, channelName string) (int64, []int64, error) { + var chID int64 + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", channelName).Scan(&chID) + if err != nil { + return 0, nil, err + } + ids, err := s.params.Database.GetChannelMemberIDs(r.Context(), chID) + return chID, ids, err +} + +// HandleCreateSession creates a new user session and returns the auth token. func (s *Handlers) HandleCreateSession() http.HandlerFunc { type request struct { Nick string `json:"nick"` } - type response struct { ID int64 `json:"id"` Nick string `json:"nick"` Token string `json:"token"` } - return func(w http.ResponseWriter, r *http.Request) { var req request - - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - s.respondError( - w, r, "invalid request", - http.StatusBadRequest, - ) - + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) return } - req.Nick = strings.TrimSpace(req.Nick) - - if req.Nick == "" || len(req.Nick) > maxNickLen { - s.respondError( - w, r, "nick must be 1-32 characters", - http.StatusBadRequest, - ) - + if req.Nick == "" || len(req.Nick) > 32 { + s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) return } - - id, token, err := s.params.Database.CreateUser( - r.Context(), req.Nick, - ) + id, token, err := s.params.Database.CreateUser(r.Context(), req.Nick) if err != nil { if strings.Contains(err.Error(), "UNIQUE") { - s.respondError( - w, r, "nick already taken", - http.StatusConflict, - ) - + s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict) return } - - s.internalError(w, r, "create user failed", err) - + s.log.Error("create user failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - - s.respondJSON( - w, r, - &response{ID: id, Nick: req.Nick, Token: token}, - http.StatusCreated, - ) + s.respondJSON(w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated) } } -// HandleState returns the current user's info and joined -// channels. +// HandleState returns the current user's info and joined channels. func (s *Handlers) HandleState() http.HandlerFunc { - type response struct { - ID int64 `json:"id"` - Nick string `json:"nick"` - Channels []db.ChannelInfo `json:"channels"` - } - return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } - - channels, err := s.params.Database.ListChannels( - r.Context(), uid, - ) + channels, err := s.params.Database.ListChannels(r.Context(), uid) if err != nil { - s.internalError( - w, r, "list channels failed", err, - ) - + s.log.Error("list channels failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - - s.respondJSON( - w, r, - &response{ - ID: uid, Nick: nick, - Channels: channels, - }, - http.StatusOK, - ) + s.respondJSON(w, r, map[string]any{ + "id": uid, + "nick": nick, + "channels": channels, + }, http.StatusOK) } } @@ -196,18 +135,12 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc { if !ok { return } - - channels, err := s.params.Database.ListAllChannels( - r.Context(), - ) + channels, err := s.params.Database.ListAllChannels(r.Context()) if err != nil { - s.internalError( - w, r, "list all channels failed", err, - ) - + s.log.Error("list all channels failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - s.respondJSON(w, r, channels, http.StatusOK) } } @@ -219,570 +152,347 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc { if !ok { return } - name := "#" + chi.URLParam(r, "channel") - var chID int64 - - err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query - r.Context(), - "SELECT id FROM channels WHERE name = ?", - name, - ).Scan(&chID) + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", name).Scan(&chID) if err != nil { - s.respondError( - w, r, "channel not found", - http.StatusNotFound, - ) - + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) return } - - members, err := s.params.Database.ChannelMembers( - r.Context(), chID, - ) + members, err := s.params.Database.ChannelMembers(r.Context(), chID) if err != nil { - s.internalError( - w, r, "channel members failed", err, - ) - + s.log.Error("channel members failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - s.respondJSON(w, r, members, http.StatusOK) } } -// HandleGetMessages returns all new messages (channel + DM) -// for the user via long-polling. +// HandleGetMessages returns messages via long-polling from the client's queue. func (s *Handlers) HandleGetMessages() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, _, ok := s.requireAuth(w, r) if !ok { return } + afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64) + timeout, _ := strconv.Atoi(r.URL.Query().Get("timeout")) + if timeout <= 0 { + timeout = 0 + } + if timeout > 30 { + timeout = 30 + } - afterID, _ := strconv.ParseInt( - r.URL.Query().Get("after"), 10, 64, - ) - - limit, _ := strconv.Atoi( - r.URL.Query().Get("limit"), - ) - - msgs, err := s.params.Database.PollMessages( - r.Context(), uid, afterID, limit, - ) + // First check for existing messages. + msgs, lastQID, err := s.params.Database.PollMessages(r.Context(), uid, afterID, 100) if err != nil { - s.internalError( - w, r, "get messages failed", err, - ) - + s.log.Error("poll messages failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - s.respondJSON(w, r, msgs, http.StatusOK) + if len(msgs) > 0 || timeout == 0 { + s.respondJSON(w, r, map[string]any{ + "messages": msgs, + "last_id": lastQID, + }, http.StatusOK) + return + } + + // Long-poll: wait for notification or timeout. + waitCh := s.broker.Wait(uid) + timer := time.NewTimer(time.Duration(timeout) * time.Second) + defer timer.Stop() + + select { + case <-waitCh: + case <-timer.C: + case <-r.Context().Done(): + s.broker.Remove(uid, waitCh) + return + } + s.broker.Remove(uid, waitCh) + + // Check again after notification. + msgs, lastQID, err = s.params.Database.PollMessages(r.Context(), uid, afterID, 100) + if err != nil { + s.log.Error("poll messages failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]any{ + "messages": msgs, + "last_id": lastQID, + }, http.StatusOK) } } -type sendRequest struct { - Command string `json:"command"` - To string `json:"to"` - Params []string `json:"params,omitempty"` - Body any `json:"body,omitempty"` -} - -// HandleSendCommand handles all C2S commands via POST -// /messages. +// HandleSendCommand handles all C2S commands via POST /messages. func (s *Handlers) HandleSendCommand() http.HandlerFunc { + type request struct { + Command string `json:"command"` + To string `json:"to"` + Body json.RawMessage `json:"body,omitempty"` + Meta json.RawMessage `json:"meta,omitempty"` + } return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } - - var req sendRequest - - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - s.respondError( - w, r, "invalid request", - http.StatusBadRequest, - ) - + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) return } - - req.Command = strings.ToUpper( - strings.TrimSpace(req.Command), - ) + req.Command = strings.ToUpper(strings.TrimSpace(req.Command)) req.To = strings.TrimSpace(req.To) - s.dispatchCommand(w, r, uid, nick, &req) - } -} - -func (s *Handlers) dispatchCommand( - w http.ResponseWriter, - r *http.Request, - uid int64, - nick string, - req *sendRequest, -) { - switch req.Command { - case "PRIVMSG", "NOTICE": - s.handlePrivmsg(w, r, uid, req) - case "JOIN": - s.handleJoin(w, r, uid, req) - case "PART": - s.handlePart(w, r, uid, req) - case "NICK": - s.handleNick(w, r, uid, req) - case "TOPIC": - s.handleTopic(w, r, uid, req) - case "PING": - s.respondJSON( - w, r, - map[string]string{ - "command": "PONG", - "from": s.params.Config.ServerName, - }, - http.StatusOK, - ) - default: - _ = nick - - s.respondError( - w, r, - "unknown command: "+req.Command, - http.StatusBadRequest, - ) - } -} - -func (s *Handlers) handlePrivmsg( - w http.ResponseWriter, - r *http.Request, - uid int64, - req *sendRequest, -) { - if req.To == "" { - s.respondError( - w, r, "to field required", - http.StatusBadRequest, - ) - - return - } - - lines := bodyLines(req.Body) - if len(lines) == 0 { - s.respondError( - w, r, "body required", http.StatusBadRequest, - ) - - return - } - - content := strings.Join(lines, "\n") - - if strings.HasPrefix(req.To, "#") { - s.sendChannelMsg(w, r, uid, req.To, content) - } else { - s.sendDM(w, r, uid, req.To, content) - } -} - -func (s *Handlers) sendChannelMsg( - w http.ResponseWriter, - r *http.Request, - uid int64, - channel, content string, -) { - var chID int64 - - err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query - r.Context(), - "SELECT id FROM channels WHERE name = ?", - channel, - ).Scan(&chID) - if err != nil { - s.respondError( - w, r, "channel not found", - http.StatusNotFound, - ) - - return - } - - msgID, err := s.params.Database.SendMessage( - r.Context(), chID, uid, content, - ) - if err != nil { - s.internalError(w, r, "send message failed", err) - - return - } - - s.respondJSON( - w, r, - map[string]any{"id": msgID, "status": "sent"}, - http.StatusCreated, - ) -} - -func (s *Handlers) sendDM( - w http.ResponseWriter, - r *http.Request, - uid int64, - toNick, content string, -) { - targetID, err := s.params.Database.GetUserByNick( - r.Context(), toNick, - ) - if err != nil { - s.respondError( - w, r, "user not found", http.StatusNotFound, - ) - - return - } - - msgID, err := s.params.Database.SendDM( - r.Context(), uid, targetID, content, - ) - if err != nil { - s.internalError(w, r, "send dm failed", err) - - return - } - - s.respondJSON( - w, r, - map[string]any{"id": msgID, "status": "sent"}, - http.StatusCreated, - ) -} - -func (s *Handlers) handleJoin( - w http.ResponseWriter, - r *http.Request, - uid int64, - req *sendRequest, -) { - if req.To == "" { - s.respondError( - w, r, "to field required", - http.StatusBadRequest, - ) - - return - } - - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - - chID, err := s.params.Database.GetOrCreateChannel( - r.Context(), channel, - ) - if err != nil { - s.internalError( - w, r, "get/create channel failed", err, - ) - - return - } - - err = s.params.Database.JoinChannel( - r.Context(), chID, uid, - ) - if err != nil { - s.internalError(w, r, "join channel failed", err) - - return - } - - s.respondJSON( - w, r, - map[string]string{ - "status": "joined", "channel": channel, - }, - http.StatusOK, - ) -} - -func (s *Handlers) handlePart( - w http.ResponseWriter, - r *http.Request, - uid int64, - req *sendRequest, -) { - if req.To == "" { - s.respondError( - w, r, "to field required", - http.StatusBadRequest, - ) - - return - } - - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - - var chID int64 - - err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query - r.Context(), - "SELECT id FROM channels WHERE name = ?", - channel, - ).Scan(&chID) - if err != nil { - s.respondError( - w, r, "channel not found", - http.StatusNotFound, - ) - - return - } - - err = s.params.Database.PartChannel( - r.Context(), chID, uid, - ) - if err != nil { - s.internalError(w, r, "part channel failed", err) - - return - } - - s.respondJSON( - w, r, - map[string]string{ - "status": "parted", "channel": channel, - }, - http.StatusOK, - ) -} - -func (s *Handlers) handleNick( - w http.ResponseWriter, - r *http.Request, - uid int64, - req *sendRequest, -) { - lines := bodyLines(req.Body) - if len(lines) == 0 { - s.respondError( - w, r, "body required (new nick)", - http.StatusBadRequest, - ) - - return - } - - newNick := strings.TrimSpace(lines[0]) - if newNick == "" || len(newNick) > maxNickLen { - s.respondError( - w, r, "nick must be 1-32 characters", - http.StatusBadRequest, - ) - - return - } - - err := s.params.Database.ChangeNick( - r.Context(), uid, newNick, - ) - if err != nil { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondError( - w, r, "nick already in use", - http.StatusConflict, - ) - - return + bodyLines := func() []string { + if req.Body == nil { + return nil + } + var lines []string + if err := json.Unmarshal(req.Body, &lines); err != nil { + return nil + } + return lines } - s.internalError(w, r, "change nick failed", err) + switch req.Command { + case "PRIVMSG", "NOTICE": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) + return + } - return + if strings.HasPrefix(req.To, "#") { + // Channel message — fan out to all channel members. + _, memberIDs, err := s.getChannelMemberIDs(r, req.To) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + _, msgUUID, err := s.fanOutDirect(r, req.Command, nick, req.To, req.Body, memberIDs) + if err != nil { + s.log.Error("send message failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) + } else { + // DM — fan out to recipient + sender. + targetUID, err := s.params.Database.GetUserByNick(r.Context(), req.To) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) + return + } + recipients := []int64{targetUID} + if targetUID != uid { + recipients = append(recipients, uid) // echo to sender + } + _, msgUUID, err := s.fanOutDirect(r, req.Command, nick, req.To, req.Body, recipients) + if err != nil { + s.log.Error("send dm failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) + } + + case "JOIN": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel) + if err != nil { + s.log.Error("get/create channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { + s.log.Error("join channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + // Broadcast JOIN to all channel members (including the joiner). + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), chID) + _ = s.fanOut(r, "JOIN", nick, channel, nil, memberIDs) + s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK) + + case "PART": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + var chID int64 + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", channel).Scan(&chID) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + // Broadcast PART before removing the member. + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), chID) + _ = s.fanOut(r, "PART", nick, channel, req.Body, memberIDs) + + if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { + s.log.Error("part channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + // Delete channel if empty (ephemeral). + _ = s.params.Database.DeleteChannelIfEmpty(r.Context(), chID) + s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK) + + case "NICK": + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest) + return + } + newNick := strings.TrimSpace(lines[0]) + if newNick == "" || len(newNick) > 32 { + s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) + return + } + if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict) + return + } + s.log.Error("change nick failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + // Broadcast NICK to all channels the user is in. + channels, _ := s.params.Database.GetAllChannelMembershipsForUser(r.Context(), uid) + notified := map[int64]bool{uid: true} + body, _ := json.Marshal([]string{newNick}) + // Notify self. + dbID, _, _ := s.params.Database.InsertMessage(r.Context(), "NICK", nick, "", json.RawMessage(body), nil) + _ = s.params.Database.EnqueueMessage(r.Context(), uid, dbID) + s.broker.Notify(uid) + + for _, ch := range channels { + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), ch.ID) + for _, mid := range memberIDs { + if !notified[mid] { + notified[mid] = true + _ = s.params.Database.EnqueueMessage(r.Context(), mid, dbID) + s.broker.Notify(mid) + } + } + } + s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) + + case "TOPIC": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest) + return + } + topic := strings.Join(lines, " ") + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + if err := s.params.Database.SetTopic(r.Context(), channel, topic); err != nil { + s.log.Error("set topic failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + // Broadcast TOPIC to channel members. + _, memberIDs, _ := s.getChannelMemberIDs(r, channel) + _ = s.fanOut(r, "TOPIC", nick, channel, req.Body, memberIDs) + s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK) + + case "QUIT": + // Broadcast QUIT to all channels, then remove user. + channels, _ := s.params.Database.GetAllChannelMembershipsForUser(r.Context(), uid) + notified := map[int64]bool{} + var dbID int64 + if len(channels) > 0 { + dbID, _, _ = s.params.Database.InsertMessage(r.Context(), "QUIT", nick, "", req.Body, nil) + } + for _, ch := range channels { + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), ch.ID) + for _, mid := range memberIDs { + if mid != uid && !notified[mid] { + notified[mid] = true + _ = s.params.Database.EnqueueMessage(r.Context(), mid, dbID) + s.broker.Notify(mid) + } + } + _ = s.params.Database.PartChannel(r.Context(), ch.ID, uid) + _ = s.params.Database.DeleteChannelIfEmpty(r.Context(), ch.ID) + } + _ = s.params.Database.DeleteUser(r.Context(), uid) + s.respondJSON(w, r, map[string]string{"status": "quit"}, http.StatusOK) + + case "PING": + s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK) + + default: + s.respondJSON(w, r, map[string]string{"error": "unknown command: " + req.Command}, http.StatusBadRequest) + } } - - s.respondJSON( - w, r, - map[string]string{"status": "ok", "nick": newNick}, - http.StatusOK, - ) } -func (s *Handlers) handleTopic( - w http.ResponseWriter, - r *http.Request, - uid int64, - req *sendRequest, -) { - if req.To == "" { - s.respondError( - w, r, "to field required", - http.StatusBadRequest, - ) - - return - } - - lines := bodyLines(req.Body) - if len(lines) == 0 { - s.respondError( - w, r, "body required (topic text)", - http.StatusBadRequest, - ) - - return - } - - topic := strings.Join(lines, " ") - - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - - err := s.params.Database.SetTopic( - r.Context(), channel, uid, topic, - ) - if err != nil { - s.internalError(w, r, "set topic failed", err) - - return - } - - s.respondJSON( - w, r, - map[string]string{"status": "ok", "topic": topic}, - http.StatusOK, - ) -} - -// HandleGetHistory returns message history for a specific -// target (channel or DM). +// HandleGetHistory returns message history for a specific target. func (s *Handlers) HandleGetHistory() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) + _, _, ok := s.requireAuth(w, r) if !ok { return } - target := r.URL.Query().Get("target") if target == "" { - s.respondError( - w, r, "target required", - http.StatusBadRequest, - ) - + s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest) return } - - beforeID, _ := strconv.ParseInt( - r.URL.Query().Get("before"), 10, 64, - ) - - limit, _ := strconv.Atoi( - r.URL.Query().Get("limit"), - ) + beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64) + limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) if limit <= 0 { - limit = defaultHistory + limit = 50 } - - if strings.HasPrefix(target, "#") { - s.getChannelHistory( - w, r, target, beforeID, limit, - ) - } else { - s.getDMHistory( - w, r, uid, target, beforeID, limit, - ) + msgs, err := s.params.Database.GetHistory(r.Context(), target, beforeID, limit) + if err != nil { + s.log.Error("get history failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } + s.respondJSON(w, r, msgs, http.StatusOK) } } -func (s *Handlers) getChannelHistory( - w http.ResponseWriter, - r *http.Request, - target string, - beforeID int64, - limit int, -) { - var chID int64 - - err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query - r.Context(), - "SELECT id FROM channels WHERE name = ?", - target, - ).Scan(&chID) - if err != nil { - s.respondError( - w, r, "channel not found", - http.StatusNotFound, - ) - - return - } - - msgs, err := s.params.Database.GetMessagesBefore( - r.Context(), chID, beforeID, limit, - ) - if err != nil { - s.internalError(w, r, "get history failed", err) - - return - } - - s.respondJSON(w, r, msgs, http.StatusOK) -} - -func (s *Handlers) getDMHistory( - w http.ResponseWriter, - r *http.Request, - uid int64, - target string, - beforeID int64, - limit int, -) { - targetID, err := s.params.Database.GetUserByNick( - r.Context(), target, - ) - if err != nil { - s.respondError( - w, r, "user not found", http.StatusNotFound, - ) - - return - } - - msgs, err := s.params.Database.GetDMsBefore( - r.Context(), uid, targetID, beforeID, limit, - ) - if err != nil { - s.internalError( - w, r, "get dm history failed", err, - ) - - return - } - - s.respondJSON(w, r, msgs, http.StatusOK) -} - -// HandleServerInfo returns server metadata (MOTD, name). +// HandleServerInfo returns server metadata. func (s *Handlers) HandleServerInfo() http.HandlerFunc { type response struct { Name string `json:"name"` MOTD string `json:"motd"` } - return func(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, r, &response{ Name: s.params.Config.ServerName, diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 11b8942..92e5234 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" + "git.eeqj.de/sneak/chat/internal/broker" "git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/db" "git.eeqj.de/sneak/chat/internal/globals" @@ -31,6 +32,7 @@ type Handlers struct { params *Params log *slog.Logger hc *healthcheck.Healthcheck + broker *broker.Broker } // New creates a new Handlers instance. @@ -39,6 +41,7 @@ func New(lc fx.Lifecycle, params Params) (*Handlers, error) { s.params = ¶ms s.log = params.Logger.Get() s.hc = params.Healthcheck + s.broker = broker.New() lc.Append(fx.Hook{ OnStart: func(_ context.Context) error { @@ -50,8 +53,8 @@ func New(lc fx.Lifecycle, params Params) (*Handlers, error) { } func (s *Handlers) respondJSON(w http.ResponseWriter, _ *http.Request, data any, status int) { - w.WriteHeader(status) w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) if data != nil { err := json.NewEncoder(w).Encode(data) diff --git a/internal/models/auth_token.go b/internal/models/auth_token.go deleted file mode 100644 index c2c3fd1..0000000 --- a/internal/models/auth_token.go +++ /dev/null @@ -1,26 +0,0 @@ -package models - -import ( - "context" - "time" -) - -// AuthToken represents an authentication token for a user session. -type AuthToken struct { - Base - - Token string `json:"-"` - UserID string `json:"userId"` - CreatedAt time.Time `json:"createdAt"` - ExpiresAt *time.Time `json:"expiresAt,omitempty"` - LastUsedAt *time.Time `json:"lastUsedAt,omitempty"` -} - -// User returns the user who owns this token. -func (t *AuthToken) User(ctx context.Context) (*User, error) { - if ul := t.GetUserLookup(); ul != nil { - return ul.GetUserByID(ctx, t.UserID) - } - - return nil, ErrUserLookupNotAvailable -} diff --git a/internal/models/channel.go b/internal/models/channel.go deleted file mode 100644 index addafc9..0000000 --- a/internal/models/channel.go +++ /dev/null @@ -1,96 +0,0 @@ -package models - -import ( - "context" - "time" -) - -// Channel represents a chat channel. -type Channel struct { - Base - - ID string `json:"id"` - Name string `json:"name"` - Topic string `json:"topic"` - Modes string `json:"modes"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` -} - -// Members returns all users who are members of this channel. -func (c *Channel) Members(ctx context.Context) ([]*ChannelMember, error) { - rows, err := c.GetDB().QueryContext(ctx, ` - SELECT cm.channel_id, cm.user_id, cm.modes, cm.joined_at, - u.nick - FROM channel_members cm - JOIN users u ON u.id = cm.user_id - WHERE cm.channel_id = ? - ORDER BY cm.joined_at`, - c.ID, - ) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - members := []*ChannelMember{} - - for rows.Next() { - m := &ChannelMember{} - m.SetDB(c.db) - - err = rows.Scan( - &m.ChannelID, &m.UserID, &m.Modes, - &m.JoinedAt, &m.Nick, - ) - if err != nil { - return nil, err - } - - members = append(members, m) - } - - return members, rows.Err() -} - -// RecentMessages returns the most recent messages in this channel. -func (c *Channel) RecentMessages( - ctx context.Context, - limit int, -) ([]*Message, error) { - rows, err := c.GetDB().QueryContext(ctx, ` - SELECT id, ts, from_user_id, from_nick, - target, type, body, meta, created_at - FROM messages - WHERE target = ? - ORDER BY ts DESC - LIMIT ?`, - c.Name, limit, - ) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - messages := []*Message{} - - for rows.Next() { - msg := &Message{} - msg.SetDB(c.db) - - err = rows.Scan( - &msg.ID, &msg.Timestamp, &msg.FromUserID, - &msg.FromNick, &msg.Target, &msg.Type, - &msg.Body, &msg.Meta, &msg.CreatedAt, - ) - if err != nil { - return nil, err - } - - messages = append(messages, msg) - } - - return messages, rows.Err() -} diff --git a/internal/models/channel_member.go b/internal/models/channel_member.go deleted file mode 100644 index 59586c7..0000000 --- a/internal/models/channel_member.go +++ /dev/null @@ -1,35 +0,0 @@ -package models - -import ( - "context" - "time" -) - -// ChannelMember represents a user's membership in a channel. -type ChannelMember struct { - Base - - ChannelID string `json:"channelId"` - UserID string `json:"userId"` - Modes string `json:"modes"` - JoinedAt time.Time `json:"joinedAt"` - Nick string `json:"nick"` // denormalized from users table -} - -// User returns the full User for this membership. -func (cm *ChannelMember) User(ctx context.Context) (*User, error) { - if ul := cm.GetUserLookup(); ul != nil { - return ul.GetUserByID(ctx, cm.UserID) - } - - return nil, ErrUserLookupNotAvailable -} - -// Channel returns the full Channel for this membership. -func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) { - if cl := cm.GetChannelLookup(); cl != nil { - return cl.GetChannelByID(ctx, cm.ChannelID) - } - - return nil, ErrChannelLookupNotAvailable -} diff --git a/internal/models/message.go b/internal/models/message.go deleted file mode 100644 index 652ae0d..0000000 --- a/internal/models/message.go +++ /dev/null @@ -1,20 +0,0 @@ -package models - -import ( - "time" -) - -// Message represents a chat message (channel or DM). -type Message struct { - Base - - ID string `json:"id"` - Timestamp time.Time `json:"ts"` - FromUserID string `json:"fromUserId"` - FromNick string `json:"from"` - Target string `json:"to"` - Type string `json:"type"` - Body string `json:"body"` - Meta string `json:"meta"` - CreatedAt time.Time `json:"createdAt"` -} diff --git a/internal/models/message_queue.go b/internal/models/message_queue.go deleted file mode 100644 index 616cbc3..0000000 --- a/internal/models/message_queue.go +++ /dev/null @@ -1,15 +0,0 @@ -package models - -import ( - "time" -) - -// MessageQueueEntry represents a pending message delivery for a user. -type MessageQueueEntry struct { - Base - - ID int64 `json:"id"` - UserID string `json:"userId"` - MessageID string `json:"messageId"` - QueuedAt time.Time `json:"queuedAt"` -} diff --git a/internal/models/model.go b/internal/models/model.go deleted file mode 100644 index fdaa90f..0000000 --- a/internal/models/model.go +++ /dev/null @@ -1,65 +0,0 @@ -// Package models defines the data models used by the chat application. -// All model structs embed Base, which provides database access for -// relation-fetching methods directly on model instances. -package models - -import ( - "context" - "database/sql" - "errors" -) - -// DB is the interface that models use to query the database. -// This avoids a circular import with the db package. -type DB interface { - GetDB() *sql.DB -} - -// UserLookup provides user lookup by ID without circular imports. -type UserLookup interface { - GetUserByID(ctx context.Context, id string) (*User, error) -} - -// ChannelLookup provides channel lookup by ID without circular imports. -type ChannelLookup interface { - GetChannelByID(ctx context.Context, id string) (*Channel, error) -} - -// Sentinel errors for model lookup methods. -var ( - ErrUserLookupNotAvailable = errors.New("user lookup not available") - ErrChannelLookupNotAvailable = errors.New("channel lookup not available") -) - -// Base is embedded in all model structs to provide database access. -type Base struct { - db DB -} - -// SetDB injects the database reference into a model. -func (b *Base) SetDB(d DB) { - b.db = d -} - -// GetDB returns the database interface for use in model methods. -func (b *Base) GetDB() *sql.DB { - return b.db.GetDB() -} - -// GetUserLookup returns the DB as a UserLookup if it implements the interface. -func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn - if ul, ok := b.db.(UserLookup); ok { - return ul - } - - return nil -} - -// GetChannelLookup returns the DB as a ChannelLookup if it implements the interface. -func (b *Base) GetChannelLookup() ChannelLookup { //nolint:ireturn - if cl, ok := b.db.(ChannelLookup); ok { - return cl - } - - return nil -} diff --git a/internal/models/server_link.go b/internal/models/server_link.go deleted file mode 100644 index 004ef67..0000000 --- a/internal/models/server_link.go +++ /dev/null @@ -1,18 +0,0 @@ -package models - -import ( - "time" -) - -// ServerLink represents a federation peer server configuration. -type ServerLink struct { - Base - - ID string `json:"id"` - Name string `json:"name"` - URL string `json:"url"` - SharedKeyHash string `json:"-"` - IsActive bool `json:"isActive"` - CreatedAt time.Time `json:"createdAt"` - LastSeenAt *time.Time `json:"lastSeenAt,omitempty"` -} diff --git a/internal/models/session.go b/internal/models/session.go deleted file mode 100644 index 295def2..0000000 --- a/internal/models/session.go +++ /dev/null @@ -1,26 +0,0 @@ -package models - -import ( - "context" - "time" -) - -// Session represents a server-held user session. -type Session struct { - Base - - ID string `json:"id"` - UserID string `json:"userId"` - CreatedAt time.Time `json:"createdAt"` - LastActiveAt time.Time `json:"lastActiveAt"` - ExpiresAt *time.Time `json:"expiresAt,omitempty"` -} - -// User returns the user who owns this session. -func (s *Session) User(ctx context.Context) (*User, error) { - if ul := s.GetUserLookup(); ul != nil { - return ul.GetUserByID(ctx, s.UserID) - } - - return nil, ErrUserLookupNotAvailable -} diff --git a/internal/models/user.go b/internal/models/user.go deleted file mode 100644 index f3d778f..0000000 --- a/internal/models/user.go +++ /dev/null @@ -1,92 +0,0 @@ -package models - -import ( - "context" - "time" -) - -// User represents a registered user account. -type User struct { - Base - - ID string `json:"id"` - Nick string `json:"nick"` - PasswordHash string `json:"-"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` - LastSeenAt *time.Time `json:"lastSeenAt,omitempty"` -} - -// Channels returns all channels the user is a member of. -func (u *User) Channels(ctx context.Context) ([]*Channel, error) { - rows, err := u.GetDB().QueryContext(ctx, ` - SELECT c.id, c.name, c.topic, c.modes, c.created_at, c.updated_at - FROM channels c - JOIN channel_members cm ON cm.channel_id = c.id - WHERE cm.user_id = ? - ORDER BY c.name`, - u.ID, - ) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - channels := []*Channel{} - - for rows.Next() { - c := &Channel{} - c.SetDB(u.db) - - err = rows.Scan( - &c.ID, &c.Name, &c.Topic, &c.Modes, - &c.CreatedAt, &c.UpdatedAt, - ) - if err != nil { - return nil, err - } - - channels = append(channels, c) - } - - return channels, rows.Err() -} - -// QueuedMessages returns undelivered messages for this user. -func (u *User) QueuedMessages(ctx context.Context) ([]*Message, error) { - rows, err := u.GetDB().QueryContext(ctx, ` - SELECT m.id, m.ts, m.from_user_id, m.from_nick, - m.target, m.type, m.body, m.meta, m.created_at - FROM messages m - JOIN message_queue mq ON mq.message_id = m.id - WHERE mq.user_id = ? - ORDER BY mq.queued_at ASC`, - u.ID, - ) - if err != nil { - return nil, err - } - - defer func() { _ = rows.Close() }() - - messages := []*Message{} - - for rows.Next() { - msg := &Message{} - msg.SetDB(u.db) - - err = rows.Scan( - &msg.ID, &msg.Timestamp, &msg.FromUserID, - &msg.FromNick, &msg.Target, &msg.Type, - &msg.Body, &msg.Meta, &msg.CreatedAt, - ) - if err != nil { - return nil, err - } - - messages = append(messages, msg) - } - - return messages, rows.Err() -} diff --git a/internal/server/http.go b/internal/server/http.go index 4b01db2..979f4cb 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,6 +4,6 @@ import "time" const ( httpReadTimeout = 10 * time.Second - httpWriteTimeout = 10 * time.Second + httpWriteTimeout = 60 * time.Second maxHeaderBytes = 1 << 20 ) diff --git a/web/dist/app.js b/web/dist/app.js index 2a5d789..38642c2 100644 --- a/web/dist/app.js +++ b/web/dist/app.js @@ -1 +1,464 @@ -(()=>{var te,b,Ce,Ge,O,ge,Se,xe,Te,ae,ie,se,Qe,q={},Ee=[],Xe=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,ne=Array.isArray;function U(e,t){for(var n in t)e[n]=t[n];return e}function le(e){e&&e.parentNode&&e.parentNode.removeChild(e)}function m(e,t,n){var _,i,o,s={};for(o in t)o=="key"?_=t[o]:o=="ref"?i=t[o]:s[o]=t[o];if(arguments.length>2&&(s.children=arguments.length>3?te.call(arguments,2):n),typeof e=="function"&&e.defaultProps!=null)for(o in e.defaultProps)s[o]===void 0&&(s[o]=e.defaultProps[o]);return Y(e,s,_,i,null)}function Y(e,t,n,_,i){var o={type:e,props:t,key:n,ref:_,__k:null,__:null,__b:0,__e:null,__c:null,constructor:void 0,__v:i??++Ce,__i:-1,__u:0};return i==null&&b.vnode!=null&&b.vnode(o),o}function _e(e){return e.children}function Z(e,t){this.props=e,this.context=t}function j(e,t){if(t==null)return e.__?j(e.__,e.__i+1):null;for(var n;tl&&O.sort(xe),e=O.shift(),l=O.length,e.__d&&(n=void 0,_=void 0,i=(_=(t=e).__v).__e,o=[],s=[],t.__P&&((n=U({},_)).__v=_.__v+1,b.vnode&&b.vnode(n),ue(t.__P,n,_,t.__n,t.__P.namespaceURI,32&_.__u?[i]:null,o,i??j(_),!!(32&_.__u),s),n.__v=_.__v,n.__.__k[n.__i]=n,He(o,n,s),_.__e=_.__=null,n.__e!=i&&Pe(n)));ee.__r=0}function Ie(e,t,n,_,i,o,s,l,d,c,h){var r,u,f,k,P,w,g,y=_&&_.__k||Ee,D=t.length;for(d=Ye(n,t,y,d,D),r=0;r0?s=e.__k[o]=Y(s.type,s.props,s.key,s.ref?s.ref:null,s.__v):e.__k[o]=s,d=o+u,s.__=e,s.__b=e.__b+1,l=null,(c=s.__i=Ze(s,n,d,r))!=-1&&(r--,(l=n[c])&&(l.__u|=2)),l==null||l.__v==null?(c==-1&&(i>h?u--:id?u--:u++,s.__u|=4))):e.__k[o]=null;if(r)for(o=0;o(h?1:0)){for(i=n-1,o=n+1;i>=0||o=0?i--:o++])!=null&&(2&c.__u)==0&&l==c.key&&d==c.type)return s}return-1}function ke(e,t,n){t[0]=="-"?e.setProperty(t,n??""):e[t]=n==null?"":typeof n!="number"||Xe.test(t)?n:n+"px"}function X(e,t,n,_,i){var o,s;e:if(t=="style")if(typeof n=="string")e.style.cssText=n;else{if(typeof _=="string"&&(e.style.cssText=_=""),_)for(t in _)n&&t in n||ke(e.style,t,"");if(n)for(t in n)_&&n[t]==_[t]||ke(e.style,t,n[t])}else if(t[0]=="o"&&t[1]=="n")o=t!=(t=t.replace(Te,"$1")),s=t.toLowerCase(),t=s in e||t=="onFocusOut"||t=="onFocusIn"?s.slice(2):t.slice(2),e.l||(e.l={}),e.l[t+o]=n,n?_?n.u=_.u:(n.u=ae,e.addEventListener(t,o?se:ie,o)):e.removeEventListener(t,o?se:ie,o);else{if(i=="http://www.w3.org/2000/svg")t=t.replace(/xlink(H|:h)/,"h").replace(/sName$/,"s");else if(t!="width"&&t!="height"&&t!="href"&&t!="list"&&t!="form"&&t!="tabIndex"&&t!="download"&&t!="rowSpan"&&t!="colSpan"&&t!="role"&&t!="popover"&&t in e)try{e[t]=n??"";break e}catch{}typeof n=="function"||(n==null||n===!1&&t[4]!="-"?e.removeAttribute(t):e.setAttribute(t,t=="popover"&&n==1?"":n))}}function we(e){return function(t){if(this.l){var n=this.l[t.type+e];if(t.t==null)t.t=ae++;else if(t.t0?e:ne(e)?e.map(De):U({},e)}function et(e,t,n,_,i,o,s,l,d){var c,h,r,u,f,k,P,w=n.props||q,g=t.props,y=t.type;if(y=="svg"?i="http://www.w3.org/2000/svg":y=="math"?i="http://www.w3.org/1998/Math/MathML":i||(i="http://www.w3.org/1999/xhtml"),o!=null){for(c=0;c=n.__.length&&n.__.push({}),n.__[e]}function I(e){return K=1,nt(qe,e)}function nt(e,t,n){var _=he(z++,2);if(_.t=e,!_.__c&&(_.__=[n?n(t):qe(void 0,t),function(l){var d=_.__N?_.__N[0]:_.__[0],c=_.t(d,l);d!==c&&(_.__N=[c,_.__[1]],_.__c.setState({}))}],_.__c=S,!S.__f)){var i=function(l,d,c){if(!_.__c.__H)return!0;var h=_.__c.__H.__.filter(function(u){return!!u.__c});if(h.every(function(u){return!u.__N}))return!o||o.call(this,l,d,c);var r=_.__c.props!==l;return h.forEach(function(u){if(u.__N){var f=u.__[0];u.__=u.__N,u.__N=void 0,f!==u.__[0]&&(r=!0)}}),o&&o.call(this,l,d,c)||r};S.__f=!0;var o=S.shouldComponentUpdate,s=S.componentWillUpdate;S.componentWillUpdate=function(l,d,c){if(this.__e){var h=o;o=void 0,i(l,d,c),o=h}s&&s.call(this,l,d,c)},S.shouldComponentUpdate=i}return _.__N||_.__}function W(e,t){var n=he(z++,3);!x.__s&&Ve(n.__H,t)&&(n.__=e,n.u=t,S.__H.__h.push(n))}function G(e){return K=5,Be(function(){return{current:e}},[])}function Be(e,t){var n=he(z++,7);return Ve(n.__H,t)&&(n.__=e(),n.__H=t,n.__h=e),n.__}function oe(e,t){return K=8,Be(function(){return e},t)}function _t(){for(var e;e=Je.shift();)if(e.__P&&e.__H)try{e.__H.__h.forEach(re),e.__H.__h.forEach(de),e.__H.__h=[]}catch(t){e.__H.__h=[],x.__e(t,e.__v)}}x.__b=function(e){S=null,Ue&&Ue(e)},x.__=function(e,t){e&&t.__k&&t.__k.__m&&(e.__m=t.__k.__m),We&&We(e,t)},x.__r=function(e){Le&&Le(e),z=0;var t=(S=e.__c).__H;t&&(pe===S?(t.__h=[],S.__h=[],t.__.forEach(function(n){n.__N&&(n.__=n.__N),n.u=n.__N=void 0})):(t.__h.forEach(re),t.__h.forEach(de),t.__h=[],z=0)),pe=S},x.diffed=function(e){Fe&&Fe(e);var t=e.__c;t&&t.__H&&(t.__H.__h.length&&(Je.push(t)!==1&&Ae===x.requestAnimationFrame||((Ae=x.requestAnimationFrame)||rt)(_t)),t.__H.__.forEach(function(n){n.u&&(n.__H=n.u),n.u=void 0})),pe=S=null},x.__c=function(e,t){t.some(function(n){try{n.__h.forEach(re),n.__h=n.__h.filter(function(_){return!_.__||de(_)})}catch(_){t.some(function(i){i.__h&&(i.__h=[])}),t=[],x.__e(_,n.__v)}}),Oe&&Oe(e,t)},x.unmount=function(e){je&&je(e);var t,n=e.__c;n&&n.__H&&(n.__H.__.forEach(function(_){try{re(_)}catch(i){t=i}}),n.__H=void 0,t&&x.__e(t,n.__v))};var Re=typeof requestAnimationFrame=="function";function rt(e){var t,n=function(){clearTimeout(_),Re&&cancelAnimationFrame(t),setTimeout(e)},_=setTimeout(n,35);Re&&(t=requestAnimationFrame(n))}function re(e){var t=S,n=e.__c;typeof n=="function"&&(e.__c=void 0,n()),S=t}function de(e){var t=S;e.__c=e.__(),S=t}function Ve(e,t){return!e||e.length!==t.length||t.some(function(n,_){return n!==e[_]})}function qe(e,t){return typeof t=="function"?t(e):t}var ot="/api/v1";function H(e,t={}){let n=localStorage.getItem("chat_token"),_={"Content-Type":"application/json",...t.headers||{}};return n&&(_.Authorization=`Bearer ${n}`),fetch(ot+e,{...t,headers:_}).then(async i=>{let o=await i.json().catch(()=>null);if(!i.ok)throw{status:i.status,data:o};return o})}function it(e){return new Date(e).toLocaleTimeString([],{hour:"2-digit",minute:"2-digit",second:"2-digit"})}function Ke(e){let t=0;for(let _=0;_{H("/server").then(u=>{u.name&&d(u.name),u.motd&&s(u.motd)}).catch(()=>{});let r=localStorage.getItem("chat_token");r&&H("/me").then(u=>e(u.nick,r)).catch(()=>localStorage.removeItem("chat_token")),c.current?.focus()},[]),m("div",{class:"login-screen"},m("h1",null,l),o&&m("div",{class:"motd"},o),m("form",{onSubmit:async r=>{r.preventDefault(),i("");try{let u=await H("/register",{method:"POST",body:JSON.stringify({nick:t.trim()})});localStorage.setItem("chat_token",u.token),e(u.nick,u.token)}catch(u){i(u.data?.error||"Connection failed")}}},m("input",{ref:c,type:"text",placeholder:"Choose a nickname...",value:t,onInput:r=>n(r.target.value),maxLength:32,autoFocus:!0}),m("button",{type:"submit"},"Connect")),_&&m("div",{class:"error"},_))}function ze({msg:e}){return m("div",{class:`message ${e.system?"system":""}`},m("span",{class:"timestamp"},it(e.createdAt)),m("span",{class:"nick",style:{color:e.system?void 0:Ke(e.nick)}},e.nick),m("span",{class:"content"},e.content))}function ct(){let[e,t]=I(!1),[n,_]=I(""),[i,o]=I([{type:"server",name:"Server"}]),[s,l]=I(0),[d,c]=I({server:[]}),[h,r]=I({}),[u,f]=I(""),[k,P]=I(""),[w,g]=I(0),y=G(),D=G(),N=G(),$=oe((a,p)=>{c(v=>({...v,[a]:[...v[a]||[],p]}))},[]),E=oe((a,p)=>{$(a,{id:Date.now(),nick:"*",content:p,createdAt:new Date().toISOString(),system:!0})},[$]),Q=oe((a,p)=>{_(a),t(!0),E("server",`Connected as ${a}`),H("/server").then(v=>{v.motd&&E("server",`MOTD: ${v.motd}`)}).catch(()=>{})},[E]);W(()=>{if(!e)return;let a=!0,p=async()=>{try{let v=await H(`/poll?after=${w}`);if(!a)return;let T=w;for(let C of v)if(C.id>T&&(T=C.id),C.isDm){let B=C.nick===n?C.dmTarget:C.nick;o(V=>V.find(ye=>ye.type==="dm"&&ye.name===B)?V:[...V,{type:"dm",name:B}]),$(B,C)}else C.channel&&$(C.channel,C);T>w&&g(T)}catch{}};return N.current=setInterval(p,1500),p(),()=>{a=!1,clearInterval(N.current)}},[e,w,n,$]),W(()=>{if(!e)return;let a=i[s];if(!a||a.type!=="channel")return;let p=a.name.replace("#","");H(`/channels/${p}/members`).then(T=>{r(C=>({...C,[a.name]:T}))}).catch(()=>{});let v=setInterval(()=>{H(`/channels/${p}/members`).then(T=>{r(C=>({...C,[a.name]:T}))}).catch(()=>{})},5e3);return()=>clearInterval(v)},[e,s,i]),W(()=>{y.current?.scrollIntoView({behavior:"smooth"})},[d,s]),W(()=>{D.current?.focus()},[s]);let L=async a=>{if(a){a=a.trim(),a.startsWith("#")||(a="#"+a);try{await H("/channels/join",{method:"POST",body:JSON.stringify({channel:a})}),o(p=>p.find(v=>v.type==="channel"&&v.name===a)?p:[...p,{type:"channel",name:a}]),l(i.length),E(a,`Joined ${a}`),P("")}catch(p){E("server",`Failed to join ${a}: ${p.data?.error||"error"}`)}}},F=async a=>{let p=a.replace("#","");try{await H(`/channels/${p}/part`,{method:"DELETE"})}catch{}o(v=>v.filter(C=>!(C.type==="channel"&&C.name===a))),l(0)},R=a=>{let p=i[a];p.type==="channel"?F(p.name):p.type==="dm"&&(o(v=>v.filter((T,C)=>C!==a)),s>=a&&l(Math.max(0,s-1)))},M=a=>{o(p=>p.find(v=>v.type==="dm"&&v.name===a)?p:[...p,{type:"dm",name:a}]),l(i.findIndex(p=>p.type==="dm"&&p.name===a)||i.length)},A=async()=>{let a=u.trim();if(!a)return;f("");let p=i[s];if(!(!p||p.type==="server")){if(a.startsWith("/")){let v=a.split(" "),T=v[0].toLowerCase();if(T==="/join"&&v[1]){L(v[1]);return}if(T==="/part"){p.type==="channel"&&F(p.name);return}if(T==="/msg"&&v[1]&&v.slice(2).join(" ")){let C=v[1],B=v.slice(2).join(" ");try{await H(`/dm/${C}/messages`,{method:"POST",body:JSON.stringify({content:B})}),M(C)}catch(V){E("server",`Failed to send DM: ${V.data?.error||"error"}`)}return}if(T==="/nick"){E("server","Nick changes not yet supported");return}E("server",`Unknown command: ${T}`);return}if(p.type==="channel"){let v=p.name.replace("#","");try{await H(`/channels/${v}/messages`,{method:"POST",body:JSON.stringify({content:a})})}catch(T){E(p.name,`Send failed: ${T.data?.error||"error"}`)}}else if(p.type==="dm")try{await H(`/dm/${p.name}/messages`,{method:"POST",body:JSON.stringify({content:a})})}catch(v){E(p.name,`Send failed: ${v.data?.error||"error"}`)}}};if(!e)return m(st,{onLogin:Q});let J=i[s]||i[0],me=d[J.name]||[],ve=h[J.name]||[];return m("div",{class:"app"},m("div",{class:"tab-bar"},i.map((a,p)=>m("div",{class:`tab ${p===s?"active":""}`,onClick:()=>l(p)},a.type==="dm"?`\u2192${a.name}`:a.name,a.type!=="server"&&m("span",{class:"close-btn",onClick:v=>{v.stopPropagation(),R(p)}},"\xD7"))),m("div",{class:"join-dialog"},m("input",{placeholder:"#channel",value:k,onInput:a=>P(a.target.value),onKeyDown:a=>a.key==="Enter"&&L(k)}),m("button",{onClick:()=>L(k)},"Join"))),m("div",{class:"content"},m("div",{class:"messages-pane"},J.type==="server"?m("div",{class:"server-messages"},me.map(a=>m(ze,{msg:a})),m("div",{ref:y})):m(Fragment,null,m("div",{class:"messages"},me.map(a=>m(ze,{msg:a})),m("div",{ref:y})),m("div",{class:"input-bar"},m("input",{ref:D,placeholder:`Message ${J.name}...`,value:u,onInput:a=>f(a.target.value),onKeyDown:a=>a.key==="Enter"&&A()}),m("button",{onClick:A},"Send")))),J.type==="channel"&&m("div",{class:"user-list"},m("h3",null,"Users (",ve.length,")"),ve.map(a=>m("div",{class:"user",onClick:()=>M(a.nick),style:{color:Ke(a.nick)}},a.nick)))))}$e(m(ct,null),document.getElementById("root"));})(); +(()=>{ +// Minimal Preact-like runtime using raw DOM for simplicity and zero build step. +// This replaces the previous Preact SPA with a vanilla JS implementation. + +const API = '/api/v1'; +let token = localStorage.getItem('chat_token'); +let myNick = ''; +let myUID = 0; +let lastQueueID = 0; +let pollController = null; +let channels = []; // [{name, topic}] +let activeTab = null; // '#channel' or 'nick' or 'server' +let messages = {}; // target -> [{command,from,to,body,ts,system}] +let unread = {}; // target -> count +let members = {}; // '#channel' -> [{nick}] + +function $(sel, parent) { return (parent||document).querySelector(sel); } +function $$(sel, parent) { return [...(parent||document).querySelectorAll(sel)]; } +function el(tag, attrs, ...children) { + const e = document.createElement(tag); + if (attrs) Object.entries(attrs).forEach(([k,v]) => { + if (k === 'class') e.className = v; + else if (k.startsWith('on')) e.addEventListener(k.slice(2).toLowerCase(), v); + else if (k === 'style' && typeof v === 'object') Object.assign(e.style, v); + else e.setAttribute(k, v); + }); + children.flat(Infinity).forEach(c => { + if (c == null) return; + e.appendChild(typeof c === 'string' ? document.createTextNode(c) : c); + }); + return e; +} + +async function api(path, opts = {}) { + const headers = {'Content-Type': 'application/json', ...(opts.headers||{})}; + if (token) headers['Authorization'] = `Bearer ${token}`; + const resp = await fetch(API + path, {...opts, headers, signal: opts.signal}); + const data = await resp.json().catch(() => null); + if (!resp.ok) throw {status: resp.status, data}; + return data; +} + +function nickColor(nick) { + let h = 0; + for (let i = 0; i < nick.length; i++) h = nick.charCodeAt(i) + ((h << 5) - h); + return `hsl(${Math.abs(h) % 360}, 70%, 65%)`; +} + +function formatTime(ts) { + return new Date(ts).toLocaleTimeString([], {hour:'2-digit',minute:'2-digit',second:'2-digit'}); +} + +function addMessage(target, msg) { + if (!messages[target]) messages[target] = []; + messages[target].push(msg); + if (messages[target].length > 500) messages[target] = messages[target].slice(-400); + if (target !== activeTab) { + unread[target] = (unread[target] || 0) + 1; + renderTabs(); + } + if (target === activeTab) renderMessages(); +} + +function addSystemMessage(target, text) { + addMessage(target, {command: 'SYSTEM', from: '*', body: [text], ts: new Date().toISOString(), system: true}); +} + +// --- Rendering --- + +function renderApp() { + const root = $('#root'); + root.innerHTML = ''; + root.appendChild(el('div', {class:'app'}, + el('div', {class:'tab-bar', id:'tabs'}), + el('div', {class:'content'}, + el('div', {class:'messages-pane'}, + el('div', {class:'messages', id:'msg-list'}), + el('div', {class:'input-bar', id:'input-bar'}, + el('input', {id:'msg-input', placeholder:'Message...', onKeydown: e => { if(e.key==='Enter') sendInput(); }}), + el('button', {onClick: sendInput}, 'Send') + ) + ), + el('div', {class:'user-list', id:'user-list'}) + ) + )); + renderTabs(); + renderMessages(); + renderMembers(); + $('#msg-input')?.focus(); +} + +function renderTabs() { + const container = $('#tabs'); + if (!container) return; + container.innerHTML = ''; + + // Server tab + const serverTab = el('div', {class: `tab ${activeTab === 'server' ? 'active' : ''}`, onClick: () => switchTab('server')}, 'Server'); + container.appendChild(serverTab); + + // Channel tabs + channels.forEach(ch => { + const badge = unread[ch.name] ? ` (${unread[ch.name]})` : ''; + const tab = el('div', {class: `tab ${activeTab === ch.name ? 'active' : ''}`}, + el('span', {onClick: () => switchTab(ch.name)}, ch.name + badge), + el('span', {class:'close-btn', onClick: (e) => { e.stopPropagation(); partChannel(ch.name); }}, '×') + ); + container.appendChild(tab); + }); + + // DM tabs + Object.keys(messages).filter(k => !k.startsWith('#') && k !== 'server').forEach(nick => { + const badge = unread[nick] ? ` (${unread[nick]})` : ''; + const tab = el('div', {class: `tab ${activeTab === nick ? 'active' : ''}`}, + el('span', {onClick: () => switchTab(nick)}, '→' + nick + badge), + el('span', {class:'close-btn', onClick: (e) => { e.stopPropagation(); delete messages[nick]; delete unread[nick]; if(activeTab===nick) switchTab('server'); else renderTabs(); }}, '×') + ); + container.appendChild(tab); + }); + + // Join input + const joinDiv = el('div', {class:'join-dialog'}, + el('input', {id:'join-input', placeholder:'#channel', onKeydown: e => { if(e.key==='Enter') joinFromInput(); }}), + el('button', {onClick: joinFromInput}, 'Join') + ); + container.appendChild(joinDiv); +} + +function renderMessages() { + const container = $('#msg-list'); + if (!container) return; + const msgs = messages[activeTab] || []; + container.innerHTML = ''; + msgs.forEach(m => { + const isSystem = m.system || ['JOIN','PART','QUIT','NICK','TOPIC'].includes(m.command); + const bodyText = Array.isArray(m.body) ? m.body.join('\n') : (m.body || ''); + + let displayText = bodyText; + if (m.command === 'JOIN') displayText = `${m.from} has joined ${m.to}`; + else if (m.command === 'PART') displayText = `${m.from} has left ${m.to}` + (bodyText ? ` (${bodyText})` : ''); + else if (m.command === 'QUIT') displayText = `${m.from} has quit` + (bodyText ? ` (${bodyText})` : ''); + else if (m.command === 'NICK') displayText = `${m.from} is now known as ${bodyText}`; + else if (m.command === 'TOPIC') displayText = `${m.from} set topic: ${bodyText}`; + + const msgEl = el('div', {class: `message ${isSystem ? 'system' : ''}`}, + el('span', {class:'timestamp'}, m.ts ? formatTime(m.ts) : ''), + isSystem + ? el('span', {class:'nick'}, '*') + : el('span', {class:'nick', style:{color: nickColor(m.from)}}, m.from), + el('span', {class:'content'}, displayText) + ); + container.appendChild(msgEl); + }); + container.scrollTop = container.scrollHeight; +} + +function renderMembers() { + const container = $('#user-list'); + if (!container) return; + if (!activeTab || !activeTab.startsWith('#')) { + container.innerHTML = ''; + return; + } + const mems = members[activeTab] || []; + container.innerHTML = ''; + container.appendChild(el('h3', null, `Users (${mems.length})`)); + mems.forEach(m => { + container.appendChild(el('div', {class:'user', style:{color: nickColor(m.nick)}, onClick: () => openDM(m.nick)}, m.nick)); + }); +} + +function switchTab(target) { + activeTab = target; + unread[target] = 0; + renderTabs(); + renderMessages(); + renderMembers(); + if (activeTab?.startsWith('#')) fetchMembers(activeTab); + $('#msg-input')?.focus(); +} + +// --- Actions --- + +async function joinFromInput() { + const input = $('#join-input'); + if (!input) return; + let name = input.value.trim(); + if (!name) return; + if (!name.startsWith('#')) name = '#' + name; + input.value = ''; + try { + await api('/messages', {method:'POST', body: JSON.stringify({command:'JOIN', to: name})}); + } catch(e) { + addSystemMessage('server', `Failed to join ${name}: ${e.data?.error || 'error'}`); + } +} + +async function partChannel(name) { + try { + await api('/messages', {method:'POST', body: JSON.stringify({command:'PART', to: name})}); + } catch(e) {} + channels = channels.filter(c => c.name !== name); + delete members[name]; + if (activeTab === name) switchTab('server'); + else renderTabs(); +} + +function openDM(nick) { + if (nick === myNick) return; + if (!messages[nick]) messages[nick] = []; + switchTab(nick); +} + +async function sendInput() { + const input = $('#msg-input'); + if (!input) return; + const text = input.value.trim(); + if (!text) return; + input.value = ''; + + if (text.startsWith('/')) { + const parts = text.split(' '); + const cmd = parts[0].toLowerCase(); + if (cmd === '/join' && parts[1]) { $('#join-input').value = parts[1]; joinFromInput(); return; } + if (cmd === '/part') { if(activeTab?.startsWith('#')) partChannel(activeTab); return; } + if (cmd === '/nick' && parts[1]) { + try { + await api('/messages', {method:'POST', body: JSON.stringify({command:'NICK', body:[parts[1]]})}); + } catch(e) { + addSystemMessage(activeTab||'server', `Nick change failed: ${e.data?.error || 'error'}`); + } + return; + } + if (cmd === '/msg' && parts[1] && parts.slice(2).join(' ')) { + const target = parts[1]; + const msg = parts.slice(2).join(' '); + try { + await api('/messages', {method:'POST', body: JSON.stringify({command:'PRIVMSG', to: target, body:[msg]})}); + openDM(target); + } catch(e) { + addSystemMessage(activeTab||'server', `DM failed: ${e.data?.error || 'error'}`); + } + return; + } + if (cmd === '/quit') { + try { await api('/messages', {method:'POST', body: JSON.stringify({command:'QUIT'})}); } catch(e) {} + localStorage.removeItem('chat_token'); + location.reload(); + return; + } + addSystemMessage(activeTab||'server', `Unknown command: ${cmd}`); + return; + } + + if (!activeTab || activeTab === 'server') { + addSystemMessage('server', 'Select a channel or user to send messages'); + return; + } + + try { + await api('/messages', {method:'POST', body: JSON.stringify({command:'PRIVMSG', to: activeTab, body:[text]})}); + } catch(e) { + addSystemMessage(activeTab, `Send failed: ${e.data?.error || 'error'}`); + } +} + +async function fetchMembers(channel) { + try { + const name = channel.replace('#',''); + const data = await api(`/channels/${name}/members`); + members[channel] = data; + renderMembers(); + } catch(e) {} +} + +// --- Polling --- + +async function pollLoop() { + while (true) { + try { + if (pollController) pollController.abort(); + pollController = new AbortController(); + const data = await api(`/messages?after=${lastQueueID}&timeout=15`, {signal: pollController.signal}); + if (data.last_id) lastQueueID = data.last_id; + + for (const msg of (data.messages || [])) { + handleMessage(msg); + } + } catch(e) { + if (e instanceof DOMException && e.name === 'AbortError') continue; + if (e.status === 401) { + localStorage.removeItem('chat_token'); + location.reload(); + return; + } + await new Promise(r => setTimeout(r, 2000)); + } + } +} + +function handleMessage(msg) { + const body = Array.isArray(msg.body) ? msg.body : []; + const bodyText = body.join('\n'); + + switch (msg.command) { + case 'PRIVMSG': + case 'NOTICE': { + let target = msg.to; + // DM: if it's to me, show under sender's nick tab + if (!target.startsWith('#')) { + target = msg.from === myNick ? msg.to : msg.from; + if (!messages[target]) messages[target] = []; + } + addMessage(target, msg); + break; + } + case 'JOIN': { + addMessage(msg.to, msg); + if (msg.from === myNick) { + // We joined a channel + if (!channels.find(c => c.name === msg.to)) { + channels.push({name: msg.to, topic: ''}); + } + switchTab(msg.to); + fetchMembers(msg.to); + } else if (activeTab === msg.to) { + fetchMembers(msg.to); + } + break; + } + case 'PART': { + addMessage(msg.to, msg); + if (msg.from === myNick) { + channels = channels.filter(c => c.name !== msg.to); + if (activeTab === msg.to) switchTab('server'); + else renderTabs(); + } else if (activeTab === msg.to) { + fetchMembers(msg.to); + } + break; + } + case 'QUIT': { + // Show in all channels where this user might be + channels.forEach(ch => { + addMessage(ch.name, msg); + }); + break; + } + case 'NICK': { + const newNick = body[0] || ''; + if (msg.from === myNick) { + myNick = newNick; + addSystemMessage(activeTab || 'server', `You are now known as ${newNick}`); + } else { + channels.forEach(ch => { + addMessage(ch.name, msg); + }); + } + break; + } + case 'TOPIC': { + addMessage(msg.to, msg); + const ch = channels.find(c => c.name === msg.to); + if (ch) ch.topic = bodyText; + break; + } + default: + addSystemMessage('server', `[${msg.command}] ${bodyText}`); + } +} + +// --- Login --- + +function renderLogin() { + const root = $('#root'); + root.innerHTML = ''; + + let serverName = 'Chat'; + let motd = ''; + + api('/server').then(data => { + if (data.name) { serverName = data.name; $('h1', root).textContent = serverName; } + if (data.motd) { motd = data.motd; const m = $('.motd', root); if(m) m.textContent = motd; } + }).catch(() => {}); + + const form = el('form', {class:'login-screen', onSubmit: async (e) => { + e.preventDefault(); + const nick = $('input', form).value.trim(); + if (!nick) return; + const errEl = $('.error', form); + if (errEl) errEl.textContent = ''; + try { + const data = await api('/session', {method:'POST', body: JSON.stringify({nick})}); + token = data.token; + myNick = data.nick; + myUID = data.id; + localStorage.setItem('chat_token', token); + startApp(); + } catch(err) { + const errEl = $('.error', form) || form.appendChild(el('div', {class:'error'})); + errEl.textContent = err.data?.error || 'Connection failed'; + } + }}, + el('h1', null, serverName), + motd ? el('div', {class:'motd'}, motd) : null, + el('input', {type:'text', placeholder:'Choose a nickname...', maxLength:'32', autofocus:'true'}), + el('button', {type:'submit'}, 'Connect'), + el('div', {class:'error'}) + ); + root.appendChild(form); + $('input', form)?.focus(); +} + +async function startApp() { + messages = {server: []}; + unread = {}; + channels = []; + activeTab = 'server'; + lastQueueID = 0; + + addSystemMessage('server', `Connected as ${myNick}`); + + // Fetch server info + try { + const info = await api('/server'); + if (info.motd) addSystemMessage('server', `MOTD: ${info.motd}`); + } catch(e) {} + + // Fetch current state (channels we're already in) + try { + const state = await api('/state'); + myNick = state.nick; + myUID = state.id; + if (state.channels) { + state.channels.forEach(ch => { + channels.push({name: ch.name, topic: ch.topic}); + if (!messages[ch.name]) messages[ch.name] = []; + }); + if (channels.length > 0) switchTab(channels[0].name); + } + } catch(e) {} + + renderApp(); + pollLoop(); +} + +// --- Init --- + +if (token) { + // Try to resume session + api('/state').then(data => { + myNick = data.nick; + myUID = data.id; + startApp(); + }).catch(() => { + localStorage.removeItem('chat_token'); + token = null; + renderLogin(); + }); +} else { + renderLogin(); +} + +})();