feat: MVP two-user chat via embedded SPA (#9)
All checks were successful
check / check (push) Successful in 1m51s
All checks were successful
check / check (push) Successful in 1m51s
Backend: - Session/client UUID model: sessions table (uuid, nick, signing_key), clients table (uuid, session_id, token) with per-client message queues - MOTD delivery as IRC numeric messages (375/372/376) on connect - EnqueueToSession fans out to all clients of a session - EnqueueToClient for targeted delivery (MOTD) - All queries updated for session/client model SPA client: - Long-poll loop (15s timeout) instead of setInterval - IRC message envelope parsing (command/from/to/body) - Display JOIN/PART/NICK/TOPIC/QUIT system messages - Nick change via /nick command - Topic display in header bar - Unread count badges on inactive tabs - Auto-rejoin channels on reconnect (localStorage) - Connection status indicator - Message deduplication by UUID - Channel history loaded on join - /topic command support Closes #9
This commit is contained in:
@@ -55,76 +55,132 @@ type MemberInfo struct {
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
}
|
||||
|
||||
// CreateUser registers a new user with the given nick.
|
||||
func (database *Database) CreateUser(
|
||||
// CreateSession registers a new session and its first client.
|
||||
func (database *Database) CreateSession(
|
||||
ctx context.Context,
|
||||
nick string,
|
||||
) (int64, string, error) {
|
||||
) (int64, int64, string, error) {
|
||||
sessionUUID := uuid.New().String()
|
||||
clientUUID := uuid.New().String()
|
||||
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
return 0, 0, "", err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
res, err := database.conn.ExecContext(ctx,
|
||||
`INSERT INTO users
|
||||
(nick, token, created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
nick, token, now, now)
|
||||
transaction, err := database.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("create user: %w", err)
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"begin tx: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
userID, _ := res.LastInsertId()
|
||||
res, err := transaction.ExecContext(ctx,
|
||||
`INSERT INTO sessions
|
||||
(uuid, nick, created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
sessionUUID, nick, now, now)
|
||||
if err != nil {
|
||||
_ = transaction.Rollback()
|
||||
|
||||
return userID, token, nil
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"create session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
sessionID, _ := res.LastInsertId()
|
||||
|
||||
clientRes, err := transaction.ExecContext(ctx,
|
||||
`INSERT INTO clients
|
||||
(uuid, session_id, token,
|
||||
created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
clientUUID, sessionID, token, now, now)
|
||||
if err != nil {
|
||||
_ = transaction.Rollback()
|
||||
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"create client: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
clientID, _ := clientRes.LastInsertId()
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"commit session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return sessionID, clientID, token, nil
|
||||
}
|
||||
|
||||
// GetUserByToken returns user id and nick for a token.
|
||||
func (database *Database) GetUserByToken(
|
||||
// GetSessionByToken returns session id, client id, and
|
||||
// nick for a client token.
|
||||
func (database *Database) GetSessionByToken(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
) (int64, string, error) {
|
||||
var userID int64
|
||||
|
||||
var nick string
|
||||
) (int64, int64, string, error) {
|
||||
var (
|
||||
sessionID int64
|
||||
clientID int64
|
||||
nick string
|
||||
)
|
||||
|
||||
err := database.conn.QueryRowContext(
|
||||
ctx,
|
||||
"SELECT id, nick FROM users WHERE token = ?",
|
||||
`SELECT s.id, c.id, s.nick
|
||||
FROM clients c
|
||||
INNER JOIN sessions s
|
||||
ON s.id = c.session_id
|
||||
WHERE c.token = ?`,
|
||||
token,
|
||||
).Scan(&userID, &nick)
|
||||
).Scan(&sessionID, &clientID, &nick)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("get user by token: %w", err)
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"get session by token: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
_, _ = database.conn.ExecContext(
|
||||
ctx,
|
||||
"UPDATE users SET last_seen = ? WHERE id = ?",
|
||||
time.Now(), userID,
|
||||
"UPDATE sessions SET last_seen = ? WHERE id = ?",
|
||||
now, sessionID,
|
||||
)
|
||||
|
||||
return userID, nick, nil
|
||||
_, _ = database.conn.ExecContext(
|
||||
ctx,
|
||||
"UPDATE clients SET last_seen = ? WHERE id = ?",
|
||||
now, clientID,
|
||||
)
|
||||
|
||||
return sessionID, clientID, nick, nil
|
||||
}
|
||||
|
||||
// GetUserByNick returns user id for a given nick.
|
||||
func (database *Database) GetUserByNick(
|
||||
// GetSessionByNick returns session id for a given nick.
|
||||
func (database *Database) GetSessionByNick(
|
||||
ctx context.Context,
|
||||
nick string,
|
||||
) (int64, error) {
|
||||
var userID int64
|
||||
var sessionID int64
|
||||
|
||||
err := database.conn.QueryRowContext(
|
||||
ctx,
|
||||
"SELECT id FROM users WHERE nick = ?",
|
||||
"SELECT id FROM sessions WHERE nick = ?",
|
||||
nick,
|
||||
).Scan(&userID)
|
||||
).Scan(&sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get user by nick: %w", err)
|
||||
return 0, fmt.Errorf(
|
||||
"get session by nick: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
// GetChannelByName returns the channel ID for a name.
|
||||
@@ -179,16 +235,16 @@ func (database *Database) GetOrCreateChannel(
|
||||
return channelID, nil
|
||||
}
|
||||
|
||||
// JoinChannel adds a user to a channel.
|
||||
// JoinChannel adds a session to a channel.
|
||||
func (database *Database) JoinChannel(
|
||||
ctx context.Context,
|
||||
channelID, userID int64,
|
||||
channelID, sessionID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO channel_members
|
||||
(channel_id, user_id, joined_at)
|
||||
(channel_id, session_id, joined_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
channelID, userID, time.Now())
|
||||
channelID, sessionID, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("join channel: %w", err)
|
||||
}
|
||||
@@ -196,15 +252,15 @@ func (database *Database) JoinChannel(
|
||||
return nil
|
||||
}
|
||||
|
||||
// PartChannel removes a user from a channel.
|
||||
// PartChannel removes a session from a channel.
|
||||
func (database *Database) PartChannel(
|
||||
ctx context.Context,
|
||||
channelID, userID int64,
|
||||
channelID, sessionID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`DELETE FROM channel_members
|
||||
WHERE channel_id = ? AND user_id = ?`,
|
||||
channelID, userID)
|
||||
WHERE channel_id = ? AND session_id = ?`,
|
||||
channelID, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("part channel: %w", err)
|
||||
}
|
||||
@@ -265,18 +321,18 @@ func scanChannels(
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ListChannels returns channels the user has joined.
|
||||
// ListChannels returns channels the session has joined.
|
||||
func (database *Database) ListChannels(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
) ([]ChannelInfo, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT c.id, c.name, c.topic
|
||||
FROM channels c
|
||||
INNER JOIN channel_members cm
|
||||
ON cm.channel_id = c.id
|
||||
WHERE cm.user_id = ?
|
||||
ORDER BY c.name`, userID)
|
||||
WHERE cm.session_id = ?
|
||||
ORDER BY c.name`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list channels: %w", err)
|
||||
}
|
||||
@@ -306,12 +362,12 @@ func (database *Database) ChannelMembers(
|
||||
channelID int64,
|
||||
) ([]MemberInfo, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT u.id, u.nick, u.last_seen
|
||||
FROM users u
|
||||
`SELECT s.id, s.nick, s.last_seen
|
||||
FROM sessions s
|
||||
INNER JOIN channel_members cm
|
||||
ON cm.user_id = u.id
|
||||
ON cm.session_id = s.id
|
||||
WHERE cm.channel_id = ?
|
||||
ORDER BY u.nick`, channelID)
|
||||
ORDER BY s.nick`, channelID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"query channel members: %w", err,
|
||||
@@ -349,17 +405,17 @@ func (database *Database) ChannelMembers(
|
||||
return members, nil
|
||||
}
|
||||
|
||||
// IsChannelMember checks if a user belongs to a channel.
|
||||
// IsChannelMember checks if a session belongs to a channel.
|
||||
func (database *Database) IsChannelMember(
|
||||
ctx context.Context,
|
||||
channelID, userID int64,
|
||||
channelID, sessionID int64,
|
||||
) (bool, error) {
|
||||
var count int
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM channel_members
|
||||
WHERE channel_id = ? AND user_id = ?`,
|
||||
channelID, userID,
|
||||
WHERE channel_id = ? AND session_id = ?`,
|
||||
channelID, sessionID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf(
|
||||
@@ -397,13 +453,13 @@ func scanInt64s(rows *sql.Rows) ([]int64, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// GetChannelMemberIDs returns user IDs in a channel.
|
||||
// GetChannelMemberIDs returns session IDs in a channel.
|
||||
func (database *Database) GetChannelMemberIDs(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
) ([]int64, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT user_id FROM channel_members
|
||||
`SELECT session_id FROM channel_members
|
||||
WHERE channel_id = ?`, channelID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
@@ -414,17 +470,17 @@ func (database *Database) GetChannelMemberIDs(
|
||||
return scanInt64s(rows)
|
||||
}
|
||||
|
||||
// GetUserChannelIDs returns channel IDs the user is in.
|
||||
func (database *Database) GetUserChannelIDs(
|
||||
// GetSessionChannelIDs returns channel IDs for a session.
|
||||
func (database *Database) GetSessionChannelIDs(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
) ([]int64, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT channel_id FROM channel_members
|
||||
WHERE user_id = ?`, userID)
|
||||
WHERE session_id = ?`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"get user channel ids: %w", err,
|
||||
"get session channel ids: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -467,27 +523,52 @@ func (database *Database) InsertMessage(
|
||||
return dbID, msgUUID, nil
|
||||
}
|
||||
|
||||
// EnqueueMessage adds a message to a user's queue.
|
||||
func (database *Database) EnqueueMessage(
|
||||
// EnqueueToSession adds a message to all clients of a
|
||||
// session's queues.
|
||||
func (database *Database) EnqueueToSession(
|
||||
ctx context.Context,
|
||||
userID, messageID int64,
|
||||
sessionID, messageID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO client_queues
|
||||
(user_id, message_id, created_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
userID, messageID, time.Now())
|
||||
(client_id, message_id, created_at)
|
||||
SELECT c.id, ?, ?
|
||||
FROM clients c
|
||||
WHERE c.session_id = ?`,
|
||||
messageID, time.Now(), sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enqueue message: %w", err)
|
||||
return fmt.Errorf(
|
||||
"enqueue to session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PollMessages returns queued messages for a user.
|
||||
// EnqueueToClient adds a message to a specific client's
|
||||
// queue.
|
||||
func (database *Database) EnqueueToClient(
|
||||
ctx context.Context,
|
||||
clientID, messageID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO client_queues
|
||||
(client_id, message_id, created_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
clientID, messageID, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"enqueue to client: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PollMessages returns queued messages for a client.
|
||||
func (database *Database) PollMessages(
|
||||
ctx context.Context,
|
||||
userID, afterQueueID int64,
|
||||
clientID, afterQueueID int64,
|
||||
limit int,
|
||||
) ([]IRCMessage, int64, error) {
|
||||
if limit <= 0 {
|
||||
@@ -501,9 +582,9 @@ func (database *Database) PollMessages(
|
||||
FROM client_queues cq
|
||||
INNER JOIN messages m
|
||||
ON m.id = cq.message_id
|
||||
WHERE cq.user_id = ? AND cq.id > ?
|
||||
WHERE cq.client_id = ? AND cq.id > ?
|
||||
ORDER BY cq.id ASC LIMIT ?`,
|
||||
userID, afterQueueID, limit)
|
||||
clientID, afterQueueID, limit)
|
||||
if err != nil {
|
||||
return nil, afterQueueID, fmt.Errorf(
|
||||
"poll messages: %w", err,
|
||||
@@ -649,15 +730,15 @@ func reverseMessages(msgs []IRCMessage) {
|
||||
}
|
||||
}
|
||||
|
||||
// ChangeNick updates a user's nickname.
|
||||
// ChangeNick updates a session's nickname.
|
||||
func (database *Database) ChangeNick(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
newNick string,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
"UPDATE users SET nick = ? WHERE id = ?",
|
||||
newNick, userID)
|
||||
"UPDATE sessions SET nick = ? WHERE id = ?",
|
||||
newNick, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("change nick: %w", err)
|
||||
}
|
||||
@@ -681,38 +762,38 @@ func (database *Database) SetTopic(
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser removes a user and all their data.
|
||||
func (database *Database) DeleteUser(
|
||||
// DeleteSession removes a session and all its data.
|
||||
func (database *Database) DeleteSession(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(
|
||||
ctx,
|
||||
"DELETE FROM users WHERE id = ?",
|
||||
userID,
|
||||
"DELETE FROM sessions WHERE id = ?",
|
||||
sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete user: %w", err)
|
||||
return fmt.Errorf("delete session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllChannelMembershipsForUser returns channels
|
||||
// a user belongs to.
|
||||
func (database *Database) GetAllChannelMembershipsForUser(
|
||||
// GetSessionChannels returns channels a session
|
||||
// belongs to.
|
||||
func (database *Database) GetSessionChannels(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
) ([]ChannelInfo, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT c.id, c.name, c.topic
|
||||
FROM channels c
|
||||
INNER JOIN channel_members cm
|
||||
ON cm.channel_id = c.id
|
||||
WHERE cm.user_id = ?`, userID)
|
||||
WHERE cm.session_id = ?`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"get memberships: %w", err,
|
||||
"get session channels: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,70 +27,91 @@ func setupTestDB(t *testing.T) *db.Database {
|
||||
return database
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
func TestCreateSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
id, token, err := database.CreateUser(ctx, "alice")
|
||||
sessionID, _, token, err := database.CreateSession(
|
||||
ctx, "alice",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id == 0 || token == "" {
|
||||
if sessionID == 0 || token == "" {
|
||||
t.Fatal("expected valid id and token")
|
||||
}
|
||||
|
||||
_, _, err = database.CreateUser(ctx, "alice")
|
||||
if err == nil {
|
||||
_, _, dupToken, dupErr := database.CreateSession(
|
||||
ctx, "alice",
|
||||
)
|
||||
if dupErr == nil {
|
||||
t.Fatal("expected error for duplicate nick")
|
||||
}
|
||||
|
||||
_ = dupToken
|
||||
}
|
||||
|
||||
func TestGetUserByToken(t *testing.T) {
|
||||
func TestGetSessionByToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
_, token, err := database.CreateUser(ctx, "bob")
|
||||
_, _, token, err := database.CreateSession(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
id, nick, err := database.GetUserByToken(ctx, token)
|
||||
sessionID, clientID, nick, err :=
|
||||
database.GetSessionByToken(ctx, token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if nick != "bob" || id == 0 {
|
||||
if nick != "bob" || sessionID == 0 || clientID == 0 {
|
||||
t.Fatalf("expected bob, got %s", nick)
|
||||
}
|
||||
|
||||
_, _, err = database.GetUserByToken(ctx, "badtoken")
|
||||
if err == nil {
|
||||
badSID, badCID, badNick, badErr :=
|
||||
database.GetSessionByToken(ctx, "badtoken")
|
||||
if badErr == nil {
|
||||
t.Fatal("expected error for bad token")
|
||||
}
|
||||
|
||||
if badSID != 0 || badCID != 0 || badNick != "" {
|
||||
t.Fatal("expected zero values on error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByNick(t *testing.T) {
|
||||
func TestGetSessionByNick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
_, _, err := database.CreateUser(ctx, "charlie")
|
||||
charlieID, charlieClientID, charlieToken, err :=
|
||||
database.CreateSession(ctx, "charlie")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
id, err := database.GetUserByNick(ctx, "charlie")
|
||||
if charlieID == 0 || charlieClientID == 0 {
|
||||
t.Fatal("expected valid session/client IDs")
|
||||
}
|
||||
|
||||
if charlieToken == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
|
||||
id, err := database.GetSessionByNick(ctx, "charlie")
|
||||
if err != nil || id == 0 {
|
||||
t.Fatal("expected to find charlie")
|
||||
}
|
||||
|
||||
_, err = database.GetUserByNick(ctx, "nobody")
|
||||
_, err = database.GetSessionByNick(ctx, "nobody")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown nick")
|
||||
}
|
||||
@@ -129,7 +150,7 @@ func TestJoinAndPart(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "user1")
|
||||
sid, _, _, err := database.CreateSession(ctx, "user1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -139,22 +160,22 @@ func TestJoinAndPart(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ids, err := database.GetChannelMemberIDs(ctx, chID)
|
||||
if err != nil || len(ids) != 1 || ids[0] != uid {
|
||||
t.Fatal("expected user in channel")
|
||||
if err != nil || len(ids) != 1 || ids[0] != sid {
|
||||
t.Fatal("expected session in channel")
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.PartChannel(ctx, chID, uid)
|
||||
err = database.PartChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -178,17 +199,17 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "temp")
|
||||
sid, _, _, err := database.CreateSession(ctx, "temp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.PartChannel(ctx, chID, uid)
|
||||
err = database.PartChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -204,7 +225,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserWithChannels(
|
||||
func createSessionWithChannels(
|
||||
t *testing.T,
|
||||
database *db.Database,
|
||||
nick, ch1Name, ch2Name string,
|
||||
@@ -213,7 +234,7 @@ func createUserWithChannels(
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, nick)
|
||||
sid, _, _, err := database.CreateSession(ctx, nick)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -232,29 +253,29 @@ func createUserWithChannels(
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, ch1, uid)
|
||||
err = database.JoinChannel(ctx, ch1, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, ch2, uid)
|
||||
err = database.JoinChannel(ctx, ch2, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return uid, ch1, ch2
|
||||
return sid, ch1, ch2
|
||||
}
|
||||
|
||||
func TestListChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
uid, _, _ := createUserWithChannels(
|
||||
sid, _, _ := createSessionWithChannels(
|
||||
t, database, "lister", "#a", "#b",
|
||||
)
|
||||
|
||||
channels, err := database.ListChannels(
|
||||
t.Context(), uid,
|
||||
t.Context(), sid,
|
||||
)
|
||||
if err != nil || len(channels) != 2 {
|
||||
t.Fatalf(
|
||||
@@ -295,17 +316,21 @@ func TestChangeNick(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, token, err := database.CreateUser(ctx, "old")
|
||||
sid, _, token, err := database.CreateSession(
|
||||
ctx, "old",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.ChangeNick(ctx, uid, "new")
|
||||
err = database.ChangeNick(ctx, sid, "new")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, nick, err := database.GetUserByToken(ctx, token)
|
||||
_, _, nick, err := database.GetSessionByToken(
|
||||
ctx, token,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -375,7 +400,16 @@ func TestPollMessages(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "poller")
|
||||
sid, _, token, err := database.CreateSession(
|
||||
ctx, "poller",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, clientID, _, err := database.GetSessionByToken(
|
||||
ctx, token,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -389,7 +423,7 @@ func TestPollMessages(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.EnqueueMessage(ctx, uid, dbID)
|
||||
err = database.EnqueueToSession(ctx, sid, dbID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -397,7 +431,7 @@ func TestPollMessages(t *testing.T) {
|
||||
const batchSize = 10
|
||||
|
||||
msgs, lastQID, err := database.PollMessages(
|
||||
ctx, uid, 0, batchSize,
|
||||
ctx, clientID, 0, batchSize,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -420,7 +454,7 @@ func TestPollMessages(t *testing.T) {
|
||||
}
|
||||
|
||||
msgs, _, _ = database.PollMessages(
|
||||
ctx, uid, lastQID, batchSize,
|
||||
ctx, clientID, lastQID, batchSize,
|
||||
)
|
||||
|
||||
if len(msgs) != 0 {
|
||||
@@ -467,13 +501,15 @@ func TestGetHistory(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
func TestDeleteSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "deleteme")
|
||||
sid, _, _, err := database.CreateSession(
|
||||
ctx, "deleteme",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -485,19 +521,19 @@ func TestDeleteUser(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.DeleteUser(ctx, uid)
|
||||
err = database.DeleteSession(ctx, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = database.GetUserByNick(ctx, "deleteme")
|
||||
_, err = database.GetSessionByNick(ctx, "deleteme")
|
||||
if err == nil {
|
||||
t.Fatal("user should be deleted")
|
||||
t.Fatal("session should be deleted")
|
||||
}
|
||||
|
||||
ids, _ := database.GetChannelMemberIDs(ctx, chID)
|
||||
@@ -512,12 +548,12 @@ func TestChannelMembers(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid1, _, err := database.CreateUser(ctx, "m1")
|
||||
sid1, _, _, err := database.CreateSession(ctx, "m1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
uid2, _, err := database.CreateUser(ctx, "m2")
|
||||
sid2, _, _, err := database.CreateSession(ctx, "m2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -529,12 +565,12 @@ func TestChannelMembers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid1)
|
||||
err = database.JoinChannel(ctx, chID, sid1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid2)
|
||||
err = database.JoinChannel(ctx, chID, sid2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -548,17 +584,17 @@ func TestChannelMembers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllChannelMembershipsForUser(t *testing.T) {
|
||||
func TestGetSessionChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
uid, _, _ := createUserWithChannels(
|
||||
sid, _, _ := createSessionWithChannels(
|
||||
t, database, "multi", "#m1", "#m2",
|
||||
)
|
||||
|
||||
channels, err :=
|
||||
database.GetAllChannelMembershipsForUser(
|
||||
t.Context(), uid,
|
||||
database.GetSessionChannels(
|
||||
t.Context(), sid,
|
||||
)
|
||||
if err != nil || len(channels) != 2 {
|
||||
t.Fatalf(
|
||||
@@ -567,3 +603,51 @@ func TestGetAllChannelMembershipsForUser(t *testing.T) {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueToClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
_, _, token, err := database.CreateSession(
|
||||
ctx, "enqclient",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, clientID, _, err := database.GetSessionByToken(
|
||||
ctx, token,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := json.RawMessage(`["test"]`)
|
||||
|
||||
dbID, _, err := database.InsertMessage(
|
||||
ctx, "PRIVMSG", "sender", "#ch", body, nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.EnqueueToClient(ctx, clientID, dbID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const batchSize = 10
|
||||
|
||||
msgs, _, err := database.PollMessages(
|
||||
ctx, clientID, 0, batchSize,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(msgs))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,28 @@
|
||||
-- Chat server schema (pre-1.0 consolidated)
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- Users: IRC-style sessions (no passwords, just nick + token)
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
-- Sessions: IRC-style sessions (no passwords, nick + optional signing key)
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
uuid TEXT NOT NULL UNIQUE,
|
||||
nick TEXT NOT NULL UNIQUE,
|
||||
signing_key TEXT NOT NULL DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid);
|
||||
|
||||
-- Clients: each session can have multiple connected clients
|
||||
CREATE TABLE IF NOT EXISTS clients (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
uuid TEXT NOT NULL UNIQUE,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_token ON users(token);
|
||||
CREATE INDEX IF NOT EXISTS idx_clients_token ON clients(token);
|
||||
CREATE INDEX IF NOT EXISTS idx_clients_session ON clients(session_id);
|
||||
|
||||
-- Channels
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
@@ -24,9 +37,9 @@ CREATE TABLE IF NOT EXISTS channels (
|
||||
CREATE TABLE IF NOT EXISTS channel_members (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(channel_id, user_id)
|
||||
UNIQUE(channel_id, session_id)
|
||||
);
|
||||
|
||||
-- Messages: IRC envelope format
|
||||
@@ -46,9 +59,9 @@ CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
|
||||
-- Per-client message queues for fan-out delivery
|
||||
CREATE TABLE IF NOT EXISTS client_queues (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
client_id INTEGER NOT NULL REFERENCES clients(id) ON DELETE CASCADE,
|
||||
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(user_id, message_id)
|
||||
UNIQUE(client_id, message_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_client_queues_user ON client_queues(user_id, id);
|
||||
CREATE INDEX IF NOT EXISTS idx_client_queues_client ON client_queues(client_id, id);
|
||||
|
||||
Reference in New Issue
Block a user