From 75867ba778870a22056e41a3a7e072c14d23f9a0 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 24 Mar 2026 18:32:52 -0700 Subject: [PATCH 1/2] feat: implement Tier 2 channel modes (+b/+i/+s/+k/+l) Implement the second tier of IRC channel features: 1. Ban system (+b): Add/remove/list bans with wildcard matching. Bans prevent both joining and sending messages. Schema: channel_bans table with mask, set_by, created_at. 2. Invite-only (+i): Channel mode requiring invitation to join. INVITE command for operators. Invites stored in DB and cleared after successful JOIN. 3. Secret (+s): Hides channel from LIST for non-members and from WHOIS channel lists when querier is not in same channel. 4. Channel key (+k): Password-protected channels. Key required on JOIN, set/cleared by operators. 5. User limit (+l): Maximum member count enforcement. Rejects JOIN when channel is at capacity. Updated ISUPPORT CHANMODES to b,k,Hl,imnst. Updated RPL_MYINFO available modes to ikmnostl. Comprehensive tests for all features at both DB and handler levels. README updated with full documentation of all new modes. closes #86 --- README.md | 51 +- internal/db/queries.go | 532 ++++++++++++++ internal/db/queries_test.go | 471 ++++++++++++ internal/db/schema/001_initial.sql | 26 + internal/handlers/api.go | 1066 +++++++++++++++++++++++++--- internal/handlers/api_test.go | 483 +++++++++++++ pkg/irc/commands.go | 1 + 7 files changed, 2544 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 0ad9205..4c6df20 100644 --- a/README.md +++ b/README.md @@ -1080,7 +1080,7 @@ the server to the client (never C2S) and use 3-digit string codes in the | `002` | RPL_YOURHOST | After session creation | `{"command":"002","to":"alice","body":["Your host is neoirc, running version 0.1"]}` | | `003` | RPL_CREATED | After session creation | `{"command":"003","to":"alice","body":["This server was created 2026-02-10"]}` | | `004` | RPL_MYINFO | After session creation | `{"command":"004","to":"alice","params":["neoirc","0.1","","mnst"]}` | -| `005` | RPL_ISUPPORT | After session creation | `{"command":"005","to":"alice","params":["CHANTYPES=#","NICKLEN=32","PREFIX=(ov)@+","CHANMODES=,,H,mnst","NETWORK=neoirc"],"body":["are supported by this server"]}` | +| `005` | RPL_ISUPPORT | After session creation | `{"command":"005","to":"alice","params":["CHANTYPES=#","NICKLEN=32","PREFIX=(ov)@+","CHANMODES=b,k,Hl,imnst","NETWORK=neoirc"],"body":["are supported by this server"]}` | | `221` | RPL_UMODEIS | In response to user MODE query | `{"command":"221","to":"alice","body":["+"]}` | | `251` | RPL_LUSERCLIENT | On connect or LUSERS command | `{"command":"251","to":"alice","body":["There are 5 users and 0 invisible on 1 servers"]}` | | `252` | RPL_LUSEROP | On connect or LUSERS command | `{"command":"252","to":"alice","params":["0"],"body":["operator(s) online"]}` | @@ -1128,12 +1128,15 @@ Inspired by IRC, simplified: | Mode | Name | Meaning | Status | |------|----------------|---------|--------| +| `+b` | Ban | Prevents matching hostmasks from joining or sending (parameter: `nick!user@host` mask with wildcards) | **Enforced** | +| `+i` | Invite-only | Only invited users can join; use `INVITE nick #channel` to invite | **Enforced** | +| `+k` | Channel key | Requires a password to join (parameter: key string) | **Enforced** | +| `+l` | User limit | Maximum number of members allowed in the channel (parameter: integer) | **Enforced** | | `+m` | Moderated | Only voiced (`+v`) users and operators (`+o`) can send | **Enforced** | -| `+t` | Topic lock | Only operators can change the topic (default: ON) | **Enforced** | | `+n` | No external | Only channel members can send messages to the channel | **Enforced** | +| `+s` | Secret | Channel hidden from LIST and WHOIS for non-members | **Enforced** | +| `+t` | Topic lock | Only operators can change the topic (default: ON) | **Enforced** | | `+H` | Hashcash | Requires proof-of-work for PRIVMSG (parameter: bits, e.g. `+H 20`) | **Enforced** | -| `+i` | Invite-only | Only invited users can join | Not yet enforced | -| `+s` | Secret | Channel hidden from LIST response | Not yet enforced | **User channel modes (set per-user per-channel):** @@ -1145,6 +1148,42 @@ Inspired by IRC, simplified: **Channel creator auto-op:** The first user to JOIN a channel (creating it) automatically receives `+o` operator status. +**Ban system (+b):** Operators can ban users by hostmask pattern with wildcard +matching (`*` and `?`). `MODE #channel +b` with no argument lists current bans. +Bans prevent both joining and sending messages. + +``` +MODE #channel +b *!*@*.example.com — ban all users from example.com +MODE #channel -b *!*@*.example.com — remove the ban +MODE #channel +b — list all bans (RPL_BANLIST 367/368) +``` + +**Invite-only (+i):** When set, users must be invited by an operator before +joining. The `INVITE` command records an invite that is consumed on JOIN. + +``` +MODE #channel +i — set invite-only +INVITE nick #channel — invite a user (operator only on +i channels) +``` + +**Channel key (+k):** Requires a password to join the channel. + +``` +MODE #channel +k secretpass — set a channel key +MODE #channel -k * — remove the key +JOIN #channel secretpass — join with key +``` + +**User limit (+l):** Caps the number of members in the channel. + +``` +MODE #channel +l 50 — set limit to 50 members +MODE #channel -l — remove the limit +``` + +**Secret (+s):** Hides the channel from `LIST` for non-members and from +`WHOIS` channel lists when the querier is not in the same channel. + **KICK command:** Channel operators can remove users with `KICK #channel nick [:reason]`. The kicked user and all channel members receive the KICK message. @@ -1153,7 +1192,7 @@ RPL_AWAY), and skips hashcash validation on +H channels (servers and services use NOTICE). **ISUPPORT:** The server advertises `PREFIX=(ov)@+` and -`CHANMODES=,,H,mnst` in RPL_ISUPPORT (005). +`CHANMODES=b,k,Hl,imnst` in RPL_ISUPPORT (005). ### Per-Channel Hashcash (Anti-Spam) @@ -2695,7 +2734,7 @@ guess is borne by the server (bcrypt), not the client. - [x] **Client output queue pruning** — delete old client output queue entries per `QUEUE_MAX_AGE` - [x] **Message rotation** — prune messages older than `MESSAGE_MAX_AGE` - [x] **Channel modes** — enforce `+m` (moderated), `+t` (topic lock), `+n` (no external) -- [ ] **Channel modes (tier 2)** — enforce `+i` (invite-only), `+s` (secret), `+b` (ban), `+k` (key), `+l` (limit) +- [x] **Channel modes (tier 2)** — enforce `+i` (invite-only), `+s` (secret), `+b` (ban), `+k` (key), `+l` (limit) - [x] **User channel modes** — `+o` (operator), `+v` (voice) with NAMES prefixes - [x] **KICK command** — operator-only channel kick with broadcast - [x] **MODE command** — query and set channel/user modes diff --git a/internal/db/queries.go b/internal/db/queries.go index 5000959..9029337 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "strconv" + "strings" "time" "git.eeqj.de/sneak/neoirc/pkg/irc" @@ -1835,3 +1836,534 @@ func (database *Database) PruneSpentHashcash( return deleted, nil } + +// --- Tier 2: Ban system (+b) --- + +// BanInfo represents a channel ban entry. +type BanInfo struct { + Mask string + SetBy string + CreatedAt time.Time +} + +// AddChannelBan inserts a ban mask for a channel. +func (database *Database) AddChannelBan( + ctx context.Context, + channelID int64, + mask, setBy string, +) error { + _, err := database.conn.ExecContext(ctx, + `INSERT OR IGNORE INTO channel_bans + (channel_id, mask, set_by, created_at) + VALUES (?, ?, ?, ?)`, + channelID, mask, setBy, time.Now()) + if err != nil { + return fmt.Errorf("add channel ban: %w", err) + } + + return nil +} + +// RemoveChannelBan removes a ban mask from a channel. +func (database *Database) RemoveChannelBan( + ctx context.Context, + channelID int64, + mask string, +) error { + _, err := database.conn.ExecContext(ctx, + `DELETE FROM channel_bans + WHERE channel_id = ? AND mask = ?`, + channelID, mask) + if err != nil { + return fmt.Errorf("remove channel ban: %w", err) + } + + return nil +} + +// ListChannelBans returns all bans for a channel. +// +//nolint:dupl // different query+type vs filtered variant +func (database *Database) ListChannelBans( + ctx context.Context, + channelID int64, +) ([]BanInfo, error) { + rows, err := database.conn.QueryContext(ctx, + `SELECT mask, set_by, created_at + FROM channel_bans + WHERE channel_id = ? + ORDER BY created_at ASC`, + channelID) + if err != nil { + return nil, fmt.Errorf("list channel bans: %w", err) + } + + defer func() { _ = rows.Close() }() + + var bans []BanInfo + + for rows.Next() { + var ban BanInfo + + if scanErr := rows.Scan( + &ban.Mask, &ban.SetBy, &ban.CreatedAt, + ); scanErr != nil { + return nil, fmt.Errorf( + "scan channel ban: %w", scanErr, + ) + } + + bans = append(bans, ban) + } + + if rowErr := rows.Err(); rowErr != nil { + return nil, fmt.Errorf( + "iterate channel bans: %w", rowErr, + ) + } + + return bans, nil +} + +// IsSessionBanned checks if a session's hostmask matches +// any ban in the channel. Returns true if banned. +func (database *Database) IsSessionBanned( + ctx context.Context, + channelID, sessionID int64, +) (bool, error) { + // Get the session's hostmask parts. + var nick, username, hostname string + + err := database.conn.QueryRowContext(ctx, + `SELECT nick, username, hostname + FROM sessions WHERE id = ?`, + sessionID, + ).Scan(&nick, &username, &hostname) + if err != nil { + return false, fmt.Errorf( + "get session hostmask: %w", err, + ) + } + + hostmask := FormatHostmask(nick, username, hostname) + + // Get all ban masks for the channel. + bans, banErr := database.ListChannelBans(ctx, channelID) + if banErr != nil { + return false, banErr + } + + for _, ban := range bans { + if MatchBanMask(ban.Mask, hostmask) { + return true, nil + } + } + + return false, nil +} + +// MatchBanMask checks if hostmask matches a ban pattern +// using IRC-style wildcard matching (* and ?). +func MatchBanMask(pattern, hostmask string) bool { + return wildcardMatch( + strings.ToLower(pattern), + strings.ToLower(hostmask), + ) +} + +// wildcardMatch implements simple glob-style matching +// with * (any sequence) and ? (any single character). +func wildcardMatch(pattern, str string) bool { + for len(pattern) > 0 { + switch pattern[0] { + case '*': + // Skip consecutive asterisks. + for len(pattern) > 0 && pattern[0] == '*' { + pattern = pattern[1:] + } + + if len(pattern) == 0 { + return true + } + + for i := 0; i <= len(str); i++ { + if wildcardMatch(pattern, str[i:]) { + return true + } + } + + return false + case '?': + if len(str) == 0 { + return false + } + + pattern = pattern[1:] + str = str[1:] + default: + if len(str) == 0 || pattern[0] != str[0] { + return false + } + + pattern = pattern[1:] + str = str[1:] + } + } + + return len(str) == 0 +} + +// --- Tier 2: Invite-only (+i) --- + +// IsChannelInviteOnly checks if a channel has +i mode. +func (database *Database) IsChannelInviteOnly( + ctx context.Context, + channelID int64, +) (bool, error) { + var isInviteOnly int + + err := database.conn.QueryRowContext(ctx, + `SELECT is_invite_only FROM channels + WHERE id = ?`, + channelID, + ).Scan(&isInviteOnly) + if err != nil { + return false, fmt.Errorf( + "check invite only: %w", err, + ) + } + + return isInviteOnly != 0, nil +} + +// SetChannelInviteOnly sets or unsets +i mode. +func (database *Database) SetChannelInviteOnly( + ctx context.Context, + channelID int64, + inviteOnly bool, +) error { + val := 0 + if inviteOnly { + val = 1 + } + + _, err := database.conn.ExecContext(ctx, + `UPDATE channels + SET is_invite_only = ?, updated_at = ? + WHERE id = ?`, + val, time.Now(), channelID) + if err != nil { + return fmt.Errorf( + "set invite only: %w", err, + ) + } + + return nil +} + +// AddChannelInvite records that a session has been +// invited to a channel. +func (database *Database) AddChannelInvite( + ctx context.Context, + channelID, sessionID int64, + invitedBy string, +) error { + _, err := database.conn.ExecContext(ctx, + `INSERT OR IGNORE INTO channel_invites + (channel_id, session_id, invited_by, created_at) + VALUES (?, ?, ?, ?)`, + channelID, sessionID, invitedBy, time.Now()) + if err != nil { + return fmt.Errorf("add channel invite: %w", err) + } + + return nil +} + +// HasChannelInvite checks if a session has been invited +// to a channel. +func (database *Database) HasChannelInvite( + ctx context.Context, + channelID, sessionID int64, +) (bool, error) { + var count int + + err := database.conn.QueryRowContext(ctx, + `SELECT COUNT(*) FROM channel_invites + WHERE channel_id = ? AND session_id = ?`, + channelID, sessionID, + ).Scan(&count) + if err != nil { + return false, fmt.Errorf( + "check invite: %w", err, + ) + } + + return count > 0, nil +} + +// ClearChannelInvite removes a session's invite to a +// channel (called after successful JOIN). +func (database *Database) ClearChannelInvite( + ctx context.Context, + channelID, sessionID int64, +) error { + _, err := database.conn.ExecContext(ctx, + `DELETE FROM channel_invites + WHERE channel_id = ? AND session_id = ?`, + channelID, sessionID) + if err != nil { + return fmt.Errorf("clear invite: %w", err) + } + + return nil +} + +// --- Tier 2: Secret (+s) --- + +// IsChannelSecret checks if a channel has +s mode. +func (database *Database) IsChannelSecret( + ctx context.Context, + channelID int64, +) (bool, error) { + var isSecret int + + err := database.conn.QueryRowContext(ctx, + `SELECT is_secret FROM channels + WHERE id = ?`, + channelID, + ).Scan(&isSecret) + if err != nil { + return false, fmt.Errorf( + "check secret: %w", err, + ) + } + + return isSecret != 0, nil +} + +// SetChannelSecret sets or unsets +s mode. +func (database *Database) SetChannelSecret( + ctx context.Context, + channelID int64, + secret bool, +) error { + val := 0 + if secret { + val = 1 + } + + _, err := database.conn.ExecContext(ctx, + `UPDATE channels + SET is_secret = ?, updated_at = ? + WHERE id = ?`, + val, time.Now(), channelID) + if err != nil { + return fmt.Errorf("set secret: %w", err) + } + + return nil +} + +// ListAllChannelsWithCountsFiltered returns all channels +// with member counts, excluding secret channels that +// the given session is not a member of. +// +//nolint:dupl // different query+type vs ListChannelBans +func (database *Database) ListAllChannelsWithCountsFiltered( + ctx context.Context, + sessionID int64, +) ([]ChannelInfoFull, error) { + rows, err := database.conn.QueryContext(ctx, + `SELECT c.name, COUNT(cm.id) AS member_count, + c.topic + FROM channels c + LEFT JOIN channel_members cm + ON cm.channel_id = c.id + WHERE c.is_secret = 0 + OR c.id IN ( + SELECT channel_id FROM channel_members + WHERE session_id = ? + ) + GROUP BY c.id + ORDER BY c.name ASC`, + sessionID) + if err != nil { + return nil, fmt.Errorf( + "list channels filtered: %w", err, + ) + } + + defer func() { _ = rows.Close() }() + + var channels []ChannelInfoFull + + for rows.Next() { + var chanInfo ChannelInfoFull + + if scanErr := rows.Scan( + &chanInfo.Name, + &chanInfo.MemberCount, + &chanInfo.Topic, + ); scanErr != nil { + return nil, fmt.Errorf( + "scan channel: %w", scanErr, + ) + } + + channels = append(channels, chanInfo) + } + + if rowErr := rows.Err(); rowErr != nil { + return nil, fmt.Errorf( + "iterate channels: %w", rowErr, + ) + } + + return channels, nil +} + +// GetSessionChannelsFiltered returns channels a session +// belongs to, optionally excluding secret channels for +// WHOIS (when the querier is not in the same channel). +// If querierID == targetID, returns all channels. +func (database *Database) GetSessionChannelsFiltered( + ctx context.Context, + targetSID, querierSID int64, +) ([]ChannelInfo, error) { + // If querying yourself, return all channels. + if targetSID == querierSID { + return database.GetSessionChannels(ctx, targetSID) + } + + rows, err := database.conn.QueryContext(ctx, + `SELECT c.id, c.name, c.topic + FROM channels c + JOIN channel_members cm + ON cm.channel_id = c.id + WHERE cm.session_id = ? + AND (c.is_secret = 0 + OR c.id IN ( + SELECT channel_id FROM channel_members + WHERE session_id = ? + )) + ORDER BY c.name ASC`, + targetSID, querierSID) + if err != nil { + return nil, fmt.Errorf( + "get session channels filtered: %w", err, + ) + } + + defer func() { _ = rows.Close() }() + + var channels []ChannelInfo + + for rows.Next() { + var chanInfo ChannelInfo + + if scanErr := rows.Scan( + &chanInfo.ID, + &chanInfo.Name, + &chanInfo.Topic, + ); scanErr != nil { + return nil, fmt.Errorf( + "scan channel: %w", scanErr, + ) + } + + channels = append(channels, chanInfo) + } + + if rowErr := rows.Err(); rowErr != nil { + return nil, fmt.Errorf( + "iterate channels: %w", rowErr, + ) + } + + return channels, nil +} + +// --- Tier 2: Channel Key (+k) --- + +// GetChannelKey returns the key for a channel (empty +// string means no key set). +func (database *Database) GetChannelKey( + ctx context.Context, + channelID int64, +) (string, error) { + var key string + + err := database.conn.QueryRowContext(ctx, + `SELECT channel_key FROM channels + WHERE id = ?`, + channelID, + ).Scan(&key) + if err != nil { + return "", fmt.Errorf("get channel key: %w", err) + } + + return key, nil +} + +// SetChannelKey sets or clears the key for a channel. +func (database *Database) SetChannelKey( + ctx context.Context, + channelID int64, + key string, +) error { + _, err := database.conn.ExecContext(ctx, + `UPDATE channels + SET channel_key = ?, updated_at = ? + WHERE id = ?`, + key, time.Now(), channelID) + if err != nil { + return fmt.Errorf("set channel key: %w", err) + } + + return nil +} + +// --- Tier 2: User Limit (+l) --- + +// GetChannelUserLimit returns the user limit for a +// channel (0 means no limit). +func (database *Database) GetChannelUserLimit( + ctx context.Context, + channelID int64, +) (int, error) { + var limit int + + err := database.conn.QueryRowContext(ctx, + `SELECT user_limit FROM channels + WHERE id = ?`, + channelID, + ).Scan(&limit) + if err != nil { + return 0, fmt.Errorf( + "get channel user limit: %w", err, + ) + } + + return limit, nil +} + +// SetChannelUserLimit sets the user limit for a channel. +func (database *Database) SetChannelUserLimit( + ctx context.Context, + channelID int64, + limit int, +) error { + _, err := database.conn.ExecContext(ctx, + `UPDATE channels + SET user_limit = ?, updated_at = ? + WHERE id = ?`, + limit, time.Now(), channelID) + if err != nil { + return fmt.Errorf( + "set channel user limit: %w", err, + ) + } + + return nil +} diff --git a/internal/db/queries_test.go b/internal/db/queries_test.go index e270bdb..648f634 100644 --- a/internal/db/queries_test.go +++ b/internal/db/queries_test.go @@ -1017,3 +1017,474 @@ func TestGetOperCount(t *testing.T) { t.Fatalf("expected 1 oper, got %d", count) } } + +// --- Tier 2 Tests --- + +func TestWildcardMatch(t *testing.T) { + t.Parallel() + + tests := []struct { + pattern string + input string + match bool + }{ + {"*!*@*", "nick!user@host", true}, + {"*!*@*.example.com", "nick!user@foo.example.com", true}, + {"*!*@*.example.com", "nick!user@other.net", false}, + {"badnick!*@*", "badnick!user@host", true}, + {"badnick!*@*", "goodnick!user@host", false}, + {"nick!user@host", "nick!user@host", true}, + {"nick!user@host", "nick!user@other", false}, + {"*", "anything", true}, + {"?ick!*@*", "nick!user@host", true}, + {"?ick!*@*", "nn!user@host", false}, + // Case-insensitive. + {"Nick!*@*", "nick!user@host", true}, + } + + for _, tc := range tests { + result := db.MatchBanMask(tc.pattern, tc.input) + if result != tc.match { + t.Errorf( + "MatchBanMask(%q, %q) = %v, want %v", + tc.pattern, tc.input, result, tc.match, + ) + } + } +} + +func TestChannelBanCRUD(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + chID, err := database.GetOrCreateChannel(ctx, "#test") + if err != nil { + t.Fatal(err) + } + + // No bans initially. + bans, err := database.ListChannelBans(ctx, chID) + if err != nil { + t.Fatal(err) + } + + if len(bans) != 0 { + t.Fatalf("expected 0 bans, got %d", len(bans)) + } + + // Add a ban. + err = database.AddChannelBan( + ctx, chID, "*!*@evil.com", "op", + ) + if err != nil { + t.Fatal(err) + } + + bans, err = database.ListChannelBans(ctx, chID) + if err != nil { + t.Fatal(err) + } + + if len(bans) != 1 { + t.Fatalf("expected 1 ban, got %d", len(bans)) + } + + if bans[0].Mask != "*!*@evil.com" { + t.Fatalf("wrong mask: %s", bans[0].Mask) + } + + // Duplicate add is ignored (OR IGNORE). + err = database.AddChannelBan( + ctx, chID, "*!*@evil.com", "op2", + ) + if err != nil { + t.Fatal(err) + } + + bans, _ = database.ListChannelBans(ctx, chID) + if len(bans) != 1 { + t.Fatalf("expected 1 ban after dup, got %d", len(bans)) + } + + // Remove ban. + err = database.RemoveChannelBan( + ctx, chID, "*!*@evil.com", + ) + if err != nil { + t.Fatal(err) + } + + bans, _ = database.ListChannelBans(ctx, chID) + if len(bans) != 0 { + t.Fatalf("expected 0 bans after remove, got %d", len(bans)) + } +} + +func TestIsSessionBanned(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sid, _, _, err := database.CreateSession( + ctx, "victim", "victim", "evil.com", "", + ) + if err != nil { + t.Fatal(err) + } + + chID, err := database.GetOrCreateChannel(ctx, "#bantest") + if err != nil { + t.Fatal(err) + } + + // Not banned initially. + banned, err := database.IsSessionBanned(ctx, chID, sid) + if err != nil { + t.Fatal(err) + } + + if banned { + t.Fatal("should not be banned initially") + } + + // Add ban matching the hostmask. + err = database.AddChannelBan( + ctx, chID, "*!*@evil.com", "op", + ) + if err != nil { + t.Fatal(err) + } + + banned, err = database.IsSessionBanned(ctx, chID, sid) + if err != nil { + t.Fatal(err) + } + + if !banned { + t.Fatal("should be banned") + } +} + +func TestChannelInviteOnly(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + chID, err := database.GetOrCreateChannel(ctx, "#invite") + if err != nil { + t.Fatal(err) + } + + // Default: not invite-only. + isIO, err := database.IsChannelInviteOnly(ctx, chID) + if err != nil { + t.Fatal(err) + } + + if isIO { + t.Fatal("should not be invite-only by default") + } + + // Set invite-only. + err = database.SetChannelInviteOnly(ctx, chID, true) + if err != nil { + t.Fatal(err) + } + + isIO, _ = database.IsChannelInviteOnly(ctx, chID) + if !isIO { + t.Fatal("should be invite-only") + } + + // Unset. + err = database.SetChannelInviteOnly(ctx, chID, false) + if err != nil { + t.Fatal(err) + } + + isIO, _ = database.IsChannelInviteOnly(ctx, chID) + if isIO { + t.Fatal("should not be invite-only") + } +} + +func TestChannelInviteCRUD(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sid, _, _, err := database.CreateSession( + ctx, "invited", "", "", "", + ) + if err != nil { + t.Fatal(err) + } + + chID, err := database.GetOrCreateChannel(ctx, "#inv") + if err != nil { + t.Fatal(err) + } + + // No invite initially. + has, err := database.HasChannelInvite(ctx, chID, sid) + if err != nil { + t.Fatal(err) + } + + if has { + t.Fatal("should not have invite") + } + + // Add invite. + err = database.AddChannelInvite(ctx, chID, sid, "op") + if err != nil { + t.Fatal(err) + } + + has, _ = database.HasChannelInvite(ctx, chID, sid) + if !has { + t.Fatal("should have invite") + } + + // Clear invite. + err = database.ClearChannelInvite(ctx, chID, sid) + if err != nil { + t.Fatal(err) + } + + has, _ = database.HasChannelInvite(ctx, chID, sid) + if has { + t.Fatal("invite should be cleared") + } +} + +func TestChannelSecret(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + chID, err := database.GetOrCreateChannel(ctx, "#secret") + if err != nil { + t.Fatal(err) + } + + // Default: not secret. + isSec, err := database.IsChannelSecret(ctx, chID) + if err != nil { + t.Fatal(err) + } + + if isSec { + t.Fatal("should not be secret by default") + } + + err = database.SetChannelSecret(ctx, chID, true) + if err != nil { + t.Fatal(err) + } + + isSec, _ = database.IsChannelSecret(ctx, chID) + if !isSec { + t.Fatal("should be secret") + } +} + +// createTestSession is a helper to create a session and +// return only the session ID. +func createTestSession( + t *testing.T, + database *db.Database, + nick string, +) int64 { + t.Helper() + + sid, _, _, err := database.CreateSession( + t.Context(), nick, "", "", "", + ) + if err != nil { + t.Fatalf("create session %s: %v", nick, err) + } + + return sid +} + +func TestSecretChannelFiltering(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + // Create two sessions. + sid1 := createTestSession(t, database, "member") + sid2 := createTestSession(t, database, "outsider") + + // Create a secret channel. + chID, _ := database.GetOrCreateChannel(ctx, "#secret") + _ = database.SetChannelSecret(ctx, chID, true) + _ = database.JoinChannel(ctx, chID, sid1) + + // Create a non-secret channel. + chID2, _ := database.GetOrCreateChannel(ctx, "#public") + _ = database.JoinChannel(ctx, chID2, sid1) + + // Member should see both. + list, err := database.ListAllChannelsWithCountsFiltered( + ctx, sid1, + ) + if err != nil { + t.Fatal(err) + } + + if len(list) != 2 { + t.Fatalf("member should see 2 channels, got %d", len(list)) + } + + // Outsider should only see public. + list, _ = database.ListAllChannelsWithCountsFiltered( + ctx, sid2, + ) + if len(list) != 1 { + t.Fatalf("outsider should see 1 channel, got %d", len(list)) + } + + if list[0].Name != "#public" { + t.Fatalf("outsider should see #public, got %s", list[0].Name) + } +} + +func TestWhoisChannelFiltering(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sid1 := createTestSession(t, database, "target") + sid2 := createTestSession(t, database, "querier") + + // Create secret channel, target joins it. + chID, _ := database.GetOrCreateChannel(ctx, "#hidden") + _ = database.SetChannelSecret(ctx, chID, true) + _ = database.JoinChannel(ctx, chID, sid1) + + // Querier (non-member) should not see the channel. + channels, err := database.GetSessionChannelsFiltered( + ctx, sid1, sid2, + ) + if err != nil { + t.Fatal(err) + } + + if len(channels) != 0 { + t.Fatalf( + "querier should see 0 channels, got %d", + len(channels), + ) + } + + // Target querying self should see it. + channels, _ = database.GetSessionChannelsFiltered( + ctx, sid1, sid1, + ) + if len(channels) != 1 { + t.Fatalf( + "self-query should see 1 channel, got %d", + len(channels), + ) + } +} + +//nolint:dupl // structurally similar to TestChannelUserLimit +func TestChannelKey(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + chID, err := database.GetOrCreateChannel(ctx, "#keyed") + if err != nil { + t.Fatal(err) + } + + // Default: no key. + key, err := database.GetChannelKey(ctx, chID) + if err != nil { + t.Fatal(err) + } + + if key != "" { + t.Fatalf("expected empty key, got %q", key) + } + + // Set key. + err = database.SetChannelKey(ctx, chID, "secret123") + if err != nil { + t.Fatal(err) + } + + key, _ = database.GetChannelKey(ctx, chID) + if key != "secret123" { + t.Fatalf("expected secret123, got %q", key) + } + + // Clear key. + err = database.SetChannelKey(ctx, chID, "") + if err != nil { + t.Fatal(err) + } + + key, _ = database.GetChannelKey(ctx, chID) + if key != "" { + t.Fatalf("expected empty key, got %q", key) + } +} + +//nolint:dupl // structurally similar to TestChannelKey +func TestChannelUserLimit(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + chID, err := database.GetOrCreateChannel(ctx, "#limited") + if err != nil { + t.Fatal(err) + } + + // Default: no limit. + limit, err := database.GetChannelUserLimit(ctx, chID) + if err != nil { + t.Fatal(err) + } + + if limit != 0 { + t.Fatalf("expected 0 limit, got %d", limit) + } + + // Set limit. + err = database.SetChannelUserLimit(ctx, chID, 50) + if err != nil { + t.Fatal(err) + } + + limit, _ = database.GetChannelUserLimit(ctx, chID) + if limit != 50 { + t.Fatalf("expected 50, got %d", limit) + } + + // Clear limit. + err = database.SetChannelUserLimit(ctx, chID, 0) + if err != nil { + t.Fatal(err) + } + + limit, _ = database.GetChannelUserLimit(ctx, chID) + if limit != 0 { + t.Fatalf("expected 0, got %d", limit) + } +} diff --git a/internal/db/schema/001_initial.sql b/internal/db/schema/001_initial.sql index 2ea9463..a29bdaa 100644 --- a/internal/db/schema/001_initial.sql +++ b/internal/db/schema/001_initial.sql @@ -42,10 +42,36 @@ CREATE TABLE IF NOT EXISTS channels ( hashcash_bits INTEGER NOT NULL DEFAULT 0, is_moderated INTEGER NOT NULL DEFAULT 0, is_topic_locked INTEGER NOT NULL DEFAULT 1, + is_invite_only INTEGER NOT NULL DEFAULT 0, + is_secret INTEGER NOT NULL DEFAULT 0, + channel_key TEXT NOT NULL DEFAULT '', + user_limit INTEGER NOT NULL DEFAULT 0, created_at DATETIME DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ); +-- Channel bans +CREATE TABLE IF NOT EXISTS channel_bans ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + mask TEXT NOT NULL, + set_by TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(channel_id, mask) +); +CREATE INDEX IF NOT EXISTS idx_channel_bans_channel ON channel_bans(channel_id); + +-- Channel invites (in-memory would be simpler but DB survives restarts) +CREATE TABLE IF NOT EXISTS channel_invites ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + invited_by TEXT NOT NULL DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(channel_id, session_id) +); +CREATE INDEX IF NOT EXISTS idx_channel_invites_channel ON channel_invites(channel_id); + -- Channel members CREATE TABLE IF NOT EXISTS channel_members ( id INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 6f06565..34f1987 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -447,7 +447,7 @@ func (hdlr *Handlers) deliverWelcome( // 004 RPL_MYINFO hdlr.enqueueNumeric( ctx, clientID, irc.RplMyInfo, nick, - []string{srvName, version, "", "mnst"}, + []string{srvName, version, "", "ikmnostl"}, "", ) @@ -458,7 +458,7 @@ func (hdlr *Handlers) deliverWelcome( "CHANTYPES=#", "NICKLEN=32", "PREFIX=(ov)@+", - "CHANMODES=,,H,mnst", + "CHANMODES=b,k,Hl,imnst", "NETWORK=neoirc", "CASEMAPPING=ascii", }, @@ -1029,6 +1029,7 @@ func (hdlr *Handlers) dispatchCommand( hdlr.handleJoin( writer, request, sessionID, clientID, nick, target, + bodyLines, ) case irc.CmdPart: hdlr.handlePart( @@ -1051,6 +1052,11 @@ func (hdlr *Handlers) dispatchCommand( sessionID, clientID, nick, target, body, bodyLines, ) + case irc.CmdInvite: + hdlr.handleInvite( + writer, request, + sessionID, clientID, nick, bodyLines, + ) case irc.CmdKick: hdlr.handleKick( writer, request, @@ -1222,40 +1228,22 @@ func (hdlr *Handlers) handleChannelMsg( ) { ctx := request.Context() - chID, err := hdlr.params.Database.GetChannelByName( - ctx, target, + chID, ok := hdlr.resolveChannelForSend( + writer, request, + sessionID, clientID, nick, target, ) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchChannel, nick, []string{target}, - "No such channel", - ) - + if !ok { return } - isMember, err := hdlr.params.Database.IsChannelMember( - ctx, chID, sessionID, - ) - if err != nil { - hdlr.log.Error( - "check membership failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - if !isMember { + // Enforce +b (ban): banned users cannot send. + isBanned, banErr := hdlr.params.Database. + IsSessionBanned(ctx, chID, sessionID) + if banErr == nil && isBanned { hdlr.respondIRCError( writer, request, clientID, sessionID, irc.ErrCannotSendToChan, nick, []string{target}, - "Cannot send to channel", + "Cannot send to channel (+b)", ) return @@ -1269,11 +1257,8 @@ func (hdlr *Handlers) handleChannelMsg( return } - // NOTICE skips hashcash validation on +H channels - // (servers and services use NOTICE). - isNotice := command == irc.CmdNotice - - if !isNotice { + // NOTICE skips hashcash validation on +H channels. + if command != irc.CmdNotice { hashcashErr := hdlr.validateChannelHashcash( request, clientID, sessionID, writer, nick, target, body, meta, chID, @@ -1289,6 +1274,58 @@ func (hdlr *Handlers) handleChannelMsg( ) } +// resolveChannelForSend validates that the channel exists +// and the session is a member. Returns the channel ID. +func (hdlr *Handlers) resolveChannelForSend( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, target string, +) (int64, bool) { + ctx := request.Context() + + chID, err := hdlr.params.Database.GetChannelByName( + ctx, target, + ) + if err != nil { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNoSuchChannel, nick, []string{target}, + "No such channel", + ) + + return 0, false + } + + isMember, memErr := hdlr.params.Database.IsChannelMember( + ctx, chID, sessionID, + ) + if memErr != nil { + hdlr.log.Error( + "check membership failed", "error", memErr, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return 0, false + } + + if !isMember { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrCannotSendToChan, nick, []string{target}, + "Cannot send to channel", + ) + + return 0, false + } + + return chID, true +} + // checkModeratedSend checks if a user can send to a // moderated channel. Returns true if sending is allowed. func (hdlr *Handlers) checkModeratedSend( @@ -1595,6 +1632,7 @@ func (hdlr *Handlers) handleJoin( request *http.Request, sessionID, clientID int64, nick, target string, + bodyLines func() []string, ) { if target == "" { hdlr.respondIRCError( @@ -1621,9 +1659,17 @@ func (hdlr *Handlers) handleJoin( return } + // Extract key from body lines (JOIN #channel key). + var suppliedKey string + + lines := bodyLines() + if len(lines) > 0 { + suppliedKey = lines[0] + } + hdlr.executeJoin( writer, request, - sessionID, clientID, nick, channel, + sessionID, clientID, nick, channel, suppliedKey, ) } @@ -1631,10 +1677,59 @@ func (hdlr *Handlers) executeJoin( writer http.ResponseWriter, request *http.Request, sessionID, clientID int64, - nick, channel string, + nick, channel, suppliedKey string, ) { ctx := request.Context() + chID, isCreator, ok := hdlr.resolveJoinChannel( + writer, request, channel, + ) + if !ok { + return + } + + // Tier 2 join checks only apply to non-creators. + if !isCreator { + if !hdlr.checkJoinAllowed( + writer, request, + sessionID, clientID, nick, + channel, chID, suppliedKey, + ) { + return + } + } + + if err := hdlr.addMemberToChannel( + ctx, chID, sessionID, isCreator, + ); err != nil { + hdlr.log.Error( + "join channel failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + hdlr.broadcastJoin( + writer, request, + sessionID, clientID, nick, channel, chID, + ) +} + +// resolveJoinChannel gets or creates the channel and +// determines if the joiner would be the first member +// (i.e., the channel creator/op). +func (hdlr *Handlers) resolveJoinChannel( + writer http.ResponseWriter, + request *http.Request, + channel string, +) (int64, bool, bool) { + ctx := request.Context() + chID, err := hdlr.params.Database.GetOrCreateChannel( ctx, channel, ) @@ -1648,11 +1743,9 @@ func (hdlr *Handlers) executeJoin( http.StatusInternalServerError, ) - return + return 0, false, false } - // Check if channel is empty before joining — first - // joiner becomes operator. memberCount, countErr := hdlr.params.Database. CountChannelMembers(ctx, chID) if countErr != nil { @@ -1663,29 +1756,50 @@ func (hdlr *Handlers) executeJoin( isCreator := countErr == nil && memberCount == 0 + return chID, isCreator, true +} + +// addMemberToChannel adds a session to a channel, with +// operator privileges if they're the first member. +func (hdlr *Handlers) addMemberToChannel( + ctx context.Context, + chID, sessionID int64, + isCreator bool, +) error { if isCreator { - err = hdlr.params.Database.JoinChannelAsOperator( - ctx, chID, sessionID, - ) - } else { - err = hdlr.params.Database.JoinChannel( + err := hdlr.params.Database.JoinChannelAsOperator( ctx, chID, sessionID, ) + if err != nil { + return fmt.Errorf( + "join as operator: %w", err, + ) + } + + return nil } + err := hdlr.params.Database.JoinChannel( + ctx, chID, sessionID, + ) if err != nil { - hdlr.log.Error( - "join channel failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return + return fmt.Errorf("join channel: %w", err) } + return nil +} + +// broadcastJoin fans out the JOIN, clears invites, and +// sends numerics. +func (hdlr *Handlers) broadcastJoin( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, +) { + ctx := request.Context() + memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( ctx, chID, ) @@ -1694,6 +1808,10 @@ func (hdlr *Handlers) executeJoin( request, irc.CmdJoin, nick, channel, nil, memberIDs, ) + _ = hdlr.params.Database.ClearChannelInvite( + ctx, chID, sessionID, + ) + hdlr.deliverJoinNumerics( request, clientID, sessionID, nick, channel, chID, ) @@ -1706,6 +1824,87 @@ func (hdlr *Handlers) executeJoin( http.StatusOK) } +// checkJoinAllowed runs Tier 2 restrictions for an +// existing channel. Returns true if join is allowed. +func (hdlr *Handlers) checkJoinAllowed( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, + suppliedKey string, +) bool { + ctx := request.Context() + + // 1. Ban check — prevents banned users from joining. + isBanned, banErr := hdlr.params.Database. + IsSessionBanned(ctx, chID, sessionID) + if banErr == nil && isBanned { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrBannedFromChan, nick, + []string{channel}, + "Cannot join channel (+b)", + ) + + return false + } + + // 2. Invite-only check (+i). + isInviteOnly, ioErr := hdlr.params.Database. + IsChannelInviteOnly(ctx, chID) + if ioErr == nil && isInviteOnly { + hasInvite, invErr := hdlr.params.Database. + HasChannelInvite(ctx, chID, sessionID) + if invErr != nil || !hasInvite { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrInviteOnlyChan, nick, + []string{channel}, + "Cannot join channel (+i)", + ) + + return false + } + } + + // 3. Channel key check (+k). + key, keyErr := hdlr.params.Database. + GetChannelKey(ctx, chID) + if keyErr == nil && key != "" { + if suppliedKey != key { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrBadChannelKey, nick, + []string{channel}, + "Cannot join channel (+k)", + ) + + return false + } + } + + // 4. User limit check (+l). + limit, limErr := hdlr.params.Database. + GetChannelUserLimit(ctx, chID) + if limErr == nil && limit > 0 { + count, cntErr := hdlr.params.Database. + CountChannelMembers(ctx, chID) + if cntErr == nil && count >= int64(limit) { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrChannelIsFull, nick, + []string{channel}, + "Cannot join channel (+l)", + ) + + return false + } + } + + return true +} + // deliverJoinNumerics sends RPL_TOPIC/RPL_NOTOPIC, // RPL_NAMREPLY, and RPL_ENDOFNAMES to the joining client. func (hdlr *Handlers) deliverJoinNumerics( @@ -2394,17 +2593,17 @@ func (hdlr *Handlers) handleChannelMode( // buildChannelModeString constructs the current mode // string for a channel, including +n (always on), +t, +m, -// and +H with its parameter. +// +i, +s, +k, +l, and +H with their parameters. func (hdlr *Handlers) buildChannelModeString( ctx context.Context, chID int64, ) string { modes := "+n" - isTopicLocked, tlErr := hdlr.params.Database. - IsChannelTopicLocked(ctx, chID) - if tlErr == nil && isTopicLocked { - modes += "t" + isInviteOnly, ioErr := hdlr.params.Database. + IsChannelInviteOnly(ctx, chID) + if ioErr == nil && isInviteOnly { + modes += "i" } isModerated, modErr := hdlr.params.Database. @@ -2413,13 +2612,42 @@ func (hdlr *Handlers) buildChannelModeString( modes += "m" } + isSecret, secErr := hdlr.params.Database. + IsChannelSecret(ctx, chID) + if secErr == nil && isSecret { + modes += "s" + } + + isTopicLocked, tlErr := hdlr.params.Database. + IsChannelTopicLocked(ctx, chID) + if tlErr == nil && isTopicLocked { + modes += "t" + } + + var modeParams string + + key, keyErr := hdlr.params.Database. + GetChannelKey(ctx, chID) + if keyErr == nil && key != "" { + modes += "k" + modeParams += " " + key + } + + limit, limErr := hdlr.params.Database. + GetChannelUserLimit(ctx, chID) + if limErr == nil && limit > 0 { + modes += "l" + modeParams += " " + strconv.Itoa(limit) + } + bits, bitsErr := hdlr.params.Database. GetChannelHashcashBits(ctx, chID) if bitsErr == nil && bits > 0 { - modes += fmt.Sprintf("H %d", bits) + modes += "H" + modeParams += " " + strconv.Itoa(bits) } - return modes + return modes + modeParams } // queryChannelMode sends RPL_CHANNELMODEIS and @@ -2501,9 +2729,39 @@ func (hdlr *Handlers) applyChannelMode( chID int64, modeArgs []string, ) { - ctx := request.Context() modeStr := modeArgs[0] + if hdlr.applyUserModeIfMatched( + writer, request, sessionID, clientID, + nick, channel, chID, modeStr, modeArgs, + ) { + return + } + + if hdlr.applyChannelFlagIfMatched( + writer, request, sessionID, clientID, + nick, channel, chID, modeStr, + ) { + return + } + + hdlr.applyParameterizedMode( + writer, request, sessionID, clientID, + nick, channel, chID, modeStr, modeArgs, + ) +} + +// applyUserModeIfMatched handles +o/-o and +v/-v. +// Returns true if the mode was handled. +func (hdlr *Handlers) applyUserModeIfMatched( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, + modeStr string, + modeArgs []string, +) bool { switch modeStr { case "+o", "-o": hdlr.applyUserMode( @@ -2511,35 +2769,104 @@ func (hdlr *Handlers) applyChannelMode( sessionID, clientID, nick, channel, chID, modeArgs, true, ) + + return true case "+v", "-v": hdlr.applyUserMode( writer, request, sessionID, clientID, nick, channel, chID, modeArgs, false, ) - case "+m": - hdlr.setChannelFlag( + + return true + default: + return false + } +} + +// applyChannelFlagIfMatched handles simple boolean modes +// (+m/-m, +t/-t, +i/-i, +s/-s). +// Returns true if the mode was handled. +func (hdlr *Handlers) applyChannelFlagIfMatched( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, + modeStr string, +) bool { + flagMap := map[string]struct { + flag string + setting bool + }{ + "+m": {"m", true}, "-m": {"m", false}, + "+t": {"t", true}, "-t": {"t", false}, + "+i": {"i", true}, "-i": {"i", false}, + "+s": {"s", true}, "-s": {"s", false}, + } + + entry, exists := flagMap[modeStr] + if !exists { + return false + } + + hdlr.setChannelFlag( + writer, request, + sessionID, clientID, nick, + channel, chID, entry.flag, entry.setting, + ) + + return true +} + +// applyParameterizedMode handles modes that take +// parameters (+k/-k, +l/-l, +b/-b, +H/-H) and +// unknown modes. +func (hdlr *Handlers) applyParameterizedMode( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, + modeStr string, + modeArgs []string, +) { + switch modeStr { + case "+k": + hdlr.setChannelKeyMode( writer, request, sessionID, clientID, nick, - channel, chID, "m", true, + channel, chID, modeArgs, ) - case "-m": - hdlr.setChannelFlag( + case "-k": + hdlr.clearChannelKeyMode( writer, request, sessionID, clientID, nick, - channel, chID, "m", false, + channel, chID, ) - case "+t": - hdlr.setChannelFlag( + case "+l": + hdlr.setChannelLimitMode( writer, request, sessionID, clientID, nick, - channel, chID, "t", true, + channel, chID, modeArgs, ) - case "-t": - hdlr.setChannelFlag( + case "-l": + hdlr.clearChannelLimitMode( writer, request, sessionID, clientID, nick, - channel, chID, "t", false, + channel, chID, + ) + case "+b": + hdlr.handleBanMode( + writer, request, + sessionID, clientID, nick, + channel, chID, modeArgs, true, + ) + case "-b": + hdlr.handleBanMode( + writer, request, + sessionID, clientID, nick, + channel, chID, modeArgs, false, ) case "+H": hdlr.setHashcashMode( @@ -2554,9 +2881,9 @@ func (hdlr *Handlers) applyChannelMode( channel, chID, ) default: - // Unknown or unsupported mode change. hdlr.enqueueNumeric( - ctx, clientID, irc.ErrUnknownMode, nick, + request.Context(), clientID, + irc.ErrUnknownMode, nick, []string{modeStr}, "is unknown mode char to me", ) @@ -2742,6 +3069,14 @@ func (hdlr *Handlers) setChannelFlag( err = hdlr.params.Database.SetChannelTopicLocked( ctx, chID, setting, ) + case "i": + err = hdlr.params.Database.SetChannelInviteOnly( + ctx, chID, setting, + ) + case "s": + err = hdlr.params.Database.SetChannelSecret( + ctx, chID, setting, + ) } if err != nil { @@ -2896,6 +3231,570 @@ func (hdlr *Handlers) clearHashcashMode( http.StatusOK) } +// handleBanMode handles MODE #channel +b/-b [mask]. +// +b with no argument lists bans; +b with argument adds +// a ban; -b removes a ban. +func (hdlr *Handlers) handleBanMode( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, + modeArgs []string, + adding bool, +) { + // +b with no argument: list bans. + if adding && len(modeArgs) < 2 { + hdlr.listBans( + writer, request, + sessionID, clientID, nick, channel, chID, + ) + + return + } + + if !hdlr.requireChannelOp( + writer, request, + sessionID, clientID, nick, channel, chID, + ) { + return + } + + if len(modeArgs) < 2 { //nolint:mnd // mode + mask + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdMode}, + "Not enough parameters", + ) + + return + } + + hdlr.executeBanChange( + writer, request, nick, channel, + chID, modeArgs[1], adding, + ) +} + +// executeBanChange applies a ban add/remove and +// broadcasts the mode change. +func (hdlr *Handlers) executeBanChange( + writer http.ResponseWriter, + request *http.Request, + nick, channel string, + chID int64, + mask string, + adding bool, +) { + ctx := request.Context() + + var err error + if adding { + err = hdlr.params.Database.AddChannelBan( + ctx, chID, mask, nick, + ) + } else { + err = hdlr.params.Database.RemoveChannelBan( + ctx, chID, mask, + ) + } + + if err != nil { + hdlr.log.Error( + "ban mode change failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + modePrefix := "+" + if !adding { + modePrefix = "-" + } + + memberIDs, _ := hdlr.params.Database. + GetChannelMemberIDs(ctx, chID) + + modeBody, mErr := json.Marshal( + []string{modePrefix + "b", mask}, + ) + if mErr == nil { + _ = hdlr.fanOutSilent( + request, irc.CmdMode, nick, channel, + json.RawMessage(modeBody), memberIDs, + ) + } + + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) +} + +// listBans sends RPL_BANLIST (367) and +// RPL_ENDOFBANLIST (368) for a channel. +func (hdlr *Handlers) listBans( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, +) { + ctx := request.Context() + + bans, err := hdlr.params.Database.ListChannelBans( + ctx, chID, + ) + if err != nil { + hdlr.log.Error( + "list bans failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + for _, ban := range bans { + hdlr.enqueueNumeric( + ctx, clientID, irc.RplBanList, nick, + []string{ + channel, + ban.Mask, + ban.SetBy, + strconv.FormatInt( + ban.CreatedAt.Unix(), 10, + ), + }, + "", + ) + } + + hdlr.enqueueNumeric( + ctx, clientID, irc.RplEndOfBanList, nick, + []string{channel}, + "End of channel ban list", + ) + + hdlr.broker.Notify(sessionID) + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) +} + +// setChannelKeyMode handles MODE #channel +k . +func (hdlr *Handlers) setChannelKeyMode( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, + modeArgs []string, +) { + ctx := request.Context() + + if !hdlr.requireChannelOp( + writer, request, + sessionID, clientID, nick, channel, chID, + ) { + return + } + + if len(modeArgs) < 2 { //nolint:mnd // +k requires key arg + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdMode}, + "Not enough parameters (+k requires a key)", + ) + + return + } + + key := modeArgs[1] + + err := hdlr.params.Database.SetChannelKey( + ctx, chID, key, + ) + if err != nil { + hdlr.log.Error( + "set channel key failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + // Broadcast +k mode change. + memberIDs, _ := hdlr.params.Database. + GetChannelMemberIDs(ctx, chID) + + modeBody, mErr := json.Marshal([]string{"+k", key}) + if mErr == nil { + _ = hdlr.fanOutSilent( + request, irc.CmdMode, nick, channel, + json.RawMessage(modeBody), memberIDs, + ) + } + + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) +} + +// clearChannelKeyMode handles MODE #channel -k *. +func (hdlr *Handlers) clearChannelKeyMode( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, +) { + ctx := request.Context() + + if !hdlr.requireChannelOp( + writer, request, + sessionID, clientID, nick, channel, chID, + ) { + return + } + + err := hdlr.params.Database.SetChannelKey( + ctx, chID, "", + ) + if err != nil { + hdlr.log.Error( + "clear channel key failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + memberIDs, _ := hdlr.params.Database. + GetChannelMemberIDs(ctx, chID) + + modeBody, mErr := json.Marshal([]string{"-k", "*"}) + if mErr == nil { + _ = hdlr.fanOutSilent( + request, irc.CmdMode, nick, channel, + json.RawMessage(modeBody), memberIDs, + ) + } + + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) +} + +// setChannelLimitMode handles MODE #channel +l . +func (hdlr *Handlers) setChannelLimitMode( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, + modeArgs []string, +) { + ctx := request.Context() + + if !hdlr.requireChannelOp( + writer, request, + sessionID, clientID, nick, channel, chID, + ) { + return + } + + if len(modeArgs) < 2 { //nolint:mnd // +l requires limit arg + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdMode}, + "Not enough parameters (+l requires a limit)", + ) + + return + } + + limit, parseErr := strconv.Atoi(modeArgs[1]) + if parseErr != nil || limit <= 0 { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrUnknownMode, nick, []string{"+l"}, + "Invalid user limit (must be positive integer)", + ) + + return + } + + err := hdlr.params.Database.SetChannelUserLimit( + ctx, chID, limit, + ) + if err != nil { + hdlr.log.Error( + "set channel user limit failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + // Broadcast +l mode change. + memberIDs, _ := hdlr.params.Database. + GetChannelMemberIDs(ctx, chID) + + modeBody, mErr := json.Marshal( + []string{"+l", modeArgs[1]}, + ) + if mErr == nil { + _ = hdlr.fanOutSilent( + request, irc.CmdMode, nick, channel, + json.RawMessage(modeBody), memberIDs, + ) + } + + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) +} + +// clearChannelLimitMode handles MODE #channel -l. +func (hdlr *Handlers) clearChannelLimitMode( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, +) { + ctx := request.Context() + + if !hdlr.requireChannelOp( + writer, request, + sessionID, clientID, nick, channel, chID, + ) { + return + } + + err := hdlr.params.Database.SetChannelUserLimit( + ctx, chID, 0, + ) + if err != nil { + hdlr.log.Error( + "clear channel user limit failed", + "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + memberIDs, _ := hdlr.params.Database. + GetChannelMemberIDs(ctx, chID) + + modeBody, mErr := json.Marshal([]string{"-l"}) + if mErr == nil { + _ = hdlr.fanOutSilent( + request, irc.CmdMode, nick, channel, + json.RawMessage(modeBody), memberIDs, + ) + } + + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) +} + +// handleInvite processes the INVITE command. +func (hdlr *Handlers) handleInvite( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick string, + bodyLines func() []string, +) { + lines := bodyLines() + if len(lines) < 2 { //nolint:mnd // nick + channel + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdInvite}, + "Not enough parameters", + ) + + return + } + + targetNick := lines[0] + channel := lines[1] + + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + chID, targetSID, ok := hdlr.validateInvite( + writer, request, + sessionID, clientID, nick, + targetNick, channel, + ) + if !ok { + return + } + + hdlr.executeInvite( + writer, request, + sessionID, clientID, nick, + targetNick, channel, chID, targetSID, + ) +} + +// validateInvite checks channel, membership, permissions, +// and target for an INVITE command. +func (hdlr *Handlers) validateInvite( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, targetNick, channel string, +) (int64, int64, bool) { + ctx := request.Context() + + chID, err := hdlr.params.Database.GetChannelByName( + ctx, channel, + ) + if err != nil { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNoSuchChannel, nick, + []string{channel}, "No such channel", + ) + + return 0, 0, false + } + + isMember, memErr := hdlr.params.Database.IsChannelMember( + ctx, chID, sessionID, + ) + if memErr != nil || !isMember { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNotOnChannel, nick, + []string{channel}, + "You're not on that channel", + ) + + return 0, 0, false + } + + isInviteOnly, ioErr := hdlr.params.Database. + IsChannelInviteOnly(ctx, chID) + if ioErr == nil && isInviteOnly { + if !hdlr.requireChannelOp( + writer, request, + sessionID, clientID, nick, channel, chID, + ) { + return 0, 0, false + } + } + + targetSID, nickErr := hdlr.params.Database. + GetSessionByNick(ctx, targetNick) + if nickErr != nil { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNoSuchNick, nick, + []string{targetNick}, + "No such nick/channel", + ) + + return 0, 0, false + } + + alreadyIn, aiErr := hdlr.params.Database. + IsChannelMember(ctx, chID, targetSID) + if aiErr == nil && alreadyIn { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrUserOnChannel, nick, + []string{targetNick, channel}, + "is already on channel", + ) + + return 0, 0, false + } + + return chID, targetSID, true +} + +// executeInvite records the invite and sends +// notifications. +func (hdlr *Handlers) executeInvite( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, targetNick, channel string, + chID, targetSID int64, +) { + ctx := request.Context() + + invErr := hdlr.params.Database.AddChannelInvite( + ctx, chID, targetSID, nick, + ) + if invErr != nil { + hdlr.log.Error( + "add invite failed", "error", invErr, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + hdlr.enqueueNumeric( + ctx, clientID, irc.RplInviting, nick, + []string{targetNick, channel}, "", + ) + hdlr.broker.Notify(sessionID) + + invBody, mErr := json.Marshal([]string{channel}) + if mErr == nil { + _ = hdlr.fanOutSilent( + request, irc.CmdInvite, nick, targetNick, + json.RawMessage(invBody), + []int64{targetSID}, + ) + } + + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) +} + // handleNames sends NAMES reply for a channel. func (hdlr *Handlers) handleNames( writer http.ResponseWriter, @@ -2975,8 +3874,10 @@ func (hdlr *Handlers) handleList( ) { ctx := request.Context() + // Use filtered list that hides +s channels + // from non-members. channels, err := hdlr.params.Database. - ListAllChannelsWithCounts(ctx) + ListAllChannelsWithCountsFiltered(ctx, sessionID) if err != nil { hdlr.log.Error( "list channels failed", "error", err, @@ -3084,7 +3985,8 @@ func (hdlr *Handlers) executeWhois( ) hdlr.deliverWhoisChannels( - ctx, clientID, nick, queryNick, targetSID, + ctx, clientID, nick, queryNick, + sessionID, targetSID, ) // 338 RPL_WHOISACTUALLY — oper-only. @@ -3193,10 +4095,14 @@ func (hdlr *Handlers) deliverWhoisChannels( ctx context.Context, clientID int64, nick, queryNick string, - targetSID int64, + querierSID, targetSID int64, ) { + // Use filtered query that hides +s channels from + // non-members. channels, chanErr := hdlr.params.Database. - GetSessionChannels(ctx, targetSID) + GetSessionChannelsFiltered( + ctx, targetSID, querierSID, + ) if chanErr != nil || len(channels) == 0 { return } diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index d52a9c2..6b0a477 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -4378,3 +4378,486 @@ func TestKickDefaultReason(t *testing.T) { ) } } + +// --- Tier 2 Handler Tests --- + +const ( + inviteCmd = "INVITE" + joinedStatus = "joined" +) + +// TestBanAddRemoveList verifies +b add, list, and -b +// remove via MODE commands. +func TestBanAddRemoveList(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("banop") + + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#bans", + }) + + // Add a ban. + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#bans", + bodyKey: []string{"+b", "*!*@evil.com"}, + }) + + _, lastID := tserver.pollMessages(opToken, 0) + + // List bans (+b with no argument). + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#bans", + bodyKey: []string{"+b"}, + }) + + msgs, _ := tserver.pollMessages(opToken, lastID) + + // Should have RPL_BANLIST (367). + banMsg := findNumericWithParams(msgs, "367") + if banMsg == nil { + t.Fatalf("expected 367 RPL_BANLIST, got %v", msgs) + } + + // Should have RPL_ENDOFBANLIST (368). + if !findNumeric(msgs, "368") { + t.Fatal("expected 368 RPL_ENDOFBANLIST") + } + + // Remove the ban. + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#bans", + bodyKey: []string{"-b", "*!*@evil.com"}, + }) + + _, lastID = tserver.pollMessages(opToken, lastID) + + // List again — should be empty (just end-of-list). + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#bans", + bodyKey: []string{"+b"}, + }) + + msgs, _ = tserver.pollMessages(opToken, lastID) + banMsg = findNumericWithParams(msgs, "367") + if banMsg != nil { + t.Fatal("expected no 367 after ban removal") + } + + if !findNumeric(msgs, "368") { + t.Fatal("expected 368 RPL_ENDOFBANLIST") + } +} + +// TestBanBlocksJoin verifies that a banned user cannot +// join a channel. +func TestBanBlocksJoin(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("banop2") + userToken := tserver.createSession("banned2") + + // Op creates channel and sets a ban. + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#banjoin", + }) + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#banjoin", + bodyKey: []string{"+b", "banned2!*@*"}, + }) + + // Banned user tries to join. + _, lastID := tserver.pollMessages(userToken, 0) + tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, toKey: "#banjoin", + }) + + msgs, _ := tserver.pollMessages(userToken, lastID) + + // Should get ERR_BANNEDFROMCHAN (474). + if !findNumeric(msgs, "474") { + t.Fatalf("expected 474 ERR_BANNEDFROMCHAN, got %v", msgs) + } +} + +// TestBanBlocksPrivmsg verifies that a banned user who +// is already in a channel cannot send messages. +func TestBanBlocksPrivmsg(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("banmsgop") + userToken := tserver.createSession("banmsgusr") + + // Both join. + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#banmsg", + }) + tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, toKey: "#banmsg", + }) + + // Op bans the user. + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#banmsg", + bodyKey: []string{"+b", "banmsgusr!*@*"}, + }) + + // User tries to send a message. + _, lastID := tserver.pollMessages(userToken, 0) + tserver.sendCommand(userToken, map[string]any{ + commandKey: privmsgCmd, + toKey: "#banmsg", + bodyKey: []string{"hello"}, + }) + + msgs, _ := tserver.pollMessages(userToken, lastID) + + // Should get ERR_CANNOTSENDTOCHAN (404). + if !findNumeric(msgs, "404") { + t.Fatalf("expected 404 ERR_CANNOTSENDTOCHAN, got %v", msgs) + } +} + +// TestInviteOnlyJoin verifies +i behavior: join rejected +// without invite, accepted with invite. +func TestInviteOnlyJoin(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("invop") + userToken := tserver.createSession("invusr") + + // Op creates channel and sets +i. + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#invonly", + }) + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#invonly", + bodyKey: []string{"+i"}, + }) + + // User tries to join without invite. + _, lastID := tserver.pollMessages(userToken, 0) + tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, toKey: "#invonly", + }) + + msgs, _ := tserver.pollMessages(userToken, lastID) + + if !findNumeric(msgs, "473") { + t.Fatalf( + "expected 473 ERR_INVITEONLYCHAN, got %v", + msgs, + ) + } + + // Op invites user. + tserver.sendCommand(opToken, map[string]any{ + commandKey: inviteCmd, + bodyKey: []string{"invusr", "#invonly"}, + }) + + // User tries again — should succeed with invite. + _, result := tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, toKey: "#invonly", + }) + + if result[statusKey] != joinedStatus { + t.Fatalf( + "expected join to succeed with invite, got %v", + result, + ) + } +} + +// TestSecretChannelHiddenFromList verifies +s hides a +// channel from LIST for non-members. +func TestSecretChannelHiddenFromList(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("secop") + outsiderToken := tserver.createSession("secout") + + // Op creates secret channel. + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#secret", + }) + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#secret", + bodyKey: []string{"+s"}, + }) + + // Outsider does LIST. + _, lastID := tserver.pollMessages(outsiderToken, 0) + tserver.sendCommand(outsiderToken, map[string]any{ + commandKey: "LIST", + }) + + msgs, _ := tserver.pollMessages(outsiderToken, lastID) + + // Should NOT see #secret in any 322 (RPL_LIST). + for _, msg := range msgs { + code, ok := msg["code"].(float64) + if !ok || int(code) != 322 { + continue + } + + params := getNumericParams(msg) + for _, p := range params { + if p == "#secret" { + t.Fatal("outsider should not see #secret in LIST") + } + } + } + + // Member does LIST — should see it. + _, lastID = tserver.pollMessages(opToken, 0) + tserver.sendCommand(opToken, map[string]any{ + commandKey: "LIST", + }) + + msgs, _ = tserver.pollMessages(opToken, lastID) + + found := false + + for _, msg := range msgs { + code, ok := msg["code"].(float64) + if !ok || int(code) != 322 { + continue + } + + params := getNumericParams(msg) + for _, p := range params { + if p == "#secret" { + found = true + } + } + } + + if !found { + t.Fatal("member should see #secret in LIST") + } +} + +// TestChannelKeyJoin verifies +k behavior: wrong/missing +// key is rejected, correct key allows join. +func TestChannelKeyJoin(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("keyop") + userToken := tserver.createSession("keyusr") + + // Op creates keyed channel. + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#keyed", + }) + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#keyed", + bodyKey: []string{"+k", "mykey"}, + }) + + // User tries without key. + _, lastID := tserver.pollMessages(userToken, 0) + tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, toKey: "#keyed", + }) + + msgs, _ := tserver.pollMessages(userToken, lastID) + + if !findNumeric(msgs, "475") { + t.Fatalf( + "expected 475 ERR_BADCHANNELKEY, got %v", + msgs, + ) + } + + // User tries with wrong key. + tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, + toKey: "#keyed", + bodyKey: []string{"wrongkey"}, + }) + + msgs, _ = tserver.pollMessages(userToken, lastID) + if !findNumeric(msgs, "475") { + t.Fatalf( + "expected 475 ERR_BADCHANNELKEY for wrong key, got %v", + msgs, + ) + } + + // User tries with correct key. + _, result := tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, + toKey: "#keyed", + bodyKey: []string{"mykey"}, + }) + + if result[statusKey] != joinedStatus { + t.Fatalf( + "expected join to succeed with correct key, got %v", + result, + ) + } +} + +// TestUserLimitEnforcement verifies +l behavior: blocks +// join when at capacity. +func TestUserLimitEnforcement(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("limop") + user1Token := tserver.createSession("limusr1") + user2Token := tserver.createSession("limusr2") + + // Op creates channel with limit 2. + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#limited", + }) + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#limited", + bodyKey: []string{"+l", "2"}, + }) + + // User1 joins — should succeed (2 members now: op + user1). + _, result := tserver.sendCommand(user1Token, map[string]any{ + commandKey: joinCmd, toKey: "#limited", + }) + if result[statusKey] != joinedStatus { + t.Fatalf("user1 should join, got %v", result) + } + + // User2 tries to join — should fail (at limit: 2/2). + _, lastID := tserver.pollMessages(user2Token, 0) + tserver.sendCommand(user2Token, map[string]any{ + commandKey: joinCmd, toKey: "#limited", + }) + + msgs, _ := tserver.pollMessages(user2Token, lastID) + + if !findNumeric(msgs, "471") { + t.Fatalf( + "expected 471 ERR_CHANNELISFULL, got %v", + msgs, + ) + } +} + +// TestModeStringIncludesNewModes verifies that querying +// channel mode returns the new modes (+i, +s, +k, +l). +func TestModeStringIncludesNewModes(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("modestrop") + + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#modestr", + }) + + // Set all tier 2 modes. + for _, modeChange := range [][]string{ + {"+i"}, {"+s"}, {"+k", "pw"}, {"+l", "50"}, + } { + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, + toKey: "#modestr", + bodyKey: modeChange, + }) + } + + _, lastID := tserver.pollMessages(opToken, 0) + + // Query mode. + tserver.sendCommand(opToken, map[string]any{ + commandKey: modeCmd, toKey: "#modestr", + }) + + msgs, _ := tserver.pollMessages(opToken, lastID) + modeMsg := findNumericWithParams(msgs, "324") + + if modeMsg == nil { + t.Fatal("expected 324 RPL_CHANNELMODEIS") + } + + params := getNumericParams(modeMsg) + if len(params) < 2 { + t.Fatalf("too few params in 324: %v", params) + } + + modeString := params[1] + + for _, c := range []string{"i", "s", "k", "l"} { + if !strings.Contains(modeString, c) { + t.Fatalf( + "mode string %q missing %q", + modeString, c, + ) + } + } +} + +// TestISUPPORT verifies the 005 numeric includes the +// updated CHANMODES string. +func TestISUPPORT(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("isupport") + + msgs, _ := tserver.pollMessages(token, 0) + + isupp := findNumericWithParams(msgs, "005") + if isupp == nil { + t.Fatal("expected 005 RPL_ISUPPORT") + } + + body, _ := isupp["body"].(string) + params := getNumericParams(isupp) + + combined := body + " " + strings.Join(params, " ") + + if !strings.Contains(combined, "CHANMODES=b,k,Hl,imnst") { + t.Fatalf( + "ISUPPORT missing updated CHANMODES, got body=%q params=%v", + body, params, + ) + } +} + +// TestNonOpCannotSetModes verifies non-operators +// cannot set +i, +s, +k, +l, +b. +func TestNonOpCannotSetModes(t *testing.T) { + tserver := newTestServer(t) + opToken := tserver.createSession("modeopx") + userToken := tserver.createSession("modeusrx") + + tserver.sendCommand(opToken, map[string]any{ + commandKey: joinCmd, toKey: "#noperm", + }) + tserver.sendCommand(userToken, map[string]any{ + commandKey: joinCmd, toKey: "#noperm", + }) + + modes := [][]string{ + {"+i"}, {"+s"}, {"+k", "key"}, {"+l", "10"}, + {"+b", "bad!*@*"}, + } + + for _, modeChange := range modes { + _, lastID := tserver.pollMessages(userToken, 0) + tserver.sendCommand(userToken, map[string]any{ + commandKey: modeCmd, + toKey: "#noperm", + bodyKey: modeChange, + }) + + msgs, _ := tserver.pollMessages(userToken, lastID) + + // Should get 482 ERR_CHANOPRIVSNEEDED. + if !findNumeric(msgs, "482") { + t.Fatalf( + "expected 482 for %v, got %v", + modeChange, msgs, + ) + } + } +} diff --git a/pkg/irc/commands.go b/pkg/irc/commands.go index 8b79e7d..91893ec 100644 --- a/pkg/irc/commands.go +++ b/pkg/irc/commands.go @@ -3,6 +3,7 @@ package irc // IRC command names (RFC 1459 / RFC 2812). const ( CmdAway = "AWAY" + CmdInvite = "INVITE" CmdJoin = "JOIN" CmdKick = "KICK" CmdList = "LIST" -- 2.49.1 From 48072cd26ee6a6e28746190e4972b8d5d220df60 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 24 Mar 2026 18:42:04 -0700 Subject: [PATCH 2/2] docs: fix RPL_MYINFO example in README to match code (ikmnostl) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4c6df20..be7d0f6 100644 --- a/README.md +++ b/README.md @@ -1079,7 +1079,7 @@ the server to the client (never C2S) and use 3-digit string codes in the | `001` | RPL_WELCOME | After session creation | `{"command":"001","to":"alice","body":["Welcome to the network, alice"]}` | | `002` | RPL_YOURHOST | After session creation | `{"command":"002","to":"alice","body":["Your host is neoirc, running version 0.1"]}` | | `003` | RPL_CREATED | After session creation | `{"command":"003","to":"alice","body":["This server was created 2026-02-10"]}` | -| `004` | RPL_MYINFO | After session creation | `{"command":"004","to":"alice","params":["neoirc","0.1","","mnst"]}` | +| `004` | RPL_MYINFO | After session creation | `{"command":"004","to":"alice","params":["neoirc","0.1","","ikmnostl"]}` | | `005` | RPL_ISUPPORT | After session creation | `{"command":"005","to":"alice","params":["CHANTYPES=#","NICKLEN=32","PREFIX=(ov)@+","CHANMODES=b,k,Hl,imnst","NETWORK=neoirc"],"body":["are supported by this server"]}` | | `221` | RPL_UMODEIS | In response to user MODE query | `{"command":"221","to":"alice","body":["+"]}` | | `251` | RPL_LUSERCLIENT | On connect or LUSERS command | `{"command":"251","to":"alice","body":["There are 5 users and 0 invisible on 1 servers"]}` | -- 2.49.1