feat: MVP two-user chat via embedded SPA (#9)
All checks were successful
check / check (push) Successful in 1m51s

Backend:
- Session/client UUID model: sessions table (uuid, nick, signing_key),
  clients table (uuid, session_id, token) with per-client message queues
- MOTD delivery as IRC numeric messages (375/372/376) on connect
- EnqueueToSession fans out to all clients of a session
- EnqueueToClient for targeted delivery (MOTD)
- All queries updated for session/client model

SPA client:
- Long-poll loop (15s timeout) instead of setInterval
- IRC message envelope parsing (command/from/to/body)
- Display JOIN/PART/NICK/TOPIC/QUIT system messages
- Nick change via /nick command
- Topic display in header bar
- Unread count badges on inactive tabs
- Auto-rejoin channels on reconnect (localStorage)
- Connection status indicator
- Message deduplication by UUID
- Channel history loaded on join
- /topic command support

Closes #9
This commit is contained in:
clawbot 2026-02-27 02:21:48 -08:00
parent 2d08a8476f
commit 32419fb1f7
8 changed files with 921 additions and 880 deletions

View File

@ -55,76 +55,132 @@ type MemberInfo struct {
LastSeen time.Time `json:"lastSeen"` LastSeen time.Time `json:"lastSeen"`
} }
// CreateUser registers a new user with the given nick. // CreateSession registers a new session and its first client.
func (database *Database) CreateUser( func (database *Database) CreateSession(
ctx context.Context, ctx context.Context,
nick string, nick string,
) (int64, string, error) { ) (int64, int64, string, error) {
sessionUUID := uuid.New().String()
clientUUID := uuid.New().String()
token, err := generateToken() token, err := generateToken()
if err != nil { if err != nil {
return 0, "", err return 0, 0, "", err
} }
now := time.Now() now := time.Now()
res, err := database.conn.ExecContext(ctx, transaction, err := database.conn.BeginTx(ctx, nil)
`INSERT INTO users
(nick, token, created_at, last_seen)
VALUES (?, ?, ?, ?)`,
nick, token, now, now)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("create user: %w", err) return 0, 0, "", fmt.Errorf(
"begin tx: %w", err,
)
} }
userID, _ := res.LastInsertId() res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, created_at, last_seen)
VALUES (?, ?, ?, ?)`,
sessionUUID, nick, now, now)
if err != nil {
_ = transaction.Rollback()
return userID, token, nil return 0, 0, "", fmt.Errorf(
"create session: %w", err,
)
}
sessionID, _ := res.LastInsertId()
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, token, now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create client: %w", err,
)
}
clientID, _ := clientRes.LastInsertId()
err = transaction.Commit()
if err != nil {
return 0, 0, "", fmt.Errorf(
"commit session: %w", err,
)
}
return sessionID, clientID, token, nil
} }
// GetUserByToken returns user id and nick for a token. // GetSessionByToken returns session id, client id, and
func (database *Database) GetUserByToken( // nick for a client token.
func (database *Database) GetSessionByToken(
ctx context.Context, ctx context.Context,
token string, token string,
) (int64, string, error) { ) (int64, int64, string, error) {
var userID int64 var (
sessionID int64
var nick string clientID int64
nick string
)
err := database.conn.QueryRowContext( err := database.conn.QueryRowContext(
ctx, ctx,
"SELECT id, nick FROM users WHERE token = ?", `SELECT s.id, c.id, s.nick
FROM clients c
INNER JOIN sessions s
ON s.id = c.session_id
WHERE c.token = ?`,
token, token,
).Scan(&userID, &nick) ).Scan(&sessionID, &clientID, &nick)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("get user by token: %w", err) return 0, 0, "", fmt.Errorf(
"get session by token: %w", err,
)
} }
now := time.Now()
_, _ = database.conn.ExecContext( _, _ = database.conn.ExecContext(
ctx, ctx,
"UPDATE users SET last_seen = ? WHERE id = ?", "UPDATE sessions SET last_seen = ? WHERE id = ?",
time.Now(), userID, now, sessionID,
) )
return userID, nick, nil _, _ = database.conn.ExecContext(
ctx,
"UPDATE clients SET last_seen = ? WHERE id = ?",
now, clientID,
)
return sessionID, clientID, nick, nil
} }
// GetUserByNick returns user id for a given nick. // GetSessionByNick returns session id for a given nick.
func (database *Database) GetUserByNick( func (database *Database) GetSessionByNick(
ctx context.Context, ctx context.Context,
nick string, nick string,
) (int64, error) { ) (int64, error) {
var userID int64 var sessionID int64
err := database.conn.QueryRowContext( err := database.conn.QueryRowContext(
ctx, ctx,
"SELECT id FROM users WHERE nick = ?", "SELECT id FROM sessions WHERE nick = ?",
nick, nick,
).Scan(&userID) ).Scan(&sessionID)
if err != nil { if err != nil {
return 0, fmt.Errorf("get user by nick: %w", err) return 0, fmt.Errorf(
"get session by nick: %w", err,
)
} }
return userID, nil return sessionID, nil
} }
// GetChannelByName returns the channel ID for a name. // GetChannelByName returns the channel ID for a name.
@ -179,16 +235,16 @@ func (database *Database) GetOrCreateChannel(
return channelID, nil return channelID, nil
} }
// JoinChannel adds a user to a channel. // JoinChannel adds a session to a channel.
func (database *Database) JoinChannel( func (database *Database) JoinChannel(
ctx context.Context, ctx context.Context,
channelID, userID int64, channelID, sessionID int64,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO channel_members `INSERT OR IGNORE INTO channel_members
(channel_id, user_id, joined_at) (channel_id, session_id, joined_at)
VALUES (?, ?, ?)`, VALUES (?, ?, ?)`,
channelID, userID, time.Now()) channelID, sessionID, time.Now())
if err != nil { if err != nil {
return fmt.Errorf("join channel: %w", err) return fmt.Errorf("join channel: %w", err)
} }
@ -196,15 +252,15 @@ func (database *Database) JoinChannel(
return nil return nil
} }
// PartChannel removes a user from a channel. // PartChannel removes a session from a channel.
func (database *Database) PartChannel( func (database *Database) PartChannel(
ctx context.Context, ctx context.Context,
channelID, userID int64, channelID, sessionID int64,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
`DELETE FROM channel_members `DELETE FROM channel_members
WHERE channel_id = ? AND user_id = ?`, WHERE channel_id = ? AND session_id = ?`,
channelID, userID) channelID, sessionID)
if err != nil { if err != nil {
return fmt.Errorf("part channel: %w", err) return fmt.Errorf("part channel: %w", err)
} }
@ -265,18 +321,18 @@ func scanChannels(
return out, nil return out, nil
} }
// ListChannels returns channels the user has joined. // ListChannels returns channels the session has joined.
func (database *Database) ListChannels( func (database *Database) ListChannels(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
) ([]ChannelInfo, error) { ) ([]ChannelInfo, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT c.id, c.name, c.topic `SELECT c.id, c.name, c.topic
FROM channels c FROM channels c
INNER JOIN channel_members cm INNER JOIN channel_members cm
ON cm.channel_id = c.id ON cm.channel_id = c.id
WHERE cm.user_id = ? WHERE cm.session_id = ?
ORDER BY c.name`, userID) ORDER BY c.name`, sessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf("list channels: %w", err) return nil, fmt.Errorf("list channels: %w", err)
} }
@ -306,12 +362,12 @@ func (database *Database) ChannelMembers(
channelID int64, channelID int64,
) ([]MemberInfo, error) { ) ([]MemberInfo, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT u.id, u.nick, u.last_seen `SELECT s.id, s.nick, s.last_seen
FROM users u FROM sessions s
INNER JOIN channel_members cm INNER JOIN channel_members cm
ON cm.user_id = u.id ON cm.session_id = s.id
WHERE cm.channel_id = ? WHERE cm.channel_id = ?
ORDER BY u.nick`, channelID) ORDER BY s.nick`, channelID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"query channel members: %w", err, "query channel members: %w", err,
@ -349,17 +405,17 @@ func (database *Database) ChannelMembers(
return members, nil return members, nil
} }
// IsChannelMember checks if a user belongs to a channel. // IsChannelMember checks if a session belongs to a channel.
func (database *Database) IsChannelMember( func (database *Database) IsChannelMember(
ctx context.Context, ctx context.Context,
channelID, userID int64, channelID, sessionID int64,
) (bool, error) { ) (bool, error) {
var count int var count int
err := database.conn.QueryRowContext(ctx, err := database.conn.QueryRowContext(ctx,
`SELECT COUNT(*) FROM channel_members `SELECT COUNT(*) FROM channel_members
WHERE channel_id = ? AND user_id = ?`, WHERE channel_id = ? AND session_id = ?`,
channelID, userID, channelID, sessionID,
).Scan(&count) ).Scan(&count)
if err != nil { if err != nil {
return false, fmt.Errorf( return false, fmt.Errorf(
@ -397,13 +453,13 @@ func scanInt64s(rows *sql.Rows) ([]int64, error) {
return ids, nil return ids, nil
} }
// GetChannelMemberIDs returns user IDs in a channel. // GetChannelMemberIDs returns session IDs in a channel.
func (database *Database) GetChannelMemberIDs( func (database *Database) GetChannelMemberIDs(
ctx context.Context, ctx context.Context,
channelID int64, channelID int64,
) ([]int64, error) { ) ([]int64, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT user_id FROM channel_members `SELECT session_id FROM channel_members
WHERE channel_id = ?`, channelID) WHERE channel_id = ?`, channelID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
@ -414,17 +470,17 @@ func (database *Database) GetChannelMemberIDs(
return scanInt64s(rows) return scanInt64s(rows)
} }
// GetUserChannelIDs returns channel IDs the user is in. // GetSessionChannelIDs returns channel IDs for a session.
func (database *Database) GetUserChannelIDs( func (database *Database) GetSessionChannelIDs(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
) ([]int64, error) { ) ([]int64, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT channel_id FROM channel_members `SELECT channel_id FROM channel_members
WHERE user_id = ?`, userID) WHERE session_id = ?`, sessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"get user channel ids: %w", err, "get session channel ids: %w", err,
) )
} }
@ -467,27 +523,52 @@ func (database *Database) InsertMessage(
return dbID, msgUUID, nil return dbID, msgUUID, nil
} }
// EnqueueMessage adds a message to a user's queue. // EnqueueToSession adds a message to all clients of a
func (database *Database) EnqueueMessage( // session's queues.
func (database *Database) EnqueueToSession(
ctx context.Context, ctx context.Context,
userID, messageID int64, sessionID, messageID int64,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO client_queues `INSERT OR IGNORE INTO client_queues
(user_id, message_id, created_at) (client_id, message_id, created_at)
VALUES (?, ?, ?)`, SELECT c.id, ?, ?
userID, messageID, time.Now()) FROM clients c
WHERE c.session_id = ?`,
messageID, time.Now(), sessionID)
if err != nil { if err != nil {
return fmt.Errorf("enqueue message: %w", err) return fmt.Errorf(
"enqueue to session: %w", err,
)
} }
return nil return nil
} }
// PollMessages returns queued messages for a user. // EnqueueToClient adds a message to a specific client's
// queue.
func (database *Database) EnqueueToClient(
ctx context.Context,
clientID, messageID int64,
) error {
_, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO client_queues
(client_id, message_id, created_at)
VALUES (?, ?, ?)`,
clientID, messageID, time.Now())
if err != nil {
return fmt.Errorf(
"enqueue to client: %w", err,
)
}
return nil
}
// PollMessages returns queued messages for a client.
func (database *Database) PollMessages( func (database *Database) PollMessages(
ctx context.Context, ctx context.Context,
userID, afterQueueID int64, clientID, afterQueueID int64,
limit int, limit int,
) ([]IRCMessage, int64, error) { ) ([]IRCMessage, int64, error) {
if limit <= 0 { if limit <= 0 {
@ -501,9 +582,9 @@ func (database *Database) PollMessages(
FROM client_queues cq FROM client_queues cq
INNER JOIN messages m INNER JOIN messages m
ON m.id = cq.message_id ON m.id = cq.message_id
WHERE cq.user_id = ? AND cq.id > ? WHERE cq.client_id = ? AND cq.id > ?
ORDER BY cq.id ASC LIMIT ?`, ORDER BY cq.id ASC LIMIT ?`,
userID, afterQueueID, limit) clientID, afterQueueID, limit)
if err != nil { if err != nil {
return nil, afterQueueID, fmt.Errorf( return nil, afterQueueID, fmt.Errorf(
"poll messages: %w", err, "poll messages: %w", err,
@ -649,15 +730,15 @@ func reverseMessages(msgs []IRCMessage) {
} }
} }
// ChangeNick updates a user's nickname. // ChangeNick updates a session's nickname.
func (database *Database) ChangeNick( func (database *Database) ChangeNick(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
newNick string, newNick string,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
"UPDATE users SET nick = ? WHERE id = ?", "UPDATE sessions SET nick = ? WHERE id = ?",
newNick, userID) newNick, sessionID)
if err != nil { if err != nil {
return fmt.Errorf("change nick: %w", err) return fmt.Errorf("change nick: %w", err)
} }
@ -681,38 +762,38 @@ func (database *Database) SetTopic(
return nil return nil
} }
// DeleteUser removes a user and all their data. // DeleteSession removes a session and all its data.
func (database *Database) DeleteUser( func (database *Database) DeleteSession(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
) error { ) error {
_, err := database.conn.ExecContext( _, err := database.conn.ExecContext(
ctx, ctx,
"DELETE FROM users WHERE id = ?", "DELETE FROM sessions WHERE id = ?",
userID, sessionID,
) )
if err != nil { if err != nil {
return fmt.Errorf("delete user: %w", err) return fmt.Errorf("delete session: %w", err)
} }
return nil return nil
} }
// GetAllChannelMembershipsForUser returns channels // GetSessionChannels returns channels a session
// a user belongs to. // belongs to.
func (database *Database) GetAllChannelMembershipsForUser( func (database *Database) GetSessionChannels(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
) ([]ChannelInfo, error) { ) ([]ChannelInfo, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT c.id, c.name, c.topic `SELECT c.id, c.name, c.topic
FROM channels c FROM channels c
INNER JOIN channel_members cm INNER JOIN channel_members cm
ON cm.channel_id = c.id ON cm.channel_id = c.id
WHERE cm.user_id = ?`, userID) WHERE cm.session_id = ?`, sessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"get memberships: %w", err, "get session channels: %w", err,
) )
} }

View File

@ -27,70 +27,91 @@ func setupTestDB(t *testing.T) *db.Database {
return database return database
} }
func TestCreateUser(t *testing.T) { func TestCreateSession(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
id, token, err := database.CreateUser(ctx, "alice") sessionID, _, token, err := database.CreateSession(
ctx, "alice",
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if id == 0 || token == "" { if sessionID == 0 || token == "" {
t.Fatal("expected valid id and token") t.Fatal("expected valid id and token")
} }
_, _, err = database.CreateUser(ctx, "alice") _, _, dupToken, dupErr := database.CreateSession(
if err == nil { ctx, "alice",
)
if dupErr == nil {
t.Fatal("expected error for duplicate nick") t.Fatal("expected error for duplicate nick")
} }
_ = dupToken
} }
func TestGetUserByToken(t *testing.T) { func TestGetSessionByToken(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
_, token, err := database.CreateUser(ctx, "bob") _, _, token, err := database.CreateSession(ctx, "bob")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
id, nick, err := database.GetUserByToken(ctx, token) sessionID, clientID, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if nick != "bob" || id == 0 { if nick != "bob" || sessionID == 0 || clientID == 0 {
t.Fatalf("expected bob, got %s", nick) t.Fatalf("expected bob, got %s", nick)
} }
_, _, err = database.GetUserByToken(ctx, "badtoken") badSID, badCID, badNick, badErr :=
if err == nil { database.GetSessionByToken(ctx, "badtoken")
if badErr == nil {
t.Fatal("expected error for bad token") t.Fatal("expected error for bad token")
} }
if badSID != 0 || badCID != 0 || badNick != "" {
t.Fatal("expected zero values on error")
}
} }
func TestGetUserByNick(t *testing.T) { func TestGetSessionByNick(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
_, _, err := database.CreateUser(ctx, "charlie") charlieID, charlieClientID, charlieToken, err :=
database.CreateSession(ctx, "charlie")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
id, err := database.GetUserByNick(ctx, "charlie") if charlieID == 0 || charlieClientID == 0 {
t.Fatal("expected valid session/client IDs")
}
if charlieToken == "" {
t.Fatal("expected non-empty token")
}
id, err := database.GetSessionByNick(ctx, "charlie")
if err != nil || id == 0 { if err != nil || id == 0 {
t.Fatal("expected to find charlie") t.Fatal("expected to find charlie")
} }
_, err = database.GetUserByNick(ctx, "nobody") _, err = database.GetSessionByNick(ctx, "nobody")
if err == nil { if err == nil {
t.Fatal("expected error for unknown nick") t.Fatal("expected error for unknown nick")
} }
@ -129,7 +150,7 @@ func TestJoinAndPart(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, "user1") sid, _, _, err := database.CreateSession(ctx, "user1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -139,22 +160,22 @@ func TestJoinAndPart(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ids, err := database.GetChannelMemberIDs(ctx, chID) ids, err := database.GetChannelMemberIDs(ctx, chID)
if err != nil || len(ids) != 1 || ids[0] != uid { if err != nil || len(ids) != 1 || ids[0] != sid {
t.Fatal("expected user in channel") t.Fatal("expected session in channel")
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.PartChannel(ctx, chID, uid) err = database.PartChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -178,17 +199,17 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
uid, _, err := database.CreateUser(ctx, "temp") sid, _, _, err := database.CreateSession(ctx, "temp")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.PartChannel(ctx, chID, uid) err = database.PartChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -204,7 +225,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
} }
} }
func createUserWithChannels( func createSessionWithChannels(
t *testing.T, t *testing.T,
database *db.Database, database *db.Database,
nick, ch1Name, ch2Name string, nick, ch1Name, ch2Name string,
@ -213,7 +234,7 @@ func createUserWithChannels(
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, nick) sid, _, _, err := database.CreateSession(ctx, nick)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -232,29 +253,29 @@ func createUserWithChannels(
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, ch1, uid) err = database.JoinChannel(ctx, ch1, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, ch2, uid) err = database.JoinChannel(ctx, ch2, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return uid, ch1, ch2 return sid, ch1, ch2
} }
func TestListChannels(t *testing.T) { func TestListChannels(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
uid, _, _ := createUserWithChannels( sid, _, _ := createSessionWithChannels(
t, database, "lister", "#a", "#b", t, database, "lister", "#a", "#b",
) )
channels, err := database.ListChannels( channels, err := database.ListChannels(
t.Context(), uid, t.Context(), sid,
) )
if err != nil || len(channels) != 2 { if err != nil || len(channels) != 2 {
t.Fatalf( t.Fatalf(
@ -295,17 +316,21 @@ func TestChangeNick(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, token, err := database.CreateUser(ctx, "old") sid, _, token, err := database.CreateSession(
ctx, "old",
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.ChangeNick(ctx, uid, "new") err = database.ChangeNick(ctx, sid, "new")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, nick, err := database.GetUserByToken(ctx, token) _, _, nick, err := database.GetSessionByToken(
ctx, token,
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -375,7 +400,16 @@ func TestPollMessages(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, "poller") sid, _, token, err := database.CreateSession(
ctx, "poller",
)
if err != nil {
t.Fatal(err)
}
_, clientID, _, err := database.GetSessionByToken(
ctx, token,
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -389,7 +423,7 @@ func TestPollMessages(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.EnqueueMessage(ctx, uid, dbID) err = database.EnqueueToSession(ctx, sid, dbID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -397,7 +431,7 @@ func TestPollMessages(t *testing.T) {
const batchSize = 10 const batchSize = 10
msgs, lastQID, err := database.PollMessages( msgs, lastQID, err := database.PollMessages(
ctx, uid, 0, batchSize, ctx, clientID, 0, batchSize,
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -420,7 +454,7 @@ func TestPollMessages(t *testing.T) {
} }
msgs, _, _ = database.PollMessages( msgs, _, _ = database.PollMessages(
ctx, uid, lastQID, batchSize, ctx, clientID, lastQID, batchSize,
) )
if len(msgs) != 0 { if len(msgs) != 0 {
@ -467,13 +501,15 @@ func TestGetHistory(t *testing.T) {
} }
} }
func TestDeleteUser(t *testing.T) { func TestDeleteSession(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, "deleteme") sid, _, _, err := database.CreateSession(
ctx, "deleteme",
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -485,19 +521,19 @@ func TestDeleteUser(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.DeleteUser(ctx, uid) err = database.DeleteSession(ctx, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = database.GetUserByNick(ctx, "deleteme") _, err = database.GetSessionByNick(ctx, "deleteme")
if err == nil { if err == nil {
t.Fatal("user should be deleted") t.Fatal("session should be deleted")
} }
ids, _ := database.GetChannelMemberIDs(ctx, chID) ids, _ := database.GetChannelMemberIDs(ctx, chID)
@ -512,12 +548,12 @@ func TestChannelMembers(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid1, _, err := database.CreateUser(ctx, "m1") sid1, _, _, err := database.CreateSession(ctx, "m1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
uid2, _, err := database.CreateUser(ctx, "m2") sid2, _, _, err := database.CreateSession(ctx, "m2")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -529,12 +565,12 @@ func TestChannelMembers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid1) err = database.JoinChannel(ctx, chID, sid1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid2) err = database.JoinChannel(ctx, chID, sid2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -548,17 +584,17 @@ func TestChannelMembers(t *testing.T) {
} }
} }
func TestGetAllChannelMembershipsForUser(t *testing.T) { func TestGetSessionChannels(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
uid, _, _ := createUserWithChannels( sid, _, _ := createSessionWithChannels(
t, database, "multi", "#m1", "#m2", t, database, "multi", "#m1", "#m2",
) )
channels, err := channels, err :=
database.GetAllChannelMembershipsForUser( database.GetSessionChannels(
t.Context(), uid, t.Context(), sid,
) )
if err != nil || len(channels) != 2 { if err != nil || len(channels) != 2 {
t.Fatalf( t.Fatalf(
@ -567,3 +603,51 @@ func TestGetAllChannelMembershipsForUser(t *testing.T) {
) )
} }
} }
func TestEnqueueToClient(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
_, _, token, err := database.CreateSession(
ctx, "enqclient",
)
if err != nil {
t.Fatal(err)
}
_, clientID, _, err := database.GetSessionByToken(
ctx, token,
)
if err != nil {
t.Fatal(err)
}
body := json.RawMessage(`["test"]`)
dbID, _, err := database.InsertMessage(
ctx, "PRIVMSG", "sender", "#ch", body, nil,
)
if err != nil {
t.Fatal(err)
}
err = database.EnqueueToClient(ctx, clientID, dbID)
if err != nil {
t.Fatal(err)
}
const batchSize = 10
msgs, _, err := database.PollMessages(
ctx, clientID, 0, batchSize,
)
if err != nil {
t.Fatal(err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1, got %d", len(msgs))
}
}

View File

@ -1,15 +1,28 @@
-- Chat server schema (pre-1.0 consolidated) -- Chat server schema (pre-1.0 consolidated)
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
-- Users: IRC-style sessions (no passwords, just nick + token) -- Sessions: IRC-style sessions (no passwords, nick + optional signing key)
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT NOT NULL UNIQUE,
nick TEXT NOT NULL UNIQUE, nick TEXT NOT NULL UNIQUE,
signing_key TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid);
-- Clients: each session can have multiple connected clients
CREATE TABLE IF NOT EXISTS clients (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT NOT NULL UNIQUE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE, token TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
); );
CREATE INDEX IF NOT EXISTS idx_users_token ON users(token); CREATE INDEX IF NOT EXISTS idx_clients_token ON clients(token);
CREATE INDEX IF NOT EXISTS idx_clients_session ON clients(session_id);
-- Channels -- Channels
CREATE TABLE IF NOT EXISTS channels ( CREATE TABLE IF NOT EXISTS channels (
@ -24,9 +37,9 @@ CREATE TABLE IF NOT EXISTS channels (
CREATE TABLE IF NOT EXISTS channel_members ( CREATE TABLE IF NOT EXISTS channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP, joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, user_id) UNIQUE(channel_id, session_id)
); );
-- Messages: IRC envelope format -- Messages: IRC envelope format
@ -46,9 +59,9 @@ CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
-- Per-client message queues for fan-out delivery -- Per-client message queues for fan-out delivery
CREATE TABLE IF NOT EXISTS client_queues ( CREATE TABLE IF NOT EXISTS client_queues (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, client_id INTEGER NOT NULL REFERENCES clients(id) ON DELETE CASCADE,
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE, message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, message_id) UNIQUE(client_id, message_id)
); );
CREATE INDEX IF NOT EXISTS idx_client_queues_user ON client_queues(user_id, id); CREATE INDEX IF NOT EXISTS idx_client_queues_client ON client_queues(client_id, id);

View File

@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -37,35 +38,37 @@ func (hdlr *Handlers) maxBodySize() int64 {
return defaultMaxBodySize return defaultMaxBodySize
} }
// authUser extracts the user from the Authorization header. // authSession extracts the session from the client token.
func (hdlr *Handlers) authUser( func (hdlr *Handlers) authSession(
request *http.Request, request *http.Request,
) (int64, string, error) { ) (int64, int64, string, error) {
auth := request.Header.Get("Authorization") auth := request.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") { if !strings.HasPrefix(auth, "Bearer ") {
return 0, "", errUnauthorized return 0, 0, "", errUnauthorized
} }
token := strings.TrimPrefix(auth, "Bearer ") token := strings.TrimPrefix(auth, "Bearer ")
if token == "" { if token == "" {
return 0, "", errUnauthorized return 0, 0, "", errUnauthorized
} }
uid, nick, err := hdlr.params.Database.GetUserByToken( sessionID, clientID, nick, err :=
request.Context(), token, hdlr.params.Database.GetSessionByToken(
) request.Context(), token,
)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("auth: %w", err) return 0, 0, "", fmt.Errorf("auth: %w", err)
} }
return uid, nick, nil return sessionID, clientID, nick, nil
} }
func (hdlr *Handlers) requireAuth( func (hdlr *Handlers) requireAuth(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) (int64, string, bool) { ) (int64, int64, string, bool) {
uid, nick, err := hdlr.authUser(request) sessionID, clientID, nick, err :=
hdlr.authSession(request)
if err != nil { if err != nil {
hdlr.respondError( hdlr.respondError(
writer, request, writer, request,
@ -73,19 +76,19 @@ func (hdlr *Handlers) requireAuth(
http.StatusUnauthorized, http.StatusUnauthorized,
) )
return 0, "", false return 0, 0, "", false
} }
return uid, nick, true return sessionID, clientID, nick, true
} }
// fanOut stores a message and enqueues it to all specified // fanOut stores a message and enqueues it to all specified
// user IDs, then notifies them. // session IDs, then notifies them.
func (hdlr *Handlers) fanOut( func (hdlr *Handlers) fanOut(
request *http.Request, request *http.Request,
command, from, target string, command, from, target string,
body json.RawMessage, body json.RawMessage,
userIDs []int64, sessionIDs []int64,
) (string, error) { ) (string, error) {
dbID, msgUUID, err := hdlr.params.Database.InsertMessage( dbID, msgUUID, err := hdlr.params.Database.InsertMessage(
request.Context(), command, from, target, body, nil, request.Context(), command, from, target, body, nil,
@ -94,16 +97,16 @@ func (hdlr *Handlers) fanOut(
return "", fmt.Errorf("insert message: %w", err) return "", fmt.Errorf("insert message: %w", err)
} }
for _, uid := range userIDs { for _, sid := range sessionIDs {
enqErr := hdlr.params.Database.EnqueueMessage( enqErr := hdlr.params.Database.EnqueueToSession(
request.Context(), uid, dbID, request.Context(), sid, dbID,
) )
if enqErr != nil { if enqErr != nil {
hdlr.log.Error("enqueue failed", hdlr.log.Error("enqueue failed",
"error", enqErr, "user_id", uid) "error", enqErr, "session_id", sid)
} }
hdlr.broker.Notify(uid) hdlr.broker.Notify(sid)
} }
return msgUUID, nil return msgUUID, nil
@ -114,10 +117,10 @@ func (hdlr *Handlers) fanOutSilent(
request *http.Request, request *http.Request,
command, from, target string, command, from, target string,
body json.RawMessage, body json.RawMessage,
userIDs []int64, sessionIDs []int64,
) error { ) error {
_, err := hdlr.fanOut( _, err := hdlr.fanOut(
request, command, from, target, body, userIDs, request, command, from, target, body, sessionIDs,
) )
return err return err
@ -125,16 +128,6 @@ func (hdlr *Handlers) fanOutSilent(
// HandleCreateSession creates a new user session. // HandleCreateSession creates a new user session.
func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc { func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
type createRequest struct {
Nick string `json:"nick"`
}
type createResponse struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
Token string `json:"token"`
}
return func( return func(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
@ -143,82 +136,174 @@ func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
writer, request.Body, hdlr.maxBodySize(), writer, request.Body, hdlr.maxBodySize(),
) )
var payload createRequest hdlr.handleCreateSession(writer, request)
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
hdlr.respondError(
writer, request,
"invalid nick format",
http.StatusBadRequest,
)
return
}
userID, token, err := hdlr.params.Database.CreateUser(
request.Context(), payload.Nick,
)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE") {
hdlr.respondError(
writer, request,
"nick already taken",
http.StatusConflict,
)
return
}
hdlr.log.Error(
"create user failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.respondJSON(
writer, request,
&createResponse{
ID: userID,
Nick: payload.Nick,
Token: token,
},
http.StatusCreated,
)
} }
} }
// HandleState returns the current user's info and channels. func (hdlr *Handlers) handleCreateSession(
writer http.ResponseWriter,
request *http.Request,
) {
type createRequest struct {
Nick string `json:"nick"`
}
var payload createRequest
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
hdlr.respondError(
writer, request,
"invalid nick format",
http.StatusBadRequest,
)
return
}
sessionID, clientID, token, err :=
hdlr.params.Database.CreateSession(
request.Context(), payload.Nick,
)
if err != nil {
hdlr.handleCreateSessionError(
writer, request, err,
)
return
}
hdlr.deliverMOTD(request, clientID, sessionID)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusCreated)
}
func (hdlr *Handlers) handleCreateSessionError(
writer http.ResponseWriter,
request *http.Request,
err error,
) {
if strings.Contains(err.Error(), "UNIQUE") {
hdlr.respondError(
writer, request,
"nick already taken",
http.StatusConflict,
)
return
}
hdlr.log.Error(
"create session failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
}
// deliverMOTD sends the MOTD as IRC numeric messages to a
// new client.
func (hdlr *Handlers) deliverMOTD(
request *http.Request,
clientID, sessionID int64,
) {
motd := hdlr.params.Config.MOTD
serverName := hdlr.params.Config.ServerName
if serverName == "" {
serverName = "chat"
}
if motd == "" {
return
}
ctx := request.Context()
hdlr.enqueueNumeric(
ctx, clientID, "375", serverName,
"- "+serverName+" Message of the Day -",
)
for line := range strings.SplitSeq(motd, "\n") {
hdlr.enqueueNumeric(
ctx, clientID, "372", serverName,
"- "+line,
)
}
hdlr.enqueueNumeric(
ctx, clientID, "376", serverName,
"End of /MOTD command.",
)
hdlr.broker.Notify(sessionID)
}
func (hdlr *Handlers) enqueueNumeric(
ctx context.Context,
clientID int64,
command, serverName, text string,
) {
body, err := json.Marshal([]string{text})
if err != nil {
hdlr.log.Error(
"marshal numeric body", "error", err,
)
return
}
dbID, _, insertErr := hdlr.params.Database.InsertMessage(
ctx, command, serverName, "",
json.RawMessage(body), nil,
)
if insertErr != nil {
hdlr.log.Error(
"insert numeric message", "error", insertErr,
)
return
}
_ = hdlr.params.Database.EnqueueToClient(
ctx, clientID, dbID,
)
}
// HandleState returns the current session's info and
// channels.
func (hdlr *Handlers) HandleState() http.HandlerFunc { func (hdlr *Handlers) HandleState() http.HandlerFunc {
return func( return func(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
uid, nick, ok := hdlr.requireAuth(writer, request) sessionID, _, nick, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
channels, err := hdlr.params.Database.ListChannels( channels, err := hdlr.params.Database.ListChannels(
request.Context(), uid, request.Context(), sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@ -234,7 +319,7 @@ func (hdlr *Handlers) HandleState() http.HandlerFunc {
} }
hdlr.respondJSON(writer, request, map[string]any{ hdlr.respondJSON(writer, request, map[string]any{
"id": uid, "id": sessionID,
"nick": nick, "nick": nick,
"channels": channels, "channels": channels,
}, http.StatusOK) }, http.StatusOK)
@ -247,7 +332,7 @@ func (hdlr *Handlers) HandleListAllChannels() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
_, _, ok := hdlr.requireAuth(writer, request) _, _, _, ok := hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@ -280,7 +365,7 @@ func (hdlr *Handlers) HandleChannelMembers() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
_, _, ok := hdlr.requireAuth(writer, request) _, _, _, ok := hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@ -328,7 +413,8 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
uid, _, ok := hdlr.requireAuth(writer, request) sessionID, clientID, _, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@ -349,7 +435,7 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
} }
msgs, lastQID, err := hdlr.params.Database.PollMessages( msgs, lastQID, err := hdlr.params.Database.PollMessages(
request.Context(), uid, request.Context(), clientID,
afterID, pollMessageLimit, afterID, pollMessageLimit,
) )
if err != nil { if err != nil {
@ -374,17 +460,20 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
return return
} }
hdlr.longPoll(writer, request, uid, afterID, timeout) hdlr.longPoll(
writer, request,
sessionID, clientID, afterID, timeout,
)
} }
} }
func (hdlr *Handlers) longPoll( func (hdlr *Handlers) longPoll(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid, afterID int64, sessionID, clientID, afterID int64,
timeout int, timeout int,
) { ) {
waitCh := hdlr.broker.Wait(uid) waitCh := hdlr.broker.Wait(sessionID)
timer := time.NewTimer( timer := time.NewTimer(
time.Duration(timeout) * time.Second, time.Duration(timeout) * time.Second,
@ -396,15 +485,15 @@ func (hdlr *Handlers) longPoll(
case <-waitCh: case <-waitCh:
case <-timer.C: case <-timer.C:
case <-request.Context().Done(): case <-request.Context().Done():
hdlr.broker.Remove(uid, waitCh) hdlr.broker.Remove(sessionID, waitCh)
return return
} }
hdlr.broker.Remove(uid, waitCh) hdlr.broker.Remove(sessionID, waitCh)
msgs, lastQID, err := hdlr.params.Database.PollMessages( msgs, lastQID, err := hdlr.params.Database.PollMessages(
request.Context(), uid, request.Context(), clientID,
afterID, pollMessageLimit, afterID, pollMessageLimit,
) )
if err != nil { if err != nil {
@ -443,7 +532,8 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
writer, request.Body, hdlr.maxBodySize(), writer, request.Body, hdlr.maxBodySize(),
) )
uid, nick, ok := hdlr.requireAuth(writer, request) sessionID, _, nick, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@ -492,7 +582,7 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
} }
hdlr.dispatchCommand( hdlr.dispatchCommand(
writer, request, uid, nick, writer, request, sessionID, nick,
payload.Command, payload.To, payload.Command, payload.To,
payload.Body, bodyLines, payload.Body, bodyLines,
) )
@ -502,7 +592,7 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
func (hdlr *Handlers) dispatchCommand( func (hdlr *Handlers) dispatchCommand(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
bodyLines func() []string, bodyLines func() []string,
@ -510,20 +600,20 @@ func (hdlr *Handlers) dispatchCommand(
switch command { switch command {
case cmdPrivmsg, "NOTICE": case cmdPrivmsg, "NOTICE":
hdlr.handlePrivmsg( hdlr.handlePrivmsg(
writer, request, uid, nick, writer, request, sessionID, nick,
command, target, body, bodyLines, command, target, body, bodyLines,
) )
case "JOIN": case "JOIN":
hdlr.handleJoin( hdlr.handleJoin(
writer, request, uid, nick, target, writer, request, sessionID, nick, target,
) )
case "PART": case "PART":
hdlr.handlePart( hdlr.handlePart(
writer, request, uid, nick, target, body, writer, request, sessionID, nick, target, body,
) )
case "NICK": case "NICK":
hdlr.handleNick( hdlr.handleNick(
writer, request, uid, nick, bodyLines, writer, request, sessionID, nick, bodyLines,
) )
case "TOPIC": case "TOPIC":
hdlr.handleTopic( hdlr.handleTopic(
@ -531,7 +621,7 @@ func (hdlr *Handlers) dispatchCommand(
) )
case "QUIT": case "QUIT":
hdlr.handleQuit( hdlr.handleQuit(
writer, request, uid, nick, body, writer, request, sessionID, nick, body,
) )
case "PING": case "PING":
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
@ -552,7 +642,7 @@ func (hdlr *Handlers) dispatchCommand(
func (hdlr *Handlers) handlePrivmsg( func (hdlr *Handlers) handlePrivmsg(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
bodyLines func() []string, bodyLines func() []string,
@ -580,7 +670,7 @@ func (hdlr *Handlers) handlePrivmsg(
if strings.HasPrefix(target, "#") { if strings.HasPrefix(target, "#") {
hdlr.handleChannelMsg( hdlr.handleChannelMsg(
writer, request, uid, nick, writer, request, sessionID, nick,
command, target, body, command, target, body,
) )
@ -588,7 +678,7 @@ func (hdlr *Handlers) handlePrivmsg(
} }
hdlr.handleDirectMsg( hdlr.handleDirectMsg(
writer, request, uid, nick, writer, request, sessionID, nick,
command, target, body, command, target, body,
) )
} }
@ -596,7 +686,7 @@ func (hdlr *Handlers) handlePrivmsg(
func (hdlr *Handlers) handleChannelMsg( func (hdlr *Handlers) handleChannelMsg(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
) { ) {
@ -614,7 +704,7 @@ func (hdlr *Handlers) handleChannelMsg(
} }
isMember, err := hdlr.params.Database.IsChannelMember( isMember, err := hdlr.params.Database.IsChannelMember(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@ -677,11 +767,11 @@ func (hdlr *Handlers) handleChannelMsg(
func (hdlr *Handlers) handleDirectMsg( func (hdlr *Handlers) handleDirectMsg(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
) { ) {
targetUID, err := hdlr.params.Database.GetUserByNick( targetSID, err := hdlr.params.Database.GetSessionByNick(
request.Context(), target, request.Context(), target,
) )
if err != nil { if err != nil {
@ -694,9 +784,9 @@ func (hdlr *Handlers) handleDirectMsg(
return return
} }
recipients := []int64{targetUID} recipients := []int64{targetSID}
if targetUID != uid { if targetSID != sessionID {
recipients = append(recipients, uid) recipients = append(recipients, sessionID)
} }
msgUUID, err := hdlr.fanOut( msgUUID, err := hdlr.fanOut(
@ -721,7 +811,7 @@ func (hdlr *Handlers) handleDirectMsg(
func (hdlr *Handlers) handleJoin( func (hdlr *Handlers) handleJoin(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, target string, nick, target string,
) { ) {
if target == "" { if target == "" {
@ -766,7 +856,7 @@ func (hdlr *Handlers) handleJoin(
} }
err = hdlr.params.Database.JoinChannel( err = hdlr.params.Database.JoinChannel(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@ -800,7 +890,7 @@ func (hdlr *Handlers) handleJoin(
func (hdlr *Handlers) handlePart( func (hdlr *Handlers) handlePart(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, target string, nick, target string,
body json.RawMessage, body json.RawMessage,
) { ) {
@ -841,7 +931,7 @@ func (hdlr *Handlers) handlePart(
) )
err = hdlr.params.Database.PartChannel( err = hdlr.params.Database.PartChannel(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@ -871,7 +961,7 @@ func (hdlr *Handlers) handlePart(
func (hdlr *Handlers) handleNick( func (hdlr *Handlers) handleNick(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick string, nick string,
bodyLines func() []string, bodyLines func() []string,
) { ) {
@ -909,7 +999,7 @@ func (hdlr *Handlers) handleNick(
} }
err := hdlr.params.Database.ChangeNick( err := hdlr.params.Database.ChangeNick(
request.Context(), uid, newNick, request.Context(), sessionID, newNick,
) )
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE") { if strings.Contains(err.Error(), "UNIQUE") {
@ -934,7 +1024,7 @@ func (hdlr *Handlers) handleNick(
return return
} }
hdlr.broadcastNick(request, uid, nick, newNick) hdlr.broadcastNick(request, sessionID, nick, newNick)
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
map[string]string{ map[string]string{
@ -945,15 +1035,15 @@ func (hdlr *Handlers) handleNick(
func (hdlr *Handlers) broadcastNick( func (hdlr *Handlers) broadcastNick(
request *http.Request, request *http.Request,
uid int64, sessionID int64,
oldNick, newNick string, oldNick, newNick string,
) { ) {
channels, _ := hdlr.params.Database. channels, _ := hdlr.params.Database.
GetAllChannelMembershipsForUser( GetSessionChannels(
request.Context(), uid, request.Context(), sessionID,
) )
notified := map[int64]bool{uid: true} notified := map[int64]bool{sessionID: true}
nickBody, err := json.Marshal([]string{newNick}) nickBody, err := json.Marshal([]string{newNick})
if err != nil { if err != nil {
@ -969,11 +1059,11 @@ func (hdlr *Handlers) broadcastNick(
json.RawMessage(nickBody), nil, json.RawMessage(nickBody), nil,
) )
_ = hdlr.params.Database.EnqueueMessage( _ = hdlr.params.Database.EnqueueToSession(
request.Context(), uid, dbID, request.Context(), sessionID, dbID,
) )
hdlr.broker.Notify(uid) hdlr.broker.Notify(sessionID)
for _, chanInfo := range channels { for _, chanInfo := range channels {
memberIDs, _ := hdlr.params.Database. memberIDs, _ := hdlr.params.Database.
@ -985,7 +1075,7 @@ func (hdlr *Handlers) broadcastNick(
if !notified[mid] { if !notified[mid] {
notified[mid] = true notified[mid] = true
_ = hdlr.params.Database.EnqueueMessage( _ = hdlr.params.Database.EnqueueToSession(
request.Context(), mid, dbID, request.Context(), mid, dbID,
) )
@ -1077,13 +1167,13 @@ func (hdlr *Handlers) handleTopic(
func (hdlr *Handlers) handleQuit( func (hdlr *Handlers) handleQuit(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick string, nick string,
body json.RawMessage, body json.RawMessage,
) { ) {
channels, _ := hdlr.params.Database. channels, _ := hdlr.params.Database.
GetAllChannelMembershipsForUser( GetSessionChannels(
request.Context(), uid, request.Context(), sessionID,
) )
notified := map[int64]bool{} notified := map[int64]bool{}
@ -1103,10 +1193,10 @@ func (hdlr *Handlers) handleQuit(
) )
for _, mid := range memberIDs { for _, mid := range memberIDs {
if mid != uid && !notified[mid] { if mid != sessionID && !notified[mid] {
notified[mid] = true notified[mid] = true
_ = hdlr.params.Database.EnqueueMessage( _ = hdlr.params.Database.EnqueueToSession(
request.Context(), mid, dbID, request.Context(), mid, dbID,
) )
@ -1115,7 +1205,7 @@ func (hdlr *Handlers) handleQuit(
} }
_ = hdlr.params.Database.PartChannel( _ = hdlr.params.Database.PartChannel(
request.Context(), chanInfo.ID, uid, request.Context(), chanInfo.ID, sessionID,
) )
_ = hdlr.params.Database.DeleteChannelIfEmpty( _ = hdlr.params.Database.DeleteChannelIfEmpty(
@ -1123,8 +1213,8 @@ func (hdlr *Handlers) handleQuit(
) )
} }
_ = hdlr.params.Database.DeleteUser( _ = hdlr.params.Database.DeleteSession(
request.Context(), uid, request.Context(), sessionID,
) )
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
@ -1138,7 +1228,8 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
uid, nick, ok := hdlr.requireAuth(writer, request) sessionID, _, nick, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@ -1155,7 +1246,7 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
} }
if !hdlr.canAccessHistory( if !hdlr.canAccessHistory(
writer, request, uid, nick, target, writer, request, sessionID, nick, target,
) { ) {
return return
} }
@ -1198,12 +1289,12 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
func (hdlr *Handlers) canAccessHistory( func (hdlr *Handlers) canAccessHistory(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, target string, nick, target string,
) bool { ) bool {
if strings.HasPrefix(target, "#") { if strings.HasPrefix(target, "#") {
return hdlr.canAccessChannelHistory( return hdlr.canAccessChannelHistory(
writer, request, uid, target, writer, request, sessionID, target,
) )
} }
@ -1225,7 +1316,7 @@ func (hdlr *Handlers) canAccessHistory(
func (hdlr *Handlers) canAccessChannelHistory( func (hdlr *Handlers) canAccessChannelHistory(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
target string, target string,
) bool { ) bool {
chID, err := hdlr.params.Database.GetChannelByName( chID, err := hdlr.params.Database.GetChannelByName(
@ -1242,7 +1333,7 @@ func (hdlr *Handlers) canAccessChannelHistory(
} }
isMember, err := hdlr.params.Database.IsChannelMember( isMember, err := hdlr.params.Database.IsChannelMember(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(

466
web/dist/app.js vendored

File diff suppressed because one or more lines are too long

43
web/dist/style.css vendored
View File

@ -14,6 +14,9 @@
--tab-active: #e94560; --tab-active: #e94560;
--tab-bg: #16213e; --tab-bg: #16213e;
--tab-hover: #1a1a3e; --tab-hover: #1a1a3e;
--topic-bg: #121a30;
--unread-bg: #e94560;
--warn: #f0ad4e;
} }
html, body, #root { html, body, #root {
@ -86,6 +89,7 @@ html, body, #root {
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
overflow-x: auto; overflow-x: auto;
flex-shrink: 0; flex-shrink: 0;
align-items: center;
} }
.tab { .tab {
@ -95,6 +99,7 @@ html, body, #root {
white-space: nowrap; white-space: nowrap;
color: var(--text-muted); color: var(--text-muted);
user-select: none; user-select: none;
position: relative;
} }
.tab:hover { .tab:hover {
@ -116,6 +121,43 @@ html, body, #root {
color: var(--accent); color: var(--accent);
} }
.tab .unread-badge {
display: inline-block;
background: var(--unread-bg);
color: white;
font-size: 10px;
font-weight: bold;
padding: 1px 5px;
border-radius: 8px;
margin-left: 6px;
min-width: 16px;
text-align: center;
}
/* Connection status */
.connection-status {
padding: 4px 12px;
background: var(--warn);
color: #1a1a2e;
font-size: 12px;
font-weight: bold;
white-space: nowrap;
flex-shrink: 0;
}
/* Topic bar */
.topic-bar {
padding: 6px 12px;
background: var(--topic-bg);
border-bottom: 1px solid var(--border);
color: var(--text-muted);
font-size: 12px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
flex-shrink: 0;
}
/* Content area */ /* Content area */
.content { .content {
display: flex; display: flex;
@ -243,6 +285,7 @@ html, body, #root {
gap: 8px; gap: 8px;
background: var(--bg-secondary); background: var(--bg-secondary);
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
margin-left: auto;
} }
.join-dialog input { .join-dialog input {

View File

@ -1,13 +1,17 @@
import { h, render, Component } from 'preact'; import { h, render } from 'preact';
import { useState, useEffect, useRef, useCallback } from 'preact/hooks'; import { useState, useEffect, useRef, useCallback } from 'preact/hooks';
const API = '/api/v1'; const API = '/api/v1';
const POLL_TIMEOUT = 15;
const RECONNECT_DELAY = 3000;
const MEMBER_REFRESH_INTERVAL = 10000;
function api(path, opts = {}) { function api(path, opts = {}) {
const token = localStorage.getItem('chat_token'); const token = localStorage.getItem('chat_token');
const headers = { 'Content-Type': 'application/json', ...(opts.headers || {}) }; const headers = { 'Content-Type': 'application/json', ...(opts.headers || {}) };
if (token) headers['Authorization'] = `Bearer ${token}`; if (token) headers['Authorization'] = `Bearer ${token}`;
return fetch(API + path, { ...opts, headers }).then(async r => { const { signal, ...rest } = opts;
return fetch(API + path, { ...rest, headers, signal }).then(async r => {
const data = await r.json().catch(() => null); const data = await r.json().catch(() => null);
if (!r.ok) throw { status: r.status, data }; if (!r.ok) throw { status: r.status, data };
return data; return data;
@ -19,7 +23,6 @@ function formatTime(ts) {
return d.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' }); return d.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' });
} }
// Nick color hashing
function nickColor(nick) { function nickColor(nick) {
let h = 0; let h = 0;
for (let i = 0; i < nick.length; i++) h = nick.charCodeAt(i) + ((h << 5) - h); for (let i = 0; i < nick.length; i++) h = nick.charCodeAt(i) + ((h << 5) - h);
@ -39,10 +42,9 @@ function LoginScreen({ onLogin }) {
if (s.name) setServerName(s.name); if (s.name) setServerName(s.name);
if (s.motd) setMotd(s.motd); if (s.motd) setMotd(s.motd);
}).catch(() => {}); }).catch(() => {});
// Check for saved token
const saved = localStorage.getItem('chat_token'); const saved = localStorage.getItem('chat_token');
if (saved) { if (saved) {
api('/state').then(u => onLogin(u.nick, saved)).catch(() => localStorage.removeItem('chat_token')); api('/state').then(u => onLogin(u.nick)).catch(() => localStorage.removeItem('chat_token'));
} }
inputRef.current?.focus(); inputRef.current?.focus();
}, []); }, []);
@ -56,7 +58,7 @@ function LoginScreen({ onLogin }) {
body: JSON.stringify({ nick: nick.trim() }) body: JSON.stringify({ nick: nick.trim() })
}); });
localStorage.setItem('chat_token', res.token); localStorage.setItem('chat_token', res.token);
onLogin(res.nick, res.token); onLogin(res.nick);
} catch (err) { } catch (err) {
setError(err.data?.error || 'Connection failed'); setError(err.data?.error || 'Connection failed');
} }
@ -84,11 +86,19 @@ function LoginScreen({ onLogin }) {
} }
function Message({ msg }) { function Message({ msg }) {
if (msg.system) {
return (
<div class="message system">
<span class="timestamp">{formatTime(msg.ts)}</span>
<span class="content">{msg.text}</span>
</div>
);
}
return ( return (
<div class={`message ${msg.system ? 'system' : ''}`}> <div class="message">
<span class="timestamp">{formatTime(msg.createdAt)}</span> <span class="timestamp">{formatTime(msg.ts)}</span>
<span class="nick" style={{ color: msg.system ? undefined : nickColor(msg.nick) }}>{msg.nick}</span> <span class="nick" style={{ color: nickColor(msg.from) }}>{msg.from}</span>
<span class="content">{msg.content}</span> <span class="content">{msg.text}</span>
</div> </div>
); );
} }
@ -98,93 +108,194 @@ function App() {
const [nick, setNick] = useState(''); const [nick, setNick] = useState('');
const [tabs, setTabs] = useState([{ type: 'server', name: 'Server' }]); const [tabs, setTabs] = useState([{ type: 'server', name: 'Server' }]);
const [activeTab, setActiveTab] = useState(0); const [activeTab, setActiveTab] = useState(0);
const [messages, setMessages] = useState({ server: [] }); // keyed by tab name const [messages, setMessages] = useState({ Server: [] });
const [members, setMembers] = useState({}); // keyed by channel name const [members, setMembers] = useState({});
const [topics, setTopics] = useState({});
const [unread, setUnread] = useState({});
const [input, setInput] = useState(''); const [input, setInput] = useState('');
const [joinInput, setJoinInput] = useState(''); const [joinInput, setJoinInput] = useState('');
const [lastMsgId, setLastMsgId] = useState(0); const [connected, setConnected] = useState(true);
const lastIdRef = useRef(0);
const seenIdsRef = useRef(new Set());
const pollAbortRef = useRef(null);
const tabsRef = useRef(tabs);
const activeTabRef = useRef(activeTab);
const nickRef = useRef(nick);
const messagesEndRef = useRef(); const messagesEndRef = useRef();
const inputRef = useRef(); const inputRef = useRef();
const pollRef = useRef();
useEffect(() => { tabsRef.current = tabs; }, [tabs]);
useEffect(() => { activeTabRef.current = activeTab; }, [activeTab]);
useEffect(() => { nickRef.current = nick; }, [nick]);
// Persist joined channels
useEffect(() => {
const channels = tabs.filter(t => t.type === 'channel').map(t => t.name);
localStorage.setItem('chat_channels', JSON.stringify(channels));
}, [tabs]);
// Clear unread on tab switch
useEffect(() => {
const tab = tabs[activeTab];
if (tab) setUnread(prev => ({ ...prev, [tab.name]: 0 }));
}, [activeTab, tabs]);
const addMessage = useCallback((tabName, msg) => { const addMessage = useCallback((tabName, msg) => {
if (msg.id && seenIdsRef.current.has(msg.id)) return;
if (msg.id) seenIdsRef.current.add(msg.id);
setMessages(prev => ({ setMessages(prev => ({
...prev, ...prev,
[tabName]: [...(prev[tabName] || []), msg] [tabName]: [...(prev[tabName] || []), msg]
})); }));
const currentTab = tabsRef.current[activeTabRef.current];
if (!currentTab || currentTab.name !== tabName) {
setUnread(prev => ({ ...prev, [tabName]: (prev[tabName] || 0) + 1 }));
}
}, []); }, []);
const addSystemMessage = useCallback((tabName, text) => { const addSystemMessage = useCallback((tabName, text) => {
addMessage(tabName, { setMessages(prev => ({
id: Date.now(), ...prev,
nick: '*', [tabName]: [...(prev[tabName] || []), {
content: text, id: 'sys-' + Date.now() + '-' + Math.random(),
createdAt: new Date().toISOString(), ts: new Date().toISOString(),
system: true text,
}); system: true
}, [addMessage]); }]
}));
}, []);
const onLogin = useCallback((userNick, token) => { const refreshMembers = useCallback((channel) => {
setNick(userNick); const chName = channel.replace('#', '');
setLoggedIn(true); api(`/channels/${chName}/members`).then(m => {
addSystemMessage('server', `Connected as ${userNick}`); setMembers(prev => ({ ...prev, [channel]: m }));
// Fetch server info
api('/server').then(s => {
if (s.motd) addSystemMessage('server', `MOTD: ${s.motd}`);
}).catch(() => {}); }).catch(() => {});
}, [addSystemMessage]); }, []);
// Poll for new messages const processMessage = useCallback((msg) => {
const body = Array.isArray(msg.body) ? msg.body.join('\n') : '';
const base = { id: msg.id, ts: msg.ts, from: msg.from, to: msg.to, command: msg.command };
switch (msg.command) {
case 'PRIVMSG':
case 'NOTICE': {
const parsed = { ...base, text: body, system: false };
const target = msg.to;
if (target && target.startsWith('#')) {
addMessage(target, parsed);
} else {
const dmPeer = msg.from === nickRef.current ? msg.to : msg.from;
setTabs(prev => {
if (!prev.find(t => t.type === 'dm' && t.name === dmPeer)) {
return [...prev, { type: 'dm', name: dmPeer }];
}
return prev;
});
addMessage(dmPeer, parsed);
}
break;
}
case 'JOIN': {
const text = `${msg.from} has joined ${msg.to}`;
if (msg.to) addMessage(msg.to, { ...base, text, system: true });
if (msg.to && msg.to.startsWith('#')) refreshMembers(msg.to);
break;
}
case 'PART': {
const reason = body ? ': ' + body : '';
const text = `${msg.from} has left ${msg.to}${reason}`;
if (msg.to) addMessage(msg.to, { ...base, text, system: true });
if (msg.to && msg.to.startsWith('#')) refreshMembers(msg.to);
break;
}
case 'QUIT': {
const reason = body ? ': ' + body : '';
const text = `${msg.from} has quit${reason}`;
tabsRef.current.forEach(tab => {
if (tab.type === 'channel') {
addMessage(tab.name, { ...base, text, system: true });
}
});
break;
}
case 'NICK': {
const newNick = Array.isArray(msg.body) ? msg.body[0] : body;
const text = `${msg.from} is now known as ${newNick}`;
tabsRef.current.forEach(tab => {
if (tab.type === 'channel') {
addMessage(tab.name, { ...base, text, system: true });
}
});
if (msg.from === nickRef.current && newNick) setNick(newNick);
// Refresh members in all channels
tabsRef.current.forEach(tab => {
if (tab.type === 'channel') refreshMembers(tab.name);
});
break;
}
case 'TOPIC': {
const text = `${msg.from} set the topic: ${body}`;
if (msg.to) {
addMessage(msg.to, { ...base, text, system: true });
setTopics(prev => ({ ...prev, [msg.to]: body }));
}
break;
}
case '375':
case '372':
case '376':
addMessage('Server', { ...base, text: body, system: true });
break;
default:
addMessage('Server', { ...base, text: body || msg.command, system: true });
}
}, [addMessage, refreshMembers]);
// Long-poll loop
useEffect(() => { useEffect(() => {
if (!loggedIn) return; if (!loggedIn) return;
let alive = true; let alive = true;
const poll = async () => { const poll = async () => {
try { while (alive) {
const msgs = await api(`/messages?after=${lastMsgId}`); try {
if (!alive) return; const controller = new AbortController();
let maxId = lastMsgId; pollAbortRef.current = controller;
for (const msg of msgs) { const result = await api(
if (msg.id > maxId) maxId = msg.id; `/messages?after=${lastIdRef.current}&timeout=${POLL_TIMEOUT}`,
if (msg.isDm) { { signal: controller.signal }
const dmTab = msg.nick === nick ? msg.dmTarget : msg.nick; );
// Ensure DM tab exists if (!alive) break;
setTabs(prev => { setConnected(true);
if (!prev.find(t => t.type === 'dm' && t.name === dmTab)) { if (result.messages) {
return [...prev, { type: 'dm', name: dmTab }]; for (const m of result.messages) processMessage(m);
}
return prev;
});
addMessage(dmTab, msg);
} else if (msg.channel) {
addMessage(msg.channel, msg);
} }
if (result.last_id > lastIdRef.current) {
lastIdRef.current = result.last_id;
}
} catch (err) {
if (!alive) break;
if (err.name === 'AbortError') continue;
setConnected(false);
await new Promise(r => setTimeout(r, RECONNECT_DELAY));
} }
if (maxId > lastMsgId) setLastMsgId(maxId);
} catch (err) {
// silent
} }
}; };
pollRef.current = setInterval(poll, 1500);
poll();
return () => { alive = false; clearInterval(pollRef.current); };
}, [loggedIn, lastMsgId, nick, addMessage]);
// Fetch members for active channel tab poll();
return () => { alive = false; pollAbortRef.current?.abort(); };
}, [loggedIn, processMessage]);
// Refresh members for active channel
useEffect(() => { useEffect(() => {
if (!loggedIn) return; if (!loggedIn) return;
const tab = tabs[activeTab]; const tab = tabs[activeTab];
if (!tab || tab.type !== 'channel') return; if (!tab || tab.type !== 'channel') return;
const chName = tab.name.replace('#', ''); refreshMembers(tab.name);
api(`/channels/${chName}/members`).then(m => { const iv = setInterval(() => refreshMembers(tab.name), MEMBER_REFRESH_INTERVAL);
setMembers(prev => ({ ...prev, [tab.name]: m }));
}).catch(() => {});
const iv = setInterval(() => {
api(`/channels/${chName}/members`).then(m => {
setMembers(prev => ({ ...prev, [tab.name]: m }));
}).catch(() => {});
}, 5000);
return () => clearInterval(iv); return () => clearInterval(iv);
}, [loggedIn, activeTab, tabs]); }, [loggedIn, activeTab, tabs, refreshMembers]);
// Auto-scroll // Auto-scroll
useEffect(() => { useEffect(() => {
@ -192,9 +303,37 @@ function App() {
}, [messages, activeTab]); }, [messages, activeTab]);
// Focus input on tab change // Focus input on tab change
useEffect(() => { inputRef.current?.focus(); }, [activeTab]);
// Fetch topic for active channel
useEffect(() => { useEffect(() => {
inputRef.current?.focus(); if (!loggedIn) return;
}, [activeTab]); const tab = tabs[activeTab];
if (!tab || tab.type !== 'channel') return;
api('/channels').then(channels => {
const ch = channels.find(c => c.name === tab.name);
if (ch && ch.topic) setTopics(prev => ({ ...prev, [tab.name]: ch.topic }));
}).catch(() => {});
}, [loggedIn, activeTab, tabs]);
const onLogin = useCallback(async (userNick) => {
setNick(userNick);
setLoggedIn(true);
addSystemMessage('Server', `Connected as ${userNick}`);
// Auto-rejoin saved channels
const saved = JSON.parse(localStorage.getItem('chat_channels') || '[]');
for (const ch of saved) {
try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'JOIN', to: ch }) });
setTabs(prev => {
if (prev.find(t => t.type === 'channel' && t.name === ch)) return prev;
return [...prev, { type: 'channel', name: ch }];
});
} catch (e) {
// Channel may not exist anymore
}
}
}, [addSystemMessage]);
const joinChannel = async (name) => { const joinChannel = async (name) => {
if (!name) return; if (!name) return;
@ -206,22 +345,29 @@ function App() {
if (prev.find(t => t.type === 'channel' && t.name === name)) return prev; if (prev.find(t => t.type === 'channel' && t.name === name)) return prev;
return [...prev, { type: 'channel', name }]; return [...prev, { type: 'channel', name }];
}); });
setActiveTab(tabs.length); // switch to new tab setActiveTab(tabs.length);
addSystemMessage(name, `Joined ${name}`); // Load history
try {
const hist = await api(`/history?target=${encodeURIComponent(name)}&limit=50`);
if (Array.isArray(hist)) {
for (const m of hist) processMessage(m);
}
} catch (e) {
// History may be empty
}
setJoinInput(''); setJoinInput('');
} catch (err) { } catch (err) {
addSystemMessage('server', `Failed to join ${name}: ${err.data?.error || 'error'}`); addSystemMessage('Server', `Failed to join ${name}: ${err.data?.error || 'error'}`);
} }
}; };
const partChannel = async (name) => { const partChannel = async (name) => {
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PART', to: name }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PART', to: name }) });
} catch (err) { /* ignore */ } } catch (e) {
setTabs(prev => { // Ignore
const next = prev.filter(t => !(t.type === 'channel' && t.name === name)); }
return next; setTabs(prev => prev.filter(t => !(t.type === 'channel' && t.name === name)));
});
setActiveTab(0); setActiveTab(0);
}; };
@ -240,7 +386,8 @@ function App() {
if (prev.find(t => t.type === 'dm' && t.name === targetNick)) return prev; if (prev.find(t => t.type === 'dm' && t.name === targetNick)) return prev;
return [...prev, { type: 'dm', name: targetNick }]; return [...prev, { type: 'dm', name: targetNick }];
}); });
setActiveTab(tabs.findIndex(t => t.type === 'dm' && t.name === targetNick) || tabs.length); const idx = tabs.findIndex(t => t.type === 'dm' && t.name === targetNick);
setActiveTab(idx >= 0 ? idx : tabs.length);
}; };
const sendMessage = async () => { const sendMessage = async () => {
@ -250,46 +397,45 @@ function App() {
const tab = tabs[activeTab]; const tab = tabs[activeTab];
if (!tab || tab.type === 'server') return; if (!tab || tab.type === 'server') return;
// Handle /commands
if (text.startsWith('/')) { if (text.startsWith('/')) {
const parts = text.split(' '); const parts = text.split(' ');
const cmd = parts[0].toLowerCase(); const cmd = parts[0].toLowerCase();
if (cmd === '/join' && parts[1]) { if (cmd === '/join' && parts[1]) { joinChannel(parts[1]); return; }
joinChannel(parts[1]); if (cmd === '/part') { if (tab.type === 'channel') partChannel(tab.name); return; }
return;
}
if (cmd === '/part') {
if (tab.type === 'channel') partChannel(tab.name);
return;
}
if (cmd === '/msg' && parts[1] && parts.slice(2).join(' ')) { if (cmd === '/msg' && parts[1] && parts.slice(2).join(' ')) {
const target = parts[1]; const target = parts[1];
const msg = parts.slice(2).join(' '); const body = parts.slice(2).join(' ');
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: target, body: [msg] }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: target, body: [body] }) });
openDM(target); openDM(target);
} catch (err) { } catch (err) {
addSystemMessage('server', `Failed to send DM: ${err.data?.error || 'error'}`); addSystemMessage('Server', `DM failed: ${err.data?.error || 'error'}`);
} }
return; return;
} }
if (cmd === '/nick' && parts[1]) { if (cmd === '/nick' && parts[1]) {
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'NICK', body: [parts[1]] }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'NICK', body: [parts[1]] }) });
setNick(parts[1]);
addSystemMessage('server', `Nick changed to ${parts[1]}`);
} catch (err) { } catch (err) {
addSystemMessage('server', `Nick change failed: ${err.data?.error || 'error'}`); addSystemMessage('Server', `Nick change failed: ${err.data?.error || 'error'}`);
} }
return; return;
} }
addSystemMessage('server', `Unknown command: ${cmd}`); if (cmd === '/topic' && tab.type === 'channel') {
const topicText = parts.slice(1).join(' ');
try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'TOPIC', to: tab.name, body: [topicText] }) });
} catch (err) {
addSystemMessage('Server', `Topic failed: ${err.data?.error || 'error'}`);
}
return;
}
addSystemMessage('Server', `Unknown command: ${cmd}`);
return; return;
} }
const to = tab.type === 'channel' ? tab.name : tab.name;
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to, body: [text] }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: tab.name, body: [text] }) });
} catch (err) { } catch (err) {
addSystemMessage(tab.name, `Send failed: ${err.data?.error || 'error'}`); addSystemMessage(tab.name, `Send failed: ${err.data?.error || 'error'}`);
} }
@ -300,16 +446,21 @@ function App() {
const currentTab = tabs[activeTab] || tabs[0]; const currentTab = tabs[activeTab] || tabs[0];
const currentMessages = messages[currentTab.name] || []; const currentMessages = messages[currentTab.name] || [];
const currentMembers = members[currentTab.name] || []; const currentMembers = members[currentTab.name] || [];
const currentTopic = topics[currentTab.name] || '';
return ( return (
<div class="app"> <div class="app">
<div class="tab-bar"> <div class="tab-bar">
{!connected && <div class="connection-status"> Reconnecting...</div>}
{tabs.map((tab, i) => ( {tabs.map((tab, i) => (
<div <div
class={`tab ${i === activeTab ? 'active' : ''}`} class={`tab ${i === activeTab ? 'active' : ''}`}
onClick={() => setActiveTab(i)} onClick={() => setActiveTab(i)}
> >
{tab.type === 'dm' ? `${tab.name}` : tab.name} {tab.type === 'dm' ? `${tab.name}` : tab.name}
{unread[tab.name] > 0 && i !== activeTab && (
<span class="unread-badge">{unread[tab.name]}</span>
)}
{tab.type !== 'server' && ( {tab.type !== 'server' && (
<span class="close-btn" onClick={(e) => { e.stopPropagation(); closeTab(i); }}>×</span> <span class="close-btn" onClick={(e) => { e.stopPropagation(); closeTab(i); }}>×</span>
)} )}
@ -326,30 +477,27 @@ function App() {
</div> </div>
</div> </div>
{currentTab.type === 'channel' && currentTopic && (
<div class="topic-bar" title={currentTopic}>{currentTopic}</div>
)}
<div class="content"> <div class="content">
<div class="messages-pane"> <div class="messages-pane">
{currentTab.type === 'server' ? ( <div class={currentTab.type === 'server' ? 'server-messages' : 'messages'}>
<div class="server-messages"> {currentMessages.map(m => <Message msg={m} />)}
{currentMessages.map(m => <Message msg={m} />)} <div ref={messagesEndRef} />
<div ref={messagesEndRef} /> </div>
{currentTab.type !== 'server' && (
<div class="input-bar">
<input
ref={inputRef}
placeholder={`Message ${currentTab.name}...`}
value={input}
onInput={e => setInput(e.target.value)}
onKeyDown={e => e.key === 'Enter' && sendMessage()}
/>
<button onClick={sendMessage}>Send</button>
</div> </div>
) : (
<>
<div class="messages">
{currentMessages.map(m => <Message msg={m} />)}
<div ref={messagesEndRef} />
</div>
<div class="input-bar">
<input
ref={inputRef}
placeholder={`Message ${currentTab.name}...`}
value={input}
onInput={e => setInput(e.target.value)}
onKeyDown={e => e.key === 'Enter' && sendMessage()}
/>
<button onClick={sendMessage}>Send</button>
</div>
</>
)} )}
</div> </div>

View File

@ -14,6 +14,9 @@
--tab-active: #e94560; --tab-active: #e94560;
--tab-bg: #16213e; --tab-bg: #16213e;
--tab-hover: #1a1a3e; --tab-hover: #1a1a3e;
--topic-bg: #121a30;
--unread-bg: #e94560;
--warn: #f0ad4e;
} }
html, body, #root { html, body, #root {
@ -86,6 +89,7 @@ html, body, #root {
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
overflow-x: auto; overflow-x: auto;
flex-shrink: 0; flex-shrink: 0;
align-items: center;
} }
.tab { .tab {
@ -95,6 +99,7 @@ html, body, #root {
white-space: nowrap; white-space: nowrap;
color: var(--text-muted); color: var(--text-muted);
user-select: none; user-select: none;
position: relative;
} }
.tab:hover { .tab:hover {
@ -116,6 +121,43 @@ html, body, #root {
color: var(--accent); color: var(--accent);
} }
.tab .unread-badge {
display: inline-block;
background: var(--unread-bg);
color: white;
font-size: 10px;
font-weight: bold;
padding: 1px 5px;
border-radius: 8px;
margin-left: 6px;
min-width: 16px;
text-align: center;
}
/* Connection status */
.connection-status {
padding: 4px 12px;
background: var(--warn);
color: #1a1a2e;
font-size: 12px;
font-weight: bold;
white-space: nowrap;
flex-shrink: 0;
}
/* Topic bar */
.topic-bar {
padding: 6px 12px;
background: var(--topic-bg);
border-bottom: 1px solid var(--border);
color: var(--text-muted);
font-size: 12px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
flex-shrink: 0;
}
/* Content area */ /* Content area */
.content { .content {
display: flex; display: flex;
@ -243,6 +285,7 @@ html, body, #root {
gap: 8px; gap: 8px;
background: var(--bg-secondary); background: var(--bg-secondary);
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
margin-left: auto;
} }
.join-dialog input { .join-dialog input {