package db import ( "context" "crypto/rand" "encoding/hex" "fmt" "time" ) func generateToken() string { b := make([]byte, 32) _, _ = rand.Read(b) return hex.EncodeToString(b) } // CreateUser registers a new user with the given nick and returns the user with token. func (s *Database) CreateUser(ctx context.Context, nick string) (int64, string, error) { token := generateToken() now := time.Now() res, err := s.db.ExecContext(ctx, "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)", nick, token, now, now) if err != nil { return 0, "", fmt.Errorf("create user: %w", err) } id, _ := res.LastInsertId() return id, token, nil } // GetUserByToken returns user id and nick for a given auth token. func (s *Database) GetUserByToken(ctx context.Context, token string) (int64, string, error) { var id int64 var nick string err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick) if err != nil { return 0, "", err } // Update last_seen _, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id) return id, nick, nil } // GetUserByNick returns user id for a given nick. func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error) { var id int64 err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id) return id, err } // GetOrCreateChannel returns the channel id, creating it if needed. func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) { var id int64 err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id) if err == nil { return id, nil } now := time.Now() res, err := s.db.ExecContext(ctx, "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)", name, now, now) if err != nil { return 0, fmt.Errorf("create channel: %w", err) } id, _ = res.LastInsertId() return id, nil } // JoinChannel adds a user to a channel. 
func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error { _, err := s.db.ExecContext(ctx, "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", channelID, userID, time.Now()) return err } // PartChannel removes a user from a channel. func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error { _, err := s.db.ExecContext(ctx, "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", channelID, userID) return err } // ListChannels returns all channels the user has joined. func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.user_id = ? ORDER BY c.name`, userID) if err != nil { return nil, err } defer rows.Close() var channels []ChannelInfo for rows.Next() { var ch ChannelInfo if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { return nil, err } channels = append(channels, ch) } if channels == nil { channels = []ChannelInfo{} } return channels, nil } // ChannelInfo is a lightweight channel representation. type ChannelInfo struct { ID int64 `json:"id"` Name string `json:"name"` Topic string `json:"topic"` } // ChannelMembers returns all members of a channel. func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT u.id, u.nick, u.last_seen FROM users u INNER JOIN channel_members cm ON cm.user_id = u.id WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) if err != nil { return nil, err } defer rows.Close() var members []MemberInfo for rows.Next() { var m MemberInfo if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil { return nil, err } members = append(members, m) } if members == nil { members = []MemberInfo{} } return members, nil } // MemberInfo represents a channel member. 
type MemberInfo struct { ID int64 `json:"id"` Nick string `json:"nick"` LastSeen time.Time `json:"lastSeen"` } // MessageInfo represents a chat message. type MessageInfo struct { ID int64 `json:"id"` Channel string `json:"channel,omitempty"` Nick string `json:"nick"` Content string `json:"content"` IsDM bool `json:"isDm,omitempty"` DMTarget string `json:"dmTarget,omitempty"` CreatedAt time.Time `json:"createdAt"` } // GetMessages returns messages for a channel, optionally after a given ID. func (s *Database) GetMessages(ctx context.Context, channelID int64, afterID int64, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 50 } rows, err := s.db.QueryContext(ctx, `SELECT m.id, c.name, u.nick, m.content, m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN channels c ON c.id = m.channel_id WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id > ? ORDER BY m.id ASC LIMIT ?`, channelID, afterID, limit) if err != nil { return nil, err } defer rows.Close() var msgs []MessageInfo for rows.Next() { var m MessageInfo if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { return nil, err } msgs = append(msgs, m) } if msgs == nil { msgs = []MessageInfo{} } return msgs, nil } // SendMessage inserts a channel message. func (s *Database) SendMessage(ctx context.Context, channelID, userID int64, content string) (int64, error) { res, err := s.db.ExecContext(ctx, "INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)", channelID, userID, content, time.Now()) if err != nil { return 0, err } return res.LastInsertId() } // SendDM inserts a direct message. 
func (s *Database) SendDM(ctx context.Context, fromID, toID int64, content string) (int64, error) { res, err := s.db.ExecContext(ctx, "INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)", fromID, content, toID, time.Now()) if err != nil { return 0, err } return res.LastInsertId() } // GetDMs returns direct messages between two users after a given ID. func (s *Database) GetDMs(ctx context.Context, userA, userB int64, afterID int64, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 50 } rows, err := s.db.QueryContext(ctx, `SELECT m.id, u.nick, m.content, t.nick, m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN users t ON t.id = m.dm_target_id WHERE m.is_dm = 1 AND m.id > ? AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) ORDER BY m.id ASC LIMIT ?`, afterID, userA, userB, userB, userA, limit) if err != nil { return nil, err } defer rows.Close() var msgs []MessageInfo for rows.Next() { var m MessageInfo if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil { return nil, err } m.IsDM = true msgs = append(msgs, m) } if msgs == nil { msgs = []MessageInfo{} } return msgs, nil } // PollMessages returns all new messages (channel + DM) for a user after a given ID. func (s *Database) PollMessages(ctx context.Context, userID int64, afterID int64, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 100 } rows, err := s.db.QueryContext(ctx, `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, m.is_dm, COALESCE(t.nick, ''), m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id LEFT JOIN channels c ON c.id = m.channel_id LEFT JOIN users t ON t.id = m.dm_target_id WHERE m.id > ? AND ( (m.is_dm = 0 AND m.channel_id IN (SELECT channel_id FROM channel_members WHERE user_id = ?)) OR (m.is_dm = 1 AND (m.user_id = ? 
OR m.dm_target_id = ?)) ) ORDER BY m.id ASC LIMIT ?`, afterID, userID, userID, userID, limit) if err != nil { return nil, err } defer rows.Close() var msgs []MessageInfo for rows.Next() { var m MessageInfo var isDM int if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &isDM, &m.DMTarget, &m.CreatedAt); err != nil { return nil, err } m.IsDM = isDM == 1 msgs = append(msgs, m) } if msgs == nil { msgs = []MessageInfo{} } return msgs, nil } // GetMessagesBefore returns channel messages before a given ID (for history scrollback). func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, beforeID int64, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 50 } var query string var args []any if beforeID > 0 { query = `SELECT m.id, c.name, u.nick, m.content, m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN channels c ON c.id = m.channel_id WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ? ORDER BY m.id DESC LIMIT ?` args = []any{channelID, beforeID, limit} } else { query = `SELECT m.id, c.name, u.nick, m.content, m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN channels c ON c.id = m.channel_id WHERE m.channel_id = ? AND m.is_dm = 0 ORDER BY m.id DESC LIMIT ?` args = []any{channelID, limit} } rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() var msgs []MessageInfo for rows.Next() { var m MessageInfo if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { return nil, err } msgs = append(msgs, m) } if msgs == nil { msgs = []MessageInfo{} } // Reverse to ascending order for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } return msgs, nil } // GetDMsBefore returns DMs between two users before a given ID (for history scrollback). 
func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeID int64, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 50 } var query string var args []any if beforeID > 0 { query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN users t ON t.id = m.dm_target_id WHERE m.is_dm = 1 AND m.id < ? AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) ORDER BY m.id DESC LIMIT ?` args = []any{beforeID, userA, userB, userB, userA, limit} } else { query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN users t ON t.id = m.dm_target_id WHERE m.is_dm = 1 AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) ORDER BY m.id DESC LIMIT ?` args = []any{userA, userB, userB, userA, limit} } rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() var msgs []MessageInfo for rows.Next() { var m MessageInfo if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil { return nil, err } m.IsDM = true msgs = append(msgs, m) } if msgs == nil { msgs = []MessageInfo{} } // Reverse to ascending order for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } return msgs, nil } // ChangeNick updates a user's nickname. func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error { _, err := s.db.ExecContext(ctx, "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) return err } // SetTopic sets the topic for a channel. func (s *Database) SetTopic(ctx context.Context, channelName string, _ int64, topic string) error { _, err := s.db.ExecContext(ctx, "UPDATE channels SET topic = ? WHERE name = ?", topic, channelName) return err } // GetServerName returns the server name (unused, config provides this). 
func (s *Database) GetServerName() string {
	// Always empty: nothing is read from the database here.
	return ""
}

// ListAllChannels returns all channels, ordered by name.
//
// The returned slice is never nil on success (an empty table yields an empty
// slice, which encodes to [] rather than null in JSON). Errors from the
// query, from scanning a row, or from row iteration are returned unwrapped.
//
// NOTE(review): topic is scanned into a plain string; a NULL topic would make
// Scan fail — confirm the column cannot be NULL.
func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) {
	rows, err := s.db.QueryContext(ctx, "SELECT id, name, topic FROM channels ORDER BY name")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	channels := []ChannelInfo{}
	for rows.Next() {
		var ch ChannelInfo
		if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
			return nil, err
		}
		channels = append(channels, ch)
	}
	// rows.Next returns false both at the end of the result and on failure;
	// only rows.Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return channels, nil
}