package db import ( "context" "crypto/rand" "encoding/hex" "encoding/json" "fmt" "time" "github.com/google/uuid" ) func generateToken() string { b := make([]byte, 32) _, _ = rand.Read(b) return hex.EncodeToString(b) } // IRCMessage is the IRC envelope format for all messages. type IRCMessage struct { ID string `json:"id"` Command string `json:"command"` From string `json:"from,omitempty"` To string `json:"to,omitempty"` Body json.RawMessage `json:"body,omitempty"` TS string `json:"ts"` Meta json.RawMessage `json:"meta,omitempty"` // Internal DB fields (not in JSON) DBID int64 `json:"-"` } // ChannelInfo is a lightweight channel representation. type ChannelInfo struct { ID int64 `json:"id"` Name string `json:"name"` Topic string `json:"topic"` } // MemberInfo represents a channel member. type MemberInfo struct { ID int64 `json:"id"` Nick string `json:"nick"` LastSeen time.Time `json:"lastSeen"` } // CreateUser registers a new user with the given nick. func (s *Database) CreateUser(ctx context.Context, nick string) (int64, string, error) { token := generateToken() now := time.Now() res, err := s.db.ExecContext(ctx, "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)", nick, token, now, now) if err != nil { return 0, "", fmt.Errorf("create user: %w", err) } id, _ := res.LastInsertId() return id, token, nil } // GetUserByToken returns user id and nick for a given auth token. func (s *Database) GetUserByToken(ctx context.Context, token string) (int64, string, error) { var id int64 var nick string err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick) if err != nil { return 0, "", err } _, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id) return id, nick, nil } // GetUserByNick returns user id for a given nick. func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error) { var id int64 err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id) return id, err } // GetOrCreateChannel returns the channel id, creating it if needed. func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) { var id int64 err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id) if err == nil { return id, nil } now := time.Now() res, err := s.db.ExecContext(ctx, "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)", name, now, now) if err != nil { return 0, fmt.Errorf("create channel: %w", err) } id, _ = res.LastInsertId() return id, nil } // JoinChannel adds a user to a channel. func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error { _, err := s.db.ExecContext(ctx, "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", channelID, userID, time.Now()) return err } // PartChannel removes a user from a channel. func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error { _, err := s.db.ExecContext(ctx, "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", channelID, userID) return err } // DeleteChannelIfEmpty deletes a channel if it has no members. func (s *Database) DeleteChannelIfEmpty(ctx context.Context, channelID int64) error { _, err := s.db.ExecContext(ctx, `DELETE FROM channels WHERE id = ? AND NOT EXISTS (SELECT 1 FROM channel_members WHERE channel_id = ?)`, channelID, channelID) return err } // ListChannels returns all channels the user has joined. func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.user_id = ? ORDER BY c.name`, userID) if err != nil { return nil, err } defer rows.Close() var channels []ChannelInfo for rows.Next() { var ch ChannelInfo if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { return nil, err } channels = append(channels, ch) } if channels == nil { channels = []ChannelInfo{} } return channels, nil } // ListAllChannels returns all channels. func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, name, topic FROM channels ORDER BY name") if err != nil { return nil, err } defer rows.Close() var channels []ChannelInfo for rows.Next() { var ch ChannelInfo if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { return nil, err } channels = append(channels, ch) } if channels == nil { channels = []ChannelInfo{} } return channels, nil } // ChannelMembers returns all members of a channel. func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT u.id, u.nick, u.last_seen FROM users u INNER JOIN channel_members cm ON cm.user_id = u.id WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) if err != nil { return nil, err } defer rows.Close() var members []MemberInfo for rows.Next() { var m MemberInfo if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil { return nil, err } members = append(members, m) } if members == nil { members = []MemberInfo{} } return members, nil } // GetChannelMemberIDs returns user IDs of all members in a channel. func (s *Database) GetChannelMemberIDs(ctx context.Context, channelID int64) ([]int64, error) { rows, err := s.db.QueryContext(ctx, "SELECT user_id FROM channel_members WHERE channel_id = ?", channelID) if err != nil { return nil, err } defer rows.Close() var ids []int64 for rows.Next() { var id int64 if err := rows.Scan(&id); err != nil { return nil, err } ids = append(ids, id) } return ids, nil } // GetUserChannelIDs returns channel IDs the user is a member of. func (s *Database) GetUserChannelIDs(ctx context.Context, userID int64) ([]int64, error) { rows, err := s.db.QueryContext(ctx, "SELECT channel_id FROM channel_members WHERE user_id = ?", userID) if err != nil { return nil, err } defer rows.Close() var ids []int64 for rows.Next() { var id int64 if err := rows.Scan(&id); err != nil { return nil, err } ids = append(ids, id) } return ids, nil } // InsertMessage stores a message and returns its DB ID. func (s *Database) InsertMessage(ctx context.Context, command, from, to string, body json.RawMessage, meta json.RawMessage) (int64, string, error) { msgUUID := uuid.New().String() now := time.Now().UTC() if body == nil { body = json.RawMessage("[]") } if meta == nil { meta = json.RawMessage("{}") } res, err := s.db.ExecContext(ctx, `INSERT INTO messages (uuid, command, msg_from, msg_to, body, meta, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, msgUUID, command, from, to, string(body), string(meta), now) if err != nil { return 0, "", err } id, _ := res.LastInsertId() return id, msgUUID, nil } // EnqueueMessage adds a message to a user's delivery queue. func (s *Database) EnqueueMessage(ctx context.Context, userID, messageID int64) error { _, err := s.db.ExecContext(ctx, "INSERT OR IGNORE INTO client_queues (user_id, message_id, created_at) VALUES (?, ?, ?)", userID, messageID, time.Now()) return err } // PollMessages returns queued messages for a user after a given queue ID. func (s *Database) PollMessages(ctx context.Context, userID int64, afterQueueID int64, limit int) ([]IRCMessage, int64, error) { if limit <= 0 { limit = 100 } rows, err := s.db.QueryContext(ctx, `SELECT cq.id, m.uuid, m.command, m.msg_from, m.msg_to, m.body, m.meta, m.created_at FROM client_queues cq INNER JOIN messages m ON m.id = cq.message_id WHERE cq.user_id = ? AND cq.id > ? ORDER BY cq.id ASC LIMIT ?`, userID, afterQueueID, limit) if err != nil { return nil, afterQueueID, err } defer rows.Close() var msgs []IRCMessage var lastQID int64 for rows.Next() { var m IRCMessage var qID int64 var body, meta string var ts time.Time if err := rows.Scan(&qID, &m.ID, &m.Command, &m.From, &m.To, &body, &meta, &ts); err != nil { return nil, afterQueueID, err } m.Body = json.RawMessage(body) m.Meta = json.RawMessage(meta) m.TS = ts.Format(time.RFC3339Nano) m.DBID = qID lastQID = qID msgs = append(msgs, m) } if msgs == nil { msgs = []IRCMessage{} } if lastQID == 0 { lastQID = afterQueueID } return msgs, lastQID, nil } // GetHistory returns message history for a target (channel or DM nick pair). func (s *Database) GetHistory(ctx context.Context, target string, beforeID int64, limit int) ([]IRCMessage, error) { if limit <= 0 { limit = 50 } var query string var args []any if beforeID > 0 { query = `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages WHERE msg_to = ? AND id < ? AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?` args = []any{target, beforeID, limit} } else { query = `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages WHERE msg_to = ? AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?` args = []any{target, limit} } rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() var msgs []IRCMessage for rows.Next() { var m IRCMessage var dbID int64 var body, meta string var ts time.Time if err := rows.Scan(&dbID, &m.ID, &m.Command, &m.From, &m.To, &body, &meta, &ts); err != nil { return nil, err } m.Body = json.RawMessage(body) m.Meta = json.RawMessage(meta) m.TS = ts.Format(time.RFC3339Nano) m.DBID = dbID msgs = append(msgs, m) } if msgs == nil { msgs = []IRCMessage{} } // Reverse to ascending order for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } return msgs, nil } // ChangeNick updates a user's nickname. func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error { _, err := s.db.ExecContext(ctx, "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) return err } // SetTopic sets the topic for a channel. func (s *Database) SetTopic(ctx context.Context, channelName string, topic string) error { _, err := s.db.ExecContext(ctx, "UPDATE channels SET topic = ?, updated_at = ? WHERE name = ?", topic, time.Now(), channelName) return err } // DeleteUser removes a user and all their data. func (s *Database) DeleteUser(ctx context.Context, userID int64) error { _, err := s.db.ExecContext(ctx, "DELETE FROM users WHERE id = ?", userID) return err } // GetAllChannelMembershipsForUser returns (channelID, channelName) for all channels a user is in. func (s *Database) GetAllChannelMembershipsForUser(ctx context.Context, userID int64) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.user_id = ?`, userID) if err != nil { return nil, err } defer rows.Close() var channels []ChannelInfo for rows.Next() { var ch ChannelInfo if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { return nil, err } channels = append(channels, ch) } return channels, nil }