package db

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"
)

// GetOrCreateChannel returns the ID of the channel called name. If no such
// channel exists, it creates one with an empty topic and empty modes.
//
// Only sql.ErrNoRows from the lookup triggers creation. Any other lookup
// failure, such as a cancelled ctx, is returned wrapped rather than being
// hidden behind an insert attempt.
//
// The lookup and the insert are not atomic. If the insert fails, the lookup
// is retried once, so a channel created concurrently by another caller is
// returned instead of an error.
// NOTE(review): without a UNIQUE constraint on channels.name, concurrent
// callers can still both insert and create duplicate rows. Confirm the schema.
//
// New IDs have the form "ch-<unix nanoseconds>".
func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (string, error) {
	const lookup = "SELECT id FROM channels WHERE name = ?"

	var id string
	err := s.db.QueryRowContext(ctx, lookup, name).Scan(&id)
	switch {
	case err == nil:
		return id, nil
	case !errors.Is(err, sql.ErrNoRows):
		return "", fmt.Errorf("lookup channel: %w", err)
	}

	now := time.Now()
	id = fmt.Sprintf("ch-%d", now.UnixNano())
	_, err = s.db.ExecContext(ctx,
		"INSERT INTO channels (id, name, topic, modes, created_at, updated_at) VALUES (?, ?, '', '', ?, ?)",
		id, name, now, now)
	if err != nil {
		// The insert may have lost a race with another creator of the same
		// channel. Prefer returning the existing row over failing.
		var existing string
		if lookupErr := s.db.QueryRowContext(ctx, lookup, name).Scan(&existing); lookupErr == nil {
			return existing, nil
		}
		return "", fmt.Errorf("create channel: %w", err)
	}
	return id, nil
}

// JoinChannel adds userID to channelID's member list with empty modes and the
// current time as joined_at.
//
// INSERT OR IGNORE suppresses constraint violations, so re-joining is a no-op.
// This presumably relies on a key over (channel_id, user_id); confirm against
// the schema.
func (s *Database) JoinChannel(ctx context.Context, channelID, userID string) error {
	_, err := s.db.ExecContext(ctx,
		"INSERT OR IGNORE INTO channel_members (channel_id, user_id, modes, joined_at) VALUES (?, ?, '', ?)",
		channelID, userID, time.Now())
	return err
}

// PartChannel removes userID from channelID. Removing a user who is not a
// member deletes nothing and is not an error.
func (s *Database) PartChannel(ctx context.Context, channelID, userID string) error {
	_, err := s.db.ExecContext(ctx,
		"DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?",
		channelID, userID)
	return err
}

// ListChannels returns every channel userID has joined, ordered by name.
// On success the slice is non-nil, even when empty, so it encodes as a JSON
// array rather than null. On error the slice is nil.
func (s *Database) ListChannels(ctx context.Context, userID string) ([]ChannelInfo, error) {
	rows, err := s.db.QueryContext(ctx, `SELECT c.id, c.name, c.topic
		FROM channels c
		INNER JOIN channel_members cm ON cm.channel_id = c.id
		WHERE cm.user_id = ?
		ORDER BY c.name`, userID)
	if err != nil {
		return nil, err
	}
	// The Close error is deliberately ignored; iteration failures are
	// reported through rows.Err below.
	defer func() { _ = rows.Close() }()

	channels := []ChannelInfo{}
	for rows.Next() {
		var ch ChannelInfo
		if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
			return nil, err
		}
		channels = append(channels, ch)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return channels, nil
}

// ChannelInfo is a lightweight channel representation.
type ChannelInfo struct { ID string `json:"id"` Name string `json:"name"` Topic string `json:"topic"` } // ChannelMembers returns all members of a channel. func (s *Database) ChannelMembers(ctx context.Context, channelID string) ([]MemberInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT u.id, u.nick, u.last_seen_at FROM users u INNER JOIN channel_members cm ON cm.user_id = u.id WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) if err != nil { return nil, err } defer func() { _ = rows.Close() }() members := []MemberInfo{} for rows.Next() { var m MemberInfo err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen) if err != nil { return nil, err } members = append(members, m) } return members, rows.Err() } // MemberInfo represents a channel member. type MemberInfo struct { ID string `json:"id"` Nick string `json:"nick"` LastSeen *time.Time `json:"lastSeen"` } // MessageInfo represents a chat message. type MessageInfo struct { ID string `json:"id"` Channel string `json:"channel,omitempty"` Nick string `json:"nick"` Content string `json:"content"` IsDM bool `json:"isDm,omitempty"` DMTarget string `json:"dmTarget,omitempty"` CreatedAt time.Time `json:"createdAt"` } // SendMessage inserts a channel message. func (s *Database) SendMessage(ctx context.Context, channelID, userID, nick, content string) (string, error) { now := time.Now() id := fmt.Sprintf("msg-%d", now.UnixNano()) _, err := s.db.ExecContext(ctx, `INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at) VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`, id, now, userID, nick, channelID, content, now) if err != nil { return "", err } return id, nil } // SendDM inserts a direct message. 
func (s *Database) SendDM(ctx context.Context, fromID, fromNick, toID, content string) (string, error) { now := time.Now() id := fmt.Sprintf("msg-%d", now.UnixNano()) _, err := s.db.ExecContext(ctx, `INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at) VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`, id, now, fromID, fromNick, toID, content, now) if err != nil { return "", err } return id, nil } // PollMessages returns all new messages for a user's joined channels, ordered by timestamp. func (s *Database) PollMessages(ctx context.Context, userID string, afterTS string, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 100 } rows, err := s.db.QueryContext(ctx, `SELECT m.id, m.target, m.from_nick, m.body, m.created_at FROM messages m WHERE m.created_at > COALESCE(NULLIF(?, ''), '1970-01-01') AND ( m.target IN (SELECT cm.channel_id FROM channel_members cm WHERE cm.user_id = ?) OR m.target = ? OR m.from_user_id = ? ) ORDER BY m.created_at ASC LIMIT ?`, afterTS, userID, userID, userID, limit) if err != nil { return nil, err } defer func() { _ = rows.Close() }() msgs := []MessageInfo{} for rows.Next() { var m MessageInfo err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt) if err != nil { return nil, err } msgs = append(msgs, m) } return msgs, rows.Err() } // GetMessagesBefore returns channel messages before a given timestamp (for history scrollback). func (s *Database) GetMessagesBefore( ctx context.Context, target string, beforeTS string, limit int, ) ([]MessageInfo, error) { if limit <= 0 { limit = 50 } var rows *sql.Rows var err error if beforeTS != "" { rows, err = s.db.QueryContext(ctx, `SELECT m.id, m.target, m.from_nick, m.body, m.created_at FROM messages m WHERE m.target = ? AND m.created_at < ? ORDER BY m.created_at DESC LIMIT ?`, target, beforeTS, limit) } else { rows, err = s.db.QueryContext(ctx, `SELECT m.id, m.target, m.from_nick, m.body, m.created_at FROM messages m WHERE m.target = ? 
ORDER BY m.created_at DESC LIMIT ?`, target, limit) } if err != nil { return nil, err } defer func() { _ = rows.Close() }() msgs := []MessageInfo{} for rows.Next() { var m MessageInfo err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt) if err != nil { return nil, err } msgs = append(msgs, m) } // Reverse to ascending order. for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } return msgs, rows.Err() } // GetDMsBefore returns DMs between two users before a given timestamp. func (s *Database) GetDMsBefore( ctx context.Context, userA, userB string, beforeTS string, limit int, ) ([]MessageInfo, error) { if limit <= 0 { limit = 50 } var rows *sql.Rows var err error if beforeTS != "" { rows, err = s.db.QueryContext(ctx, `SELECT m.id, m.from_nick, m.body, m.target, m.created_at FROM messages m WHERE m.created_at < ? AND ((m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?)) ORDER BY m.created_at DESC LIMIT ?`, beforeTS, userA, userB, userB, userA, limit) } else { rows, err = s.db.QueryContext(ctx, `SELECT m.id, m.from_nick, m.body, m.target, m.created_at FROM messages m WHERE (m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?) ORDER BY m.created_at DESC LIMIT ?`, userA, userB, userB, userA, limit) } if err != nil { return nil, err } defer func() { _ = rows.Close() }() msgs := []MessageInfo{} for rows.Next() { var m MessageInfo err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt) if err != nil { return nil, err } m.IsDM = true msgs = append(msgs, m) } // Reverse to ascending order. for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } return msgs, rows.Err() } // ChangeNick updates a user's nickname. func (s *Database) ChangeNick(ctx context.Context, userID string, newNick string) error { _, err := s.db.ExecContext(ctx, "UPDATE users SET nick = ? 
WHERE id = ?", newNick, userID) return err } // SetTopic sets the topic for a channel. func (s *Database) SetTopic(ctx context.Context, channelName string, _ string, topic string) error { _, err := s.db.ExecContext(ctx, "UPDATE channels SET topic = ? WHERE name = ?", topic, channelName) return err } // GetServerName returns the server name (unused, config provides this). func (s *Database) GetServerName() string { return "" } // ListAllChannels returns all channels. func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, name, topic FROM channels ORDER BY name") if err != nil { return nil, err } defer func() { _ = rows.Close() }() channels := []ChannelInfo{} for rows.Next() { var ch ChannelInfo err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic) if err != nil { return nil, err } channels = append(channels, ch) } return channels, rows.Err() }