package db

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"time"

	"github.com/google/uuid"
)

const (
	// tokenBytes is the number of random bytes in an auth token; the
	// hex-encoded token is twice this many characters long.
	tokenBytes = 32
	// defaultPollLimit is the page size PollMessages uses when the caller
	// passes a non-positive limit.
	defaultPollLimit = 100
	// defaultHistLimit is the page size GetHistory uses when the caller
	// passes a non-positive limit.
	defaultHistLimit = 50
)

// randomToken returns tokenBytes of cryptographically secure randomness,
// hex-encoded. It reports a failure of the system randomness source as an
// error instead of returning a predictable token.
func randomToken() (string, error) {
	b := make([]byte, tokenBytes)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("read random bytes: %w", err)
	}

	return hex.EncodeToString(b), nil
}

// generateToken returns a new random hex-encoded token.
//
// It panics if the system randomness source fails: ignoring that error
// would hand out an all-zero (trivially guessable) token, which is worse
// than crashing. Prefer randomToken wherever the error can be returned.
func generateToken() string {
	tok, err := randomToken()
	if err != nil {
		panic(err)
	}

	return tok
}

// IRCMessage is the IRC envelope for all messages.
type IRCMessage struct {
	// ID is the message's UUID.
	ID      string `json:"id"`
	Command string `json:"command"`
	From    string `json:"from,omitempty"`
	To      string `json:"to,omitempty"`
	// Body is stored JSON; InsertMessage defaults it to [].
	Body json.RawMessage `json:"body,omitempty"`
	// TS is the creation time formatted as RFC 3339 with nanoseconds.
	TS string `json:"ts"`
	// Meta is stored JSON; InsertMessage defaults it to {}.
	Meta json.RawMessage `json:"meta,omitempty"`
	// DBID is the database row ID the message was read from (the queue
	// row for polled messages, the message row for history). It is never
	// serialized.
	DBID int64 `json:"-"`
}

// ChannelInfo is a lightweight channel representation.
type ChannelInfo struct {
	ID    int64  `json:"id"`
	Name  string `json:"name"`
	Topic string `json:"topic"`
}

// MemberInfo represents a channel member.
type MemberInfo struct {
	ID       int64     `json:"id"`
	Nick     string    `json:"nick"`
	LastSeen time.Time `json:"lastSeen"`
}

// CreateUser registers a new user with the given nick and returns the new
// user's ID together with a freshly generated auth token. Uniqueness of
// the nick, if any, is enforced by the schema rather than here.
func (s *Database) CreateUser(
	ctx context.Context,
	nick string,
) (int64, string, error) {
	token, err := randomToken()
	if err != nil {
		return 0, "", fmt.Errorf("create user: %w", err)
	}

	now := time.Now()

	res, err := s.db.ExecContext(ctx,
		`INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)`,
		nick, token, now, now)
	if err != nil {
		return 0, "", fmt.Errorf("create user: %w", err)
	}

	// Without a valid ID the caller would be handed user 0, so a failure
	// here must not be swallowed.
	id, err := res.LastInsertId()
	if err != nil {
		return 0, "", fmt.Errorf("create user: last insert id: %w", err)
	}

	return id, token, nil
}

// GetUserByToken returns the user ID and nick that own token. The lookup
// error is returned unwrapped, so an unknown token surfaces as
// sql.ErrNoRows. On success the user's last_seen timestamp is refreshed.
func (s *Database) GetUserByToken(
	ctx context.Context,
	token string,
) (int64, string, error) {
	var id int64

	var nick string

	err := s.db.QueryRowContext(
		ctx,
		"SELECT id, nick FROM users WHERE token = ?",
		token,
	).Scan(&id, &nick)
	if err != nil {
		return 0, "", err
	}

	// Refreshing last_seen is best-effort bookkeeping: a failed write here
	// should not reject an otherwise valid token.
	_, _ = s.db.ExecContext(
		ctx,
		"UPDATE users SET last_seen = ? WHERE id = ?",
		time.Now(), id,
	)

	return id, nick, nil
}

// GetUserByNick returns user id for a given nick.
func (s *Database) GetUserByNick( ctx context.Context, nick string, ) (int64, error) { var id int64 err := s.db.QueryRowContext( ctx, "SELECT id FROM users WHERE nick = ?", nick, ).Scan(&id) return id, err } // GetChannelByName returns the channel ID for a name. func (s *Database) GetChannelByName( ctx context.Context, name string, ) (int64, error) { var id int64 err := s.db.QueryRowContext( ctx, "SELECT id FROM channels WHERE name = ?", name, ).Scan(&id) return id, err } // GetOrCreateChannel returns channel id, creating if needed. func (s *Database) GetOrCreateChannel( ctx context.Context, name string, ) (int64, error) { var id int64 err := s.db.QueryRowContext( ctx, "SELECT id FROM channels WHERE name = ?", name, ).Scan(&id) if err == nil { return id, nil } now := time.Now() res, err := s.db.ExecContext(ctx, `INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)`, name, now, now) if err != nil { return 0, fmt.Errorf("create channel: %w", err) } id, _ = res.LastInsertId() return id, nil } // JoinChannel adds a user to a channel. func (s *Database) JoinChannel( ctx context.Context, channelID, userID int64, ) error { _, err := s.db.ExecContext(ctx, `INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)`, channelID, userID, time.Now()) return err } // PartChannel removes a user from a channel. func (s *Database) PartChannel( ctx context.Context, channelID, userID int64, ) error { _, err := s.db.ExecContext(ctx, `DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?`, channelID, userID) return err } // DeleteChannelIfEmpty removes a channel with no members. func (s *Database) DeleteChannelIfEmpty( ctx context.Context, channelID int64, ) error { _, err := s.db.ExecContext(ctx, `DELETE FROM channels WHERE id = ? AND NOT EXISTS (SELECT 1 FROM channel_members WHERE channel_id = ?)`, channelID, channelID) return err } // scanChannels scans rows into a ChannelInfo slice. 
func scanChannels( rows *sql.Rows, ) ([]ChannelInfo, error) { defer func() { _ = rows.Close() }() var out []ChannelInfo for rows.Next() { var ch ChannelInfo err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic) if err != nil { return nil, err } out = append(out, ch) } err := rows.Err() if err != nil { return nil, err } if out == nil { out = []ChannelInfo{} } return out, nil } // ListChannels returns channels the user has joined. func (s *Database) ListChannels( ctx context.Context, userID int64, ) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.user_id = ? ORDER BY c.name`, userID) if err != nil { return nil, err } return scanChannels(rows) } // ListAllChannels returns every channel. func (s *Database) ListAllChannels( ctx context.Context, ) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT id, name, topic FROM channels ORDER BY name`) if err != nil { return nil, err } return scanChannels(rows) } // ChannelMembers returns all members of a channel. func (s *Database) ChannelMembers( ctx context.Context, channelID int64, ) ([]MemberInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT u.id, u.nick, u.last_seen FROM users u INNER JOIN channel_members cm ON cm.user_id = u.id WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var members []MemberInfo for rows.Next() { var m MemberInfo err = rows.Scan(&m.ID, &m.Nick, &m.LastSeen) if err != nil { return nil, err } members = append(members, m) } err = rows.Err() if err != nil { return nil, err } if members == nil { members = []MemberInfo{} } return members, nil } // scanInt64s scans rows into an int64 slice. 
func scanInt64s(rows *sql.Rows) ([]int64, error) { defer func() { _ = rows.Close() }() var ids []int64 for rows.Next() { var id int64 err := rows.Scan(&id) if err != nil { return nil, err } ids = append(ids, id) } err := rows.Err() if err != nil { return nil, err } return ids, nil } // GetChannelMemberIDs returns user IDs in a channel. func (s *Database) GetChannelMemberIDs( ctx context.Context, channelID int64, ) ([]int64, error) { rows, err := s.db.QueryContext(ctx, `SELECT user_id FROM channel_members WHERE channel_id = ?`, channelID) if err != nil { return nil, err } return scanInt64s(rows) } // GetUserChannelIDs returns channel IDs the user is in. func (s *Database) GetUserChannelIDs( ctx context.Context, userID int64, ) ([]int64, error) { rows, err := s.db.QueryContext(ctx, `SELECT channel_id FROM channel_members WHERE user_id = ?`, userID) if err != nil { return nil, err } return scanInt64s(rows) } // InsertMessage stores a message and returns its DB ID. func (s *Database) InsertMessage( ctx context.Context, command, from, to string, body json.RawMessage, meta json.RawMessage, ) (int64, string, error) { msgUUID := uuid.New().String() now := time.Now().UTC() if body == nil { body = json.RawMessage("[]") } if meta == nil { meta = json.RawMessage("{}") } res, err := s.db.ExecContext(ctx, `INSERT INTO messages (uuid, command, msg_from, msg_to, body, meta, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, msgUUID, command, from, to, string(body), string(meta), now) if err != nil { return 0, "", err } id, _ := res.LastInsertId() return id, msgUUID, nil } // EnqueueMessage adds a message to a user's queue. func (s *Database) EnqueueMessage( ctx context.Context, userID, messageID int64, ) error { _, err := s.db.ExecContext(ctx, `INSERT OR IGNORE INTO client_queues (user_id, message_id, created_at) VALUES (?, ?, ?)`, userID, messageID, time.Now()) return err } // PollMessages returns queued messages for a user. 
func (s *Database) PollMessages( ctx context.Context, userID, afterQueueID int64, limit int, ) ([]IRCMessage, int64, error) { if limit <= 0 { limit = defaultPollLimit } rows, err := s.db.QueryContext(ctx, `SELECT cq.id, m.uuid, m.command, m.msg_from, m.msg_to, m.body, m.meta, m.created_at FROM client_queues cq INNER JOIN messages m ON m.id = cq.message_id WHERE cq.user_id = ? AND cq.id > ? ORDER BY cq.id ASC LIMIT ?`, userID, afterQueueID, limit) if err != nil { return nil, afterQueueID, err } msgs, lastQID, scanErr := scanMessages( rows, afterQueueID, ) if scanErr != nil { return nil, afterQueueID, scanErr } return msgs, lastQID, nil } // GetHistory returns message history for a target. func (s *Database) GetHistory( ctx context.Context, target string, beforeID int64, limit int, ) ([]IRCMessage, error) { if limit <= 0 { limit = defaultHistLimit } rows, err := s.queryHistory( ctx, target, beforeID, limit, ) if err != nil { return nil, err } msgs, _, scanErr := scanMessages(rows, 0) if scanErr != nil { return nil, scanErr } if msgs == nil { msgs = []IRCMessage{} } reverseMessages(msgs) return msgs, nil } func (s *Database) queryHistory( ctx context.Context, target string, beforeID int64, limit int, ) (*sql.Rows, error) { if beforeID > 0 { return s.db.QueryContext(ctx, `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages WHERE msg_to = ? AND id < ? AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?`, target, beforeID, limit) } return s.db.QueryContext(ctx, `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages WHERE msg_to = ? 
AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?`, target, limit) } func scanMessages( rows *sql.Rows, fallbackQID int64, ) ([]IRCMessage, int64, error) { defer func() { _ = rows.Close() }() var msgs []IRCMessage lastQID := fallbackQID for rows.Next() { var ( m IRCMessage qID int64 body, meta string ts time.Time ) err := rows.Scan( &qID, &m.ID, &m.Command, &m.From, &m.To, &body, &meta, &ts, ) if err != nil { return nil, fallbackQID, err } m.Body = json.RawMessage(body) m.Meta = json.RawMessage(meta) m.TS = ts.Format(time.RFC3339Nano) m.DBID = qID lastQID = qID msgs = append(msgs, m) } err := rows.Err() if err != nil { return nil, fallbackQID, err } if msgs == nil { msgs = []IRCMessage{} } return msgs, lastQID, nil } func reverseMessages(msgs []IRCMessage) { for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } } // ChangeNick updates a user's nickname. func (s *Database) ChangeNick( ctx context.Context, userID int64, newNick string, ) error { _, err := s.db.ExecContext(ctx, "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) return err } // SetTopic sets the topic for a channel. func (s *Database) SetTopic( ctx context.Context, channelName, topic string, ) error { _, err := s.db.ExecContext(ctx, `UPDATE channels SET topic = ?, updated_at = ? WHERE name = ?`, topic, time.Now(), channelName) return err } // DeleteUser removes a user and all their data. func (s *Database) DeleteUser( ctx context.Context, userID int64, ) error { _, err := s.db.ExecContext( ctx, "DELETE FROM users WHERE id = ?", userID, ) return err } // GetAllChannelMembershipsForUser returns channels // a user belongs to. func (s *Database) GetAllChannelMembershipsForUser( ctx context.Context, userID int64, ) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.user_id = ?`, userID) if err != nil { return nil, err } return scanChannels(rows) }