All checks were successful
check / check (push) Successful in 59s
1. ISUPPORT/applyChannelModes: extend IRC MODE handler to support +i/-i, +s/-s, +n/-n (routed through svc.SetChannelFlag), and +H/-H (hashcash bits with parameter parsing). Add 'n' (no external messages) as a proper DB-backed channel flag with is_no_external column (default: on). Update IRC ISUPPORT to CHANMODES=,,H,imnst to match actual support. 2. QueryChannelMode: rewrite to return complete mode string including all boolean flags (n, i, m, s, t) and parameterized modes (k, l, H), matching the HTTP handler's buildChannelModeString logic. Simplify buildChannelModeString to delegate to QueryChannelMode for consistency. 3. Service struct encapsulation: change exported fields (DB, Broker, Config, Log) to unexported (db, broker, config, log). Add NewTestService constructor for use by external test packages. Update ircserver export_test.go to use the new constructor. Closes #89
2416 lines
50 KiB
Go
2416 lines
50 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.eeqj.de/sneak/neoirc/pkg/irc"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
const (
	// tokenBytes is how many random bytes generateToken draws; the
	// hex-encoded token it returns is therefore twice this long.
	tokenBytes = 32

	// defaultPollLimit is the PollMessages page size used when the
	// caller passes a non-positive limit.
	defaultPollLimit = 100

	// defaultHistLimit is the GetHistory page size used when the
	// caller passes a non-positive limit.
	defaultHistLimit = 50
)
|
|
|
|
func generateToken() (string, error) {
|
|
buf := make([]byte, tokenBytes)
|
|
|
|
_, err := rand.Read(buf)
|
|
if err != nil {
|
|
return "", fmt.Errorf("generate token: %w", err)
|
|
}
|
|
|
|
return hex.EncodeToString(buf), nil
|
|
}
|
|
|
|
// hashToken converts a plaintext client token into the value stored in
// (and looked up from) the clients.token column: the SHA-256 digest of
// the token's bytes, rendered as lowercase hex.
func hashToken(token string) string {
	digest := sha256.New()

	// hash.Hash.Write never returns an error.
	_, _ = digest.Write([]byte(token))

	return hex.EncodeToString(digest.Sum(nil))
}
|
|
|
|
// IRCMessage is the IRC envelope for all messages.
//
// scanMessages builds these from stored message rows for both
// PollMessages and GetHistory.
type IRCMessage struct {
	// ID is the message UUID (messages.uuid).
	ID string `json:"id"`
	// Command is the IRC command; for numeric replies it is replaced
	// by the symbolic name when the irc package recognizes the code.
	Command string `json:"command"`
	// Code holds the numeric reply code when the stored command was a
	// 3-digit numeric; zero (and omitted) otherwise.
	Code int `json:"code,omitempty"`
	// From is the stored msg_from value.
	From string `json:"from,omitempty"`
	// To is the stored msg_to value.
	To string `json:"to,omitempty"`
	// Params is left empty (and omitted) when the stored params are
	// "" or "[]".
	Params json.RawMessage `json:"params,omitempty"`
	Body   json.RawMessage `json:"body,omitempty"`
	// TS is the creation time formatted as RFC 3339 with nanoseconds.
	TS   string          `json:"ts"`
	Meta json.RawMessage `json:"meta,omitempty"`
	// DBID is the first column of the scanned row: the client queue ID
	// for PollMessages, the messages.id for history. Never serialized.
	DBID int64 `json:"-"`
}
|
|
|
|
// isNumericCode reports whether s looks like an IRC numeric reply code:
// exactly three bytes, each an ASCII digit '0'-'9'.
func isNumericCode(s string) bool {
	if len(s) != 3 {
		return false
	}

	for i := 0; i < len(s); i++ {
		if c := s[i]; c < '0' || c > '9' {
			return false
		}
	}

	return true
}
|
|
|
|
// ChannelInfo is a lightweight channel representation: the id, name
// and topic columns of a channels row.
type ChannelInfo struct {
	ID    int64  `json:"id"`
	Name  string `json:"name"`
	Topic string `json:"topic"`
}
|
|
|
|
// MemberInfo represents a channel member as returned by ChannelMembers:
// session identity plus the per-channel operator/voice flags.
type MemberInfo struct {
	// ID is the member's session ID.
	ID       int64  `json:"id"`
	Nick     string `json:"nick"`
	Username string `json:"username"`
	Hostname string `json:"hostname"`
	// IsOperator and IsVoiced come from the channel membership row, not
	// from the session.
	IsOperator bool `json:"isOperator"`
	IsVoiced   bool `json:"isVoiced"`
	// LastSeen is the session's last_seen timestamp.
	LastSeen time.Time `json:"lastSeen"`
}
|
|
|
|
// Hostmask returns the IRC hostmask in
|
|
// nick!user@host format.
|
|
func (m *MemberInfo) Hostmask() string {
|
|
return FormatHostmask(m.Nick, m.Username, m.Hostname)
|
|
}
|
|
|
|
// FormatHostmask joins a nick, username and hostname into a standard
// IRC hostmask "nick!user@host". An empty username falls back to the
// nick, and an empty hostname falls back to "*".
func FormatHostmask(nick, username, hostname string) string {
	user := username
	if user == "" {
		user = nick
	}

	host := hostname
	if host == "" {
		host = "*"
	}

	return nick + "!" + user + "@" + host
}
|
|
|
|
// CreateSession registers a new session and its first client.
|
|
func (database *Database) CreateSession(
|
|
ctx context.Context,
|
|
nick, username, hostname, remoteIP string,
|
|
) (int64, int64, string, error) {
|
|
if username == "" {
|
|
username = nick
|
|
}
|
|
|
|
sessionUUID := uuid.New().String()
|
|
clientUUID := uuid.New().String()
|
|
|
|
token, err := generateToken()
|
|
if err != nil {
|
|
return 0, 0, "", err
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
transaction, err := database.conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"begin tx: %w", err,
|
|
)
|
|
}
|
|
|
|
res, err := transaction.ExecContext(ctx,
|
|
`INSERT INTO sessions
|
|
(uuid, nick, username, hostname, ip,
|
|
created_at, last_seen)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
sessionUUID, nick, username, hostname,
|
|
remoteIP, now, now)
|
|
if err != nil {
|
|
_ = transaction.Rollback()
|
|
|
|
return 0, 0, "", fmt.Errorf(
|
|
"create session: %w", err,
|
|
)
|
|
}
|
|
|
|
sessionID, _ := res.LastInsertId()
|
|
|
|
tokenHash := hashToken(token)
|
|
|
|
clientRes, err := transaction.ExecContext(ctx,
|
|
`INSERT INTO clients
|
|
(uuid, session_id, token, ip, hostname,
|
|
created_at, last_seen)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
clientUUID, sessionID, tokenHash,
|
|
remoteIP, hostname, now, now)
|
|
if err != nil {
|
|
_ = transaction.Rollback()
|
|
|
|
return 0, 0, "", fmt.Errorf(
|
|
"create client: %w", err,
|
|
)
|
|
}
|
|
|
|
clientID, _ := clientRes.LastInsertId()
|
|
|
|
err = transaction.Commit()
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"commit session: %w", err,
|
|
)
|
|
}
|
|
|
|
return sessionID, clientID, token, nil
|
|
}
|
|
|
|
// GetSessionByToken returns session id, client id, and
|
|
// nick for a client token.
|
|
func (database *Database) GetSessionByToken(
|
|
ctx context.Context,
|
|
token string,
|
|
) (int64, int64, string, error) {
|
|
var (
|
|
sessionID int64
|
|
clientID int64
|
|
nick string
|
|
)
|
|
|
|
tokenHash := hashToken(token)
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
`SELECT s.id, c.id, s.nick
|
|
FROM clients c
|
|
INNER JOIN sessions s
|
|
ON s.id = c.session_id
|
|
WHERE c.token = ?`,
|
|
tokenHash,
|
|
).Scan(&sessionID, &clientID, &nick)
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"get session by token: %w", err,
|
|
)
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
_, _ = database.conn.ExecContext(
|
|
ctx,
|
|
"UPDATE sessions SET last_seen = ? WHERE id = ?",
|
|
now, sessionID,
|
|
)
|
|
|
|
_, _ = database.conn.ExecContext(
|
|
ctx,
|
|
"UPDATE clients SET last_seen = ? WHERE id = ?",
|
|
now, clientID,
|
|
)
|
|
|
|
return sessionID, clientID, nick, nil
|
|
}
|
|
|
|
// GetSessionByNick returns session id for a given nick.
|
|
func (database *Database) GetSessionByNick(
|
|
ctx context.Context,
|
|
nick string,
|
|
) (int64, error) {
|
|
var sessionID int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT id FROM sessions WHERE nick = ?",
|
|
nick,
|
|
).Scan(&sessionID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get session by nick: %w", err,
|
|
)
|
|
}
|
|
|
|
return sessionID, nil
|
|
}
|
|
|
|
// SessionHostInfo carries the username, hostname and IP columns of a
// sessions row, in that order.
type SessionHostInfo struct {
	Username, Hostname, IP string
}
|
|
|
|
// GetSessionHostInfo returns the username, hostname,
|
|
// and IP for a session.
|
|
func (database *Database) GetSessionHostInfo(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) (*SessionHostInfo, error) {
|
|
var info SessionHostInfo
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
`SELECT username, hostname, ip
|
|
FROM sessions WHERE id = ?`,
|
|
sessionID,
|
|
).Scan(&info.Username, &info.Hostname, &info.IP)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get session host info: %w", err,
|
|
)
|
|
}
|
|
|
|
return &info, nil
|
|
}
|
|
|
|
// ClientHostInfo carries the ip and hostname columns of a clients row,
// in that order.
type ClientHostInfo struct {
	IP, Hostname string
}
|
|
|
|
// GetClientHostInfo returns the IP and hostname for a
|
|
// client.
|
|
func (database *Database) GetClientHostInfo(
|
|
ctx context.Context,
|
|
clientID int64,
|
|
) (*ClientHostInfo, error) {
|
|
var info ClientHostInfo
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
`SELECT ip, hostname
|
|
FROM clients WHERE id = ?`,
|
|
clientID,
|
|
).Scan(&info.IP, &info.Hostname)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get client host info: %w", err,
|
|
)
|
|
}
|
|
|
|
return &info, nil
|
|
}
|
|
|
|
// SetSessionOper sets the is_oper flag on a session.
|
|
func (database *Database) SetSessionOper(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
isOper bool,
|
|
) error {
|
|
val := 0
|
|
if isOper {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(
|
|
ctx,
|
|
`UPDATE sessions SET is_oper = ? WHERE id = ?`,
|
|
val, sessionID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("set session oper: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsSessionOper returns whether the session has oper
|
|
// status.
|
|
func (database *Database) IsSessionOper(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) (bool, error) {
|
|
var isOper int
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
`SELECT is_oper FROM sessions WHERE id = ?`,
|
|
sessionID,
|
|
).Scan(&isOper)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check session oper: %w", err,
|
|
)
|
|
}
|
|
|
|
return isOper != 0, nil
|
|
}
|
|
|
|
// GetLatestClientForSession returns the IP and hostname
|
|
// of the most recently created client for a session.
|
|
func (database *Database) GetLatestClientForSession(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) (*ClientHostInfo, error) {
|
|
var info ClientHostInfo
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
`SELECT ip, hostname FROM clients
|
|
WHERE session_id = ?
|
|
ORDER BY created_at DESC LIMIT 1`,
|
|
sessionID,
|
|
).Scan(&info.IP, &info.Hostname)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get latest client for session: %w", err,
|
|
)
|
|
}
|
|
|
|
return &info, nil
|
|
}
|
|
|
|
// GetChannelByName returns the channel ID for a name.
|
|
func (database *Database) GetChannelByName(
|
|
ctx context.Context,
|
|
name string,
|
|
) (int64, error) {
|
|
var channelID int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT id FROM channels WHERE name = ?",
|
|
name,
|
|
).Scan(&channelID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get channel by name: %w", err,
|
|
)
|
|
}
|
|
|
|
return channelID, nil
|
|
}
|
|
|
|
// GetOrCreateChannel returns channel id, creating if needed.
|
|
// Uses INSERT OR IGNORE to avoid TOCTOU races.
|
|
func (database *Database) GetOrCreateChannel(
|
|
ctx context.Context,
|
|
name string,
|
|
) (int64, error) {
|
|
now := time.Now()
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO channels
|
|
(name, created_at, updated_at)
|
|
VALUES (?, ?, ?)`,
|
|
name, now, now)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("create channel: %w", err)
|
|
}
|
|
|
|
var channelID int64
|
|
|
|
err = database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT id FROM channels WHERE name = ?",
|
|
name,
|
|
).Scan(&channelID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("get channel: %w", err)
|
|
}
|
|
|
|
return channelID, nil
|
|
}
|
|
|
|
// JoinChannel adds a session to a channel.
|
|
func (database *Database) JoinChannel(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO channel_members
|
|
(channel_id, session_id, joined_at)
|
|
VALUES (?, ?, ?)`,
|
|
channelID, sessionID, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("join channel: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// JoinChannelAsOperator adds a session to a channel with
|
|
// operator status. Used when a user creates a new channel.
|
|
func (database *Database) JoinChannelAsOperator(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO channel_members
|
|
(channel_id, session_id, is_operator, joined_at)
|
|
VALUES (?, ?, 1, ?)`,
|
|
channelID, sessionID, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"join channel as operator: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CountChannelMembers returns the number of members in
|
|
// a channel.
|
|
func (database *Database) CountChannelMembers(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (int64, error) {
|
|
var count int64
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT COUNT(*) FROM channel_members
|
|
WHERE channel_id = ?`,
|
|
channelID,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"count channel members: %w", err,
|
|
)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// IsChannelOperator checks if a session has operator
|
|
// status in a channel.
|
|
func (database *Database) IsChannelOperator(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) (bool, error) {
|
|
var isOp int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT is_operator FROM channel_members
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
channelID, sessionID,
|
|
).Scan(&isOp)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check channel operator: %w", err,
|
|
)
|
|
}
|
|
|
|
return isOp != 0, nil
|
|
}
|
|
|
|
// IsChannelVoiced checks if a session has voice status
|
|
// in a channel.
|
|
func (database *Database) IsChannelVoiced(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) (bool, error) {
|
|
var isVoiced int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT is_voiced FROM channel_members
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
channelID, sessionID,
|
|
).Scan(&isVoiced)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check channel voiced: %w", err,
|
|
)
|
|
}
|
|
|
|
return isVoiced != 0, nil
|
|
}
|
|
|
|
// SetChannelMemberOperator sets or clears operator status
|
|
// for a session in a channel.
|
|
func (database *Database) SetChannelMemberOperator(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
isOp bool,
|
|
) error {
|
|
val := 0
|
|
if isOp {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channel_members
|
|
SET is_operator = ?
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
val, channelID, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"set channel member operator: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetChannelMemberVoiced sets or clears voice status
|
|
// for a session in a channel.
|
|
func (database *Database) SetChannelMemberVoiced(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
isVoiced bool,
|
|
) error {
|
|
val := 0
|
|
if isVoiced {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channel_members
|
|
SET is_voiced = ?
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
val, channelID, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"set channel member voiced: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsChannelModerated returns whether a channel has +m set.
|
|
func (database *Database) IsChannelModerated(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (bool, error) {
|
|
var isMod int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT is_moderated FROM channels
|
|
WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&isMod)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check channel moderated: %w", err,
|
|
)
|
|
}
|
|
|
|
return isMod != 0, nil
|
|
}
|
|
|
|
// SetChannelModerated sets or clears +m on a channel.
|
|
func (database *Database) SetChannelModerated(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
moderated bool,
|
|
) error {
|
|
val := 0
|
|
if moderated {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET is_moderated = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
val, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"set channel moderated: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsChannelTopicLocked returns whether a channel has
|
|
// +t set.
|
|
func (database *Database) IsChannelTopicLocked(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (bool, error) {
|
|
var isLocked int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT is_topic_locked FROM channels
|
|
WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&isLocked)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check channel topic locked: %w", err,
|
|
)
|
|
}
|
|
|
|
return isLocked != 0, nil
|
|
}
|
|
|
|
// SetChannelTopicLocked sets or clears +t on a channel.
|
|
func (database *Database) SetChannelTopicLocked(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
locked bool,
|
|
) error {
|
|
val := 0
|
|
if locked {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET is_topic_locked = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
val, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"set channel topic locked: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// PartChannel removes a session from a channel.
|
|
func (database *Database) PartChannel(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`DELETE FROM channel_members
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
channelID, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("part channel: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteChannelIfEmpty removes a channel with no members.
|
|
func (database *Database) DeleteChannelIfEmpty(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`DELETE FROM channels WHERE id = ?
|
|
AND NOT EXISTS
|
|
(SELECT 1 FROM channel_members
|
|
WHERE channel_id = ?)`,
|
|
channelID, channelID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"delete channel if empty: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// scanChannels scans rows into a ChannelInfo slice.
|
|
func scanChannels(
|
|
rows *sql.Rows,
|
|
) ([]ChannelInfo, error) {
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var out []ChannelInfo
|
|
|
|
for rows.Next() {
|
|
var chanInfo ChannelInfo
|
|
|
|
err := rows.Scan(
|
|
&chanInfo.ID, &chanInfo.Name, &chanInfo.Topic,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("scan channel: %w", err)
|
|
}
|
|
|
|
out = append(out, chanInfo)
|
|
}
|
|
|
|
err := rows.Err()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("rows error: %w", err)
|
|
}
|
|
|
|
if out == nil {
|
|
out = []ChannelInfo{}
|
|
}
|
|
|
|
return out, nil
|
|
}
|
|
|
|
// ListChannels returns channels the session has joined.
|
|
func (database *Database) ListChannels(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) ([]ChannelInfo, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT c.id, c.name, c.topic
|
|
FROM channels c
|
|
INNER JOIN channel_members cm
|
|
ON cm.channel_id = c.id
|
|
WHERE cm.session_id = ?
|
|
ORDER BY c.name`, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list channels: %w", err)
|
|
}
|
|
|
|
return scanChannels(rows)
|
|
}
|
|
|
|
// ListAllChannels returns every channel.
|
|
func (database *Database) ListAllChannels(
|
|
ctx context.Context,
|
|
) ([]ChannelInfo, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT id, name, topic
|
|
FROM channels ORDER BY name`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"list all channels: %w", err,
|
|
)
|
|
}
|
|
|
|
return scanChannels(rows)
|
|
}
|
|
|
|
// ChannelMembers returns all members of a channel.
|
|
func (database *Database) ChannelMembers(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) ([]MemberInfo, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT s.id, s.nick, s.username,
|
|
s.hostname, cm.is_operator, cm.is_voiced,
|
|
s.last_seen
|
|
FROM sessions s
|
|
INNER JOIN channel_members cm
|
|
ON cm.session_id = s.id
|
|
WHERE cm.channel_id = ?
|
|
ORDER BY s.nick`, channelID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"query channel members: %w", err,
|
|
)
|
|
}
|
|
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var members []MemberInfo
|
|
|
|
for rows.Next() {
|
|
var (
|
|
member MemberInfo
|
|
isOp int
|
|
isV int
|
|
)
|
|
|
|
err = rows.Scan(
|
|
&member.ID, &member.Nick,
|
|
&member.Username, &member.Hostname,
|
|
&isOp, &isV,
|
|
&member.LastSeen,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"scan member: %w", err,
|
|
)
|
|
}
|
|
|
|
member.IsOperator = isOp != 0
|
|
member.IsVoiced = isV != 0
|
|
|
|
members = append(members, member)
|
|
}
|
|
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("rows error: %w", err)
|
|
}
|
|
|
|
if members == nil {
|
|
members = []MemberInfo{}
|
|
}
|
|
|
|
return members, nil
|
|
}
|
|
|
|
// IsChannelMember checks if a session belongs to a channel.
|
|
func (database *Database) IsChannelMember(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) (bool, error) {
|
|
var count int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT COUNT(*) FROM channel_members
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
channelID, sessionID,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check membership: %w", err,
|
|
)
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|
|
|
|
// scanInt64s reads single-column integer rows into a slice and always
// closes rows. Unlike scanChannels, no rows yields a nil slice.
func scanInt64s(rows *sql.Rows) ([]int64, error) {
	defer func() { _ = rows.Close() }()

	var result []int64

	for rows.Next() {
		var id int64

		if err := rows.Scan(&id); err != nil {
			return nil, fmt.Errorf(
				"scan int64: %w", err,
			)
		}

		result = append(result, id)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("rows error: %w", err)
	}

	return result, nil
}
|
|
|
|
// GetChannelMemberIDs returns session IDs in a channel.
|
|
func (database *Database) GetChannelMemberIDs(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) ([]int64, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT session_id FROM channel_members
|
|
WHERE channel_id = ?`, channelID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get channel member ids: %w", err,
|
|
)
|
|
}
|
|
|
|
return scanInt64s(rows)
|
|
}
|
|
|
|
// GetSessionChannelIDs returns channel IDs for a session.
|
|
func (database *Database) GetSessionChannelIDs(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) ([]int64, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT channel_id FROM channel_members
|
|
WHERE session_id = ?`, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get session channel ids: %w", err,
|
|
)
|
|
}
|
|
|
|
return scanInt64s(rows)
|
|
}
|
|
|
|
// InsertMessage stores a message and returns its DB ID.
|
|
func (database *Database) InsertMessage(
|
|
ctx context.Context,
|
|
command, from, target string,
|
|
params json.RawMessage,
|
|
body json.RawMessage,
|
|
meta json.RawMessage,
|
|
) (int64, string, error) {
|
|
msgUUID := uuid.New().String()
|
|
now := time.Now().UTC()
|
|
|
|
if params == nil {
|
|
params = json.RawMessage("[]")
|
|
}
|
|
|
|
if body == nil {
|
|
body = json.RawMessage("[]")
|
|
}
|
|
|
|
if meta == nil {
|
|
meta = json.RawMessage("{}")
|
|
}
|
|
|
|
res, err := database.conn.ExecContext(ctx,
|
|
`INSERT INTO messages
|
|
(uuid, command, msg_from, msg_to,
|
|
params, body, meta, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
msgUUID, command, from, target,
|
|
string(params), string(body), string(meta), now)
|
|
if err != nil {
|
|
return 0, "", fmt.Errorf(
|
|
"insert message: %w", err,
|
|
)
|
|
}
|
|
|
|
dbID, _ := res.LastInsertId()
|
|
|
|
return dbID, msgUUID, nil
|
|
}
|
|
|
|
// EnqueueToSession adds a message to all clients of a
|
|
// session's queues.
|
|
func (database *Database) EnqueueToSession(
|
|
ctx context.Context,
|
|
sessionID, messageID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO client_queues
|
|
(client_id, message_id, created_at)
|
|
SELECT c.id, ?, ?
|
|
FROM clients c
|
|
WHERE c.session_id = ?`,
|
|
messageID, time.Now(), sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"enqueue to session: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// EnqueueToClient adds a message to a specific client's
|
|
// queue.
|
|
func (database *Database) EnqueueToClient(
|
|
ctx context.Context,
|
|
clientID, messageID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO client_queues
|
|
(client_id, message_id, created_at)
|
|
VALUES (?, ?, ?)`,
|
|
clientID, messageID, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"enqueue to client: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// PollMessages returns queued messages for a client.
|
|
func (database *Database) PollMessages(
|
|
ctx context.Context,
|
|
clientID, afterQueueID int64,
|
|
limit int,
|
|
) ([]IRCMessage, int64, error) {
|
|
if limit <= 0 {
|
|
limit = defaultPollLimit
|
|
}
|
|
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT cq.id, m.uuid, m.command,
|
|
m.msg_from, m.msg_to,
|
|
m.params, m.body, m.meta, m.created_at
|
|
FROM client_queues cq
|
|
INNER JOIN messages m
|
|
ON m.id = cq.message_id
|
|
WHERE cq.client_id = ? AND cq.id > ?
|
|
ORDER BY cq.id ASC LIMIT ?`,
|
|
clientID, afterQueueID, limit)
|
|
if err != nil {
|
|
return nil, afterQueueID, fmt.Errorf(
|
|
"poll messages: %w", err,
|
|
)
|
|
}
|
|
|
|
msgs, lastQID, scanErr := scanMessages(
|
|
rows, afterQueueID,
|
|
)
|
|
if scanErr != nil {
|
|
return nil, afterQueueID, scanErr
|
|
}
|
|
|
|
return msgs, lastQID, nil
|
|
}
|
|
|
|
// GetHistory returns message history for a target.
|
|
func (database *Database) GetHistory(
|
|
ctx context.Context,
|
|
target string,
|
|
beforeID int64,
|
|
limit int,
|
|
) ([]IRCMessage, error) {
|
|
if limit <= 0 {
|
|
limit = defaultHistLimit
|
|
}
|
|
|
|
rows, err := database.queryHistory(
|
|
ctx, target, beforeID, limit,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
msgs, _, scanErr := scanMessages(rows, 0)
|
|
if scanErr != nil {
|
|
return nil, scanErr
|
|
}
|
|
|
|
if msgs == nil {
|
|
msgs = []IRCMessage{}
|
|
}
|
|
|
|
reverseMessages(msgs)
|
|
|
|
return msgs, nil
|
|
}
|
|
|
|
func (database *Database) queryHistory(
|
|
ctx context.Context,
|
|
target string,
|
|
beforeID int64,
|
|
limit int,
|
|
) (*sql.Rows, error) {
|
|
if beforeID > 0 {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT id, uuid, command, msg_from,
|
|
msg_to, params, body, meta, created_at
|
|
FROM messages
|
|
WHERE msg_to = ? AND id < ?
|
|
AND command = 'PRIVMSG'
|
|
ORDER BY id DESC LIMIT ?`,
|
|
target, beforeID, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"query history: %w", err,
|
|
)
|
|
}
|
|
|
|
return rows, nil
|
|
}
|
|
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT id, uuid, command, msg_from,
|
|
msg_to, params, body, meta, created_at
|
|
FROM messages
|
|
WHERE msg_to = ?
|
|
AND command = 'PRIVMSG'
|
|
ORDER BY id DESC LIMIT ?`,
|
|
target, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query history: %w", err)
|
|
}
|
|
|
|
return rows, nil
|
|
}
|
|
|
|
func scanMessages(
|
|
rows *sql.Rows,
|
|
fallbackQID int64,
|
|
) ([]IRCMessage, int64, error) {
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var msgs []IRCMessage
|
|
|
|
lastQID := fallbackQID
|
|
|
|
for rows.Next() {
|
|
var (
|
|
msg IRCMessage
|
|
qID int64
|
|
params, body, meta string
|
|
createdAt time.Time
|
|
)
|
|
|
|
err := rows.Scan(
|
|
&qID, &msg.ID, &msg.Command,
|
|
&msg.From, &msg.To,
|
|
¶ms, &body, &meta, &createdAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fallbackQID, fmt.Errorf(
|
|
"scan message: %w", err,
|
|
)
|
|
}
|
|
|
|
if params != "" && params != "[]" {
|
|
msg.Params = json.RawMessage(params)
|
|
}
|
|
|
|
msg.Body = json.RawMessage(body)
|
|
msg.Meta = json.RawMessage(meta)
|
|
msg.TS = createdAt.Format(time.RFC3339Nano)
|
|
msg.DBID = qID
|
|
lastQID = qID
|
|
|
|
if isNumericCode(msg.Command) {
|
|
code, _ := strconv.Atoi(msg.Command)
|
|
msg.Code = code
|
|
|
|
if mt, err := irc.FromInt(code); err == nil {
|
|
msg.Command = mt.Name()
|
|
}
|
|
}
|
|
|
|
msgs = append(msgs, msg)
|
|
}
|
|
|
|
err := rows.Err()
|
|
if err != nil {
|
|
return nil, fallbackQID, fmt.Errorf(
|
|
"rows error: %w", err,
|
|
)
|
|
}
|
|
|
|
if msgs == nil {
|
|
msgs = []IRCMessage{}
|
|
}
|
|
|
|
return msgs, lastQID, nil
|
|
}
|
|
|
|
func reverseMessages(msgs []IRCMessage) {
|
|
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
|
msgs[i], msgs[j] = msgs[j], msgs[i]
|
|
}
|
|
}
|
|
|
|
// ChangeNick updates a session's nickname.
|
|
func (database *Database) ChangeNick(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
newNick string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
"UPDATE sessions SET nick = ? WHERE id = ?",
|
|
newNick, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("change nick: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetTopic sets the topic for a channel.
|
|
func (database *Database) SetTopic(
|
|
ctx context.Context,
|
|
channelName, topic string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels SET topic = ?,
|
|
updated_at = ? WHERE name = ?`,
|
|
topic, time.Now(), channelName)
|
|
if err != nil {
|
|
return fmt.Errorf("set topic: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteSession removes a session and all its data.
|
|
func (database *Database) DeleteSession(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(
|
|
ctx,
|
|
"DELETE FROM sessions WHERE id = ?",
|
|
sessionID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("delete session: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteClient removes a single client record by ID.
|
|
func (database *Database) DeleteClient(
|
|
ctx context.Context,
|
|
clientID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(
|
|
ctx,
|
|
"DELETE FROM clients WHERE id = ?",
|
|
clientID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("delete client: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetUserCount returns the number of active users.
|
|
func (database *Database) GetUserCount(
|
|
ctx context.Context,
|
|
) (int64, error) {
|
|
var count int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT COUNT(*) FROM sessions",
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get user count: %w", err,
|
|
)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// GetOperCount returns the number of sessions with oper
|
|
// status.
|
|
func (database *Database) GetOperCount(
|
|
ctx context.Context,
|
|
) (int64, error) {
|
|
var count int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT COUNT(*) FROM sessions WHERE is_oper = 1",
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get oper count: %w", err,
|
|
)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// ClientCountForSession returns the number of clients
|
|
// belonging to a session.
|
|
func (database *Database) ClientCountForSession(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) (int64, error) {
|
|
var count int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
`SELECT COUNT(*) FROM clients
|
|
WHERE session_id = ?`,
|
|
sessionID,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"client count for session: %w", err,
|
|
)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// DeleteStaleUsers removes clients not seen since the
|
|
// cutoff and cleans up orphaned users (sessions).
|
|
func (database *Database) DeleteStaleUsers(
|
|
ctx context.Context,
|
|
cutoff time.Time,
|
|
) (int64, error) {
|
|
res, err := database.conn.ExecContext(ctx,
|
|
"DELETE FROM clients WHERE last_seen < ?",
|
|
cutoff,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"delete stale clients: %w", err,
|
|
)
|
|
}
|
|
|
|
deleted, _ := res.RowsAffected()
|
|
|
|
_, err = database.conn.ExecContext(ctx,
|
|
`DELETE FROM sessions WHERE id NOT IN
|
|
(SELECT DISTINCT session_id FROM clients)`,
|
|
)
|
|
if err != nil {
|
|
return deleted, fmt.Errorf(
|
|
"delete orphan sessions: %w", err,
|
|
)
|
|
}
|
|
|
|
return deleted, nil
|
|
}
|
|
|
|
// StaleSession identifies a session, by row id and nick,
// whose clients have all gone stale.
type StaleSession struct {
	ID   int64  // sessions.id
	Nick string // sessions.nick
}
|
|
|
|
// GetStaleOrphanSessions returns sessions where every
|
|
// client has a last_seen before cutoff.
|
|
func (database *Database) GetStaleOrphanSessions(
|
|
ctx context.Context,
|
|
cutoff time.Time,
|
|
) ([]StaleSession, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT s.id, s.nick
|
|
FROM sessions s
|
|
WHERE s.id IN (
|
|
SELECT DISTINCT session_id FROM clients
|
|
WHERE last_seen < ?
|
|
)
|
|
AND s.id NOT IN (
|
|
SELECT DISTINCT session_id FROM clients
|
|
WHERE last_seen >= ?
|
|
)`, cutoff, cutoff)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get stale orphan sessions: %w", err,
|
|
)
|
|
}
|
|
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var result []StaleSession
|
|
|
|
for rows.Next() {
|
|
var stale StaleSession
|
|
if err := rows.Scan(&stale.ID, &stale.Nick); err != nil {
|
|
return nil, fmt.Errorf(
|
|
"scan stale session: %w", err,
|
|
)
|
|
}
|
|
|
|
result = append(result, stale)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf(
|
|
"iterate stale sessions: %w", err,
|
|
)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// GetSessionChannels returns channels a session
|
|
// belongs to.
|
|
func (database *Database) GetSessionChannels(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) ([]ChannelInfo, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT c.id, c.name, c.topic
|
|
FROM channels c
|
|
INNER JOIN channel_members cm
|
|
ON cm.channel_id = c.id
|
|
WHERE cm.session_id = ?`, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get session channels: %w", err,
|
|
)
|
|
}
|
|
|
|
return scanChannels(rows)
|
|
}
|
|
|
|
// GetChannelCount returns the total number of channels.
|
|
func (database *Database) GetChannelCount(
|
|
ctx context.Context,
|
|
) (int64, error) {
|
|
var count int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT COUNT(*) FROM channels",
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get channel count: %w", err,
|
|
)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// ChannelInfoFull describes a channel together with its
// current member count; the JSON tags define its wire
// representation.
type ChannelInfoFull struct {
	ID          int64  `json:"id"`          // channels.id
	Name        string `json:"name"`        // channels.name
	Topic       string `json:"topic"`       // channels.topic
	MemberCount int64  `json:"memberCount"` // number of channel_members rows
}
|
|
|
|
// ListAllChannelsWithCounts returns every channel
|
|
// with its member count.
|
|
func (database *Database) ListAllChannelsWithCounts(
|
|
ctx context.Context,
|
|
) ([]ChannelInfoFull, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT c.id, c.name, c.topic,
|
|
COUNT(cm.session_id) AS member_count
|
|
FROM channels c
|
|
LEFT JOIN channel_members cm
|
|
ON cm.channel_id = c.id
|
|
GROUP BY c.id
|
|
ORDER BY c.name`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"list channels with counts: %w", err,
|
|
)
|
|
}
|
|
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var out []ChannelInfoFull
|
|
|
|
for rows.Next() {
|
|
var chanInfo ChannelInfoFull
|
|
|
|
err = rows.Scan(
|
|
&chanInfo.ID, &chanInfo.Name,
|
|
&chanInfo.Topic, &chanInfo.MemberCount,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"scan channel full: %w", err,
|
|
)
|
|
}
|
|
|
|
out = append(out, chanInfo)
|
|
}
|
|
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("rows error: %w", err)
|
|
}
|
|
|
|
if out == nil {
|
|
out = []ChannelInfoFull{}
|
|
}
|
|
|
|
return out, nil
|
|
}
|
|
|
|
// GetChannelCreatedAt returns the creation time of a
|
|
// channel.
|
|
func (database *Database) GetChannelCreatedAt(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (time.Time, error) {
|
|
var createdAt time.Time
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT created_at FROM channels WHERE id = ?",
|
|
channelID,
|
|
).Scan(&createdAt)
|
|
if err != nil {
|
|
return time.Time{}, fmt.Errorf(
|
|
"get channel created_at: %w", err,
|
|
)
|
|
}
|
|
|
|
return createdAt, nil
|
|
}
|
|
|
|
// GetSessionCreatedAt returns the creation time of a
|
|
// session.
|
|
func (database *Database) GetSessionCreatedAt(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) (time.Time, error) {
|
|
var createdAt time.Time
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT created_at FROM sessions WHERE id = ?",
|
|
sessionID,
|
|
).Scan(&createdAt)
|
|
if err != nil {
|
|
return time.Time{}, fmt.Errorf(
|
|
"get session created_at: %w", err,
|
|
)
|
|
}
|
|
|
|
return createdAt, nil
|
|
}
|
|
|
|
// SetAway sets the away message for a session.
|
|
// An empty message clears the away status.
|
|
func (database *Database) SetAway(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
message string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
"UPDATE sessions SET away_message = ? WHERE id = ?",
|
|
message, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("set away: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAway returns the away message for a session.
|
|
// Returns an empty string if the user is not away.
|
|
func (database *Database) GetAway(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) (string, error) {
|
|
var msg string
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
"SELECT away_message FROM sessions WHERE id = ?",
|
|
sessionID,
|
|
).Scan(&msg)
|
|
if err != nil {
|
|
return "", fmt.Errorf("get away: %w", err)
|
|
}
|
|
|
|
return msg, nil
|
|
}
|
|
|
|
// SetTopicMeta sets the topic along with who set it and
|
|
// when.
|
|
func (database *Database) SetTopicMeta(
|
|
ctx context.Context,
|
|
channelName, topic, setBy string,
|
|
) error {
|
|
now := time.Now()
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET topic = ?, topic_set_by = ?,
|
|
topic_set_at = ?, updated_at = ?
|
|
WHERE name = ?`,
|
|
topic, setBy, now, now, channelName)
|
|
if err != nil {
|
|
return fmt.Errorf("set topic meta: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// TopicMeta records who last set a channel's topic and
// when it was set.
type TopicMeta struct {
	SetBy string    // value of channels.topic_set_by
	SetAt time.Time // value of channels.topic_set_at
}
|
|
|
|
// GetTopicMeta returns who set the topic and when.
|
|
func (database *Database) GetTopicMeta(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (*TopicMeta, error) {
|
|
var (
|
|
setBy string
|
|
setAt sql.NullTime
|
|
)
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT topic_set_by, topic_set_at
|
|
FROM channels WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&setBy, &setAt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get topic meta: %w", err,
|
|
)
|
|
}
|
|
|
|
if setBy == "" || !setAt.Valid {
|
|
return nil, nil //nolint:nilnil
|
|
}
|
|
|
|
return &TopicMeta{
|
|
SetBy: setBy,
|
|
SetAt: setAt.Time,
|
|
}, nil
|
|
}
|
|
|
|
// GetSessionLastSeen returns the last_seen time for a
|
|
// session.
|
|
func (database *Database) GetSessionLastSeen(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) (time.Time, error) {
|
|
var lastSeen time.Time
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
"SELECT last_seen FROM sessions WHERE id = ?",
|
|
sessionID,
|
|
).Scan(&lastSeen)
|
|
if err != nil {
|
|
return time.Time{}, fmt.Errorf(
|
|
"get session last_seen: %w", err,
|
|
)
|
|
}
|
|
|
|
return lastSeen, nil
|
|
}
|
|
|
|
// PruneOldQueueEntries deletes client output queue entries
|
|
// older than cutoff and returns the number of rows removed.
|
|
func (database *Database) PruneOldQueueEntries(
|
|
ctx context.Context,
|
|
cutoff time.Time,
|
|
) (int64, error) {
|
|
res, err := database.conn.ExecContext(ctx,
|
|
"DELETE FROM client_queues WHERE created_at < ?",
|
|
cutoff,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"prune old client output queue entries: %w", err,
|
|
)
|
|
}
|
|
|
|
deleted, _ := res.RowsAffected()
|
|
|
|
return deleted, nil
|
|
}
|
|
|
|
// PruneOldMessages deletes messages older than cutoff and
|
|
// returns the number of rows removed.
|
|
func (database *Database) PruneOldMessages(
|
|
ctx context.Context,
|
|
cutoff time.Time,
|
|
) (int64, error) {
|
|
res, err := database.conn.ExecContext(ctx,
|
|
"DELETE FROM messages WHERE created_at < ?",
|
|
cutoff,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"prune old messages: %w", err,
|
|
)
|
|
}
|
|
|
|
deleted, _ := res.RowsAffected()
|
|
|
|
return deleted, nil
|
|
}
|
|
|
|
// GetClientCount returns the total number of clients.
|
|
func (database *Database) GetClientCount(
|
|
ctx context.Context,
|
|
) (int64, error) {
|
|
var count int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT COUNT(*) FROM clients",
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get client count: %w", err,
|
|
)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// GetQueueEntryCount returns the total number of entries
|
|
// in the client output queues.
|
|
func (database *Database) GetQueueEntryCount(
|
|
ctx context.Context,
|
|
) (int64, error) {
|
|
var count int64
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT COUNT(*) FROM client_queues",
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get queue entry count: %w", err,
|
|
)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// GetChannelHashcashBits returns the hashcash difficulty
|
|
// requirement for a channel. Returns 0 if not set.
|
|
func (database *Database) GetChannelHashcashBits(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (int, error) {
|
|
var bits int
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
"SELECT hashcash_bits FROM channels WHERE id = ?",
|
|
channelID,
|
|
).Scan(&bits)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get channel hashcash bits: %w", err,
|
|
)
|
|
}
|
|
|
|
return bits, nil
|
|
}
|
|
|
|
// SetChannelHashcashBits sets the hashcash difficulty
|
|
// requirement for a channel. A value of 0 disables the
|
|
// requirement.
|
|
func (database *Database) SetChannelHashcashBits(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
bits int,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET hashcash_bits = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
bits, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"set channel hashcash bits: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RecordSpentHashcash stores a spent hashcash stamp hash
|
|
// for replay prevention.
|
|
func (database *Database) RecordSpentHashcash(
|
|
ctx context.Context,
|
|
stampHash string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO spent_hashcash
|
|
(stamp_hash, created_at)
|
|
VALUES (?, ?)`,
|
|
stampHash, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"record spent hashcash: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsHashcashSpent checks whether a hashcash stamp hash
|
|
// has already been used.
|
|
func (database *Database) IsHashcashSpent(
|
|
ctx context.Context,
|
|
stampHash string,
|
|
) (bool, error) {
|
|
var count int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT COUNT(*) FROM spent_hashcash
|
|
WHERE stamp_hash = ?`,
|
|
stampHash,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check spent hashcash: %w", err,
|
|
)
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|
|
|
|
// PruneSpentHashcash deletes spent hashcash tokens older
|
|
// than the cutoff and returns the number of rows removed.
|
|
func (database *Database) PruneSpentHashcash(
|
|
ctx context.Context,
|
|
cutoff time.Time,
|
|
) (int64, error) {
|
|
res, err := database.conn.ExecContext(ctx,
|
|
"DELETE FROM spent_hashcash WHERE created_at < ?",
|
|
cutoff,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"prune spent hashcash: %w", err,
|
|
)
|
|
}
|
|
|
|
deleted, _ := res.RowsAffected()
|
|
|
|
return deleted, nil
|
|
}
|
|
|
|
// --- Tier 2: Ban system (+b) ---
|
|
|
|
// BanInfo is one row of channel_bans: the ban mask, who
// set it, and when it was created.
type BanInfo struct {
	Mask, SetBy string    // channel_bans.mask / channel_bans.set_by
	CreatedAt   time.Time // channel_bans.created_at
}
|
|
|
|
// AddChannelBan inserts a ban mask for a channel.
|
|
func (database *Database) AddChannelBan(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
mask, setBy string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO channel_bans
|
|
(channel_id, mask, set_by, created_at)
|
|
VALUES (?, ?, ?, ?)`,
|
|
channelID, mask, setBy, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("add channel ban: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveChannelBan removes a ban mask from a channel.
|
|
func (database *Database) RemoveChannelBan(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
mask string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`DELETE FROM channel_bans
|
|
WHERE channel_id = ? AND mask = ?`,
|
|
channelID, mask)
|
|
if err != nil {
|
|
return fmt.Errorf("remove channel ban: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListChannelBans returns all bans for a channel.
|
|
//
|
|
//nolint:dupl // different query+type vs filtered variant
|
|
func (database *Database) ListChannelBans(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) ([]BanInfo, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT mask, set_by, created_at
|
|
FROM channel_bans
|
|
WHERE channel_id = ?
|
|
ORDER BY created_at ASC`,
|
|
channelID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list channel bans: %w", err)
|
|
}
|
|
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var bans []BanInfo
|
|
|
|
for rows.Next() {
|
|
var ban BanInfo
|
|
|
|
if scanErr := rows.Scan(
|
|
&ban.Mask, &ban.SetBy, &ban.CreatedAt,
|
|
); scanErr != nil {
|
|
return nil, fmt.Errorf(
|
|
"scan channel ban: %w", scanErr,
|
|
)
|
|
}
|
|
|
|
bans = append(bans, ban)
|
|
}
|
|
|
|
if rowErr := rows.Err(); rowErr != nil {
|
|
return nil, fmt.Errorf(
|
|
"iterate channel bans: %w", rowErr,
|
|
)
|
|
}
|
|
|
|
return bans, nil
|
|
}
|
|
|
|
// IsSessionBanned checks if a session's hostmask matches
|
|
// any ban in the channel. Returns true if banned.
|
|
func (database *Database) IsSessionBanned(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) (bool, error) {
|
|
// Get the session's hostmask parts.
|
|
var nick, username, hostname string
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT nick, username, hostname
|
|
FROM sessions WHERE id = ?`,
|
|
sessionID,
|
|
).Scan(&nick, &username, &hostname)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"get session hostmask: %w", err,
|
|
)
|
|
}
|
|
|
|
hostmask := FormatHostmask(nick, username, hostname)
|
|
|
|
// Get all ban masks for the channel.
|
|
bans, banErr := database.ListChannelBans(ctx, channelID)
|
|
if banErr != nil {
|
|
return false, banErr
|
|
}
|
|
|
|
for _, ban := range bans {
|
|
if MatchBanMask(ban.Mask, hostmask) {
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// MatchBanMask checks if hostmask matches a ban pattern
|
|
// using IRC-style wildcard matching (* and ?).
|
|
func MatchBanMask(pattern, hostmask string) bool {
|
|
return wildcardMatch(
|
|
strings.ToLower(pattern),
|
|
strings.ToLower(hostmask),
|
|
)
|
|
}
|
|
|
|
// wildcardMatch reports whether str matches the glob
// pattern, where * matches any (possibly empty) byte
// sequence and ? matches exactly one byte. Every other
// pattern byte must match literally. Matching is
// byte-wise, with no case folding.
//
// This is the iterative "remember the last star"
// algorithm, which runs in O(len(pattern)*len(str)) time
// in the worst case. The previous recursive version
// re-tried every split point for each * and was
// exponential in the number of stars (e.g. "*a*a*a*a*b"
// against a long run of "a"). Ban masks and hostmasks are
// both user-controlled, so that was a CPU-exhaustion
// vector.
func wildcardMatch(pattern, str string) bool {
	pIdx, sIdx := 0, 0

	// starIdx is the position of the most recent * in
	// pattern (-1 if none yet). resumeIdx is the position in
	// str that this * is currently assumed to extend to.
	starIdx, resumeIdx := -1, 0

	for sIdx < len(str) {
		switch {
		case pIdx < len(pattern) && pattern[pIdx] == '*':
			// Tentatively let * match the empty sequence.
			starIdx = pIdx
			resumeIdx = sIdx
			pIdx++
		case pIdx < len(pattern) &&
			(pattern[pIdx] == '?' || pattern[pIdx] == str[sIdx]):
			pIdx++
			sIdx++
		case starIdx >= 0:
			// Mismatch: extend the last * by one byte and retry.
			pIdx = starIdx + 1
			resumeIdx++
			sIdx = resumeIdx
		default:
			return false
		}
	}

	// str is exhausted; only trailing stars may remain.
	for pIdx < len(pattern) && pattern[pIdx] == '*' {
		pIdx++
	}

	return pIdx == len(pattern)
}
|
|
|
|
// --- Tier 2: Invite-only (+i) ---
|
|
|
|
// IsChannelInviteOnly checks if a channel has +i mode.
|
|
func (database *Database) IsChannelInviteOnly(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (bool, error) {
|
|
var isInviteOnly int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT is_invite_only FROM channels
|
|
WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&isInviteOnly)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check invite only: %w", err,
|
|
)
|
|
}
|
|
|
|
return isInviteOnly != 0, nil
|
|
}
|
|
|
|
// SetChannelInviteOnly sets or unsets +i mode.
|
|
func (database *Database) SetChannelInviteOnly(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
inviteOnly bool,
|
|
) error {
|
|
val := 0
|
|
if inviteOnly {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET is_invite_only = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
val, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"set invite only: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddChannelInvite records that a session has been
|
|
// invited to a channel.
|
|
func (database *Database) AddChannelInvite(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
invitedBy string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO channel_invites
|
|
(channel_id, session_id, invited_by, created_at)
|
|
VALUES (?, ?, ?, ?)`,
|
|
channelID, sessionID, invitedBy, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("add channel invite: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// HasChannelInvite checks if a session has been invited
|
|
// to a channel.
|
|
func (database *Database) HasChannelInvite(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) (bool, error) {
|
|
var count int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT COUNT(*) FROM channel_invites
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
channelID, sessionID,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check invite: %w", err,
|
|
)
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|
|
|
|
// ClearChannelInvite removes a session's invite to a
|
|
// channel (called after successful JOIN).
|
|
func (database *Database) ClearChannelInvite(
|
|
ctx context.Context,
|
|
channelID, sessionID int64,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`DELETE FROM channel_invites
|
|
WHERE channel_id = ? AND session_id = ?`,
|
|
channelID, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("clear invite: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// --- Tier 2: Secret (+s) ---
|
|
|
|
// IsChannelSecret checks if a channel has +s mode.
|
|
func (database *Database) IsChannelSecret(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (bool, error) {
|
|
var isSecret int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT is_secret FROM channels
|
|
WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&isSecret)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check secret: %w", err,
|
|
)
|
|
}
|
|
|
|
return isSecret != 0, nil
|
|
}
|
|
|
|
// SetChannelSecret sets or unsets +s mode.
|
|
func (database *Database) SetChannelSecret(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
secret bool,
|
|
) error {
|
|
val := 0
|
|
if secret {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET is_secret = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
val, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf("set secret: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// --- No External Messages (+n) ---
|
|
|
|
// IsChannelNoExternal checks if a channel has +n mode.
|
|
func (database *Database) IsChannelNoExternal(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (bool, error) {
|
|
var isNoExternal int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT is_no_external FROM channels
|
|
WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&isNoExternal)
|
|
if err != nil {
|
|
return false, fmt.Errorf(
|
|
"check no external: %w", err,
|
|
)
|
|
}
|
|
|
|
return isNoExternal != 0, nil
|
|
}
|
|
|
|
// SetChannelNoExternal sets or unsets +n mode.
|
|
func (database *Database) SetChannelNoExternal(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
noExternal bool,
|
|
) error {
|
|
val := 0
|
|
if noExternal {
|
|
val = 1
|
|
}
|
|
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET is_no_external = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
val, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf("set no external: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListAllChannelsWithCountsFiltered returns all channels
|
|
// with member counts, excluding secret channels that
|
|
// the given session is not a member of.
|
|
//
|
|
//nolint:dupl // different query+type vs ListChannelBans
|
|
func (database *Database) ListAllChannelsWithCountsFiltered(
|
|
ctx context.Context,
|
|
sessionID int64,
|
|
) ([]ChannelInfoFull, error) {
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT c.name, COUNT(cm.id) AS member_count,
|
|
c.topic
|
|
FROM channels c
|
|
LEFT JOIN channel_members cm
|
|
ON cm.channel_id = c.id
|
|
WHERE c.is_secret = 0
|
|
OR c.id IN (
|
|
SELECT channel_id FROM channel_members
|
|
WHERE session_id = ?
|
|
)
|
|
GROUP BY c.id
|
|
ORDER BY c.name ASC`,
|
|
sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"list channels filtered: %w", err,
|
|
)
|
|
}
|
|
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var channels []ChannelInfoFull
|
|
|
|
for rows.Next() {
|
|
var chanInfo ChannelInfoFull
|
|
|
|
if scanErr := rows.Scan(
|
|
&chanInfo.Name,
|
|
&chanInfo.MemberCount,
|
|
&chanInfo.Topic,
|
|
); scanErr != nil {
|
|
return nil, fmt.Errorf(
|
|
"scan channel: %w", scanErr,
|
|
)
|
|
}
|
|
|
|
channels = append(channels, chanInfo)
|
|
}
|
|
|
|
if rowErr := rows.Err(); rowErr != nil {
|
|
return nil, fmt.Errorf(
|
|
"iterate channels: %w", rowErr,
|
|
)
|
|
}
|
|
|
|
return channels, nil
|
|
}
|
|
|
|
// GetSessionChannelsFiltered returns channels a session
|
|
// belongs to, optionally excluding secret channels for
|
|
// WHOIS (when the querier is not in the same channel).
|
|
// If querierID == targetID, returns all channels.
|
|
func (database *Database) GetSessionChannelsFiltered(
|
|
ctx context.Context,
|
|
targetSID, querierSID int64,
|
|
) ([]ChannelInfo, error) {
|
|
// If querying yourself, return all channels.
|
|
if targetSID == querierSID {
|
|
return database.GetSessionChannels(ctx, targetSID)
|
|
}
|
|
|
|
rows, err := database.conn.QueryContext(ctx,
|
|
`SELECT c.id, c.name, c.topic
|
|
FROM channels c
|
|
JOIN channel_members cm
|
|
ON cm.channel_id = c.id
|
|
WHERE cm.session_id = ?
|
|
AND (c.is_secret = 0
|
|
OR c.id IN (
|
|
SELECT channel_id FROM channel_members
|
|
WHERE session_id = ?
|
|
))
|
|
ORDER BY c.name ASC`,
|
|
targetSID, querierSID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"get session channels filtered: %w", err,
|
|
)
|
|
}
|
|
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var channels []ChannelInfo
|
|
|
|
for rows.Next() {
|
|
var chanInfo ChannelInfo
|
|
|
|
if scanErr := rows.Scan(
|
|
&chanInfo.ID,
|
|
&chanInfo.Name,
|
|
&chanInfo.Topic,
|
|
); scanErr != nil {
|
|
return nil, fmt.Errorf(
|
|
"scan channel: %w", scanErr,
|
|
)
|
|
}
|
|
|
|
channels = append(channels, chanInfo)
|
|
}
|
|
|
|
if rowErr := rows.Err(); rowErr != nil {
|
|
return nil, fmt.Errorf(
|
|
"iterate channels: %w", rowErr,
|
|
)
|
|
}
|
|
|
|
return channels, nil
|
|
}
|
|
|
|
// --- Tier 2: Channel Key (+k) ---
|
|
|
|
// GetChannelKey returns the key for a channel (empty
|
|
// string means no key set).
|
|
func (database *Database) GetChannelKey(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (string, error) {
|
|
var key string
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT channel_key FROM channels
|
|
WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&key)
|
|
if err != nil {
|
|
return "", fmt.Errorf("get channel key: %w", err)
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
// SetChannelKey sets or clears the key for a channel.
|
|
func (database *Database) SetChannelKey(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
key string,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET channel_key = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
key, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf("set channel key: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// --- Tier 2: User Limit (+l) ---
|
|
|
|
// GetChannelUserLimit returns the user limit for a
|
|
// channel (0 means no limit).
|
|
func (database *Database) GetChannelUserLimit(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
) (int, error) {
|
|
var limit int
|
|
|
|
err := database.conn.QueryRowContext(ctx,
|
|
`SELECT user_limit FROM channels
|
|
WHERE id = ?`,
|
|
channelID,
|
|
).Scan(&limit)
|
|
if err != nil {
|
|
return 0, fmt.Errorf(
|
|
"get channel user limit: %w", err,
|
|
)
|
|
}
|
|
|
|
return limit, nil
|
|
}
|
|
|
|
// SetChannelUserLimit sets the user limit for a channel.
|
|
func (database *Database) SetChannelUserLimit(
|
|
ctx context.Context,
|
|
channelID int64,
|
|
limit int,
|
|
) error {
|
|
_, err := database.conn.ExecContext(ctx,
|
|
`UPDATE channels
|
|
SET user_limit = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
limit, time.Now(), channelID)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"set channel user limit: %w", err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|