feat: implement Tier 2 channel modes (+b/+i/+s/+k/+l) (#92)
Some checks failed
check / check (push) Failing after 1m31s
Some checks failed
check / check (push) Failing after 1m31s
## Summary Implements the second tier of IRC channel features as described in [#86](#86). ## Features ### 1. Ban System (+b) - `channel_bans` table with mask, set_by, created_at - Add/remove/list bans via MODE +b/-b - Wildcard matching (`*!*@*.example.com`, `badnick!*@*`, etc.) - Ban enforcement on both JOIN and PRIVMSG - RPL_BANLIST (367) / RPL_ENDOFBANLIST (368) for ban listing ### 2. Invite-Only (+i) - `is_invite_only` column on channels table - INVITE command: operators can invite users - `channel_invites` table tracks pending invites - Invites consumed on successful JOIN - ERR_INVITEONLYCHAN (473) for uninvited JOIN attempts ### 3. Secret (+s) - `is_secret` column on channels table - Secret channels hidden from LIST for non-members - Secret channels hidden from WHOIS channel list for non-members ### 4. Channel Key (+k) - `channel_key` column on channels table - MODE +k sets key, MODE -k clears it - Key required on JOIN (`JOIN #channel key`) - ERR_BADCHANNELKEY (475) for wrong/missing key ### 5. User Limit (+l) - `user_limit` column on channels table (0 = no limit) - MODE +l sets limit, MODE -l removes it - ERR_CHANNELISFULL (471) when limit reached ## ISUPPORT Changes - CHANMODES updated to `b,k,Hl,imnst` - RPL_MYINFO modes updated to `ikmnostl` ## Tests ### Database-level tests: - Wildcard matching (10 patterns) - Ban CRUD operations - Session ban checking - Invite-only flag toggle - Invite CRUD + clearing - Secret channel filtering (LIST and WHOIS) - Channel key set/get/clear - User limit set/get/clear ### Handler-level tests: - Ban add/remove/list via MODE - Ban blocks JOIN - Ban blocks PRIVMSG - Invite-only JOIN rejection + INVITE acceptance - Secret channel hidden from LIST - Channel key required on JOIN - User limit enforcement - Mode string includes new modes - ISUPPORT updated CHANMODES - Non-operators cannot set any Tier 2 modes ## Schema Changes - Added `is_invite_only`, `is_secret`, `channel_key`, `user_limit` to `channels` table - Added `channel_bans` table - Added `channel_invites` table - All changes in `001_initial.sql` (pre-1.0.0 repo) closes #86 Co-authored-by: user <user@Mac.lan guest wan> Reviewed-on: #92 Co-authored-by: clawbot <clawbot@noreply.example.org> Co-committed-by: clawbot <clawbot@noreply.example.org>
This commit was merged in pull request #92.
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/neoirc/pkg/irc"
|
||||
@@ -1835,3 +1836,534 @@ func (database *Database) PruneSpentHashcash(
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// --- Tier 2: Ban system (+b) ---
|
||||
|
||||
// BanInfo represents a channel ban entry.
|
||||
type BanInfo struct {
|
||||
Mask string
|
||||
SetBy string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// AddChannelBan inserts a ban mask for a channel.
|
||||
func (database *Database) AddChannelBan(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
mask, setBy string,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO channel_bans
|
||||
(channel_id, mask, set_by, created_at)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
channelID, mask, setBy, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add channel ban: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveChannelBan removes a ban mask from a channel.
|
||||
func (database *Database) RemoveChannelBan(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
mask string,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`DELETE FROM channel_bans
|
||||
WHERE channel_id = ? AND mask = ?`,
|
||||
channelID, mask)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove channel ban: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListChannelBans returns all bans for a channel.
|
||||
//
|
||||
//nolint:dupl // different query+type vs filtered variant
|
||||
func (database *Database) ListChannelBans(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
) ([]BanInfo, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT mask, set_by, created_at
|
||||
FROM channel_bans
|
||||
WHERE channel_id = ?
|
||||
ORDER BY created_at ASC`,
|
||||
channelID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list channel bans: %w", err)
|
||||
}
|
||||
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var bans []BanInfo
|
||||
|
||||
for rows.Next() {
|
||||
var ban BanInfo
|
||||
|
||||
if scanErr := rows.Scan(
|
||||
&ban.Mask, &ban.SetBy, &ban.CreatedAt,
|
||||
); scanErr != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"scan channel ban: %w", scanErr,
|
||||
)
|
||||
}
|
||||
|
||||
bans = append(bans, ban)
|
||||
}
|
||||
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"iterate channel bans: %w", rowErr,
|
||||
)
|
||||
}
|
||||
|
||||
return bans, nil
|
||||
}
|
||||
|
||||
// IsSessionBanned checks if a session's hostmask matches
|
||||
// any ban in the channel. Returns true if banned.
|
||||
func (database *Database) IsSessionBanned(
|
||||
ctx context.Context,
|
||||
channelID, sessionID int64,
|
||||
) (bool, error) {
|
||||
// Get the session's hostmask parts.
|
||||
var nick, username, hostname string
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT nick, username, hostname
|
||||
FROM sessions WHERE id = ?`,
|
||||
sessionID,
|
||||
).Scan(&nick, &username, &hostname)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf(
|
||||
"get session hostmask: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
hostmask := FormatHostmask(nick, username, hostname)
|
||||
|
||||
// Get all ban masks for the channel.
|
||||
bans, banErr := database.ListChannelBans(ctx, channelID)
|
||||
if banErr != nil {
|
||||
return false, banErr
|
||||
}
|
||||
|
||||
for _, ban := range bans {
|
||||
if MatchBanMask(ban.Mask, hostmask) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// MatchBanMask checks if hostmask matches a ban pattern
|
||||
// using IRC-style wildcard matching (* and ?).
|
||||
func MatchBanMask(pattern, hostmask string) bool {
|
||||
return wildcardMatch(
|
||||
strings.ToLower(pattern),
|
||||
strings.ToLower(hostmask),
|
||||
)
|
||||
}
|
||||
|
||||
// wildcardMatch implements simple glob-style matching
|
||||
// with * (any sequence) and ? (any single character).
|
||||
func wildcardMatch(pattern, str string) bool {
|
||||
for len(pattern) > 0 {
|
||||
switch pattern[0] {
|
||||
case '*':
|
||||
// Skip consecutive asterisks.
|
||||
for len(pattern) > 0 && pattern[0] == '*' {
|
||||
pattern = pattern[1:]
|
||||
}
|
||||
|
||||
if len(pattern) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for i := 0; i <= len(str); i++ {
|
||||
if wildcardMatch(pattern, str[i:]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
case '?':
|
||||
if len(str) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern = pattern[1:]
|
||||
str = str[1:]
|
||||
default:
|
||||
if len(str) == 0 || pattern[0] != str[0] {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern = pattern[1:]
|
||||
str = str[1:]
|
||||
}
|
||||
}
|
||||
|
||||
return len(str) == 0
|
||||
}
|
||||
|
||||
// --- Tier 2: Invite-only (+i) ---
|
||||
|
||||
// IsChannelInviteOnly checks if a channel has +i mode.
|
||||
func (database *Database) IsChannelInviteOnly(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
) (bool, error) {
|
||||
var isInviteOnly int
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT is_invite_only FROM channels
|
||||
WHERE id = ?`,
|
||||
channelID,
|
||||
).Scan(&isInviteOnly)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf(
|
||||
"check invite only: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return isInviteOnly != 0, nil
|
||||
}
|
||||
|
||||
// SetChannelInviteOnly sets or unsets +i mode.
|
||||
func (database *Database) SetChannelInviteOnly(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
inviteOnly bool,
|
||||
) error {
|
||||
val := 0
|
||||
if inviteOnly {
|
||||
val = 1
|
||||
}
|
||||
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`UPDATE channels
|
||||
SET is_invite_only = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
val, time.Now(), channelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"set invite only: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddChannelInvite records that a session has been
|
||||
// invited to a channel.
|
||||
func (database *Database) AddChannelInvite(
|
||||
ctx context.Context,
|
||||
channelID, sessionID int64,
|
||||
invitedBy string,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO channel_invites
|
||||
(channel_id, session_id, invited_by, created_at)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
channelID, sessionID, invitedBy, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add channel invite: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasChannelInvite checks if a session has been invited
|
||||
// to a channel.
|
||||
func (database *Database) HasChannelInvite(
|
||||
ctx context.Context,
|
||||
channelID, sessionID int64,
|
||||
) (bool, error) {
|
||||
var count int
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM channel_invites
|
||||
WHERE channel_id = ? AND session_id = ?`,
|
||||
channelID, sessionID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf(
|
||||
"check invite: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ClearChannelInvite removes a session's invite to a
|
||||
// channel (called after successful JOIN).
|
||||
func (database *Database) ClearChannelInvite(
|
||||
ctx context.Context,
|
||||
channelID, sessionID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`DELETE FROM channel_invites
|
||||
WHERE channel_id = ? AND session_id = ?`,
|
||||
channelID, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("clear invite: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Tier 2: Secret (+s) ---
|
||||
|
||||
// IsChannelSecret checks if a channel has +s mode.
|
||||
func (database *Database) IsChannelSecret(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
) (bool, error) {
|
||||
var isSecret int
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT is_secret FROM channels
|
||||
WHERE id = ?`,
|
||||
channelID,
|
||||
).Scan(&isSecret)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf(
|
||||
"check secret: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return isSecret != 0, nil
|
||||
}
|
||||
|
||||
// SetChannelSecret sets or unsets +s mode.
|
||||
func (database *Database) SetChannelSecret(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
secret bool,
|
||||
) error {
|
||||
val := 0
|
||||
if secret {
|
||||
val = 1
|
||||
}
|
||||
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`UPDATE channels
|
||||
SET is_secret = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
val, time.Now(), channelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set secret: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAllChannelsWithCountsFiltered returns all channels
|
||||
// with member counts, excluding secret channels that
|
||||
// the given session is not a member of.
|
||||
//
|
||||
//nolint:dupl // different query+type vs ListChannelBans
|
||||
func (database *Database) ListAllChannelsWithCountsFiltered(
|
||||
ctx context.Context,
|
||||
sessionID int64,
|
||||
) ([]ChannelInfoFull, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT c.name, COUNT(cm.id) AS member_count,
|
||||
c.topic
|
||||
FROM channels c
|
||||
LEFT JOIN channel_members cm
|
||||
ON cm.channel_id = c.id
|
||||
WHERE c.is_secret = 0
|
||||
OR c.id IN (
|
||||
SELECT channel_id FROM channel_members
|
||||
WHERE session_id = ?
|
||||
)
|
||||
GROUP BY c.id
|
||||
ORDER BY c.name ASC`,
|
||||
sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"list channels filtered: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var channels []ChannelInfoFull
|
||||
|
||||
for rows.Next() {
|
||||
var chanInfo ChannelInfoFull
|
||||
|
||||
if scanErr := rows.Scan(
|
||||
&chanInfo.Name,
|
||||
&chanInfo.MemberCount,
|
||||
&chanInfo.Topic,
|
||||
); scanErr != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"scan channel: %w", scanErr,
|
||||
)
|
||||
}
|
||||
|
||||
channels = append(channels, chanInfo)
|
||||
}
|
||||
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"iterate channels: %w", rowErr,
|
||||
)
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// GetSessionChannelsFiltered returns channels a session
|
||||
// belongs to, optionally excluding secret channels for
|
||||
// WHOIS (when the querier is not in the same channel).
|
||||
// If querierID == targetID, returns all channels.
|
||||
func (database *Database) GetSessionChannelsFiltered(
|
||||
ctx context.Context,
|
||||
targetSID, querierSID int64,
|
||||
) ([]ChannelInfo, error) {
|
||||
// If querying yourself, return all channels.
|
||||
if targetSID == querierSID {
|
||||
return database.GetSessionChannels(ctx, targetSID)
|
||||
}
|
||||
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT c.id, c.name, c.topic
|
||||
FROM channels c
|
||||
JOIN channel_members cm
|
||||
ON cm.channel_id = c.id
|
||||
WHERE cm.session_id = ?
|
||||
AND (c.is_secret = 0
|
||||
OR c.id IN (
|
||||
SELECT channel_id FROM channel_members
|
||||
WHERE session_id = ?
|
||||
))
|
||||
ORDER BY c.name ASC`,
|
||||
targetSID, querierSID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"get session channels filtered: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var channels []ChannelInfo
|
||||
|
||||
for rows.Next() {
|
||||
var chanInfo ChannelInfo
|
||||
|
||||
if scanErr := rows.Scan(
|
||||
&chanInfo.ID,
|
||||
&chanInfo.Name,
|
||||
&chanInfo.Topic,
|
||||
); scanErr != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"scan channel: %w", scanErr,
|
||||
)
|
||||
}
|
||||
|
||||
channels = append(channels, chanInfo)
|
||||
}
|
||||
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"iterate channels: %w", rowErr,
|
||||
)
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// --- Tier 2: Channel Key (+k) ---
|
||||
|
||||
// GetChannelKey returns the key for a channel (empty
|
||||
// string means no key set).
|
||||
func (database *Database) GetChannelKey(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
) (string, error) {
|
||||
var key string
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT channel_key FROM channels
|
||||
WHERE id = ?`,
|
||||
channelID,
|
||||
).Scan(&key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get channel key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SetChannelKey sets or clears the key for a channel.
|
||||
func (database *Database) SetChannelKey(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
key string,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`UPDATE channels
|
||||
SET channel_key = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
key, time.Now(), channelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set channel key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Tier 2: User Limit (+l) ---
|
||||
|
||||
// GetChannelUserLimit returns the user limit for a
|
||||
// channel (0 means no limit).
|
||||
func (database *Database) GetChannelUserLimit(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
) (int, error) {
|
||||
var limit int
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT user_limit FROM channels
|
||||
WHERE id = ?`,
|
||||
channelID,
|
||||
).Scan(&limit)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf(
|
||||
"get channel user limit: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
// SetChannelUserLimit sets the user limit for a channel.
|
||||
func (database *Database) SetChannelUserLimit(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
limit int,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`UPDATE channels
|
||||
SET user_limit = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
limit, time.Now(), channelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"set channel user limit: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1017,3 +1017,474 @@ func TestGetOperCount(t *testing.T) {
|
||||
t.Fatalf("expected 1 oper, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tier 2 Tests ---
|
||||
|
||||
func TestWildcardMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
pattern string
|
||||
input string
|
||||
match bool
|
||||
}{
|
||||
{"*!*@*", "nick!user@host", true},
|
||||
{"*!*@*.example.com", "nick!user@foo.example.com", true},
|
||||
{"*!*@*.example.com", "nick!user@other.net", false},
|
||||
{"badnick!*@*", "badnick!user@host", true},
|
||||
{"badnick!*@*", "goodnick!user@host", false},
|
||||
{"nick!user@host", "nick!user@host", true},
|
||||
{"nick!user@host", "nick!user@other", false},
|
||||
{"*", "anything", true},
|
||||
{"?ick!*@*", "nick!user@host", true},
|
||||
{"?ick!*@*", "nn!user@host", false},
|
||||
// Case-insensitive.
|
||||
{"Nick!*@*", "nick!user@host", true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
result := db.MatchBanMask(tc.pattern, tc.input)
|
||||
if result != tc.match {
|
||||
t.Errorf(
|
||||
"MatchBanMask(%q, %q) = %v, want %v",
|
||||
tc.pattern, tc.input, result, tc.match,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelBanCRUD(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
chID, err := database.GetOrCreateChannel(ctx, "#test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// No bans initially.
|
||||
bans, err := database.ListChannelBans(ctx, chID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(bans) != 0 {
|
||||
t.Fatalf("expected 0 bans, got %d", len(bans))
|
||||
}
|
||||
|
||||
// Add a ban.
|
||||
err = database.AddChannelBan(
|
||||
ctx, chID, "*!*@evil.com", "op",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bans, err = database.ListChannelBans(ctx, chID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(bans) != 1 {
|
||||
t.Fatalf("expected 1 ban, got %d", len(bans))
|
||||
}
|
||||
|
||||
if bans[0].Mask != "*!*@evil.com" {
|
||||
t.Fatalf("wrong mask: %s", bans[0].Mask)
|
||||
}
|
||||
|
||||
// Duplicate add is ignored (OR IGNORE).
|
||||
err = database.AddChannelBan(
|
||||
ctx, chID, "*!*@evil.com", "op2",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bans, _ = database.ListChannelBans(ctx, chID)
|
||||
if len(bans) != 1 {
|
||||
t.Fatalf("expected 1 ban after dup, got %d", len(bans))
|
||||
}
|
||||
|
||||
// Remove ban.
|
||||
err = database.RemoveChannelBan(
|
||||
ctx, chID, "*!*@evil.com",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bans, _ = database.ListChannelBans(ctx, chID)
|
||||
if len(bans) != 0 {
|
||||
t.Fatalf("expected 0 bans after remove, got %d", len(bans))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSessionBanned(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
sid, _, _, err := database.CreateSession(
|
||||
ctx, "victim", "victim", "evil.com", "",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
chID, err := database.GetOrCreateChannel(ctx, "#bantest")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Not banned initially.
|
||||
banned, err := database.IsSessionBanned(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if banned {
|
||||
t.Fatal("should not be banned initially")
|
||||
}
|
||||
|
||||
// Add ban matching the hostmask.
|
||||
err = database.AddChannelBan(
|
||||
ctx, chID, "*!*@evil.com", "op",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
banned, err = database.IsSessionBanned(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !banned {
|
||||
t.Fatal("should be banned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelInviteOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
chID, err := database.GetOrCreateChannel(ctx, "#invite")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Default: not invite-only.
|
||||
isIO, err := database.IsChannelInviteOnly(ctx, chID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if isIO {
|
||||
t.Fatal("should not be invite-only by default")
|
||||
}
|
||||
|
||||
// Set invite-only.
|
||||
err = database.SetChannelInviteOnly(ctx, chID, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
|
||||
if !isIO {
|
||||
t.Fatal("should be invite-only")
|
||||
}
|
||||
|
||||
// Unset.
|
||||
err = database.SetChannelInviteOnly(ctx, chID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
|
||||
if isIO {
|
||||
t.Fatal("should not be invite-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelInviteCRUD(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
sid, _, _, err := database.CreateSession(
|
||||
ctx, "invited", "", "", "",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
chID, err := database.GetOrCreateChannel(ctx, "#inv")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// No invite initially.
|
||||
has, err := database.HasChannelInvite(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if has {
|
||||
t.Fatal("should not have invite")
|
||||
}
|
||||
|
||||
// Add invite.
|
||||
err = database.AddChannelInvite(ctx, chID, sid, "op")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
has, _ = database.HasChannelInvite(ctx, chID, sid)
|
||||
if !has {
|
||||
t.Fatal("should have invite")
|
||||
}
|
||||
|
||||
// Clear invite.
|
||||
err = database.ClearChannelInvite(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
has, _ = database.HasChannelInvite(ctx, chID, sid)
|
||||
if has {
|
||||
t.Fatal("invite should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
chID, err := database.GetOrCreateChannel(ctx, "#secret")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Default: not secret.
|
||||
isSec, err := database.IsChannelSecret(ctx, chID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if isSec {
|
||||
t.Fatal("should not be secret by default")
|
||||
}
|
||||
|
||||
err = database.SetChannelSecret(ctx, chID, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
isSec, _ = database.IsChannelSecret(ctx, chID)
|
||||
if !isSec {
|
||||
t.Fatal("should be secret")
|
||||
}
|
||||
}
|
||||
|
||||
// createTestSession is a helper to create a session and
|
||||
// return only the session ID.
|
||||
func createTestSession(
|
||||
t *testing.T,
|
||||
database *db.Database,
|
||||
nick string,
|
||||
) int64 {
|
||||
t.Helper()
|
||||
|
||||
sid, _, _, err := database.CreateSession(
|
||||
t.Context(), nick, "", "", "",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("create session %s: %v", nick, err)
|
||||
}
|
||||
|
||||
return sid
|
||||
}
|
||||
|
||||
func TestSecretChannelFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
// Create two sessions.
|
||||
sid1 := createTestSession(t, database, "member")
|
||||
sid2 := createTestSession(t, database, "outsider")
|
||||
|
||||
// Create a secret channel.
|
||||
chID, _ := database.GetOrCreateChannel(ctx, "#secret")
|
||||
_ = database.SetChannelSecret(ctx, chID, true)
|
||||
_ = database.JoinChannel(ctx, chID, sid1)
|
||||
|
||||
// Create a non-secret channel.
|
||||
chID2, _ := database.GetOrCreateChannel(ctx, "#public")
|
||||
_ = database.JoinChannel(ctx, chID2, sid1)
|
||||
|
||||
// Member should see both.
|
||||
list, err := database.ListAllChannelsWithCountsFiltered(
|
||||
ctx, sid1,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(list) != 2 {
|
||||
t.Fatalf("member should see 2 channels, got %d", len(list))
|
||||
}
|
||||
|
||||
// Outsider should only see public.
|
||||
list, _ = database.ListAllChannelsWithCountsFiltered(
|
||||
ctx, sid2,
|
||||
)
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("outsider should see 1 channel, got %d", len(list))
|
||||
}
|
||||
|
||||
if list[0].Name != "#public" {
|
||||
t.Fatalf("outsider should see #public, got %s", list[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhoisChannelFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
sid1 := createTestSession(t, database, "target")
|
||||
sid2 := createTestSession(t, database, "querier")
|
||||
|
||||
// Create secret channel, target joins it.
|
||||
chID, _ := database.GetOrCreateChannel(ctx, "#hidden")
|
||||
_ = database.SetChannelSecret(ctx, chID, true)
|
||||
_ = database.JoinChannel(ctx, chID, sid1)
|
||||
|
||||
// Querier (non-member) should not see the channel.
|
||||
channels, err := database.GetSessionChannelsFiltered(
|
||||
ctx, sid1, sid2,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(channels) != 0 {
|
||||
t.Fatalf(
|
||||
"querier should see 0 channels, got %d",
|
||||
len(channels),
|
||||
)
|
||||
}
|
||||
|
||||
// Target querying self should see it.
|
||||
channels, _ = database.GetSessionChannelsFiltered(
|
||||
ctx, sid1, sid1,
|
||||
)
|
||||
if len(channels) != 1 {
|
||||
t.Fatalf(
|
||||
"self-query should see 1 channel, got %d",
|
||||
len(channels),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:dupl // structurally similar to TestChannelUserLimit
|
||||
func TestChannelKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
chID, err := database.GetOrCreateChannel(ctx, "#keyed")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Default: no key.
|
||||
key, err := database.GetChannelKey(ctx, chID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if key != "" {
|
||||
t.Fatalf("expected empty key, got %q", key)
|
||||
}
|
||||
|
||||
// Set key.
|
||||
err = database.SetChannelKey(ctx, chID, "secret123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, _ = database.GetChannelKey(ctx, chID)
|
||||
if key != "secret123" {
|
||||
t.Fatalf("expected secret123, got %q", key)
|
||||
}
|
||||
|
||||
// Clear key.
|
||||
err = database.SetChannelKey(ctx, chID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, _ = database.GetChannelKey(ctx, chID)
|
||||
if key != "" {
|
||||
t.Fatalf("expected empty key, got %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:dupl // structurally similar to TestChannelKey
|
||||
func TestChannelUserLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
chID, err := database.GetOrCreateChannel(ctx, "#limited")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Default: no limit.
|
||||
limit, err := database.GetChannelUserLimit(ctx, chID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
t.Fatalf("expected 0 limit, got %d", limit)
|
||||
}
|
||||
|
||||
// Set limit.
|
||||
err = database.SetChannelUserLimit(ctx, chID, 50)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
limit, _ = database.GetChannelUserLimit(ctx, chID)
|
||||
if limit != 50 {
|
||||
t.Fatalf("expected 50, got %d", limit)
|
||||
}
|
||||
|
||||
// Clear limit.
|
||||
err = database.SetChannelUserLimit(ctx, chID, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
limit, _ = database.GetChannelUserLimit(ctx, chID)
|
||||
if limit != 0 {
|
||||
t.Fatalf("expected 0, got %d", limit)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,10 +42,36 @@ CREATE TABLE IF NOT EXISTS channels (
|
||||
hashcash_bits INTEGER NOT NULL DEFAULT 0,
|
||||
is_moderated INTEGER NOT NULL DEFAULT 0,
|
||||
is_topic_locked INTEGER NOT NULL DEFAULT 1,
|
||||
is_invite_only INTEGER NOT NULL DEFAULT 0,
|
||||
is_secret INTEGER NOT NULL DEFAULT 0,
|
||||
channel_key TEXT NOT NULL DEFAULT '',
|
||||
user_limit INTEGER NOT NULL DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Channel bans
|
||||
CREATE TABLE IF NOT EXISTS channel_bans (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
mask TEXT NOT NULL,
|
||||
set_by TEXT NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(channel_id, mask)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_bans_channel ON channel_bans(channel_id);
|
||||
|
||||
-- Channel invites (in-memory would be simpler but DB survives restarts)
|
||||
CREATE TABLE IF NOT EXISTS channel_invites (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
invited_by TEXT NOT NULL DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(channel_id, session_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_invites_channel ON channel_invites(channel_id);
|
||||
|
||||
-- Channel members
|
||||
CREATE TABLE IF NOT EXISTS channel_members (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
|
||||
Reference in New Issue
Block a user