fix: address 3 blocking review findings for IRC protocol listener
All checks were successful
check / check (push) Successful in 59s

1. ISUPPORT/applyChannelModes: extend IRC MODE handler to support +i/-i,
   +s/-s, +n/-n (routed through svc.SetChannelFlag), and +H/-H (hashcash
   bits with parameter parsing). Add 'n' (no external messages) as a
   proper DB-backed channel flag with is_no_external column (default: on).
   Update IRC ISUPPORT to CHANMODES=,,H,imnst to match actual support.

2. QueryChannelMode: rewrite to return complete mode string including all
   boolean flags (n, i, m, s, t) and parameterized modes (k, l, H),
   matching the HTTP handler's buildChannelModeString logic. Simplify
   buildChannelModeString to delegate to QueryChannelMode for consistency.

3. Service struct encapsulation: change exported fields (DB, Broker,
   Config, Log) to unexported (db, broker, config, log). Add NewTestService
   constructor for use by external test packages. Update ircserver
   export_test.go to use the new constructor.

Closes #89
This commit is contained in:
user
2026-03-28 11:48:01 -07:00
parent 260f798af4
commit f57a373053
7 changed files with 357 additions and 175 deletions

View File

@@ -2165,6 +2165,52 @@ func (database *Database) SetChannelSecret(
return nil return nil
} }
// --- No External Messages (+n) ---
// IsChannelNoExternal checks if a channel has +n mode.
func (database *Database) IsChannelNoExternal(
ctx context.Context,
channelID int64,
) (bool, error) {
var isNoExternal int
err := database.conn.QueryRowContext(ctx,
`SELECT is_no_external FROM channels
WHERE id = ?`,
channelID,
).Scan(&isNoExternal)
if err != nil {
return false, fmt.Errorf(
"check no external: %w", err,
)
}
return isNoExternal != 0, nil
}
// SetChannelNoExternal sets or unsets +n mode.
func (database *Database) SetChannelNoExternal(
ctx context.Context,
channelID int64,
noExternal bool,
) error {
val := 0
if noExternal {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET is_no_external = ?, updated_at = ?
WHERE id = ?`,
val, time.Now(), channelID)
if err != nil {
return fmt.Errorf("set no external: %w", err)
}
return nil
}
// ListAllChannelsWithCountsFiltered returns all channels // ListAllChannelsWithCountsFiltered returns all channels
// with member counts, excluding secret channels that // with member counts, excluding secret channels that
// the given session is not a member of. // the given session is not a member of.

View File

@@ -44,6 +44,7 @@ CREATE TABLE IF NOT EXISTS channels (
is_topic_locked INTEGER NOT NULL DEFAULT 1, is_topic_locked INTEGER NOT NULL DEFAULT 1,
is_invite_only INTEGER NOT NULL DEFAULT 0, is_invite_only INTEGER NOT NULL DEFAULT 0,
is_secret INTEGER NOT NULL DEFAULT 0, is_secret INTEGER NOT NULL DEFAULT 0,
is_no_external INTEGER NOT NULL DEFAULT 1,
channel_key TEXT NOT NULL DEFAULT '', channel_key TEXT NOT NULL DEFAULT '',
user_limit INTEGER NOT NULL DEFAULT 0, user_limit INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,

View File

@@ -2016,62 +2016,14 @@ func (hdlr *Handlers) handleChannelMode(
} }
// buildChannelModeString constructs the current mode // buildChannelModeString constructs the current mode
// string for a channel, including +n (always on), +t, +m, // string for a channel by delegating to the service
// +i, +s, +k, +l, and +H with their parameters. // layer's QueryChannelMode, which returns the complete
// mode string including all flags and parameters.
func (hdlr *Handlers) buildChannelModeString( func (hdlr *Handlers) buildChannelModeString(
ctx context.Context, ctx context.Context,
chID int64, chID int64,
) string { ) string {
modes := "+n" return hdlr.svc.QueryChannelMode(ctx, chID)
isInviteOnly, ioErr := hdlr.params.Database.
IsChannelInviteOnly(ctx, chID)
if ioErr == nil && isInviteOnly {
modes += "i"
}
isModerated, modErr := hdlr.params.Database.
IsChannelModerated(ctx, chID)
if modErr == nil && isModerated {
modes += "m"
}
isSecret, secErr := hdlr.params.Database.
IsChannelSecret(ctx, chID)
if secErr == nil && isSecret {
modes += "s"
}
isTopicLocked, tlErr := hdlr.params.Database.
IsChannelTopicLocked(ctx, chID)
if tlErr == nil && isTopicLocked {
modes += "t"
}
var modeParams string
key, keyErr := hdlr.params.Database.
GetChannelKey(ctx, chID)
if keyErr == nil && key != "" {
modes += "k"
modeParams += " " + key
}
limit, limErr := hdlr.params.Database.
GetChannelUserLimit(ctx, chID)
if limErr == nil && limit > 0 {
modes += "l"
modeParams += " " + strconv.Itoa(limit)
}
bits, bitsErr := hdlr.params.Database.
GetChannelHashcashBits(ctx, chID)
if bitsErr == nil && bits > 0 {
modes += "H"
modeParams += " " + strconv.Itoa(bits)
}
return modes + modeParams
} }
// queryChannelMode sends RPL_CHANNELMODEIS and // queryChannelMode sends RPL_CHANNELMODEIS and

View File

@@ -490,6 +490,124 @@ func (c *Conn) handleChannelMode(
) )
} }
// modeResult holds the delta strings produced by a
// single mode-char application.
type modeResult struct {
applied string
appliedArgs string
consumed int
skip bool
}
// applyHashcashMode handles +H/-H (hashcash difficulty).
func (c *Conn) applyHashcashMode(
ctx context.Context,
chID int64,
adding bool,
args []string,
argIdx int,
) modeResult {
if !adding {
_ = c.database.SetChannelHashcashBits(
ctx, chID, 0,
)
return modeResult{
applied: "-H",
appliedArgs: "",
consumed: 0,
skip: false,
}
}
if argIdx >= len(args) {
return modeResult{
applied: "",
appliedArgs: "",
consumed: 0,
skip: true,
}
}
bitsStr := args[argIdx]
bits, parseErr := strconv.Atoi(bitsStr)
if parseErr != nil ||
bits < 1 || bits > maxHashcashBits {
c.sendNumeric(
irc.ErrUnknownMode, "H",
"is unknown mode char to me",
)
return modeResult{
applied: "",
appliedArgs: "",
consumed: 1,
skip: true,
}
}
_ = c.database.SetChannelHashcashBits(
ctx, chID, bits,
)
return modeResult{
applied: "+H",
appliedArgs: " " + bitsStr,
consumed: 1,
skip: false,
}
}
// applyMemberMode handles +o/-o and +v/-v.
func (c *Conn) applyMemberMode(
ctx context.Context,
chID int64,
channel string,
modeChar rune,
adding bool,
args []string,
argIdx int,
) modeResult {
if argIdx >= len(args) {
return modeResult{
applied: "",
appliedArgs: "",
consumed: 0,
skip: true,
}
}
targetNick := args[argIdx]
err := c.svc.ApplyMemberMode(
ctx, chID, channel,
targetNick, modeChar, adding,
)
if err != nil {
c.sendIRCError(err)
return modeResult{
applied: "",
appliedArgs: "",
consumed: 1,
skip: true,
}
}
prefix := "+"
if !adding {
prefix = "-"
}
return modeResult{
applied: prefix + string(modeChar),
appliedArgs: " " + targetNick,
consumed: 1,
skip: false,
}
}
// applyChannelModes applies mode changes using the // applyChannelModes applies mode changes using the
// service for individual mode operations. // service for individual mode operations.
func (c *Conn) applyChannelModes( func (c *Conn) applyChannelModes(
@@ -505,52 +623,57 @@ func (c *Conn) applyChannelModes(
appliedArgs := "" appliedArgs := ""
for _, modeChar := range modeStr { for _, modeChar := range modeStr {
var res modeResult
switch modeChar { switch modeChar {
case '+': case '+':
adding = true adding = true
continue
case '-': case '-':
adding = false adding = false
case 'm', 't':
continue
case 'i', 'm', 'n', 's', 't':
_ = c.svc.SetChannelFlag( _ = c.svc.SetChannelFlag(
ctx, chID, modeChar, adding, ctx, chID, modeChar, adding,
) )
if adding { prefix := "+"
applied += "+" + string(modeChar) if !adding {
} else { prefix = "-"
applied += "-" + string(modeChar)
}
case 'o', 'v':
if argIdx >= len(args) {
break
} }
targetNick := args[argIdx] res = modeResult{
argIdx++ applied: prefix + string(modeChar),
appliedArgs: "",
err := c.svc.ApplyMemberMode( consumed: 0,
ctx, chID, channel, skip: false,
targetNick, modeChar, adding, }
case 'H':
res = c.applyHashcashMode(
ctx, chID, adding, args, argIdx,
)
case 'o', 'v':
res = c.applyMemberMode(
ctx, chID, channel,
modeChar, adding, args, argIdx,
) )
if err != nil {
c.sendIRCError(err)
continue
}
if adding {
applied += "+" + string(modeChar)
} else {
applied += "-" + string(modeChar)
}
appliedArgs += " " + targetNick
default: default:
c.sendNumeric( c.sendNumeric(
irc.ErrUnknownMode, irc.ErrUnknownMode,
string(modeChar), string(modeChar),
"is unknown mode char to me", "is unknown mode char to me",
) )
continue
}
argIdx += res.consumed
if !res.skip {
applied += res.applied
appliedArgs += res.appliedArgs
} }
} }

View File

@@ -28,6 +28,7 @@ const (
pongDeadline = 30 * time.Second pongDeadline = 30 * time.Second
maxNickLen = 32 maxNickLen = 32
minPasswordLen = 8 minPasswordLen = 8
maxHashcashBits = 40
) )
// cmdHandler is the signature for registered IRC command // cmdHandler is the signature for registered IRC command
@@ -434,7 +435,7 @@ func (c *Conn) deliverWelcome() {
"CHANTYPES=#", "CHANTYPES=#",
"NICKLEN=32", "NICKLEN=32",
"PREFIX=(ov)@+", "PREFIX=(ov)@+",
"CHANMODES=,,H,mnst", "CHANMODES=,,H,imnst",
"NETWORK="+c.serverSfx, "NETWORK="+c.serverSfx,
"are supported by this server", "are supported by this server",
) )

View File

@@ -19,12 +19,9 @@ func NewTestServer(
database *db.Database, database *db.Database,
brk *broker.Broker, brk *broker.Broker,
) *Server { ) *Server {
svc := &service.Service{ svc := service.NewTestService(
DB: database, database, brk, cfg, log,
Broker: brk, )
Config: cfg,
Log: log,
}
return &Server{ //nolint:exhaustruct return &Server{ //nolint:exhaustruct
log: log, log: log,

View File

@@ -8,6 +8,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"strconv"
"strings" "strings"
"git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/broker"
@@ -30,19 +31,35 @@ type Params struct {
// Service provides shared business logic for IRC commands. // Service provides shared business logic for IRC commands.
type Service struct { type Service struct {
DB *db.Database db *db.Database
Broker *broker.Broker broker *broker.Broker
Config *config.Config config *config.Config
Log *slog.Logger log *slog.Logger
} }
// New creates a new Service. // New creates a new Service.
func New(params Params) *Service { func New(params Params) *Service {
return &Service{ return &Service{
DB: params.Database, db: params.Database,
Broker: params.Broker, broker: params.Broker,
Config: params.Config, config: params.Config,
Log: params.Logger.Get(), log: params.Logger.Get(),
}
}
// NewTestService creates a Service for use in tests
// outside the service package.
func NewTestService(
database *db.Database,
brk *broker.Broker,
cfg *config.Config,
log *slog.Logger,
) *Service {
return &Service{
db: database,
broker: brk,
config: cfg,
log: log,
} }
} }
@@ -76,7 +93,7 @@ func (s *Service) FanOut(
params, body, meta json.RawMessage, params, body, meta json.RawMessage,
sessionIDs []int64, sessionIDs []int64,
) (int64, string, error) { ) (int64, string, error) {
dbID, msgUUID, err := s.DB.InsertMessage( dbID, msgUUID, err := s.db.InsertMessage(
ctx, command, from, to, params, body, meta, ctx, command, from, to, params, body, meta,
) )
if err != nil { if err != nil {
@@ -84,8 +101,8 @@ func (s *Service) FanOut(
} }
for _, sid := range sessionIDs { for _, sid := range sessionIDs {
_ = s.DB.EnqueueToSession(ctx, sid, dbID) _ = s.db.EnqueueToSession(ctx, sid, dbID)
s.Broker.Notify(sid) s.broker.Notify(sid)
} }
return dbID, msgUUID, nil return dbID, msgUUID, nil
@@ -120,7 +137,7 @@ func (s *Service) SendChannelMessage(
nick, command, channel string, nick, command, channel string,
body, meta json.RawMessage, body, meta json.RawMessage,
) (int64, string, error) { ) (int64, string, error) {
chID, err := s.DB.GetChannelByName(ctx, channel) chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil { if err != nil {
return 0, "", &IRCError{ return 0, "", &IRCError{
irc.ErrNoSuchChannel, irc.ErrNoSuchChannel,
@@ -129,7 +146,7 @@ func (s *Service) SendChannelMessage(
} }
} }
isMember, _ := s.DB.IsChannelMember( isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
if !isMember { if !isMember {
@@ -141,7 +158,7 @@ func (s *Service) SendChannelMessage(
} }
// Ban check — banned users cannot send messages. // Ban check — banned users cannot send messages.
isBanned, banErr := s.DB.IsSessionBanned( isBanned, banErr := s.db.IsSessionBanned(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
if banErr == nil && isBanned { if banErr == nil && isBanned {
@@ -152,12 +169,12 @@ func (s *Service) SendChannelMessage(
} }
} }
moderated, _ := s.DB.IsChannelModerated(ctx, chID) moderated, _ := s.db.IsChannelModerated(ctx, chID)
if moderated { if moderated {
isOp, _ := s.DB.IsChannelOperator( isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
isVoiced, _ := s.DB.IsChannelVoiced( isVoiced, _ := s.db.IsChannelVoiced(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
@@ -170,7 +187,7 @@ func (s *Service) SendChannelMessage(
} }
} }
memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
recipients := excludeSession(memberIDs, sessionID) recipients := excludeSession(memberIDs, sessionID)
dbID, uuid, fanErr := s.FanOut( dbID, uuid, fanErr := s.FanOut(
@@ -193,7 +210,7 @@ func (s *Service) SendDirectMessage(
nick, command, target string, nick, command, target string,
body, meta json.RawMessage, body, meta json.RawMessage,
) (*DirectMsgResult, error) { ) (*DirectMsgResult, error) {
targetSID, err := s.DB.GetSessionByNick(ctx, target) targetSID, err := s.db.GetSessionByNick(ctx, target)
if err != nil { if err != nil {
return nil, &IRCError{ return nil, &IRCError{
irc.ErrNoSuchNick, irc.ErrNoSuchNick,
@@ -202,7 +219,7 @@ func (s *Service) SendDirectMessage(
} }
} }
away, _ := s.DB.GetAway(ctx, targetSID) away, _ := s.db.GetAway(ctx, targetSID)
recipients := []int64{targetSID} recipients := []int64{targetSID}
if targetSID != sessionID { if targetSID != sessionID {
@@ -228,19 +245,19 @@ func (s *Service) JoinChannel(
sessionID int64, sessionID int64,
nick, channel, suppliedKey string, nick, channel, suppliedKey string,
) (*JoinResult, error) { ) (*JoinResult, error) {
chID, err := s.DB.GetOrCreateChannel(ctx, channel) chID, err := s.db.GetOrCreateChannel(ctx, channel)
if err != nil { if err != nil {
return nil, fmt.Errorf("get/create channel: %w", err) return nil, fmt.Errorf("get/create channel: %w", err)
} }
memberCount, countErr := s.DB.CountChannelMembers( memberCount, countErr := s.db.CountChannelMembers(
ctx, chID, ctx, chID,
) )
isCreator := countErr == nil && memberCount == 0 isCreator := countErr == nil && memberCount == 0
if !isCreator { if !isCreator {
if joinErr := checkJoinRestrictions( if joinErr := checkJoinRestrictions(
ctx, s.DB, chID, sessionID, ctx, s.db, chID, sessionID,
channel, suppliedKey, memberCount, channel, suppliedKey, memberCount,
); joinErr != nil { ); joinErr != nil {
return nil, joinErr return nil, joinErr
@@ -248,11 +265,11 @@ func (s *Service) JoinChannel(
} }
if isCreator { if isCreator {
err = s.DB.JoinChannelAsOperator( err = s.db.JoinChannelAsOperator(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
} else { } else {
err = s.DB.JoinChannel(ctx, chID, sessionID) err = s.db.JoinChannel(ctx, chID, sessionID)
} }
if err != nil { if err != nil {
@@ -260,9 +277,9 @@ func (s *Service) JoinChannel(
} }
// Clear invite after successful join. // Clear invite after successful join.
_ = s.DB.ClearChannelInvite(ctx, chID, sessionID) _ = s.db.ClearChannelInvite(ctx, chID, sessionID)
memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{channel}) //nolint:errchkjson body, _ := json.Marshal([]string{channel}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
@@ -284,7 +301,7 @@ func (s *Service) PartChannel(
sessionID int64, sessionID int64,
nick, channel, reason string, nick, channel, reason string,
) error { ) error {
chID, err := s.DB.GetChannelByName(ctx, channel) chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil { if err != nil {
return &IRCError{ return &IRCError{
irc.ErrNoSuchChannel, irc.ErrNoSuchChannel,
@@ -293,7 +310,7 @@ func (s *Service) PartChannel(
} }
} }
isMember, _ := s.DB.IsChannelMember( isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
if !isMember { if !isMember {
@@ -304,7 +321,7 @@ func (s *Service) PartChannel(
} }
} }
memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
recipients := excludeSession(memberIDs, sessionID) recipients := excludeSession(memberIDs, sessionID)
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
@@ -313,8 +330,8 @@ func (s *Service) PartChannel(
nil, body, nil, recipients, nil, body, nil, recipients,
) )
s.DB.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec s.db.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec
s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec
return nil return nil
} }
@@ -326,7 +343,7 @@ func (s *Service) SetTopic(
sessionID int64, sessionID int64,
nick, channel, topic string, nick, channel, topic string,
) error { ) error {
chID, err := s.DB.GetChannelByName(ctx, channel) chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil { if err != nil {
return &IRCError{ return &IRCError{
irc.ErrNoSuchChannel, irc.ErrNoSuchChannel,
@@ -335,7 +352,7 @@ func (s *Service) SetTopic(
} }
} }
isMember, _ := s.DB.IsChannelMember( isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
if !isMember { if !isMember {
@@ -346,9 +363,9 @@ func (s *Service) SetTopic(
} }
} }
topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID)
if topicLocked { if topicLocked {
isOp, _ := s.DB.IsChannelOperator( isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
if !isOp { if !isOp {
@@ -360,15 +377,15 @@ func (s *Service) SetTopic(
} }
} }
if setErr := s.DB.SetTopic( if setErr := s.db.SetTopic(
ctx, channel, topic, ctx, channel, topic,
); setErr != nil { ); setErr != nil {
return fmt.Errorf("set topic: %w", setErr) return fmt.Errorf("set topic: %w", setErr)
} }
_ = s.DB.SetTopicMeta(ctx, channel, topic, nick) _ = s.db.SetTopicMeta(ctx, channel, topic, nick)
memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{topic}) //nolint:errchkjson body, _ := json.Marshal([]string{topic}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
@@ -387,7 +404,7 @@ func (s *Service) KickUser(
sessionID int64, sessionID int64,
nick, channel, targetNick, reason string, nick, channel, targetNick, reason string,
) error { ) error {
chID, err := s.DB.GetChannelByName(ctx, channel) chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil { if err != nil {
return &IRCError{ return &IRCError{
irc.ErrNoSuchChannel, irc.ErrNoSuchChannel,
@@ -396,7 +413,7 @@ func (s *Service) KickUser(
} }
} }
isOp, _ := s.DB.IsChannelOperator( isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
if !isOp { if !isOp {
@@ -407,7 +424,7 @@ func (s *Service) KickUser(
} }
} }
targetSID, err := s.DB.GetSessionByNick( targetSID, err := s.db.GetSessionByNick(
ctx, targetNick, ctx, targetNick,
) )
if err != nil { if err != nil {
@@ -418,7 +435,7 @@ func (s *Service) KickUser(
} }
} }
isMember, _ := s.DB.IsChannelMember( isMember, _ := s.db.IsChannelMember(
ctx, chID, targetSID, ctx, chID, targetSID,
) )
if !isMember { if !isMember {
@@ -429,7 +446,7 @@ func (s *Service) KickUser(
} }
} }
memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
params, _ := json.Marshal( //nolint:errchkjson params, _ := json.Marshal( //nolint:errchkjson
[]string{targetNick}, []string{targetNick},
@@ -440,8 +457,8 @@ func (s *Service) KickUser(
params, body, nil, memberIDs, params, body, nil, memberIDs,
) )
s.DB.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec s.db.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec
s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec
return nil return nil
} }
@@ -453,7 +470,7 @@ func (s *Service) ChangeNick(
sessionID int64, sessionID int64,
oldNick, newNick string, oldNick, newNick string,
) error { ) error {
err := s.DB.ChangeNick(ctx, sessionID, newNick) err := s.db.ChangeNick(ctx, sessionID, newNick)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE") || if strings.Contains(err.Error(), "UNIQUE") ||
db.IsUniqueConstraintError(err) { db.IsUniqueConstraintError(err) {
@@ -485,7 +502,7 @@ func (s *Service) BroadcastQuit(
sessionID int64, sessionID int64,
nick, reason string, nick, reason string,
) { ) {
channels, err := s.DB.GetSessionChannels( channels, err := s.db.GetSessionChannels(
ctx, sessionID, ctx, sessionID,
) )
if err != nil { if err != nil {
@@ -495,7 +512,7 @@ func (s *Service) BroadcastQuit(
notified := make(map[int64]bool) notified := make(map[int64]bool)
for _, ch := range channels { for _, ch := range channels {
memberIDs, memErr := s.DB.GetChannelMemberIDs( memberIDs, memErr := s.db.GetChannelMemberIDs(
ctx, ch.ID, ctx, ch.ID,
) )
if memErr != nil { if memErr != nil {
@@ -526,11 +543,11 @@ func (s *Service) BroadcastQuit(
} }
for _, ch := range channels { for _, ch := range channels {
s.DB.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec s.db.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec
s.DB.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec s.db.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec
} }
s.DB.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec s.db.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec
} }
// SetAway sets or clears the away message. Returns true // SetAway sets or clears the away message. Returns true
@@ -540,7 +557,7 @@ func (s *Service) SetAway(
sessionID int64, sessionID int64,
message string, message string,
) (bool, error) { ) (bool, error) {
err := s.DB.SetAway(ctx, sessionID, message) err := s.db.SetAway(ctx, sessionID, message)
if err != nil { if err != nil {
return false, fmt.Errorf("set away: %w", err) return false, fmt.Errorf("set away: %w", err)
} }
@@ -555,8 +572,8 @@ func (s *Service) Oper(
sessionID int64, sessionID int64,
name, password string, name, password string,
) error { ) error {
cfgName := s.Config.OperName cfgName := s.config.OperName
cfgPassword := s.Config.OperPassword cfgPassword := s.config.OperPassword
// Use constant-time comparison and return the same // Use constant-time comparison and return the same
// error for all failures to prevent information // error for all failures to prevent information
@@ -575,7 +592,7 @@ func (s *Service) Oper(
} }
} }
_ = s.DB.SetSessionOper(ctx, sessionID, true) _ = s.db.SetSessionOper(ctx, sessionID, true)
return nil return nil
} }
@@ -587,7 +604,7 @@ func (s *Service) ValidateChannelOp(
sessionID int64, sessionID int64,
channel string, channel string,
) (int64, error) { ) (int64, error) {
chID, err := s.DB.GetChannelByName(ctx, channel) chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil { if err != nil {
return 0, &IRCError{ return 0, &IRCError{
irc.ErrNoSuchChannel, irc.ErrNoSuchChannel,
@@ -596,7 +613,7 @@ func (s *Service) ValidateChannelOp(
} }
} }
isOp, _ := s.DB.IsChannelOperator( isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID, ctx, chID, sessionID,
) )
if !isOp { if !isOp {
@@ -619,7 +636,7 @@ func (s *Service) ApplyMemberMode(
mode rune, mode rune,
adding bool, adding bool,
) error { ) error {
targetSID, err := s.DB.GetSessionByNick( targetSID, err := s.db.GetSessionByNick(
ctx, targetNick, ctx, targetNick,
) )
if err != nil { if err != nil {
@@ -630,7 +647,7 @@ func (s *Service) ApplyMemberMode(
} }
} }
isMember, _ := s.DB.IsChannelMember( isMember, _ := s.db.IsChannelMember(
ctx, chID, targetSID, ctx, chID, targetSID,
) )
if !isMember { if !isMember {
@@ -643,11 +660,11 @@ func (s *Service) ApplyMemberMode(
switch mode { switch mode {
case 'o': case 'o':
_ = s.DB.SetChannelMemberOperator( _ = s.db.SetChannelMemberOperator(
ctx, chID, targetSID, adding, ctx, chID, targetSID, adding,
) )
case 'v': case 'v':
_ = s.DB.SetChannelMemberVoiced( _ = s.db.SetChannelMemberVoiced(
ctx, chID, targetSID, adding, ctx, chID, targetSID, adding,
) )
} }
@@ -655,7 +672,8 @@ func (s *Service) ApplyMemberMode(
return nil return nil
} }
// SetChannelFlag applies +m/-m or +t/-t on a channel. // SetChannelFlag applies a simple boolean channel mode
// (+m/-m, +t/-t, +i/-i, +s/-s, +n/-n).
func (s *Service) SetChannelFlag( func (s *Service) SetChannelFlag(
ctx context.Context, ctx context.Context,
chID int64, chID int64,
@@ -664,29 +682,37 @@ func (s *Service) SetChannelFlag(
) error { ) error {
switch flag { switch flag {
case 'm': case 'm':
if err := s.DB.SetChannelModerated( if err := s.db.SetChannelModerated(
ctx, chID, setting, ctx, chID, setting,
); err != nil { ); err != nil {
return fmt.Errorf("set moderated: %w", err) return fmt.Errorf("set moderated: %w", err)
} }
case 't': case 't':
if err := s.DB.SetChannelTopicLocked( if err := s.db.SetChannelTopicLocked(
ctx, chID, setting, ctx, chID, setting,
); err != nil { ); err != nil {
return fmt.Errorf("set topic locked: %w", err) return fmt.Errorf("set topic locked: %w", err)
} }
case 'i': case 'i':
if err := s.DB.SetChannelInviteOnly( if err := s.db.SetChannelInviteOnly(
ctx, chID, setting, ctx, chID, setting,
); err != nil { ); err != nil {
return fmt.Errorf("set invite only: %w", err) return fmt.Errorf("set invite only: %w", err)
} }
case 's': case 's':
if err := s.DB.SetChannelSecret( if err := s.db.SetChannelSecret(
ctx, chID, setting, ctx, chID, setting,
); err != nil { ); err != nil {
return fmt.Errorf("set secret: %w", err) return fmt.Errorf("set secret: %w", err)
} }
case 'n':
if err := s.db.SetChannelNoExternal(
ctx, chID, setting,
); err != nil {
return fmt.Errorf(
"set no external: %w", err,
)
}
} }
return nil return nil
@@ -700,7 +726,7 @@ func (s *Service) BroadcastMode(
chID int64, chID int64,
modeText string, modeText string,
) { ) {
memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{modeText}) //nolint:errchkjson body, _ := json.Marshal([]string{modeText}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
@@ -709,24 +735,60 @@ func (s *Service) BroadcastMode(
) )
} }
// QueryChannelMode returns the channel mode string. // QueryChannelMode returns the complete channel mode
// string including all flags and parameterized modes.
func (s *Service) QueryChannelMode( func (s *Service) QueryChannelMode(
ctx context.Context, ctx context.Context,
chID int64, chID int64,
) string { ) string {
modes := "+" modes := "+"
moderated, _ := s.DB.IsChannelModerated(ctx, chID) noExternal, _ := s.db.IsChannelNoExternal(ctx, chID)
if noExternal {
modes += "n"
}
inviteOnly, _ := s.db.IsChannelInviteOnly(ctx, chID)
if inviteOnly {
modes += "i"
}
moderated, _ := s.db.IsChannelModerated(ctx, chID)
if moderated { if moderated {
modes += "m" modes += "m"
} }
topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) secret, _ := s.db.IsChannelSecret(ctx, chID)
if secret {
modes += "s"
}
topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID)
if topicLocked { if topicLocked {
modes += "t" modes += "t"
} }
return modes var modeParams string
key, _ := s.db.GetChannelKey(ctx, chID)
if key != "" {
modes += "k"
modeParams += " " + key
}
limit, _ := s.db.GetChannelUserLimit(ctx, chID)
if limit > 0 {
modes += "l"
modeParams += " " + strconv.Itoa(limit)
}
bits, _ := s.db.GetChannelHashcashBits(ctx, chID)
if bits > 0 {
modes += "H"
modeParams += " " + strconv.Itoa(bits)
}
return modes + modeParams
} }
// broadcastNickChange notifies channel peers of a nick // broadcastNickChange notifies channel peers of a nick
@@ -736,7 +798,7 @@ func (s *Service) broadcastNickChange(
sessionID int64, sessionID int64,
oldNick, newNick string, oldNick, newNick string,
) { ) {
channels, err := s.DB.GetSessionChannels( channels, err := s.db.GetSessionChannels(
ctx, sessionID, ctx, sessionID,
) )
if err != nil { if err != nil {
@@ -746,7 +808,7 @@ func (s *Service) broadcastNickChange(
body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson
notified := make(map[int64]bool) notified := make(map[int64]bool)
dbID, _, insErr := s.DB.InsertMessage( dbID, _, insErr := s.db.InsertMessage(
ctx, irc.CmdNick, oldNick, "", ctx, irc.CmdNick, oldNick, "",
nil, body, nil, nil, body, nil,
) )
@@ -755,12 +817,12 @@ func (s *Service) broadcastNickChange(
} }
// Notify the user themselves (for multi-client sync). // Notify the user themselves (for multi-client sync).
_ = s.DB.EnqueueToSession(ctx, sessionID, dbID) _ = s.db.EnqueueToSession(ctx, sessionID, dbID)
s.Broker.Notify(sessionID) s.broker.Notify(sessionID)
notified[sessionID] = true notified[sessionID] = true
for _, ch := range channels { for _, ch := range channels {
memberIDs, memErr := s.DB.GetChannelMemberIDs( memberIDs, memErr := s.db.GetChannelMemberIDs(
ctx, ch.ID, ctx, ch.ID,
) )
if memErr != nil { if memErr != nil {
@@ -774,8 +836,8 @@ func (s *Service) broadcastNickChange(
notified[mid] = true notified[mid] = true
_ = s.DB.EnqueueToSession(ctx, mid, dbID) _ = s.db.EnqueueToSession(ctx, mid, dbID)
s.Broker.Notify(mid) s.broker.Notify(mid)
} }
} }
} }