package db import ( "context" "crypto/rand" "database/sql" "encoding/hex" "encoding/json" "fmt" "time" "github.com/google/uuid" ) const ( tokenBytes = 32 defaultPollLimit = 100 defaultHistLimit = 50 ) func generateToken() (string, error) { buf := make([]byte, tokenBytes) _, err := rand.Read(buf) if err != nil { return "", fmt.Errorf("generate token: %w", err) } return hex.EncodeToString(buf), nil } // IRCMessage is the IRC envelope for all messages. type IRCMessage struct { ID string `json:"id"` Command string `json:"command"` From string `json:"from,omitempty"` To string `json:"to,omitempty"` Body json.RawMessage `json:"body,omitempty"` TS string `json:"ts"` Meta json.RawMessage `json:"meta,omitempty"` DBID int64 `json:"-"` } // ChannelInfo is a lightweight channel representation. type ChannelInfo struct { ID int64 `json:"id"` Name string `json:"name"` Topic string `json:"topic"` } // MemberInfo represents a channel member. type MemberInfo struct { ID int64 `json:"id"` Nick string `json:"nick"` LastSeen time.Time `json:"lastSeen"` } // CreateSession registers a new session and its first client. func (database *Database) CreateSession( ctx context.Context, nick string, ) (int64, int64, string, error) { sessionUUID := uuid.New().String() clientUUID := uuid.New().String() token, err := generateToken() if err != nil { return 0, 0, "", err } now := time.Now() transaction, err := database.conn.BeginTx(ctx, nil) if err != nil { return 0, 0, "", fmt.Errorf( "begin tx: %w", err, ) } res, err := transaction.ExecContext(ctx, `INSERT INTO sessions (uuid, nick, created_at, last_seen) VALUES (?, ?, ?, ?)`, sessionUUID, nick, now, now) if err != nil { _ = transaction.Rollback() return 0, 0, "", fmt.Errorf( "create session: %w", err, ) } sessionID, _ := res.LastInsertId() clientRes, err := transaction.ExecContext(ctx, `INSERT INTO clients (uuid, session_id, token, created_at, last_seen) VALUES (?, ?, ?, ?, ?)`, clientUUID, sessionID, token, now, now) if err != nil { _ = transaction.Rollback() return 0, 0, "", fmt.Errorf( "create client: %w", err, ) } clientID, _ := clientRes.LastInsertId() err = transaction.Commit() if err != nil { return 0, 0, "", fmt.Errorf( "commit session: %w", err, ) } return sessionID, clientID, token, nil } // GetSessionByToken returns session id, client id, and // nick for a client token. func (database *Database) GetSessionByToken( ctx context.Context, token string, ) (int64, int64, string, error) { var ( sessionID int64 clientID int64 nick string ) err := database.conn.QueryRowContext( ctx, `SELECT s.id, c.id, s.nick FROM clients c INNER JOIN sessions s ON s.id = c.session_id WHERE c.token = ?`, token, ).Scan(&sessionID, &clientID, &nick) if err != nil { return 0, 0, "", fmt.Errorf( "get session by token: %w", err, ) } now := time.Now() _, _ = database.conn.ExecContext( ctx, "UPDATE sessions SET last_seen = ? WHERE id = ?", now, sessionID, ) _, _ = database.conn.ExecContext( ctx, "UPDATE clients SET last_seen = ? WHERE id = ?", now, clientID, ) return sessionID, clientID, nick, nil } // GetSessionByNick returns session id for a given nick. func (database *Database) GetSessionByNick( ctx context.Context, nick string, ) (int64, error) { var sessionID int64 err := database.conn.QueryRowContext( ctx, "SELECT id FROM sessions WHERE nick = ?", nick, ).Scan(&sessionID) if err != nil { return 0, fmt.Errorf( "get session by nick: %w", err, ) } return sessionID, nil } // GetChannelByName returns the channel ID for a name. func (database *Database) GetChannelByName( ctx context.Context, name string, ) (int64, error) { var channelID int64 err := database.conn.QueryRowContext( ctx, "SELECT id FROM channels WHERE name = ?", name, ).Scan(&channelID) if err != nil { return 0, fmt.Errorf( "get channel by name: %w", err, ) } return channelID, nil } // GetOrCreateChannel returns channel id, creating if needed. // Uses INSERT OR IGNORE to avoid TOCTOU races. func (database *Database) GetOrCreateChannel( ctx context.Context, name string, ) (int64, error) { now := time.Now() _, err := database.conn.ExecContext(ctx, `INSERT OR IGNORE INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)`, name, now, now) if err != nil { return 0, fmt.Errorf("create channel: %w", err) } var channelID int64 err = database.conn.QueryRowContext( ctx, "SELECT id FROM channels WHERE name = ?", name, ).Scan(&channelID) if err != nil { return 0, fmt.Errorf("get channel: %w", err) } return channelID, nil } // JoinChannel adds a session to a channel. func (database *Database) JoinChannel( ctx context.Context, channelID, sessionID int64, ) error { _, err := database.conn.ExecContext(ctx, `INSERT OR IGNORE INTO channel_members (channel_id, session_id, joined_at) VALUES (?, ?, ?)`, channelID, sessionID, time.Now()) if err != nil { return fmt.Errorf("join channel: %w", err) } return nil } // PartChannel removes a session from a channel. func (database *Database) PartChannel( ctx context.Context, channelID, sessionID int64, ) error { _, err := database.conn.ExecContext(ctx, `DELETE FROM channel_members WHERE channel_id = ? AND session_id = ?`, channelID, sessionID) if err != nil { return fmt.Errorf("part channel: %w", err) } return nil } // DeleteChannelIfEmpty removes a channel with no members. func (database *Database) DeleteChannelIfEmpty( ctx context.Context, channelID int64, ) error { _, err := database.conn.ExecContext(ctx, `DELETE FROM channels WHERE id = ? AND NOT EXISTS (SELECT 1 FROM channel_members WHERE channel_id = ?)`, channelID, channelID) if err != nil { return fmt.Errorf( "delete channel if empty: %w", err, ) } return nil } // scanChannels scans rows into a ChannelInfo slice. func scanChannels( rows *sql.Rows, ) ([]ChannelInfo, error) { defer func() { _ = rows.Close() }() var out []ChannelInfo for rows.Next() { var chanInfo ChannelInfo err := rows.Scan( &chanInfo.ID, &chanInfo.Name, &chanInfo.Topic, ) if err != nil { return nil, fmt.Errorf("scan channel: %w", err) } out = append(out, chanInfo) } err := rows.Err() if err != nil { return nil, fmt.Errorf("rows error: %w", err) } if out == nil { out = []ChannelInfo{} } return out, nil } // ListChannels returns channels the session has joined. func (database *Database) ListChannels( ctx context.Context, sessionID int64, ) ([]ChannelInfo, error) { rows, err := database.conn.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.session_id = ? ORDER BY c.name`, sessionID) if err != nil { return nil, fmt.Errorf("list channels: %w", err) } return scanChannels(rows) } // ListAllChannels returns every channel. func (database *Database) ListAllChannels( ctx context.Context, ) ([]ChannelInfo, error) { rows, err := database.conn.QueryContext(ctx, `SELECT id, name, topic FROM channels ORDER BY name`) if err != nil { return nil, fmt.Errorf( "list all channels: %w", err, ) } return scanChannels(rows) } // ChannelMembers returns all members of a channel. func (database *Database) ChannelMembers( ctx context.Context, channelID int64, ) ([]MemberInfo, error) { rows, err := database.conn.QueryContext(ctx, `SELECT s.id, s.nick, s.last_seen FROM sessions s INNER JOIN channel_members cm ON cm.session_id = s.id WHERE cm.channel_id = ? ORDER BY s.nick`, channelID) if err != nil { return nil, fmt.Errorf( "query channel members: %w", err, ) } defer func() { _ = rows.Close() }() var members []MemberInfo for rows.Next() { var member MemberInfo err = rows.Scan( &member.ID, &member.Nick, &member.LastSeen, ) if err != nil { return nil, fmt.Errorf( "scan member: %w", err, ) } members = append(members, member) } err = rows.Err() if err != nil { return nil, fmt.Errorf("rows error: %w", err) } if members == nil { members = []MemberInfo{} } return members, nil } // IsChannelMember checks if a session belongs to a channel. func (database *Database) IsChannelMember( ctx context.Context, channelID, sessionID int64, ) (bool, error) { var count int err := database.conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM channel_members WHERE channel_id = ? AND session_id = ?`, channelID, sessionID, ).Scan(&count) if err != nil { return false, fmt.Errorf( "check membership: %w", err, ) } return count > 0, nil } // scanInt64s scans rows into an int64 slice. func scanInt64s(rows *sql.Rows) ([]int64, error) { defer func() { _ = rows.Close() }() var ids []int64 for rows.Next() { var val int64 err := rows.Scan(&val) if err != nil { return nil, fmt.Errorf( "scan int64: %w", err, ) } ids = append(ids, val) } err := rows.Err() if err != nil { return nil, fmt.Errorf("rows error: %w", err) } return ids, nil } // GetChannelMemberIDs returns session IDs in a channel. func (database *Database) GetChannelMemberIDs( ctx context.Context, channelID int64, ) ([]int64, error) { rows, err := database.conn.QueryContext(ctx, `SELECT session_id FROM channel_members WHERE channel_id = ?`, channelID) if err != nil { return nil, fmt.Errorf( "get channel member ids: %w", err, ) } return scanInt64s(rows) } // GetSessionChannelIDs returns channel IDs for a session. func (database *Database) GetSessionChannelIDs( ctx context.Context, sessionID int64, ) ([]int64, error) { rows, err := database.conn.QueryContext(ctx, `SELECT channel_id FROM channel_members WHERE session_id = ?`, sessionID) if err != nil { return nil, fmt.Errorf( "get session channel ids: %w", err, ) } return scanInt64s(rows) } // InsertMessage stores a message and returns its DB ID. func (database *Database) InsertMessage( ctx context.Context, command, from, target string, body json.RawMessage, meta json.RawMessage, ) (int64, string, error) { msgUUID := uuid.New().String() now := time.Now().UTC() if body == nil { body = json.RawMessage("[]") } if meta == nil { meta = json.RawMessage("{}") } res, err := database.conn.ExecContext(ctx, `INSERT INTO messages (uuid, command, msg_from, msg_to, body, meta, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, msgUUID, command, from, target, string(body), string(meta), now) if err != nil { return 0, "", fmt.Errorf( "insert message: %w", err, ) } dbID, _ := res.LastInsertId() return dbID, msgUUID, nil } // EnqueueToSession adds a message to all clients of a // session's queues. func (database *Database) EnqueueToSession( ctx context.Context, sessionID, messageID int64, ) error { _, err := database.conn.ExecContext(ctx, `INSERT OR IGNORE INTO client_queues (client_id, message_id, created_at) SELECT c.id, ?, ? FROM clients c WHERE c.session_id = ?`, messageID, time.Now(), sessionID) if err != nil { return fmt.Errorf( "enqueue to session: %w", err, ) } return nil } // EnqueueToClient adds a message to a specific client's // queue. func (database *Database) EnqueueToClient( ctx context.Context, clientID, messageID int64, ) error { _, err := database.conn.ExecContext(ctx, `INSERT OR IGNORE INTO client_queues (client_id, message_id, created_at) VALUES (?, ?, ?)`, clientID, messageID, time.Now()) if err != nil { return fmt.Errorf( "enqueue to client: %w", err, ) } return nil } // PollMessages returns queued messages for a client. func (database *Database) PollMessages( ctx context.Context, clientID, afterQueueID int64, limit int, ) ([]IRCMessage, int64, error) { if limit <= 0 { limit = defaultPollLimit } rows, err := database.conn.QueryContext(ctx, `SELECT cq.id, m.uuid, m.command, m.msg_from, m.msg_to, m.body, m.meta, m.created_at FROM client_queues cq INNER JOIN messages m ON m.id = cq.message_id WHERE cq.client_id = ? AND cq.id > ? ORDER BY cq.id ASC LIMIT ?`, clientID, afterQueueID, limit) if err != nil { return nil, afterQueueID, fmt.Errorf( "poll messages: %w", err, ) } msgs, lastQID, scanErr := scanMessages( rows, afterQueueID, ) if scanErr != nil { return nil, afterQueueID, scanErr } return msgs, lastQID, nil } // GetHistory returns message history for a target. func (database *Database) GetHistory( ctx context.Context, target string, beforeID int64, limit int, ) ([]IRCMessage, error) { if limit <= 0 { limit = defaultHistLimit } rows, err := database.queryHistory( ctx, target, beforeID, limit, ) if err != nil { return nil, err } msgs, _, scanErr := scanMessages(rows, 0) if scanErr != nil { return nil, scanErr } if msgs == nil { msgs = []IRCMessage{} } reverseMessages(msgs) return msgs, nil } func (database *Database) queryHistory( ctx context.Context, target string, beforeID int64, limit int, ) (*sql.Rows, error) { if beforeID > 0 { rows, err := database.conn.QueryContext(ctx, `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages WHERE msg_to = ? AND id < ? AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?`, target, beforeID, limit) if err != nil { return nil, fmt.Errorf( "query history: %w", err, ) } return rows, nil } rows, err := database.conn.QueryContext(ctx, `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages WHERE msg_to = ? AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?`, target, limit) if err != nil { return nil, fmt.Errorf("query history: %w", err) } return rows, nil } func scanMessages( rows *sql.Rows, fallbackQID int64, ) ([]IRCMessage, int64, error) { defer func() { _ = rows.Close() }() var msgs []IRCMessage lastQID := fallbackQID for rows.Next() { var ( msg IRCMessage qID int64 body, meta string createdAt time.Time ) err := rows.Scan( &qID, &msg.ID, &msg.Command, &msg.From, &msg.To, &body, &meta, &createdAt, ) if err != nil { return nil, fallbackQID, fmt.Errorf( "scan message: %w", err, ) } msg.Body = json.RawMessage(body) msg.Meta = json.RawMessage(meta) msg.TS = createdAt.Format(time.RFC3339Nano) msg.DBID = qID lastQID = qID msgs = append(msgs, msg) } err := rows.Err() if err != nil { return nil, fallbackQID, fmt.Errorf( "rows error: %w", err, ) } if msgs == nil { msgs = []IRCMessage{} } return msgs, lastQID, nil } func reverseMessages(msgs []IRCMessage) { for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } } // ChangeNick updates a session's nickname. func (database *Database) ChangeNick( ctx context.Context, sessionID int64, newNick string, ) error { _, err := database.conn.ExecContext(ctx, "UPDATE sessions SET nick = ? WHERE id = ?", newNick, sessionID) if err != nil { return fmt.Errorf("change nick: %w", err) } return nil } // SetTopic sets the topic for a channel. func (database *Database) SetTopic( ctx context.Context, channelName, topic string, ) error { _, err := database.conn.ExecContext(ctx, `UPDATE channels SET topic = ?, updated_at = ? WHERE name = ?`, topic, time.Now(), channelName) if err != nil { return fmt.Errorf("set topic: %w", err) } return nil } // DeleteSession removes a session and all its data. func (database *Database) DeleteSession( ctx context.Context, sessionID int64, ) error { _, err := database.conn.ExecContext( ctx, "DELETE FROM sessions WHERE id = ?", sessionID, ) if err != nil { return fmt.Errorf("delete session: %w", err) } return nil } // DeleteClient removes a single client record by ID. func (database *Database) DeleteClient( ctx context.Context, clientID int64, ) error { _, err := database.conn.ExecContext( ctx, "DELETE FROM clients WHERE id = ?", clientID, ) if err != nil { return fmt.Errorf("delete client: %w", err) } return nil } // GetSessionCount returns the number of active sessions. func (database *Database) GetSessionCount( ctx context.Context, ) (int64, error) { var count int64 err := database.conn.QueryRowContext( ctx, "SELECT COUNT(*) FROM sessions", ).Scan(&count) if err != nil { return 0, fmt.Errorf( "get session count: %w", err, ) } return count, nil } // ClientCountForSession returns the number of clients // belonging to a session. func (database *Database) ClientCountForSession( ctx context.Context, sessionID int64, ) (int64, error) { var count int64 err := database.conn.QueryRowContext( ctx, `SELECT COUNT(*) FROM clients WHERE session_id = ?`, sessionID, ).Scan(&count) if err != nil { return 0, fmt.Errorf( "client count for session: %w", err, ) } return count, nil } // DeleteStaleSessions removes clients not seen since the // cutoff and cleans up orphaned sessions. func (database *Database) DeleteStaleSessions( ctx context.Context, cutoff time.Time, ) (int64, error) { res, err := database.conn.ExecContext(ctx, "DELETE FROM clients WHERE last_seen < ?", cutoff, ) if err != nil { return 0, fmt.Errorf( "delete stale clients: %w", err, ) } deleted, _ := res.RowsAffected() _, err = database.conn.ExecContext(ctx, `DELETE FROM sessions WHERE id NOT IN (SELECT DISTINCT session_id FROM clients)`, ) if err != nil { return deleted, fmt.Errorf( "delete orphan sessions: %w", err, ) } return deleted, nil } // GetSessionChannels returns channels a session // belongs to. func (database *Database) GetSessionChannels( ctx context.Context, sessionID int64, ) ([]ChannelInfo, error) { rows, err := database.conn.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.session_id = ?`, sessionID) if err != nil { return nil, fmt.Errorf( "get session channels: %w", err, ) } return scanChannels(rows) }