From 260f798af4b85f5dffde4e27f939040c909c0808 Mon Sep 17 00:00:00 2001 From: clawbot Date: Wed, 25 Mar 2026 18:01:36 -0700 Subject: [PATCH 1/2] feat: add IRC wire protocol listener with shared service layer Adds a backward-compatible IRC wire protocol listener (RFC 1459/2812) with a shared service layer used by both IRC and HTTP transports. - TCP listener on configurable port (default :6667) - Full IRC protocol: NICK, USER, JOIN, PART, PRIVMSG, MODE, TOPIC, etc. - Shared service layer (internal/service/) for consistent code paths - Tier 2 join restrictions (ban, invite-only, key, limit) in service layer - Ban check on PRIVMSG in service layer - SetChannelFlag handles +i and +s modes - Command dispatch via map[string]cmdHandler pattern - EXPOSE 6667 in Dockerfile - Service layer unit tests closes #89 --- Dockerfile | 2 +- README.md | 65 ++ cmd/neoircd/main.go | 12 +- internal/config/config.go | 3 + internal/db/testing.go | 25 + internal/handlers/api.go | 1541 +++++++---------------------- internal/handlers/api_test.go | 12 + internal/handlers/handlers.go | 7 +- internal/ircserver/commands.go | 1178 ++++++++++++++++++++++ internal/ircserver/conn.go | 501 ++++++++++ internal/ircserver/export_test.go | 52 + internal/ircserver/parser.go | 123 +++ internal/ircserver/parser_test.go | 328 ++++++ internal/ircserver/relay.go | 319 ++++++ internal/ircserver/server.go | 157 +++ internal/ircserver/server_test.go | 625 ++++++++++++ internal/service/service.go | 839 ++++++++++++++++ internal/service/service_test.go | 365 +++++++ pkg/irc/commands.go | 1 + 19 files changed, 4962 insertions(+), 1193 deletions(-) create mode 100644 internal/db/testing.go create mode 100644 internal/ircserver/commands.go create mode 100644 internal/ircserver/conn.go create mode 100644 internal/ircserver/export_test.go create mode 100644 internal/ircserver/parser.go create mode 100644 internal/ircserver/parser_test.go create mode 100644 internal/ircserver/relay.go create mode 100644 internal/ircserver/server.go create mode 100644 internal/ircserver/server_test.go create mode 100644 internal/service/service.go create mode 100644 internal/service/service_test.go diff --git a/Dockerfile b/Dockerfile index 273f53f..32acda7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,7 +53,7 @@ RUN apk add --no-cache ca-certificates \ COPY --from=builder /neoircd /usr/local/bin/neoircd USER neoirc -EXPOSE 8080 +EXPOSE 8080 6667 HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ CMD wget -qO- http://localhost:8080/.well-known/healthcheck.json || exit 1 ENTRYPOINT ["neoircd"] diff --git a/README.md b/README.md index be7d0f6..fb97988 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ web client as a convenience/reference implementation, but the API comes first. - [Federation](#federation-server-to-server) - [Storage](#storage) - [Configuration](#configuration) +- [IRC Protocol Listener](#irc-protocol-listener) - [Deployment](#deployment) - [Client Development Guide](#client-development-guide) - [Rate Limiting & Abuse Prevention](#rate-limiting--abuse-prevention) @@ -2266,6 +2267,7 @@ directory is also loaded automatically via | `NEOIRC_OPER_PASSWORD` | string | `""` | Server operator (o-line) password. Both name and password must be set to enable OPER. | | `LOGIN_RATE_LIMIT` | float | `1` | Allowed login attempts per second per IP address. | | `LOGIN_RATE_BURST` | int | `5` | Maximum burst of login attempts per IP before rate limiting kicks in. | +| `IRC_LISTEN_ADDR` | string | `:6667` | TCP address for the traditional IRC protocol listener. Set to empty string to disable. | | `MAINTENANCE_MODE` | bool | `false` | Maintenance mode flag (reserved) | ### Example `.env` file @@ -2282,6 +2284,69 @@ NEOIRC_HASHCASH_BITS=20 --- +## IRC Protocol Listener + +neoirc includes an optional traditional IRC wire protocol listener (RFC +1459/2812) that allows standard IRC clients to connect directly. This enables +backward compatibility with existing IRC clients like irssi, weechat, hexchat, +and others. + +### Configuration + +The IRC listener is **enabled by default** on `:6667`. To disable it, set +`IRC_LISTEN_ADDR` to an empty string: + +```bash +IRC_LISTEN_ADDR= +``` + +### Supported Commands + +| Category | Commands | +|------------|------------------------------------------------------| +| Connection | `NICK`, `USER`, `PASS`, `QUIT`, `PING`/`PONG`, `CAP` | +| Channels | `JOIN`, `PART`, `MODE`, `TOPIC`, `NAMES`, `LIST`, `KICK`, `INVITE` | +| Messaging | `PRIVMSG`, `NOTICE` | +| Info | `WHO`, `WHOIS`, `LUSERS`, `MOTD`, `AWAY` | +| Operator | `OPER` (requires `NEOIRC_OPER_NAME` and `NEOIRC_OPER_PASSWORD`) | + +### Protocol Details + +- **Wire format**: CR-LF delimited lines, max 512 bytes per line +- **Connection registration**: Clients must send `NICK` and `USER` to register. + An optional `PASS` before registration sets the session password (minimum 8 + characters). +- **CAP negotiation**: `CAP LS` and `CAP END` are silently handled for + compatibility with modern clients. No capabilities are advertised. +- **Channel prefixes**: Channels must start with `#`. If omitted, it's + automatically prepended. +- **First joiner**: The first user to join a channel is automatically granted + operator status (`@`). +- **Channel modes**: `+m` (moderated), `+t` (topic lock), `+o` (operator), + `+v` (voice) + +### Bridge to HTTP API + +Messages sent by IRC clients appear in channels visible to HTTP/JSON API +clients and vice versa. The IRC listener and HTTP API share the same database, +broker, and session infrastructure. A user connected via IRC and a user +connected via the HTTP API can communicate in the same channels seamlessly. + +### Docker Usage + +To expose the IRC port in Docker (the listener is enabled by default on +`:6667`): + +```bash +docker run -d \ + -p 8080:8080 \ + -p 6667:6667 \ + -v neoirc-data:/var/lib/neoirc \ + neoirc +``` + +--- + ## Deployment ### Docker (Recommended) diff --git a/cmd/neoircd/main.go b/cmd/neoircd/main.go index d8d6601..8a5b141 100644 --- a/cmd/neoircd/main.go +++ b/cmd/neoircd/main.go @@ -2,14 +2,17 @@ package main import ( + "git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/globals" "git.eeqj.de/sneak/neoirc/internal/handlers" "git.eeqj.de/sneak/neoirc/internal/healthcheck" + "git.eeqj.de/sneak/neoirc/internal/ircserver" "git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/middleware" "git.eeqj.de/sneak/neoirc/internal/server" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/internal/stats" "go.uber.org/fx" ) @@ -28,16 +31,23 @@ func main() { fx.New( fx.Provide( + broker.New, config.New, db.New, globals.New, handlers.New, + ircserver.New, logger.New, server.New, middleware.New, healthcheck.New, + service.New, stats.New, ), - fx.Invoke(func(*server.Server) {}), + fx.Invoke(func( + _ *server.Server, + _ *ircserver.Server, + ) { + }), ).Run() } diff --git a/internal/config/config.go b/internal/config/config.go index 05f6882..c3377d7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -50,6 +50,7 @@ type Config struct { OperPassword string LoginRateLimit float64 LoginRateBurst int + IRCListenAddr string params *Params log *slog.Logger } @@ -86,6 +87,7 @@ func New( viper.SetDefault("NEOIRC_OPER_PASSWORD", "") viper.SetDefault("LOGIN_RATE_LIMIT", "1") viper.SetDefault("LOGIN_RATE_BURST", "5") + viper.SetDefault("IRC_LISTEN_ADDR", ":6667") err := viper.ReadInConfig() if err != nil { @@ -116,6 +118,7 @@ func New( OperPassword: viper.GetString("NEOIRC_OPER_PASSWORD"), LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"), LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"), + IRCListenAddr: viper.GetString("IRC_LISTEN_ADDR"), log: log, params: ¶ms, } diff --git a/internal/db/testing.go b/internal/db/testing.go new file mode 100644 index 0000000..d036611 --- /dev/null +++ b/internal/db/testing.go @@ -0,0 +1,25 @@ +package db + +import ( + "context" + "database/sql" + "log/slog" +) + +// NewTestDatabaseFromConn creates a Database wrapping an +// existing *sql.DB connection. Intended for integration +// tests in other packages. +func NewTestDatabaseFromConn(conn *sql.DB) *Database { + return &Database{ //nolint:exhaustruct + conn: conn, + log: slog.Default(), + } +} + +// RunMigrations applies all schema migrations. Exposed +// for integration tests in other packages. +func (database *Database) RunMigrations( + ctx context.Context, +) error { + return database.runMigrations(ctx) +} diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 34f1987..22825af 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -2,7 +2,6 @@ package handlers import ( "context" - "crypto/subtle" "encoding/json" "errors" "fmt" @@ -15,6 +14,7 @@ import ( "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/hashcash" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/pkg/irc" "github.com/go-chi/chi/v5" ) @@ -182,51 +182,6 @@ func (hdlr *Handlers) requireAuth( return sessionID, clientID, nick, true } -// fanOut stores a message and enqueues it to all specified -// session IDs, then notifies them. -func (hdlr *Handlers) fanOut( - request *http.Request, - command, from, target string, - body json.RawMessage, - meta json.RawMessage, - sessionIDs []int64, -) (string, error) { - dbID, msgUUID, err := hdlr.params.Database.InsertMessage( - request.Context(), command, from, target, nil, body, meta, - ) - if err != nil { - return "", fmt.Errorf("insert message: %w", err) - } - - for _, sid := range sessionIDs { - enqErr := hdlr.params.Database.EnqueueToSession( - request.Context(), sid, dbID, - ) - if enqErr != nil { - hdlr.log.Error("enqueue failed", - "error", enqErr, "session_id", sid) - } - - hdlr.broker.Notify(sid) - } - - return msgUUID, nil -} - -// fanOutSilent calls fanOut and discards the UUID. -func (hdlr *Handlers) fanOutSilent( - request *http.Request, - command, from, target string, - body json.RawMessage, - sessionIDs []int64, -) error { - _, err := hdlr.fanOut( - request, command, from, target, body, nil, sessionIDs, - ) - - return err -} - // HandleCreateSession creates a new user session. func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc { return func( @@ -1218,6 +1173,43 @@ func (hdlr *Handlers) respondIRCError( http.StatusOK) } +// handleServiceError maps a service-layer error to an IRC +// numeric reply or a generic HTTP 500 response. Returns +// true if an error was handled (response sent). +func (hdlr *Handlers) handleServiceError( + writer http.ResponseWriter, + request *http.Request, + clientID, sessionID int64, + nick string, + err error, +) bool { + if err == nil { + return false + } + + var ircErr *service.IRCError + if errors.As(err, &ircErr) { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + ircErr.Code, nick, ircErr.Params, + ircErr.Message, + ) + + return true + } + + hdlr.log.Error( + "service error", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return true +} + func (hdlr *Handlers) handleChannelMsg( writer http.ResponseWriter, request *http.Request, @@ -1228,141 +1220,45 @@ func (hdlr *Handlers) handleChannelMsg( ) { ctx := request.Context() - chID, ok := hdlr.resolveChannelForSend( - writer, request, - sessionID, clientID, nick, target, + // Hashcash validation is HTTP-specific; needs chID. + isNotice := command == irc.CmdNotice + + if !isNotice { + chID, chErr := hdlr.params.Database. + GetChannelByName(ctx, target) + if chErr == nil { + hashcashErr := hdlr.validateChannelHashcash( + request, clientID, sessionID, + writer, nick, target, body, meta, chID, + ) + if hashcashErr != nil { + return + } + } + } + + // Delegate validation + fan-out to service layer. + dbID, uuid, err := hdlr.svc.SendChannelMessage( + ctx, sessionID, nick, + command, target, body, meta, ) - if !ok { - return - } - - // Enforce +b (ban): banned users cannot send. - isBanned, banErr := hdlr.params.Database. - IsSessionBanned(ctx, chID, sessionID) - if banErr == nil && isBanned { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrCannotSendToChan, nick, []string{target}, - "Cannot send to channel (+b)", - ) - - return - } - - // Enforce +m (moderated): only +o and +v can send. - if !hdlr.checkModeratedSend( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, target, chID, + clientID, sessionID, nick, err, ) { return } - // NOTICE skips hashcash validation on +H channels. - if command != irc.CmdNotice { - hashcashErr := hdlr.validateChannelHashcash( - request, clientID, sessionID, - writer, nick, target, body, meta, chID, - ) - if hashcashErr != nil { - return - } - } - - hdlr.sendChannelMsg( - writer, request, command, nick, target, - body, meta, chID, + // HTTP echo: enqueue to sender so all their clients + // see the message in long-poll responses. + _ = hdlr.params.Database.EnqueueToSession( + ctx, sessionID, dbID, ) -} + hdlr.broker.Notify(sessionID) -// resolveChannelForSend validates that the channel exists -// and the session is a member. Returns the channel ID. -func (hdlr *Handlers) resolveChannelForSend( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, target string, -) (int64, bool) { - ctx := request.Context() - - chID, err := hdlr.params.Database.GetChannelByName( - ctx, target, - ) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchChannel, nick, []string{target}, - "No such channel", - ) - - return 0, false - } - - isMember, memErr := hdlr.params.Database.IsChannelMember( - ctx, chID, sessionID, - ) - if memErr != nil { - hdlr.log.Error( - "check membership failed", "error", memErr, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return 0, false - } - - if !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrCannotSendToChan, nick, []string{target}, - "Cannot send to channel", - ) - - return 0, false - } - - return chID, true -} - -// checkModeratedSend checks if a user can send to a -// moderated channel. Returns true if sending is allowed. -func (hdlr *Handlers) checkModeratedSend( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, target string, - chID int64, -) bool { - ctx := request.Context() - - isModerated, err := hdlr.params.Database. - IsChannelModerated(ctx, chID) - if err != nil || !isModerated { - return true - } - - isOp, opErr := hdlr.params.Database. - IsChannelOperator(ctx, chID, sessionID) - if opErr == nil && isOp { - return true - } - - isVoiced, vErr := hdlr.params.Database. - IsChannelVoiced(ctx, chID, sessionID) - if vErr == nil && isVoiced { - return true - } - - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrCannotSendToChan, nick, - []string{target}, - "Cannot send to channel (+m)", - ) - - return false + hdlr.respondJSON(writer, request, + map[string]string{"id": uuid, "status": "sent"}, + http.StatusOK) } // validateChannelHashcash checks whether the channel @@ -1519,49 +1415,6 @@ func (hdlr *Handlers) extractHashcashFromMeta( return stamp } -func (hdlr *Handlers) sendChannelMsg( - writer http.ResponseWriter, - request *http.Request, - command, nick, target string, - body json.RawMessage, - meta json.RawMessage, - chID int64, -) { - memberIDs, err := hdlr.params.Database.GetChannelMemberIDs( - request.Context(), chID, - ) - if err != nil { - hdlr.log.Error( - "get channel members failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - msgUUID, err := hdlr.fanOut( - request, command, nick, target, body, meta, memberIDs, - ) - if err != nil { - hdlr.log.Error("send message failed", "error", err) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - hdlr.respondJSON(writer, request, - map[string]string{"id": msgUUID, "status": "sent"}, - http.StatusOK) -} - func (hdlr *Handlers) handleDirectMsg( writer http.ResponseWriter, request *http.Request, @@ -1570,60 +1423,32 @@ func (hdlr *Handlers) handleDirectMsg( body json.RawMessage, meta json.RawMessage, ) { - targetSID, err := hdlr.params.Database.GetSessionByNick( - request.Context(), target, + result, err := hdlr.svc.SendDirectMessage( + request.Context(), sessionID, nick, + command, target, body, meta, ) - if err != nil { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.ErrNoSuchNick, nick, []string{target}, - "No such nick/channel", - ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) - - return - } - - recipients := []int64{targetSID} - if targetSID != sessionID { - recipients = append(recipients, sessionID) - } - - msgUUID, err := hdlr.fanOut( - request, command, nick, target, body, meta, recipients, - ) - if err != nil { - hdlr.log.Error("send dm failed", "error", err) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } // Per RFC 2812: NOTICE must NOT trigger auto-replies // including RPL_AWAY. - if command != irc.CmdNotice { - awayMsg, awayErr := hdlr.params.Database.GetAway( - request.Context(), targetSID, + if command != irc.CmdNotice && result.AwayMsg != "" { + hdlr.enqueueNumeric( + request.Context(), clientID, + irc.RplAway, nick, + []string{target}, result.AwayMsg, ) - if awayErr == nil && awayMsg != "" { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.RplAway, nick, - []string{target}, awayMsg, - ) - hdlr.broker.Notify(sessionID) - } + hdlr.broker.Notify(sessionID) } hdlr.respondJSON(writer, request, - map[string]string{"id": msgUUID, "status": "sent"}, + map[string]string{ + "id": result.UUID, "status": "sent", + }, http.StatusOK) } @@ -1679,141 +1504,20 @@ func (hdlr *Handlers) executeJoin( sessionID, clientID int64, nick, channel, suppliedKey string, ) { - ctx := request.Context() - - chID, isCreator, ok := hdlr.resolveJoinChannel( - writer, request, channel, + result, err := hdlr.svc.JoinChannel( + request.Context(), sessionID, nick, + channel, suppliedKey, ) - if !ok { - return - } - - // Tier 2 join checks only apply to non-creators. - if !isCreator { - if !hdlr.checkJoinAllowed( - writer, request, - sessionID, clientID, nick, - channel, chID, suppliedKey, - ) { - return - } - } - - if err := hdlr.addMemberToChannel( - ctx, chID, sessionID, isCreator, - ); err != nil { - hdlr.log.Error( - "join channel failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - hdlr.broadcastJoin( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, channel, chID, - ) -} - -// resolveJoinChannel gets or creates the channel and -// determines if the joiner would be the first member -// (i.e., the channel creator/op). -func (hdlr *Handlers) resolveJoinChannel( - writer http.ResponseWriter, - request *http.Request, - channel string, -) (int64, bool, bool) { - ctx := request.Context() - - chID, err := hdlr.params.Database.GetOrCreateChannel( - ctx, channel, - ) - if err != nil { - hdlr.log.Error( - "get/create channel failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return 0, false, false + clientID, sessionID, nick, err, + ) { + return } - memberCount, countErr := hdlr.params.Database. - CountChannelMembers(ctx, chID) - if countErr != nil { - hdlr.log.Error( - "count members failed", "error", countErr, - ) - } - - isCreator := countErr == nil && memberCount == 0 - - return chID, isCreator, true -} - -// addMemberToChannel adds a session to a channel, with -// operator privileges if they're the first member. -func (hdlr *Handlers) addMemberToChannel( - ctx context.Context, - chID, sessionID int64, - isCreator bool, -) error { - if isCreator { - err := hdlr.params.Database.JoinChannelAsOperator( - ctx, chID, sessionID, - ) - if err != nil { - return fmt.Errorf( - "join as operator: %w", err, - ) - } - - return nil - } - - err := hdlr.params.Database.JoinChannel( - ctx, chID, sessionID, - ) - if err != nil { - return fmt.Errorf("join channel: %w", err) - } - - return nil -} - -// broadcastJoin fans out the JOIN, clears invites, and -// sends numerics. -func (hdlr *Handlers) broadcastJoin( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel string, - chID int64, -) { - ctx := request.Context() - - memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( - ctx, chID, - ) - - _ = hdlr.fanOutSilent( - request, irc.CmdJoin, nick, channel, nil, memberIDs, - ) - - _ = hdlr.params.Database.ClearChannelInvite( - ctx, chID, sessionID, - ) - hdlr.deliverJoinNumerics( - request, clientID, sessionID, nick, channel, chID, + request, clientID, sessionID, nick, + channel, result.ChannelID, ) hdlr.respondJSON(writer, request, @@ -1824,87 +1528,6 @@ func (hdlr *Handlers) broadcastJoin( http.StatusOK) } -// checkJoinAllowed runs Tier 2 restrictions for an -// existing channel. Returns true if join is allowed. -func (hdlr *Handlers) checkJoinAllowed( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel string, - chID int64, - suppliedKey string, -) bool { - ctx := request.Context() - - // 1. Ban check — prevents banned users from joining. - isBanned, banErr := hdlr.params.Database. - IsSessionBanned(ctx, chID, sessionID) - if banErr == nil && isBanned { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrBannedFromChan, nick, - []string{channel}, - "Cannot join channel (+b)", - ) - - return false - } - - // 2. Invite-only check (+i). - isInviteOnly, ioErr := hdlr.params.Database. - IsChannelInviteOnly(ctx, chID) - if ioErr == nil && isInviteOnly { - hasInvite, invErr := hdlr.params.Database. - HasChannelInvite(ctx, chID, sessionID) - if invErr != nil || !hasInvite { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrInviteOnlyChan, nick, - []string{channel}, - "Cannot join channel (+i)", - ) - - return false - } - } - - // 3. Channel key check (+k). - key, keyErr := hdlr.params.Database. - GetChannelKey(ctx, chID) - if keyErr == nil && key != "" { - if suppliedKey != key { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrBadChannelKey, nick, - []string{channel}, - "Cannot join channel (+k)", - ) - - return false - } - } - - // 4. User limit check (+l). - limit, limErr := hdlr.params.Database. - GetChannelUserLimit(ctx, chID) - if limErr == nil && limit > 0 { - count, cntErr := hdlr.params.Database. - CountChannelMembers(ctx, chID) - if cntErr == nil && count >= int64(limit) { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrChannelIsFull, nick, - []string{channel}, - "Cannot join channel (+l)", - ) - - return false - } - } - - return true -} - // deliverJoinNumerics sends RPL_TOPIC/RPL_NOTOPIC, // RPL_NAMREPLY, and RPL_ENDOFNAMES to the joining client. func (hdlr *Handlers) deliverJoinNumerics( @@ -2039,15 +1662,12 @@ func (hdlr *Handlers) handlePart( body json.RawMessage, ) { if target == "" { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.ErrNeedMoreParams, nick, []string{irc.CmdPart}, + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdPart}, "Not enough parameters", ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) return } @@ -2057,51 +1677,27 @@ func (hdlr *Handlers) handlePart( channel = "#" + channel } - chID, err := hdlr.params.Database.GetChannelByName( - request.Context(), channel, - ) - if err != nil { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.ErrNoSuchChannel, nick, []string{channel}, - "No such channel", - ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) - - return + // Extract reason from body for the service call. + reason := "" + if body != nil { + var lines []string + if json.Unmarshal(body, &lines) == nil && + len(lines) > 0 { + reason = lines[0] + } } - memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( - request.Context(), chID, + err := hdlr.svc.PartChannel( + request.Context(), sessionID, + nick, channel, reason, ) - - _ = hdlr.fanOutSilent( - request, irc.CmdPart, nick, channel, body, memberIDs, - ) - - err = hdlr.params.Database.PartChannel( - request.Context(), chID, sessionID, - ) - if err != nil { - hdlr.log.Error( - "part channel failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - _ = hdlr.params.Database.DeleteChannelIfEmpty( - request.Context(), chID, - ) - hdlr.respondJSON(writer, request, map[string]string{ "status": "parted", @@ -2162,34 +1758,16 @@ func (hdlr *Handlers) executeNickChange( sessionID, clientID int64, nick, newNick string, ) { - err := hdlr.params.Database.ChangeNick( - request.Context(), sessionID, newNick, + err := hdlr.svc.ChangeNick( + request.Context(), sessionID, nick, newNick, ) - if err != nil { - if db.IsUniqueConstraintError(err) { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNicknameInUse, nick, []string{newNick}, - "Nickname is already in use", - ) - - return - } - - hdlr.log.Error( - "change nick failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - hdlr.broadcastNick(request, sessionID, nick, newNick) - hdlr.respondJSON(writer, request, map[string]string{ "status": "ok", "nick": newNick, @@ -2197,70 +1775,19 @@ func (hdlr *Handlers) executeNickChange( http.StatusOK) } -func (hdlr *Handlers) broadcastNick( - request *http.Request, - sessionID int64, - oldNick, newNick string, -) { - channels, _ := hdlr.params.Database. - GetSessionChannels( - request.Context(), sessionID, - ) - - notified := map[int64]bool{sessionID: true} - - nickBody, err := json.Marshal([]string{newNick}) - if err != nil { - hdlr.log.Error( - "marshal nick body", "error", err, - ) - - return - } - - dbID, _, _ := hdlr.params.Database.InsertMessage( - request.Context(), irc.CmdNick, oldNick, "", - nil, json.RawMessage(nickBody), nil, - ) - - _ = hdlr.params.Database.EnqueueToSession( - request.Context(), sessionID, dbID, - ) - - hdlr.broker.Notify(sessionID) - - for _, chanInfo := range channels { - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs( - request.Context(), chanInfo.ID, - ) - - for _, mid := range memberIDs { - if !notified[mid] { - notified[mid] = true - - _ = hdlr.params.Database.EnqueueToSession( - request.Context(), mid, dbID, - ) - - hdlr.broker.Notify(mid) - } - } - } -} - func (hdlr *Handlers) handleTopic( writer http.ResponseWriter, request *http.Request, sessionID, clientID int64, nick, target string, - body json.RawMessage, + _ json.RawMessage, bodyLines func() []string, ) { if target == "" { hdlr.respondIRCError( writer, request, clientID, sessionID, - irc.ErrNeedMoreParams, nick, []string{irc.CmdTopic}, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdTopic}, "Not enough parameters", ) @@ -2271,7 +1798,8 @@ func (hdlr *Handlers) handleTopic( if len(lines) == 0 { hdlr.respondIRCError( writer, request, clientID, sessionID, - irc.ErrNeedMoreParams, nick, []string{irc.CmdTopic}, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdTopic}, "Not enough parameters", ) @@ -2283,131 +1811,24 @@ func (hdlr *Handlers) handleTopic( channel = "#" + channel } - chID, err := hdlr.params.Database.GetChannelByName( - request.Context(), channel, + topic := strings.Join(lines, " ") + + err := hdlr.svc.SetTopic( + request.Context(), sessionID, + nick, channel, topic, ) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchChannel, nick, []string{channel}, - "No such channel", - ) - - return - } - - isMember, err := hdlr.params.Database.IsChannelMember( - request.Context(), chID, sessionID, - ) - if err != nil { - hdlr.log.Error( - "check membership failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - if !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNotOnChannel, nick, []string{channel}, - "You're not on that channel", - ) - - return - } - - hdlr.executeTopic( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, - channel, strings.Join(lines, " "), - body, chID, - ) -} - -func (hdlr *Handlers) executeTopic( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel, topic string, - body json.RawMessage, - chID int64, -) { - ctx := request.Context() - - // Enforce +t: only operators can change topic when - // topic lock is active. - isLocked, lockErr := hdlr.params.Database. - IsChannelTopicLocked(ctx, chID) - if lockErr == nil && isLocked { - isOp, opErr := hdlr.params.Database. - IsChannelOperator(ctx, chID, sessionID) - if opErr != nil || !isOp { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrChanOpPrivsNeeded, nick, - []string{channel}, - "You're not channel operator", - ) - - return - } - } - - setErr := hdlr.params.Database.SetTopicMeta( - request.Context(), channel, topic, nick, - ) - if setErr != nil { - hdlr.log.Error( - "set topic failed", "error", setErr, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + clientID, sessionID, nick, err, + ) { return } - memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( - request.Context(), chID, + hdlr.deliverSetTopicNumerics( + request.Context(), clientID, sessionID, + nick, channel, topic, ) - _ = hdlr.fanOutSilent( - request, irc.CmdTopic, nick, channel, body, memberIDs, - ) - - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.RplTopic, nick, []string{channel}, topic, - ) - - // 333 RPL_TOPICWHOTIME - topicMeta, tmErr := hdlr.params.Database. - GetTopicMeta(request.Context(), chID) - if tmErr == nil && topicMeta != nil { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.RplTopicWhoTime, nick, - []string{ - channel, - topicMeta.SetBy, - strconv.FormatInt( - topicMeta.SetAt.Unix(), 10, - ), - }, - "", - ) - } - - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, map[string]string{ "status": "ok", "topic": topic, @@ -2415,6 +1836,43 @@ func (hdlr *Handlers) executeTopic( http.StatusOK) } +// deliverSetTopicNumerics sends RPL_TOPIC and +// RPL_TOPICWHOTIME to the client after a topic change. +func (hdlr *Handlers) deliverSetTopicNumerics( + ctx context.Context, + clientID, sessionID int64, + nick, channel, topic string, +) { + hdlr.enqueueNumeric( + ctx, clientID, irc.RplTopic, nick, + []string{channel}, topic, + ) + + chID, chErr := hdlr.params.Database.GetChannelByName( + ctx, channel, + ) + if chErr == nil { + topicMeta, tmErr := hdlr.params.Database. + GetTopicMeta(ctx, chID) + if tmErr == nil && topicMeta != nil { + hdlr.enqueueNumeric( + ctx, clientID, + irc.RplTopicWhoTime, nick, + []string{ + channel, + topicMeta.SetBy, + strconv.FormatInt( + topicMeta.SetAt.Unix(), 10, + ), + }, + "", + ) + } + } + + hdlr.broker.Notify(sessionID) +} + // dispatchInfoCommand handles informational IRC commands // that produce server-side numerics (MOTD, PING). func (hdlr *Handlers) dispatchInfoCommand( @@ -2457,51 +1915,17 @@ func (hdlr *Handlers) handleQuit( nick string, body json.RawMessage, ) { - channels, _ := hdlr.params.Database. - GetSessionChannels( - request.Context(), sessionID, - ) - - notified := map[int64]bool{} - - var dbID int64 - - if len(channels) > 0 { - dbID, _, _ = hdlr.params.Database.InsertMessage( - request.Context(), irc.CmdQuit, nick, "", - nil, body, nil, - ) - } - - for _, chanInfo := range channels { - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs( - request.Context(), chanInfo.ID, - ) - - for _, mid := range memberIDs { - if mid != sessionID && !notified[mid] { - notified[mid] = true - - _ = hdlr.params.Database.EnqueueToSession( - request.Context(), mid, dbID, - ) - - hdlr.broker.Notify(mid) - } + reason := "Client quit" + if body != nil { + var lines []string + if json.Unmarshal(body, &lines) == nil && + len(lines) > 0 { + reason = lines[0] } - - _ = hdlr.params.Database.PartChannel( - request.Context(), chanInfo.ID, sessionID, - ) - - _ = hdlr.params.Database.DeleteChannelIfEmpty( - request.Context(), chanInfo.ID, - ) } - _ = hdlr.params.Database.DeleteSession( - request.Context(), sessionID, + hdlr.svc.BroadcastQuit( + request.Context(), sessionID, nick, reason, ) hdlr.clearAuthCookie(writer, request) @@ -2695,30 +2119,6 @@ func (hdlr *Handlers) queryChannelMode( // requireChannelOp checks that the session has +o in the // channel. If not, it sends ERR_CHANOPRIVSNEEDED and // returns false. -func (hdlr *Handlers) requireChannelOp( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel string, - chID int64, -) bool { - isOp, err := hdlr.params.Database.IsChannelOperator( - request.Context(), chID, sessionID, - ) - if err != nil || !isOp { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrChanOpPrivsNeeded, nick, - []string{channel}, - "You're not channel operator", - ) - - return false - } - - return true -} - // applyChannelMode handles setting channel modes. // Supports +o/-o, +v/-v, +m/-m, +t/-t, +H/-H. func (hdlr *Handlers) applyChannelMode( @@ -2894,67 +2294,6 @@ func (hdlr *Handlers) applyParameterizedMode( } } -// resolveUserModeTarget validates a user-mode change -// target and returns the target session ID if valid. -// Returns -1 on error (error response already sent). -func (hdlr *Handlers) resolveUserModeTarget( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel string, - chID int64, - modeArgs []string, -) (int64, string, bool) { - if !hdlr.requireChannelOp( - writer, request, - sessionID, clientID, nick, channel, chID, - ) { - return -1, "", false - } - - if len(modeArgs) < 2 { //nolint:mnd // mode + nick - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNeedMoreParams, nick, - []string{irc.CmdMode}, - "Not enough parameters", - ) - - return -1, "", false - } - - ctx := request.Context() - targetNick := modeArgs[1] - - targetSID, err := hdlr.params.Database. - GetSessionByNick(ctx, targetNick) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchNick, nick, - []string{targetNick}, - "No such nick/channel", - ) - - return -1, "", false - } - - isMember, memErr := hdlr.params.Database. - IsChannelMember(ctx, chID, targetSID) - if memErr != nil || !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrUserNotInChannel, nick, - []string{targetNick, channel}, - "They aren't on that channel", - ) - - return -1, "", false - } - - return targetSID, targetNick, true -} - // applyUserMode handles +o/-o and +v/-v mode changes. // isOperMode=true for +o/-o, false for +v/-v. func (hdlr *Handlers) applyUserMode( @@ -2966,46 +2305,53 @@ func (hdlr *Handlers) applyUserMode( modeArgs []string, isOperMode bool, ) { - ctx := request.Context() - - targetSID, targetNick, ok := hdlr.resolveUserModeTarget( - writer, request, - sessionID, clientID, nick, - channel, chID, modeArgs, + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + request.Context(), sessionID, channel, ) - if !ok { + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, opErr, + ) { return } + if len(modeArgs) < 2 { //nolint:mnd // mode + nick + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdMode}, + "Not enough parameters", + ) + + return + } + + targetNick := modeArgs[1] setting := strings.HasPrefix(modeArgs[0], "+") - var err error + var modeChar rune if isOperMode { - err = hdlr.params.Database.SetChannelMemberOperator( - ctx, chID, targetSID, setting, - ) + modeChar = 'o' } else { - err = hdlr.params.Database.SetChannelMemberVoiced( - ctx, chID, targetSID, setting, - ) + modeChar = 'v' } - if err != nil { - hdlr.log.Error( - "set user mode failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + err := hdlr.svc.ApplyMemberMode( + request.Context(), chID, channel, + targetNick, modeChar, setting, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - hdlr.broadcastUserModeChange( - request, nick, channel, chID, - modeArgs[0], targetNick, + modeText := modeArgs[0] + " " + targetNick + hdlr.svc.BroadcastMode( + request.Context(), nick, channel, chID, + modeText, ) hdlr.respondJSON(writer, request, @@ -3013,32 +2359,6 @@ func (hdlr *Handlers) applyUserMode( http.StatusOK) } -// broadcastUserModeChange fans out a user-mode change -// to all channel members. -func (hdlr *Handlers) broadcastUserModeChange( - request *http.Request, - nick, channel string, - chID int64, - modeStr, targetNick string, -) { - ctx := request.Context() - - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chID) - - modeBody, err := json.Marshal( - []string{modeStr, targetNick}, - ) - if err != nil { - return - } - - _ = hdlr.fanOutSilent( - request, irc.CmdMode, nick, channel, - json.RawMessage(modeBody), memberIDs, - ) -} - // setChannelFlag handles +m/-m and +t/-t mode changes. func (hdlr *Handlers) setChannelFlag( writer http.ResponseWriter, @@ -3049,70 +2369,35 @@ func (hdlr *Handlers) setChannelFlag( flag string, setting bool, ) { - ctx := request.Context() - - if !hdlr.requireChannelOp( + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + request.Context(), sessionID, channel, + ) + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, channel, chID, + clientID, sessionID, nick, opErr, ) { return } - var err error - - switch flag { - case "m": - err = hdlr.params.Database.SetChannelModerated( - ctx, chID, setting, - ) - case "t": - err = hdlr.params.Database.SetChannelTopicLocked( - ctx, chID, setting, - ) - case "i": - err = hdlr.params.Database.SetChannelInviteOnly( - ctx, chID, setting, - ) - case "s": - err = hdlr.params.Database.SetChannelSecret( - ctx, chID, setting, - ) - } - - if err != nil { - hdlr.log.Error( - "set channel flag failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + err := hdlr.svc.SetChannelFlag( + request.Context(), chID, rune(flag[0]), setting, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - // Broadcast the MODE change. modeStr := "+" + flag if !setting { modeStr = "-" + flag } - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chID) - - modeBody, mErr := json.Marshal([]string{modeStr}) - if mErr != nil { - hdlr.log.Error( - "marshal mode body", "error", mErr, - ) - - return - } - - _ = hdlr.fanOutSilent( - request, irc.CmdMode, nick, channel, - json.RawMessage(modeBody), memberIDs, + hdlr.svc.BroadcastMode( + request.Context(), nick, channel, chID, + modeStr, ) hdlr.respondJSON(writer, request, @@ -3140,6 +2425,17 @@ func (hdlr *Handlers) setHashcashMode( ) { ctx := request.Context() + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + ctx, sessionID, channel, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, opErr, + ) { + return + } + if len(modeArgs) < 2 { //nolint:mnd // +H requires a bits arg hdlr.respondIRCError( writer, request, clientID, sessionID, @@ -3203,6 +2499,17 @@ func (hdlr *Handlers) clearHashcashMode( ) { ctx := request.Context() + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + ctx, sessionID, channel, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, opErr, + ) { + return + } + err := hdlr.params.Database.SetChannelHashcashBits( ctx, chID, 0, ) @@ -4453,52 +3760,16 @@ func (hdlr *Handlers) HandleLogout() http.HandlerFunc { } // cleanupUser parts the user from all channels (notifying -// members) and deletes the session. +// members) and deletes the session via the shared service +// layer. func (hdlr *Handlers) cleanupUser( ctx context.Context, sessionID int64, nick string, ) { - channels, _ := hdlr.params.Database. - GetSessionChannels(ctx, sessionID) - - notified := map[int64]bool{} - - var quitDBID int64 - - if len(channels) > 0 { - quitDBID, _, _ = hdlr.params.Database.InsertMessage( - ctx, irc.CmdQuit, nick, "", - nil, nil, nil, - ) - } - - for _, chanInfo := range channels { - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chanInfo.ID) - - for _, mid := range memberIDs { - if mid != sessionID && !notified[mid] { - notified[mid] = true - - _ = hdlr.params.Database.EnqueueToSession( - ctx, mid, quitDBID, - ) - - hdlr.broker.Notify(mid) - } - } - - _ = hdlr.params.Database.PartChannel( - ctx, chanInfo.ID, sessionID, - ) - - _ = hdlr.params.Database.DeleteChannelIfEmpty( - ctx, chanInfo.ID, - ) - } - - _ = hdlr.params.Database.DeleteSession(ctx, sessionID) + hdlr.svc.BroadcastQuit( + ctx, sessionID, nick, "Connection closed", + ) } // HandleUsersMe returns the current user's session info. @@ -4545,7 +3816,8 @@ func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc { } } -// handleOper handles the OPER command for server operator authentication. +// handleOper handles the OPER command for server operator +// authentication. func (hdlr *Handlers) handleOper( writer http.ResponseWriter, request *http.Request, @@ -4567,39 +3839,13 @@ func (hdlr *Handlers) handleOper( return } - operName := lines[0] - operPass := lines[1] - - cfgName := hdlr.params.Config.OperName - cfgPass := hdlr.params.Config.OperPassword - - if cfgName == "" || cfgPass == "" || - subtle.ConstantTimeCompare([]byte(operName), []byte(cfgName)) != 1 || - subtle.ConstantTimeCompare([]byte(operPass), []byte(cfgPass)) != 1 { - hdlr.enqueueNumeric( - ctx, clientID, irc.ErrNoOperHost, nick, - nil, "No O-lines for your host", - ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) - - return - } - - err := hdlr.params.Database.SetSessionOper( - ctx, sessionID, true, + err := hdlr.svc.Oper( + ctx, sessionID, lines[0], lines[1], ) - if err != nil { - hdlr.log.Error( - "set oper failed", "error", err, - ) - hdlr.respondError( - writer, request, "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } @@ -4633,21 +3879,17 @@ func (hdlr *Handlers) handleAway( awayMsg = strings.Join(lines, " ") } - err := hdlr.params.Database.SetAway( + cleared, err := hdlr.svc.SetAway( ctx, sessionID, awayMsg, ) - if err != nil { - hdlr.log.Error("set away failed", "error", err) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - if awayMsg == "" { + if cleared { // 305 RPL_UNAWAY hdlr.enqueueNumeric( ctx, clientID, irc.RplUnaway, nick, nil, @@ -4676,6 +3918,8 @@ func (hdlr *Handlers) handleKick( body json.RawMessage, bodyLines func() []string, ) { + _ = body + if target == "" { hdlr.respondIRCError( writer, request, clientID, sessionID, @@ -4711,178 +3955,22 @@ func (hdlr *Handlers) handleKick( reason = lines[1] } - hdlr.executeKick( - writer, request, - sessionID, clientID, nick, - channel, targetNick, reason, body, + err := hdlr.svc.KickUser( + request.Context(), sessionID, nick, + channel, targetNick, reason, ) -} - -func (hdlr *Handlers) executeKick( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel, targetNick, reason string, - _ json.RawMessage, -) { - ctx := request.Context() - - chID, targetSID, ok := hdlr.validateKick( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, - channel, targetNick, - ) - if !ok { - return - } - - if !hdlr.broadcastKick( - writer, request, - nick, channel, targetNick, reason, chID, + clientID, sessionID, nick, err, ) { return } - // Remove the kicked user from the channel. - _ = hdlr.params.Database.PartChannel( - ctx, chID, targetSID, - ) - - // Clean up empty channel. - _ = hdlr.params.Database.DeleteChannelIfEmpty( - ctx, chID, - ) - hdlr.respondJSON(writer, request, map[string]string{"status": "ok"}, http.StatusOK) } -// validateKick checks the channel exists, the kicker is -// an operator, and the target is in the channel. -func (hdlr *Handlers) validateKick( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel, targetNick string, -) (int64, int64, bool) { - ctx := request.Context() - - chID, err := hdlr.params.Database.GetChannelByName( - ctx, channel, - ) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchChannel, nick, - []string{channel}, - "No such channel", - ) - - return 0, 0, false - } - - if !hdlr.requireChannelOp( - writer, request, - sessionID, clientID, nick, channel, chID, - ) { - return 0, 0, false - } - - targetSID, err := hdlr.params.Database. - GetSessionByNick(ctx, targetNick) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchNick, nick, - []string{targetNick}, - "No such nick/channel", - ) - - return 0, 0, false - } - - isMember, memErr := hdlr.params.Database. - IsChannelMember(ctx, chID, targetSID) - if memErr != nil || !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrUserNotInChannel, nick, - []string{targetNick, channel}, - "They aren't on that channel", - ) - - return 0, 0, false - } - - return chID, targetSID, true -} - -// broadcastKick inserts a KICK message and fans it out -// to all channel members. -func (hdlr *Handlers) broadcastKick( - writer http.ResponseWriter, - request *http.Request, - nick, channel, targetNick, reason string, - chID int64, -) bool { - ctx := request.Context() - - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chID) - - kickBody, bErr := json.Marshal([]string{reason}) - if bErr != nil { - hdlr.log.Error("marshal kick body", "error", bErr) - - return false - } - - kickParams, pErr := json.Marshal( - []string{targetNick}, - ) - if pErr != nil { - hdlr.log.Error( - "marshal kick params", "error", pErr, - ) - - return false - } - - dbID, _, insertErr := hdlr.params.Database. - InsertMessage( - ctx, irc.CmdKick, nick, channel, - json.RawMessage(kickParams), - json.RawMessage(kickBody), nil, - ) - if insertErr != nil { - hdlr.log.Error( - "insert kick message", "error", insertErr, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return false - } - - for _, sid := range memberIDs { - enqErr := hdlr.params.Database.EnqueueToSession( - ctx, sid, dbID, - ) - if enqErr != nil { - hdlr.log.Error("enqueue kick failed", - "error", enqErr, "session_id", sid) - } - - hdlr.broker.Notify(sid) - } - - return true -} - // deliverWhoisIdle sends RPL_WHOISIDLE (317) with idle // time and signon time. func (hdlr *Handlers) deliverWhoisIdle( @@ -4922,3 +4010,76 @@ func (hdlr *Handlers) deliverWhoisIdle( "seconds idle, signon time", ) } + +// fanOut inserts a message and enqueues it to all given +// sessions. +func (hdlr *Handlers) fanOut( + request *http.Request, + command, from, target string, + body json.RawMessage, + meta json.RawMessage, + sessionIDs []int64, +) (string, error) { + dbID, msgUUID, err := hdlr.params.Database.InsertMessage( + request.Context(), command, from, target, + nil, body, meta, + ) + if err != nil { + return "", fmt.Errorf("insert message: %w", err) + } + + for _, sid := range sessionIDs { + enqErr := hdlr.params.Database.EnqueueToSession( + request.Context(), sid, dbID, + ) + if enqErr != nil { + hdlr.log.Error("enqueue failed", + "error", enqErr, "session_id", sid) + } + + hdlr.broker.Notify(sid) + } + + return msgUUID, nil +} + +// fanOutSilent calls fanOut and discards the UUID. +func (hdlr *Handlers) fanOutSilent( + request *http.Request, + command, from, target string, + body json.RawMessage, + sessionIDs []int64, +) error { + _, err := hdlr.fanOut( + request, command, from, target, + body, nil, sessionIDs, + ) + + return err +} + +// requireChannelOp checks if the session is a channel +// operator and sends ERR_CHANOPRIVSNEEDED if not. +func (hdlr *Handlers) requireChannelOp( + writer http.ResponseWriter, + request *http.Request, + sessionID, clientID int64, + nick, channel string, + chID int64, +) bool { + isOp, err := hdlr.params.Database.IsChannelOperator( + request.Context(), chID, sessionID, + ) + if err != nil || !isOp { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrChanOpPrivsNeeded, nick, + []string{channel}, + "You're not channel operator", + ) + + return false + } + + return true +} diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index 6b0a477..ad70bdb 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/globals" @@ -27,6 +28,7 @@ import ( "git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/middleware" "git.eeqj.de/sneak/neoirc/internal/server" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/internal/stats" "go.uber.org/fx" "go.uber.org/fx/fxtest" @@ -206,6 +208,14 @@ func newTestHandlers( hcheck *healthcheck.Healthcheck, tracker *stats.Tracker, ) (*handlers.Handlers, error) { + brk := broker.New() + svc := service.New(service.Params{ //nolint:exhaustruct + Logger: log, + Config: cfg, + Database: database, + Broker: brk, + }) + hdlr, err := handlers.New(lifecycle, handlers.Params{ //nolint:exhaustruct Logger: log, Globals: globs, @@ -213,6 +223,8 @@ func newTestHandlers( Database: database, Healthcheck: hcheck, Stats: tracker, + Broker: brk, + Service: svc, }) if err != nil { return nil, fmt.Errorf("test handlers: %w", err) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 5d5c225..d185b5d 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -17,6 +17,7 @@ import ( "git.eeqj.de/sneak/neoirc/internal/healthcheck" "git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/ratelimit" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/internal/stats" "go.uber.org/fx" ) @@ -33,6 +34,8 @@ type Params struct { Database *db.Database Healthcheck *healthcheck.Healthcheck Stats *stats.Tracker + Broker *broker.Broker + Service *service.Service } const defaultIdleTimeout = 30 * 24 * time.Hour @@ -48,6 +51,7 @@ type Handlers struct { log *slog.Logger hc *healthcheck.Healthcheck broker *broker.Broker + svc *service.Service hashcashVal *hashcash.Validator channelHashcash *hashcash.ChannelValidator loginLimiter *ratelimit.Limiter @@ -79,7 +83,8 @@ func New( params: ¶ms, log: params.Logger.Get(), hc: params.Healthcheck, - broker: broker.New(), + broker: params.Broker, + svc: params.Service, hashcashVal: hashcash.NewValidator(resource), channelHashcash: hashcash.NewChannelValidator(), loginLimiter: ratelimit.New(loginRate, loginBurst), diff --git a/internal/ircserver/commands.go b/internal/ircserver/commands.go new file mode 100644 index 0000000..297ea4f --- /dev/null +++ b/internal/ircserver/commands.go @@ -0,0 +1,1178 @@ +package ircserver + +import ( + "context" + "encoding/json" + "errors" + "strconv" + "strings" + "time" + + "git.eeqj.de/sneak/neoirc/internal/service" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +// sendIRCError maps a service.IRCError to an IRC numeric +// reply on the wire. +func (c *Conn) sendIRCError(err error) { + var ircErr *service.IRCError + if errors.As(err, &ircErr) { + args := make([]string, 0, len(ircErr.Params)+1) + args = append(args, ircErr.Params...) + args = append(args, ircErr.Message) + c.sendNumeric(ircErr.Code, args...) + } +} + +// handleCAP silently acknowledges CAP negotiation. +func (c *Conn) handleCAP(msg *Message) { + if len(msg.Params) == 0 { + return + } + + sub := strings.ToUpper(msg.Params[0]) + if sub == "LS" { + c.send(FormatMessage( + c.serverSfx, "CAP", "*", "LS", "", + )) + } + + // CAP END and other subcommands are silently ignored. +} + +// handlePing replies with a PONG. +func (c *Conn) handlePing(msg *Message) { + token := c.serverSfx + + if len(msg.Params) > 0 { + token = msg.Params[0] + } + + c.sendFromServer("PONG", c.serverSfx, token) +} + +// handleNick changes the user's nickname via the shared +// service layer. +func (c *Conn) handleNick( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + newNick := msg.Params[0] + if len(newNick) > maxNickLen { + newNick = newNick[:maxNickLen] + } + + oldMask := c.hostmask() + oldNick := c.nick + + err := c.svc.ChangeNick( + ctx, c.sessionID, oldNick, newNick, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + c.mu.Lock() + c.nick = newNick + c.mu.Unlock() + + // Echo NICK change to the client on wire. + c.send(FormatMessage(oldMask, "NICK", newNick)) +} + +// handlePrivmsg handles PRIVMSG and NOTICE commands via +// the shared service layer. +func (c *Conn) handlePrivmsg( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoRecipient, + "No recipient given ("+msg.Command+")", + ) + + return + } + + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNoTextToSend, "No text to send", + ) + + return + } + + target := msg.Params[0] + text := msg.Params[1] + body, _ := json.Marshal([]string{text}) //nolint:errchkjson + + if strings.HasPrefix(target, "#") { + _, _, err := c.svc.SendChannelMessage( + ctx, c.sessionID, c.nick, + msg.Command, target, body, nil, + ) + if err != nil { + c.sendIRCError(err) + } + } else { + result, err := c.svc.SendDirectMessage( + ctx, c.sessionID, c.nick, + msg.Command, target, body, nil, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + if result.AwayMsg != "" { + c.sendNumeric( + irc.RplAway, target, result.AwayMsg, + ) + } + } +} + +// handleJoin joins one or more channels via the shared +// service layer. +func (c *Conn) handleJoin( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "JOIN", "Not enough parameters", + ) + + return + } + + channels := strings.Split(msg.Params[0], ",") + + for _, chanName := range channels { + chanName = strings.TrimSpace(chanName) + + if !strings.HasPrefix(chanName, "#") { + chanName = "#" + chanName + } + + c.joinChannel(ctx, chanName) + } +} + +// joinChannel joins a single channel using the service +// and delivers topic/names on the wire. +func (c *Conn) joinChannel( + ctx context.Context, + channel string, +) { + result, err := c.svc.JoinChannel( + ctx, c.sessionID, c.nick, channel, "", + ) + if err != nil { + c.sendIRCError(err) + + if !errors.As(err, new(*service.IRCError)) { + c.log.Error( + "join channel failed", "error", err, + ) + } + + return + } + + // Send JOIN echo to this client directly on wire. + c.send(FormatMessage(c.hostmask(), "JOIN", channel)) + + // Send topic. + c.deliverTopic(ctx, channel, result.ChannelID) + + // Send NAMES. + c.deliverNames(ctx, channel, result.ChannelID) +} + +// deliverTopic sends RPL_TOPIC or RPL_NOTOPIC. +func (c *Conn) deliverTopic( + ctx context.Context, + channel string, + chID int64, +) { + channels, err := c.database.ListChannels( + ctx, c.sessionID, + ) + + topic := "" + + if err == nil { + for _, ch := range channels { + if ch.Name == channel { + topic = ch.Topic + + break + } + } + } + + if topic == "" { + c.sendNumeric( + irc.RplNoTopic, channel, "No topic is set", + ) + + return + } + + c.sendNumeric(irc.RplTopic, channel, topic) + + meta, tmErr := c.database.GetTopicMeta(ctx, chID) + if tmErr == nil && meta != nil { + c.sendNumeric( + irc.RplTopicWhoTime, channel, + meta.SetBy, + strconv.FormatInt(meta.SetAt.Unix(), 10), + ) + } +} + +// deliverNames sends RPL_NAMREPLY and RPL_ENDOFNAMES. +func (c *Conn) deliverNames( + ctx context.Context, + channel string, + chID int64, +) { + members, err := c.database.ChannelMembers(ctx, chID) + if err != nil { + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) + + return + } + + names := make([]string, 0, len(members)) + + for _, member := range members { + prefix := "" + if member.IsOperator { + prefix = "@" + } else if member.IsVoiced { + prefix = "+" + } + + names = append(names, prefix+member.Nick) + } + + nameStr := strings.Join(names, " ") + + c.sendNumeric( + irc.RplNamReply, "=", channel, nameStr, + ) + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) +} + +// handlePart leaves one or more channels via the shared +// service layer. +func (c *Conn) handlePart( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PART", "Not enough parameters", + ) + + return + } + + reason := "" + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + channels := strings.Split(msg.Params[0], ",") + + for _, ch := range channels { + ch = strings.TrimSpace(ch) + c.partChannel(ctx, ch, reason) + } +} + +// partChannel leaves a single channel using the service. +func (c *Conn) partChannel( + ctx context.Context, + channel, reason string, +) { + err := c.svc.PartChannel( + ctx, c.sessionID, c.nick, channel, reason, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + // Echo PART to the client on wire. + if reason != "" { + c.send(FormatMessage( + c.hostmask(), "PART", channel, reason, + )) + } else { + c.send(FormatMessage( + c.hostmask(), "PART", channel, + )) + } +} + +// handleQuit handles the QUIT command. +func (c *Conn) handleQuit(msg *Message) { + reason := "Client quit" + + if len(msg.Params) > 0 { + reason = msg.Params[0] + } + + c.send("ERROR :Closing Link: " + c.hostname + + " (Quit: " + reason + ")") + c.closed = true +} + +// handleTopic gets or sets a channel topic via the shared +// service layer. +func (c *Conn) handleTopic( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "TOPIC", "Not enough parameters", + ) + + return + } + + channel := msg.Params[0] + + // If no second param, query the topic. + if len(msg.Params) < 2 { //nolint:mnd + c.queryTopic(ctx, channel) + + return + } + + // Set topic via service. + newTopic := msg.Params[1] + + err := c.svc.SetTopic( + ctx, c.sessionID, c.nick, channel, newTopic, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + // Echo TOPIC to the setting client on wire. + c.send(FormatMessage( + c.hostmask(), "TOPIC", channel, newTopic, + )) +} + +// queryTopic sends the current topic for a channel. +func (c *Conn) queryTopic( + ctx context.Context, + channel string, +) { + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + c.deliverTopic(ctx, channel, chID) +} + +// handleMode handles MODE queries and changes. +func (c *Conn) handleMode( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "MODE", "Not enough parameters", + ) + + return + } + + target := msg.Params[0] + + if strings.HasPrefix(target, "#") { + c.handleChannelMode(ctx, msg) + } else { + c.handleUserMode(msg) + } +} + +// handleChannelMode handles MODE for channels. +func (c *Conn) handleChannelMode( + ctx context.Context, + msg *Message, +) { + channel := msg.Params[0] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + // Query mode if no mode string given. + if len(msg.Params) < 2 { //nolint:mnd + modeStr := c.svc.QueryChannelMode(ctx, chID) + c.sendNumeric( + irc.RplChannelModeIs, channel, modeStr, + ) + + created, _ := c.database.GetChannelCreatedAt( + ctx, chID, + ) + if !created.IsZero() { + c.sendNumeric( + irc.RplCreationTime, channel, + strconv.FormatInt(created.Unix(), 10), + ) + } + + return + } + + // Need ops to change modes — validated by service. + _, opErr := c.svc.ValidateChannelOp( + ctx, c.sessionID, channel, + ) + if opErr != nil { + c.sendIRCError(opErr) + + return + } + + modeStr := msg.Params[1] + modeArgs := msg.Params[2:] + + c.applyChannelModes( + ctx, channel, chID, modeStr, modeArgs, + ) +} + +// applyChannelModes applies mode changes using the +// service for individual mode operations. +func (c *Conn) applyChannelModes( + ctx context.Context, + channel string, + chID int64, + modeStr string, + args []string, +) { + adding := true + argIdx := 0 + applied := "" + appliedArgs := "" + + for _, modeChar := range modeStr { + switch modeChar { + case '+': + adding = true + case '-': + adding = false + case 'm', 't': + _ = c.svc.SetChannelFlag( + ctx, chID, modeChar, adding, + ) + + if adding { + applied += "+" + string(modeChar) + } else { + applied += "-" + string(modeChar) + } + case 'o', 'v': + if argIdx >= len(args) { + break + } + + targetNick := args[argIdx] + argIdx++ + + err := c.svc.ApplyMemberMode( + ctx, chID, channel, + targetNick, modeChar, adding, + ) + if err != nil { + c.sendIRCError(err) + + continue + } + + if adding { + applied += "+" + string(modeChar) + } else { + applied += "-" + string(modeChar) + } + + appliedArgs += " " + targetNick + default: + c.sendNumeric( + irc.ErrUnknownMode, + string(modeChar), + "is unknown mode char to me", + ) + } + } + + if applied != "" { + modeReply := applied + if appliedArgs != "" { + modeReply += appliedArgs + } + + c.send(FormatMessage( + c.hostmask(), "MODE", channel, modeReply, + )) + + c.svc.BroadcastMode( + ctx, c.nick, channel, chID, modeReply, + ) + } +} + +// handleUserMode handles MODE for users. +func (c *Conn) handleUserMode(msg *Message) { + target := msg.Params[0] + + if !strings.EqualFold(target, c.nick) { + c.sendNumeric( + irc.ErrUsersDoNotMatch, + "Can't change mode for other users", + ) + + return + } + + // We don't support user modes beyond the basics. + c.sendNumeric(irc.RplUmodeIs, "+") +} + +// handleNames replies with channel member list. +func (c *Conn) handleNames( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.RplEndOfNames, + "*", "End of /NAMES list", + ) + + return + } + + channel := msg.Params[0] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) + + return + } + + c.deliverNames(ctx, channel, chID) +} + +// handleList sends the channel list. +func (c *Conn) handleList(ctx context.Context) { + channels, err := c.database.ListAllChannelsWithCounts( + ctx, + ) + if err != nil { + c.sendNumeric( + irc.RplListEnd, "End of /LIST", + ) + + return + } + + c.sendNumeric(irc.RplListStart, "Channel", "Users Name") + + for idx := range channels { + c.sendNumeric( + irc.RplList, channels[idx].Name, + strconv.FormatInt( + channels[idx].MemberCount, 10, + ), + channels[idx].Topic, + ) + } + + c.sendNumeric(irc.RplListEnd, "End of /LIST") +} + +// handleWhois replies with user info. Individual numeric +// replies are split into focused helper methods. +func (c *Conn) handleWhois( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + target := msg.Params[0] + + if len(msg.Params) > 1 { + target = msg.Params[1] + } + + targetID, err := c.database.GetSessionByNick( + ctx, target, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, target, "No such nick", + ) + c.sendNumeric( + irc.RplEndOfWhois, + target, "End of /WHOIS list", + ) + + return + } + + c.whoisUser(ctx, target, targetID) + c.whoisServer(target) + c.whoisOper(ctx, target, targetID) + c.whoisChannels(ctx, target, targetID) + c.whoisIdle(ctx, target, targetID) + c.whoisAway(ctx, target, targetID) + + c.sendNumeric( + irc.RplEndOfWhois, + target, "End of /WHOIS list", + ) +} + +// whoisUser sends 311 RPL_WHOISUSER. +func (c *Conn) whoisUser( + ctx context.Context, + target string, + targetID int64, +) { + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, targetID, + ) + + username := target + hostname := "*" + + if hostInfo != nil { + username = hostInfo.Username + hostname = hostInfo.Hostname + } + + c.sendNumeric( + irc.RplWhoisUser, target, + username, hostname, "*", target, + ) +} + +// whoisServer sends 312 RPL_WHOISSERVER. +func (c *Conn) whoisServer(target string) { + c.sendNumeric( + irc.RplWhoisServer, target, + c.serverSfx, "neoirc server", + ) +} + +// whoisOper sends 313 RPL_WHOISOPERATOR if applicable. +func (c *Conn) whoisOper( + ctx context.Context, + target string, + targetID int64, +) { + isOper, _ := c.database.IsSessionOper(ctx, targetID) + if isOper { + c.sendNumeric( + irc.RplWhoisOperator, + target, "is an IRC operator", + ) + } +} + +// whoisChannels sends 319 RPL_WHOISCHANNELS. +func (c *Conn) whoisChannels( + ctx context.Context, + target string, + targetID int64, +) { + userChannels, _ := c.database.GetSessionChannels( + ctx, targetID, + ) + if len(userChannels) == 0 { + return + } + + chanList := make([]string, 0, len(userChannels)) + + for _, userChan := range userChannels { + chID, getErr := c.database.GetChannelByName( + ctx, userChan.Name, + ) + if getErr != nil { + chanList = append(chanList, userChan.Name) + + continue + } + + isChOp, _ := c.database.IsChannelOperator( + ctx, chID, targetID, + ) + isVoiced, _ := c.database.IsChannelVoiced( + ctx, chID, targetID, + ) + + prefix := "" + if isChOp { + prefix = "@" + } else if isVoiced { + prefix = "+" + } + + chanList = append(chanList, prefix+userChan.Name) + } + + c.sendNumeric( + irc.RplWhoisChannels, target, + strings.Join(chanList, " "), + ) +} + +// whoisIdle sends 317 RPL_WHOISIDLE. +func (c *Conn) whoisIdle( + ctx context.Context, + target string, + targetID int64, +) { + lastSeen, _ := c.database.GetSessionLastSeen( + ctx, targetID, + ) + created, _ := c.database.GetSessionCreatedAt( + ctx, targetID, + ) + + if lastSeen.IsZero() { + return + } + + idle := int64(time.Since(lastSeen).Seconds()) + + signonTS := int64(0) + if !created.IsZero() { + signonTS = created.Unix() + } + + c.sendNumeric( + irc.RplWhoisIdle, target, + strconv.FormatInt(idle, 10), + strconv.FormatInt(signonTS, 10), + "seconds idle, signon time", + ) +} + +// whoisAway sends 301 RPL_AWAY if the target is away. +func (c *Conn) whoisAway( + ctx context.Context, + target string, + targetID int64, +) { + away, _ := c.database.GetAway(ctx, targetID) + if away != "" { + c.sendNumeric(irc.RplAway, target, away) + } +} + +// handleWho sends WHO replies for a channel. +func (c *Conn) handleWho( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.RplEndOfWho, "*", "End of /WHO list", + ) + + return + } + + target := msg.Params[0] + + if !strings.HasPrefix(target, "#") { + // WHO for a nick. + c.whoNick(ctx, target) + + return + } + + chID, err := c.database.GetChannelByName(ctx, target) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) + + return + } + + members, err := c.database.ChannelMembers(ctx, chID) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) + + return + } + + for _, member := range members { + flags := "H" + if member.IsOperator { + flags += "@" + } else if member.IsVoiced { + flags += "+" + } + + c.sendNumeric( + irc.RplWhoReply, + target, member.Username, member.Hostname, + c.serverSfx, member.Nick, flags, + "0 "+member.Nick, + ) + } + + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) +} + +// whoNick sends WHO reply for a single nick. +func (c *Conn) whoNick(ctx context.Context, nick string) { + targetID, err := c.database.GetSessionByNick(ctx, nick) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, nick, "End of /WHO list", + ) + + return + } + + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, targetID, + ) + + username := nick + hostname := "*" + + if hostInfo != nil { + username = hostInfo.Username + hostname = hostInfo.Hostname + } + + c.sendNumeric( + irc.RplWhoReply, + "*", username, hostname, + c.serverSfx, nick, "H", + "0 "+nick, + ) + c.sendNumeric( + irc.RplEndOfWho, nick, "End of /WHO list", + ) +} + +// handleLusers replies with server statistics. +func (c *Conn) handleLusers(ctx context.Context) { + c.deliverLusers(ctx) +} + +// handleOper handles the OPER command via the shared +// service layer. +func (c *Conn) handleOper( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "OPER", "Not enough parameters", + ) + + return + } + + err := c.svc.Oper( + ctx, c.sessionID, + msg.Params[0], msg.Params[1], + ) + if err != nil { + c.sendIRCError(err) + + return + } + + c.sendNumeric( + irc.RplYoureOper, + "You are now an IRC operator", + ) +} + +// handleAway sets or clears the AWAY status via the +// shared service layer. +func (c *Conn) handleAway( + ctx context.Context, + msg *Message, +) { + message := "" + if len(msg.Params) > 0 { + message = msg.Params[0] + } + + cleared, err := c.svc.SetAway( + ctx, c.sessionID, message, + ) + if err != nil { + c.log.Error("set away failed", "error", err) + + return + } + + if cleared { + c.sendNumeric( + irc.RplUnaway, + "You are no longer marked as being away", + ) + } else { + c.sendNumeric( + irc.RplNowAway, + "You have been marked as being away", + ) + } +} + +// handleKick kicks a user from a channel via the shared +// service layer. +func (c *Conn) handleKick( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "KICK", "Not enough parameters", + ) + + return + } + + channel := msg.Params[0] + targetNick := msg.Params[1] + + reason := targetNick + if len(msg.Params) > 2 { //nolint:mnd + reason = msg.Params[2] + } + + err := c.svc.KickUser( + ctx, c.sessionID, c.nick, + channel, targetNick, reason, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + // Echo KICK on wire. + c.send(FormatMessage( + c.hostmask(), "KICK", channel, targetNick, reason, + )) +} + +// handlePassPostReg handles PASS after registration (for +// setting a session password). +func (c *Conn) handlePassPostReg( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PASS", "Not enough parameters", + ) + + return + } + + password := msg.Params[0] + if len(password) < minPasswordLen { + c.sendFromServer("NOTICE", c.nick, + "Password must be at least 8 characters", + ) + + return + } + + c.setPassword(ctx, password) + + c.sendFromServer("NOTICE", c.nick, + "Password set. You can reconnect using "+ + "PASS with your nick.", + ) +} + +// handleInvite handles the INVITE command. +func (c *Conn) handleInvite( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "INVITE", "Not enough parameters", + ) + + return + } + + targetNick := msg.Params[0] + channel := msg.Params[1] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + isMember, _ := c.database.IsChannelMember( + ctx, chID, c.sessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrNotOnChannel, + channel, "You're not on that channel", + ) + + return + } + + targetID, err := c.database.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, + targetNick, "No such nick", + ) + + return + } + + c.sendNumeric( + irc.RplInviting, targetNick, channel, + ) + + // Send INVITE notice to target via service fan-out. + body, _ := json.Marshal( //nolint:errchkjson + []string{"You have been invited to " + channel}, + ) + + _, _, _ = c.svc.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, "INVITE", c.nick, targetNick, + nil, body, nil, []int64{targetID}, + ) +} + +// handleUserhost replies with USERHOST info. +func (c *Conn) handleUserhost( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + return + } + + replies := make([]string, 0, len(msg.Params)) + + for _, nick := range msg.Params { + sid, err := c.database.GetSessionByNick(ctx, nick) + if err != nil { + continue + } + + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, sid, + ) + + host := "*" + if hostInfo != nil { + host = hostInfo.Hostname + } + + isOper, _ := c.database.IsSessionOper(ctx, sid) + + operStar := "" + if isOper { + operStar = "*" + } + + replies = append( + replies, + nick+operStar+"=+"+nick+"@"+host, + ) + } + + c.sendNumeric( + irc.RplUserHost, + strings.Join(replies, " "), + ) +} diff --git a/internal/ircserver/conn.go b/internal/ircserver/conn.go new file mode 100644 index 0000000..173ac19 --- /dev/null +++ b/internal/ircserver/conn.go @@ -0,0 +1,501 @@ +package ircserver + +import ( + "bufio" + "context" + "fmt" + "log/slog" + "net" + "strconv" + "strings" + "sync" + "time" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/service" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +const ( + maxLineLen = 512 + readTimeout = 5 * time.Minute + writeTimeout = 30 * time.Second + dnsTimeout = 3 * time.Second + pollInterval = 100 * time.Millisecond + pingInterval = 90 * time.Second + pongDeadline = 30 * time.Second + maxNickLen = 32 + minPasswordLen = 8 +) + +// cmdHandler is the signature for registered IRC command +// handlers. +type cmdHandler func(ctx context.Context, msg *Message) + +// Conn represents a single IRC client TCP connection. +type Conn struct { + conn net.Conn + log *slog.Logger + database *db.Database + brk *broker.Broker + cfg *config.Config + svc *service.Service + serverSfx string + commands map[string]cmdHandler + + mu sync.Mutex + nick string + username string + realname string + hostname string + remoteIP string + sessionID int64 + clientID int64 + + registered bool + gotNick bool + gotUser bool + passWord string + + lastQueueID int64 + closed bool + cancel context.CancelFunc +} + +func newConn( + ctx context.Context, + tcpConn net.Conn, + log *slog.Logger, + database *db.Database, + brk *broker.Broker, + cfg *config.Config, + svc *service.Service, +) *Conn { + host, _, _ := net.SplitHostPort(tcpConn.RemoteAddr().String()) + + srvName := cfg.ServerName + if srvName == "" { + srvName = "neoirc" + } + + conn := &Conn{ //nolint:exhaustruct // zero-value defaults + conn: tcpConn, + log: log, + database: database, + brk: brk, + cfg: cfg, + svc: svc, + serverSfx: srvName, + remoteIP: host, + hostname: resolveHost(ctx, host), + } + + conn.commands = conn.buildCommandMap() + + return conn +} + +// buildCommandMap returns a map from IRC command strings +// to handler functions. +func (c *Conn) buildCommandMap() map[string]cmdHandler { + return map[string]cmdHandler{ + irc.CmdPing: func(_ context.Context, msg *Message) { + c.handlePing(msg) + }, + "PONG": func(context.Context, *Message) {}, + irc.CmdNick: c.handleNick, + irc.CmdPrivmsg: c.handlePrivmsg, + irc.CmdNotice: c.handlePrivmsg, + irc.CmdJoin: c.handleJoin, + irc.CmdPart: c.handlePart, + irc.CmdQuit: func(_ context.Context, msg *Message) { + c.handleQuit(msg) + }, + irc.CmdTopic: c.handleTopic, + irc.CmdMode: c.handleMode, + irc.CmdNames: c.handleNames, + irc.CmdList: func(ctx context.Context, _ *Message) { c.handleList(ctx) }, + irc.CmdWhois: c.handleWhois, + irc.CmdWho: c.handleWho, + irc.CmdLusers: func(ctx context.Context, _ *Message) { c.handleLusers(ctx) }, + irc.CmdMotd: func(context.Context, *Message) { c.deliverMOTD() }, + irc.CmdOper: c.handleOper, + irc.CmdAway: c.handleAway, + irc.CmdKick: c.handleKick, + irc.CmdPass: c.handlePassPostReg, + "INVITE": c.handleInvite, + "CAP": func(_ context.Context, msg *Message) { + c.handleCAP(msg) + }, + "USERHOST": c.handleUserhost, + } +} + +// resolveHost does a reverse DNS lookup, returning the IP +// on failure. +func resolveHost(ctx context.Context, addr string) string { + ctx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + + resolver := &net.Resolver{} //nolint:exhaustruct + + names, err := resolver.LookupAddr(ctx, addr) + if err != nil || len(names) == 0 { + return addr + } + + return strings.TrimSuffix(names[0], ".") +} + +// serve is the main loop for a single IRC client connection. +func (c *Conn) serve(ctx context.Context) { + ctx, c.cancel = context.WithCancel(ctx) + defer c.cleanup(ctx) + + scanner := bufio.NewScanner(c.conn) + scanner.Buffer(make([]byte, maxLineLen), maxLineLen) + + for { + _ = c.conn.SetReadDeadline( + time.Now().Add(readTimeout), + ) + + if !scanner.Scan() { + return + } + + line := scanner.Text() + if line == "" { + continue + } + + msg := ParseMessage(line) + if msg == nil { + continue + } + + c.handleMessage(ctx, msg) + + if c.closed { + return + } + } +} + +func (c *Conn) cleanup(ctx context.Context) { + c.mu.Lock() + wasRegistered := c.registered + sessID := c.sessionID + nick := c.nick + c.closed = true + c.mu.Unlock() + + if wasRegistered && sessID > 0 { + c.svc.BroadcastQuit( + ctx, sessID, nick, "Connection closed", + ) + } + + c.conn.Close() //nolint:errcheck,gosec +} + +// send writes a formatted IRC line to the connection. +func (c *Conn) send(line string) { + _ = c.conn.SetWriteDeadline( + time.Now().Add(writeTimeout), + ) + + _, _ = fmt.Fprintf(c.conn, "%s\r\n", line) +} + +// sendNumeric sends a numeric reply from the server. +func (c *Conn) sendNumeric( + code irc.IRCMessageType, + params ...string, +) { + nick := c.nick + if nick == "" { + nick = "*" + } + + allParams := make([]string, 0, 1+len(params)) + allParams = append(allParams, nick) + allParams = append(allParams, params...) + + c.send(FormatMessage( + c.serverSfx, code.Code(), allParams..., + )) +} + +// sendFromServer sends a message from the server. +func (c *Conn) sendFromServer( + command string, params ...string, +) { + c.send(FormatMessage(c.serverSfx, command, params...)) +} + +// hostmask returns the client's full hostmask +// (nick!user@host). +func (c *Conn) hostmask() string { + user := c.username + if user == "" { + user = c.nick + } + + host := c.hostname + if host == "" { + host = c.remoteIP + } + + return c.nick + "!" + user + "@" + host +} + +// handleMessage dispatches a parsed IRC message using +// the command handler map. +func (c *Conn) handleMessage( + ctx context.Context, + msg *Message, +) { + // Before registration, only NICK, USER, PASS, PING, + // QUIT, and CAP are accepted. + if !c.registered { + c.handlePreRegistration(ctx, msg) + + return + } + + handler, ok := c.commands[msg.Command] + if !ok { + c.sendNumeric( + irc.ErrUnknownCommand, + msg.Command, "Unknown command", + ) + + return + } + + handler(ctx, msg) +} + +// handlePreRegistration handles messages before the +// connection is registered (NICK+USER received). +func (c *Conn) handlePreRegistration( + ctx context.Context, + msg *Message, +) { + switch msg.Command { + case irc.CmdPass: + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PASS", "Not enough parameters", + ) + + return + } + + c.passWord = msg.Params[0] + case irc.CmdNick: + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, + "No nickname given", + ) + + return + } + + c.nick = msg.Params[0] + if len(c.nick) > maxNickLen { + c.nick = c.nick[:maxNickLen] + } + + c.gotNick = true + case irc.CmdUser: + if len(msg.Params) < 4 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "USER", "Not enough parameters", + ) + + return + } + + c.username = msg.Params[0] + c.realname = msg.Params[3] + c.gotUser = true + case irc.CmdPing: + c.handlePing(msg) + + return + case irc.CmdQuit: + c.handleQuit(msg) + + return + case "CAP": + c.handleCAP(msg) + + return + default: + c.sendNumeric( + irc.ErrNotRegistered, + "You have not registered", + ) + + return + } + + // Try to complete registration once we have both + // NICK and USER. + if c.gotNick && c.gotUser { + c.completeRegistration(ctx) + } +} + +// completeRegistration creates a session and sends the +// welcome burst. +func (c *Conn) completeRegistration(ctx context.Context) { + // Check if nick is valid. + if c.nick == "" { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + // Create session in DB. + sessionID, clientID, _, err := c.database.CreateSession( + ctx, c.nick, c.username, c.hostname, c.remoteIP, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint") || + strings.Contains(err.Error(), "nick") { + c.sendNumeric( + irc.ErrNicknameInUse, + c.nick, "Nickname is already in use", + ) + + return + } + + c.log.Error( + "failed to create session", "error", err, + ) + c.send("ERROR :Internal server error") + c.closed = true + + return + } + + c.mu.Lock() + c.sessionID = sessionID + c.clientID = clientID + c.registered = true + c.mu.Unlock() + + // If PASS was provided before registration, set the + // session password. + if c.passWord != "" && len(c.passWord) >= minPasswordLen { + c.setPassword(ctx, c.passWord) + } + + // Send welcome burst. + c.deliverWelcome() + c.deliverLusers(ctx) + c.deliverMOTD() + + // Start the message relay goroutine. + go c.relayMessages(ctx) +} + +// deliverWelcome sends 001-005 welcome numerics. +func (c *Conn) deliverWelcome() { + c.sendNumeric(irc.RplWelcome, fmt.Sprintf( + "Welcome to the %s Network, %s", + c.serverSfx, c.hostmask(), + )) + c.sendNumeric(irc.RplYourHost, fmt.Sprintf( + "Your host is %s, running version neoirc", + c.serverSfx, + )) + c.sendNumeric( + irc.RplCreated, + "This server was created recently", + ) + c.sendNumeric( + irc.RplMyInfo, + c.serverSfx, "neoirc", "", "mnst", + ) + c.sendNumeric( + irc.RplIsupport, + "CHANTYPES=#", + "NICKLEN=32", + "PREFIX=(ov)@+", + "CHANMODES=,,H,mnst", + "NETWORK="+c.serverSfx, + "are supported by this server", + ) +} + +// deliverLusers sends 251/252/254/255 server statistics. +func (c *Conn) deliverLusers(ctx context.Context) { + users, _ := c.database.GetUserCount(ctx) + opers, _ := c.database.GetOperCount(ctx) + channels, _ := c.database.GetChannelCount(ctx) + + c.sendNumeric(irc.RplLuserClient, fmt.Sprintf( + "There are %d users and 0 invisible on 1 servers", + users, + )) + c.sendNumeric( + irc.RplLuserOp, + strconv.FormatInt(opers, 10), + "operator(s) online", + ) + c.sendNumeric( + irc.RplLuserChannels, + strconv.FormatInt(channels, 10), + "channels formed", + ) + c.sendNumeric(irc.RplLuserMe, fmt.Sprintf( + "I have %d clients and 1 servers", users, + )) +} + +// deliverMOTD sends 375/372/376 MOTD lines. +func (c *Conn) deliverMOTD() { + motd := c.cfg.MOTD + if motd == "" { + c.sendNumeric( + irc.ErrNoMotd, "MOTD File is missing", + ) + + return + } + + c.sendNumeric(irc.RplMotdStart, fmt.Sprintf( + "- %s Message of the Day -", c.serverSfx, + )) + + for _, line := range strings.Split(motd, "\n") { + c.sendNumeric(irc.RplMotd, "- "+line) + } + + c.sendNumeric( + irc.RplEndOfMotd, "End of /MOTD command", + ) +} + +// setPassword sets a bcrypt password on the session. +func (c *Conn) setPassword(ctx context.Context, pw string) { + // Use the database's auth module to hash and store. + err := c.database.SetPassword(ctx, c.sessionID, pw) + if err != nil { + c.log.Error( + "failed to set password", "error", err, + ) + } +} diff --git a/internal/ircserver/export_test.go b/internal/ircserver/export_test.go new file mode 100644 index 0000000..dc75ee7 --- /dev/null +++ b/internal/ircserver/export_test.go @@ -0,0 +1,52 @@ +package ircserver + +import ( + "context" + "log/slog" + "net" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/service" +) + +// NewTestServer creates a Server suitable for testing. +// The caller must call Stop() when finished. +func NewTestServer( + log *slog.Logger, + cfg *config.Config, + database *db.Database, + brk *broker.Broker, +) *Server { + svc := &service.Service{ + DB: database, + Broker: brk, + Config: cfg, + Log: log, + } + + return &Server{ //nolint:exhaustruct + log: log, + cfg: cfg, + database: database, + brk: brk, + svc: svc, + conns: make(map[*Conn]struct{}), + } +} + +// Start exposes the unexported start method for tests. +func (s *Server) Start(addr string) error { + return s.start(context.Background(), addr) +} + +// Stop exposes the unexported stop method for tests. +func (s *Server) Stop() { + s.stop() +} + +// Listener returns the server's net.Listener for tests. +func (s *Server) Listener() net.Listener { + return s.listener +} diff --git a/internal/ircserver/parser.go b/internal/ircserver/parser.go new file mode 100644 index 0000000..5374bc0 --- /dev/null +++ b/internal/ircserver/parser.go @@ -0,0 +1,123 @@ +// Package ircserver implements a traditional IRC wire protocol +// listener (RFC 1459/2812) that bridges to the neoirc HTTP/JSON +// server internals. +package ircserver + +import "strings" + +// Message represents a parsed IRC wire protocol message. +type Message struct { + // Prefix is the optional :prefix at the start (may be + // empty for client-to-server messages). + Prefix string + // Command is the IRC command (e.g., "PRIVMSG", "NICK"). + Command string + // Params holds the positional parameters, including the + // trailing parameter (which was preceded by ':' on the + // wire). + Params []string +} + +// ParseMessage parses a single IRC wire protocol line +// (without the trailing CR-LF) into a Message. +// Returns nil if the line is empty. +// +// IRC message format (RFC 1459 §2.3.1): +// +// [":" prefix SPACE] command { SPACE param } [SPACE ":" trailing] +func ParseMessage(line string) *Message { + if line == "" { + return nil + } + + msg := &Message{} //nolint:exhaustruct // fields set below + + // Extract prefix if present. + if line[0] == ':' { + idx := strings.IndexByte(line, ' ') + if idx < 0 { + // Only a prefix, no command — invalid. + return nil + } + + msg.Prefix = line[1:idx] + line = line[idx+1:] + } + + // Skip leading spaces. + line = strings.TrimLeft(line, " ") + if line == "" { + return nil + } + + // Extract command. + idx := strings.IndexByte(line, ' ') + if idx < 0 { + msg.Command = strings.ToUpper(line) + + return msg + } + + msg.Command = strings.ToUpper(line[:idx]) + line = line[idx+1:] + + // Extract parameters. + for line != "" { + line = strings.TrimLeft(line, " ") + if line == "" { + break + } + + // Trailing parameter (everything after ':'). + if line[0] == ':' { + msg.Params = append(msg.Params, line[1:]) + + break + } + + idx = strings.IndexByte(line, ' ') + if idx < 0 { + msg.Params = append(msg.Params, line) + + break + } + + msg.Params = append(msg.Params, line[:idx]) + line = line[idx+1:] + } + + return msg +} + +// FormatMessage formats an IRC message into wire protocol +// format (without the trailing CR-LF). +func FormatMessage( + prefix, command string, + params ...string, +) string { + var buf strings.Builder + + if prefix != "" { + buf.WriteByte(':') + buf.WriteString(prefix) + buf.WriteByte(' ') + } + + buf.WriteString(command) + + for i, param := range params { + buf.WriteByte(' ') + + isLast := i == len(params)-1 + needsColon := strings.Contains(param, " ") || + param == "" || param[0] == ':' + + if isLast && needsColon { + buf.WriteByte(':') + } + + buf.WriteString(param) + } + + return buf.String() +} diff --git a/internal/ircserver/parser_test.go b/internal/ircserver/parser_test.go new file mode 100644 index 0000000..8585eab --- /dev/null +++ b/internal/ircserver/parser_test.go @@ -0,0 +1,328 @@ +package ircserver_test + +import ( + "testing" + + "git.eeqj.de/sneak/neoirc/internal/ircserver" +) + +//nolint:funlen // table-driven test +func TestParseMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want *ircserver.Message + wantNil bool + }{ + { + name: "empty", + input: "", + want: nil, + wantNil: true, + }, + { + name: "simple command", + input: "PING", + want: &ircserver.Message{ + Prefix: "", + Command: "PING", + Params: nil, + }, + wantNil: false, + }, + { + name: "command with one param", + input: "NICK alice", + want: &ircserver.Message{ + Prefix: "", + Command: "NICK", + Params: []string{"alice"}, + }, + wantNil: false, + }, + { + name: "command case insensitive", + input: "nick Alice", + want: &ircserver.Message{ + Prefix: "", + Command: "NICK", + Params: []string{"Alice"}, + }, + wantNil: false, + }, + { + name: "privmsg with trailing", + input: "PRIVMSG #general :hello world", + want: &ircserver.Message{ + Prefix: "", + Command: "PRIVMSG", + Params: []string{"#general", "hello world"}, + }, + wantNil: false, + }, + { + name: "with prefix", + input: ":server.example.com 001 alice :Welcome to IRC", + want: &ircserver.Message{ + Prefix: "server.example.com", + Command: "001", + Params: []string{"alice", "Welcome to IRC"}, + }, + wantNil: false, + }, + { + name: "user command", + input: "USER alice 0 * :Alice Smith", + want: &ircserver.Message{ + Prefix: "", + Command: "USER", + Params: []string{ + "alice", "0", "*", "Alice Smith", + }, + }, + wantNil: false, + }, + { + name: "join channel", + input: "JOIN #general", + want: &ircserver.Message{ + Prefix: "", + Command: "JOIN", + Params: []string{"#general"}, + }, + wantNil: false, + }, + { + name: "quit with trailing", + input: "QUIT :leaving now", + want: &ircserver.Message{ + Prefix: "", + Command: "QUIT", + Params: []string{"leaving now"}, + }, + wantNil: false, + }, + { + name: "quit without reason", + input: "QUIT", + want: &ircserver.Message{ + Prefix: "", + Command: "QUIT", + Params: nil, + }, + wantNil: false, + }, + { + name: "mode query", + input: "MODE #general", + want: &ircserver.Message{ + Prefix: "", + Command: "MODE", + Params: []string{"#general"}, + }, + wantNil: false, + }, + { + name: "kick with reason", + input: "KICK #general bob :misbehaving", + want: &ircserver.Message{ + Prefix: "", + Command: "KICK", + Params: []string{ + "#general", "bob", "misbehaving", + }, + }, + wantNil: false, + }, + { + name: "empty trailing", + input: "PRIVMSG #general :", + want: &ircserver.Message{ + Prefix: "", + Command: "PRIVMSG", + Params: []string{"#general", ""}, + }, + wantNil: false, + }, + { + name: "pass command", + input: "PASS mysecret", + want: &ircserver.Message{ + Prefix: "", + Command: "PASS", + Params: []string{"mysecret"}, + }, + wantNil: false, + }, + { + name: "ping with server", + input: "PING :irc.example.com", + want: &ircserver.Message{ + Prefix: "", + Command: "PING", + Params: []string{"irc.example.com"}, + }, + wantNil: false, + }, + { + name: "topic with trailing spaces", + input: "TOPIC #general :Welcome to the channel!", + want: &ircserver.Message{ + Prefix: "", + Command: "TOPIC", + Params: []string{ + "#general", + "Welcome to the channel!", + }, + }, + wantNil: false, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := ircserver.ParseMessage(testCase.input) + if testCase.wantNil { + if got != nil { + t.Fatalf("expected nil, got %+v", got) + } + + return + } + + if got == nil { + t.Fatal("expected non-nil message") + } + + if got.Prefix != testCase.want.Prefix { + t.Errorf( + "prefix: got %q, want %q", + got.Prefix, testCase.want.Prefix, + ) + } + + if got.Command != testCase.want.Command { + t.Errorf( + "command: got %q, want %q", + got.Command, testCase.want.Command, + ) + } + + if len(got.Params) != len(testCase.want.Params) { + t.Fatalf( + "params length: got %d, want %d (%v vs %v)", + len(got.Params), + len(testCase.want.Params), + got.Params, + testCase.want.Params, + ) + } + + for i, p := range got.Params { + if p != testCase.want.Params[i] { + t.Errorf( + "param[%d]: got %q, want %q", + i, p, testCase.want.Params[i], + ) + } + } + }) + } +} + +func TestFormatMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prefix string + command string + params []string + want string + }{ + { + name: "simple command", + prefix: "", + command: "PING", + params: nil, + want: "PING", + }, + { + name: "with prefix", + prefix: "server", + command: "PONG", + params: []string{"server"}, + want: ":server PONG server", + }, + { + name: "privmsg with trailing", + prefix: "alice!alice@host", + command: "PRIVMSG", + params: []string{"#general", "hello world"}, + want: ":alice!alice@host PRIVMSG #general :hello world", + }, + { + name: "numeric reply", + prefix: "server", + command: "001", + params: []string{"alice", "Welcome to IRC"}, + want: ":server 001 alice :Welcome to IRC", + }, + { + name: "empty trailing", + prefix: "server", + command: "PRIVMSG", + params: []string{"#chan", ""}, + want: ":server PRIVMSG #chan :", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := ircserver.FormatMessage( + testCase.prefix, testCase.command, testCase.params..., + ) + if got != testCase.want { + t.Errorf("got %q, want %q", got, testCase.want) + } + }) + } +} + +func TestParseFormatRoundTrip(t *testing.T) { + t.Parallel() + + // Round-trip only works for lines where the last + // parameter either contains a space (gets ':' prefix + // on format) or is a non-trailing single token. + lines := []string{ + "PING", + "NICK alice", + "PRIVMSG #general :hello world", + "JOIN #general", + "MODE #general", + } + + for _, line := range lines { + msg := ircserver.ParseMessage(line) + if msg == nil { + t.Fatalf("failed to parse: %q", line) + } + + formatted := ircserver.FormatMessage( + msg.Prefix, msg.Command, msg.Params..., + ) + if formatted != line { + t.Errorf( + "round-trip failed: input %q, got %q", + line, formatted, + ) + } + } +} diff --git a/internal/ircserver/relay.go b/internal/ircserver/relay.go new file mode 100644 index 0000000..c1f3436 --- /dev/null +++ b/internal/ircserver/relay.go @@ -0,0 +1,319 @@ +package ircserver + +import ( + "context" + "encoding/json" + "strings" + "time" + + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +// relayMessages polls the client output queue and delivers +// IRC-formatted messages to the TCP connection. It runs +// in a goroutine for the lifetime of the connection. +func (c *Conn) relayMessages(ctx context.Context) { + // Use a ticker as a fallback; primary wakeup is via + // broker notification. + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + default: + } + + // Drain any available messages. + delivered := c.drainQueue(ctx) + + if delivered { + // Tight loop while there are messages. + continue + } + + // Wait for notification or timeout. + waitCh := c.brk.Wait(c.sessionID) + + select { + case <-waitCh: + // New message notification — loop back. + case <-ticker.C: + // Periodic check. + case <-ctx.Done(): + c.brk.Remove(c.sessionID, waitCh) + + return + } + } +} + +const relayPollLimit = 100 + +// drainQueue polls the output queue and delivers all +// pending messages. Returns true if at least one message +// was delivered. +func (c *Conn) drainQueue(ctx context.Context) bool { + msgs, lastID, err := c.database.PollMessages( + ctx, c.clientID, c.lastQueueID, relayPollLimit, + ) + if err != nil { + return false + } + + if len(msgs) == 0 { + return false + } + + for i := range msgs { + c.deliverIRCMessage(ctx, &msgs[i]) + } + + if lastID > c.lastQueueID { + c.lastQueueID = lastID + } + + return true +} + +// deliverIRCMessage converts a db.IRCMessage to wire +// protocol and sends it. +// +//nolint:cyclop // dispatch table +func (c *Conn) deliverIRCMessage( + _ context.Context, + msg *db.IRCMessage, +) { + command := msg.Command + + // Decode body as []string for the trailing text. + var bodyLines []string + + if msg.Body != nil { + _ = json.Unmarshal(msg.Body, &bodyLines) + } + + text := "" + if len(bodyLines) > 0 { + text = bodyLines[0] + } + + // Route by command type. + switch { + case isNumeric(command): + c.deliverNumeric(msg, text) + case command == irc.CmdPrivmsg || command == irc.CmdNotice: + c.deliverTextMessage(msg, command, text) + case command == irc.CmdJoin: + c.deliverJoin(msg) + case command == irc.CmdPart: + c.deliverPart(msg, text) + case command == irc.CmdNick: + c.deliverNickChange(msg, text) + case command == irc.CmdQuit: + c.deliverQuitMsg(msg, text) + case command == irc.CmdTopic: + c.deliverTopicChange(msg, text) + case command == irc.CmdKick: + c.deliverKickMsg(msg, text) + case command == "INVITE": + c.deliverInviteMsg(msg, text) + case command == irc.CmdMode: + c.deliverMode(msg, text) + case command == irc.CmdPing: + // Server-originated PING — reply with PONG. + c.sendFromServer("PING", c.serverSfx) + default: + // Unknown command — deliver as server notice. + if text != "" { + c.sendFromServer("NOTICE", c.nick, text) + } + } +} + +// isNumeric returns true if the command is a 3-digit +// numeric code. +func isNumeric(cmd string) bool { + return len(cmd) == 3 && + cmd[0] >= '0' && cmd[0] <= '9' && + cmd[1] >= '0' && cmd[1] <= '9' && + cmd[2] >= '0' && cmd[2] <= '9' +} + +// deliverNumeric sends a numeric reply. +func (c *Conn) deliverNumeric( + msg *db.IRCMessage, + text string, +) { + from := msg.From + if from == "" { + from = c.serverSfx + } + + var params []string + + if msg.Params != nil { + _ = json.Unmarshal(msg.Params, ¶ms) + } + + allParams := make([]string, 0, 1+len(params)+1) + allParams = append(allParams, c.nick) + allParams = append(allParams, params...) + + if text != "" { + allParams = append(allParams, text) + } + + c.send(FormatMessage(from, msg.Command, allParams...)) +} + +// deliverTextMessage sends PRIVMSG or NOTICE. +func (c *Conn) deliverTextMessage( + msg *db.IRCMessage, + command, text string, +) { + from := msg.From + target := msg.To + + // Don't echo our own messages back. + if strings.EqualFold(from, c.nick) { + return + } + + prefix := from + if !strings.Contains(prefix, "!") { + prefix = from + "!" + from + "@*" + } + + c.send(FormatMessage(prefix, command, target, text)) +} + +// deliverJoin sends a JOIN notification. +func (c *Conn) deliverJoin(msg *db.IRCMessage) { + // Don't echo our own JOINs (we already sent them + // during joinChannel). + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + c.send(FormatMessage(prefix, "JOIN", channel)) +} + +// deliverPart sends a PART notification. +func (c *Conn) deliverPart(msg *db.IRCMessage, text string) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + if text != "" { + c.send(FormatMessage( + prefix, "PART", channel, text, + )) + } else { + c.send(FormatMessage(prefix, "PART", channel)) + } +} + +// deliverNickChange sends a NICK change notification. +func (c *Conn) deliverNickChange( + msg *db.IRCMessage, + newNick string, +) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + + c.send(FormatMessage(prefix, "NICK", newNick)) +} + +// deliverQuitMsg sends a QUIT notification. +func (c *Conn) deliverQuitMsg( + msg *db.IRCMessage, + text string, +) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + + if text != "" { + c.send(FormatMessage( + prefix, "QUIT", "Quit: "+text, + )) + } else { + c.send(FormatMessage(prefix, "QUIT", "Quit")) + } +} + +// deliverTopicChange sends a TOPIC change notification. +func (c *Conn) deliverTopicChange( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + c.send(FormatMessage(prefix, "TOPIC", channel, text)) +} + +// deliverKickMsg sends a KICK notification. +func (c *Conn) deliverKickMsg( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + var params []string + + if msg.Params != nil { + _ = json.Unmarshal(msg.Params, ¶ms) + } + + kickTarget := "" + if len(params) > 0 { + kickTarget = params[0] + } + + if kickTarget != "" { + c.send(FormatMessage( + prefix, "KICK", channel, kickTarget, text, + )) + } else { + c.send(FormatMessage( + prefix, "KICK", channel, "?", text, + )) + } +} + +// deliverInviteMsg sends an INVITE notification. +func (c *Conn) deliverInviteMsg( + _ *db.IRCMessage, + text string, +) { + c.sendFromServer("NOTICE", c.nick, text) +} + +// deliverMode sends a MODE change notification. +func (c *Conn) deliverMode( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + target := msg.To + + if text != "" { + c.send(FormatMessage(prefix, "MODE", target, text)) + } +} diff --git a/internal/ircserver/server.go b/internal/ircserver/server.go new file mode 100644 index 0000000..0ee9256 --- /dev/null +++ b/internal/ircserver/server.go @@ -0,0 +1,157 @@ +package ircserver + +import ( + "context" + "fmt" + "log/slog" + "net" + "sync" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/logger" + "git.eeqj.de/sneak/neoirc/internal/service" + "go.uber.org/fx" +) + +// Params defines the dependencies for creating an IRC +// Server. +type Params struct { + fx.In + + Logger *logger.Logger + Config *config.Config + Database *db.Database + Broker *broker.Broker + Service *service.Service +} + +// Server is the TCP IRC protocol server. +type Server struct { + log *slog.Logger + cfg *config.Config + database *db.Database + brk *broker.Broker + svc *service.Service + listener net.Listener + mu sync.Mutex + conns map[*Conn]struct{} + cancel context.CancelFunc +} + +// New creates a new IRC Server and registers its lifecycle +// hooks. The listener is only started if IRC_LISTEN_ADDR +// is configured; otherwise the server is inert. +func New( + lifecycle fx.Lifecycle, + params Params, +) *Server { + srv := &Server{ + log: params.Logger.Get(), + cfg: params.Config, + database: params.Database, + brk: params.Broker, + svc: params.Service, + conns: make(map[*Conn]struct{}), + listener: nil, + cancel: nil, + mu: sync.Mutex{}, + } + + listenAddr := params.Config.IRCListenAddr + if listenAddr == "" { + return srv + } + + lifecycle.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return srv.start(ctx, listenAddr) + }, + OnStop: func(_ context.Context) error { + srv.stop() + + return nil + }, + }) + + return srv +} + +// start begins listening for TCP connections. +// +//nolint:contextcheck // long-lived server ctx, not the short Fx one +func (s *Server) start(_ context.Context, addr string) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("irc listen: %w", err) + } + + s.listener = ln + + ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + + s.log.Info( + "irc server listening", "addr", addr, + ) + + go s.acceptLoop(ctx) + + return nil +} + +// stop shuts down the listener and all connections. +func (s *Server) stop() { + if s.cancel != nil { + s.cancel() + } + + if s.listener != nil { + s.listener.Close() //nolint:errcheck,gosec + } + + s.mu.Lock() + for c := range s.conns { + c.conn.Close() //nolint:errcheck,gosec + } + s.mu.Unlock() +} + +// acceptLoop accepts new connections. +func (s *Server) acceptLoop(ctx context.Context) { + for { + tcpConn, err := s.listener.Accept() + if err != nil { + select { + case <-ctx.Done(): + return + default: + s.log.Error( + "irc accept error", "error", err, + ) + + continue + } + } + + client := newConn( + ctx, tcpConn, s.log, + s.database, s.brk, s.cfg, s.svc, + ) + + s.mu.Lock() + s.conns[client] = struct{}{} + s.mu.Unlock() + + go func() { + defer func() { + s.mu.Lock() + delete(s.conns, client) + s.mu.Unlock() + }() + + client.serve(ctx) + }() + } +} diff --git a/internal/ircserver/server_test.go b/internal/ircserver/server_test.go new file mode 100644 index 0000000..fb21e6b --- /dev/null +++ b/internal/ircserver/server_test.go @@ -0,0 +1,625 @@ +package ircserver_test + +import ( + "bufio" + "database/sql" + "fmt" + "log/slog" + "net" + "os" + "strings" + "testing" + "time" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/ircserver" + + _ "modernc.org/sqlite" +) + +const testTimeout = 5 * time.Second + +func TestMain(m *testing.M) { + db.SetBcryptCost(4) + + os.Exit(m.Run()) +} + +// testEnv holds the shared test infrastructure. +type testEnv struct { + database *db.Database + brk *broker.Broker + cfg *config.Config + srv *ircserver.Server +} + +func newTestEnv(t *testing.T) *testEnv { + t.Helper() + + dsn := fmt.Sprintf( + "file:%s?mode=memory&cache=shared&_journal_mode=WAL", + t.Name(), + ) + + conn, err := sql.Open("sqlite", dsn) + if err != nil { + t.Fatalf("open db: %v", err) + } + + conn.SetMaxOpenConns(1) + + _, err = conn.ExecContext( + t.Context(), "PRAGMA foreign_keys = ON", + ) + if err != nil { + t.Fatalf("pragma: %v", err) + } + + database := db.NewTestDatabaseFromConn(conn) + + err = database.RunMigrations(t.Context()) + if err != nil { + t.Fatalf("migrate: %v", err) + } + + brk := broker.New() + + cfg := &config.Config{ //nolint:exhaustruct + ServerName: "test.irc", + MOTD: "Welcome to test IRC", + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + addr := listener.Addr().String() + + err = listener.Close() + if err != nil { + t.Fatalf("close listener: %v", err) + } + + log := slog.New(slog.NewTextHandler( + os.Stderr, + &slog.HandlerOptions{Level: slog.LevelError}, //nolint:exhaustruct + )) + + srv := ircserver.NewTestServer(log, cfg, database, brk) + + err = srv.Start(addr) + if err != nil { + t.Fatalf("start irc server: %v", err) + } + + t.Cleanup(func() { + srv.Stop() + + err := conn.Close() + if err != nil { + t.Logf("close db: %v", err) + } + }) + + return &testEnv{ + database: database, + brk: brk, + cfg: cfg, + srv: srv, + } +} + +// dial connects to the test server. +func (env *testEnv) dial(t *testing.T) *testClient { + t.Helper() + + conn, err := net.DialTimeout( + "tcp", + env.srv.Listener().Addr().String(), + testTimeout, + ) + if err != nil { + t.Fatalf("dial: %v", err) + } + + t.Cleanup(func() { + err := conn.Close() + if err != nil { + t.Logf("close conn: %v", err) + } + }) + + return &testClient{ + t: t, + conn: conn, + scanner: bufio.NewScanner(conn), + } +} + +// testClient wraps a raw TCP connection with helpers. +type testClient struct { + t *testing.T + conn net.Conn + scanner *bufio.Scanner +} + +func (tc *testClient) send(line string) { + tc.t.Helper() + + _ = tc.conn.SetWriteDeadline( + time.Now().Add(testTimeout), + ) + + _, err := fmt.Fprintf(tc.conn, "%s\r\n", line) + if err != nil { + tc.t.Fatalf("send: %v", err) + } +} + +func (tc *testClient) readLine() string { + tc.t.Helper() + + _ = tc.conn.SetReadDeadline( + time.Now().Add(testTimeout), + ) + + if !tc.scanner.Scan() { + err := tc.scanner.Err() + if err != nil { + tc.t.Fatalf("read: %v", err) + } + + tc.t.Fatal("connection closed unexpectedly") + } + + return tc.scanner.Text() +} + +// readUntil reads lines until one matches the predicate. +func (tc *testClient) readUntil( + pred func(string) bool, +) []string { + tc.t.Helper() + + var lines []string + + for { + line := tc.readLine() + lines = append(lines, line) + + if pred(line) { + return lines + } + } +} + +// register sends NICK + USER and reads through the welcome +// burst. +func (tc *testClient) register(nick string) []string { + tc.t.Helper() + + tc.send("NICK " + nick) + tc.send("USER " + nick + " 0 * :Test User") + + return tc.readUntil(func(line string) bool { + return strings.Contains(line, " 376 ") || + strings.Contains(line, " 422 ") + }) +} + +// assertContains checks that at least one line matches the +// given substring. +func assertContains( + t *testing.T, + lines []string, + substr, description string, +) { + t.Helper() + + for _, line := range lines { + if strings.Contains(line, substr) { + return + } + } + + t.Errorf("did not find %q in output: %s", substr, description) +} + +// joinAndDrain joins a channel and reads until +// RPL_ENDOFNAMES. +func (tc *testClient) joinAndDrain(channel string) { + tc.t.Helper() + + tc.send("JOIN " + channel) + + tc.readUntil(func(line string) bool { + return strings.Contains(line, " 366 ") + }) +} + +// sendAndExpect sends a command and reads until a line +// containing the expected substring is found. +func (tc *testClient) sendAndExpect( + cmd, expect string, +) []string { + tc.t.Helper() + + tc.send(cmd) + + return tc.readUntil(func(line string) bool { + return strings.Contains(line, expect) + }) +} + +func TestRegistration(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + lines := client.register("alice") + assertContains(t, lines, " 001 ", "RPL_WELCOME") +} + +func TestWelcomeContainsNick(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + lines := client.register("bob") + + for _, line := range lines { + if strings.Contains(line, " 001 ") && + !strings.Contains(line, "bob") { + t.Errorf("001 should contain nick: %s", line) + } + } +} + +func TestPingPong(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("pingtest") + lines := client.sendAndExpect("PING :hello", "PONG") + assertContains(t, lines, "PONG", "PONG response") +} + +func TestJoinChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("joiner") + client.send("JOIN #test") + + lines := client.readUntil(func(line string) bool { + return strings.Contains(line, " 366 ") + }) + + assertContains(t, lines, "JOIN", "JOIN echo") + assertContains(t, lines, " 366 ", "RPL_ENDOFNAMES") +} + +func TestPrivmsgBetweenClients(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + alice := env.dial(t) + alice.register("alice_pm") + + bob := env.dial(t) + bob.register("bob_pm") + + alice.joinAndDrain("#chat") + bob.joinAndDrain("#chat") + + alice.send("PRIVMSG #chat :hello bob!") + lines := bob.sendAndExpect("PING :sync", "hello bob!") + assertContains(t, lines, "hello bob!", "channel PRIVMSG") +} + +func TestNickChange(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("oldnick") + lines := client.sendAndExpect("NICK newnick", "newnick") + assertContains(t, lines, "NICK", "NICK change echo") +} + +func TestDuplicateNick(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + first := env.dial(t) + first.register("taken") + + second := env.dial(t) + second.send("NICK taken") + second.send("USER taken 0 * :Test") + + lines := second.readUntil(func(line string) bool { + return strings.Contains(line, " 433 ") + }) + + assertContains(t, lines, " 433 ", "ERR_NICKNAMEINUSE") +} + +func TestListChannels(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("lister") + client.joinAndDrain("#listtest") + lines := client.sendAndExpect("LIST", " 323 ") + assertContains(t, lines, " 323 ", "RPL_LISTEND") //nolint:misspell // IRC term +} + +func TestWhois(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("whoistest") + lines := client.sendAndExpect( + "WHOIS whoistest", " 318 ", + ) + + assertContains(t, lines, " 311 ", "RPL_WHOISUSER") + assertContains(t, lines, " 318 ", "RPL_ENDOFWHOIS") +} + +func TestQuit(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("quitter") + lines := client.sendAndExpect( + "QUIT :goodbye", "ERROR", + ) + + assertContains(t, lines, "goodbye", "QUIT reason") +} + +func TestTopicSetAndGet(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("topicuser") + client.joinAndDrain("#topictest") + lines := client.sendAndExpect( + "TOPIC #topictest :New topic here", + "New topic here", + ) + + assertContains( + t, lines, "New topic here", "TOPIC echo", + ) +} + +func TestUnknownCommand(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("unknowncmd") + lines := client.sendAndExpect("FOOBAR", " 421 ") + assertContains(t, lines, " 421 ", "ERR_UNKNOWNCOMMAND") +} + +func TestDirectMessage(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + sender := env.dial(t) + sender.register("dmsender") + + receiver := env.dial(t) + receiver.register("dmreceiver") + + // Give relay goroutines time to start. + time.Sleep(100 * time.Millisecond) + + sender.send("PRIVMSG dmreceiver :hello privately") + + lines := receiver.readUntil(func(line string) bool { + return strings.Contains(line, "hello privately") + }) + + assertContains( + t, lines, "hello privately", "direct PRIVMSG", + ) +} + +func TestCAPNegotiation(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.send("CAP LS 302") + + line := client.readLine() + if !strings.Contains(line, "CAP") { + t.Errorf("expected CAP response, got: %s", line) + } + + client.send("CAP END") + lines := client.register("capuser") + assertContains(t, lines, " 001 ", "registration after CAP") +} + +func TestPartChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("parter") + client.joinAndDrain("#parttest") + lines := client.sendAndExpect( + "PART #parttest :leaving", "PART", + ) + + assertContains(t, lines, "#parttest", "PART echo") +} + +func TestModeQuery(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("modeuser") + client.joinAndDrain("#modetest") + lines := client.sendAndExpect( + "MODE #modetest", " 324 ", + ) + + assertContains( + t, lines, " 324 ", "RPL_CHANNELMODEIS", + ) +} + +func TestWhoChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("whouser") + client.joinAndDrain("#whotest") + lines := client.sendAndExpect("WHO #whotest", " 315 ") + assertContains(t, lines, " 352 ", "RPL_WHOREPLY") +} + +func TestLusers(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("luseruser") + lines := client.sendAndExpect("LUSERS", " 255 ") + assertContains(t, lines, " 251 ", "RPL_LUSERCLIENT") +} + +func TestMotd(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("motduser") + lines := client.sendAndExpect("MOTD", " 376 ") + assertContains(t, lines, " 376 ", "RPL_ENDOFMOTD") +} + +func TestAwaySetAndClear(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("awayuser") + + setLines := client.sendAndExpect( + "AWAY :brb lunch", " 306 ", + ) + assertContains(t, setLines, " 306 ", "RPL_NOWAWAY") + + clearLines := client.sendAndExpect("AWAY", " 305 ") + assertContains(t, clearLines, " 305 ", "RPL_UNAWAY") +} + +func TestHandlePassPostRegistration(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("passuser") + lines := client.sendAndExpect( + "PASS :mypassword123", "Password set", + ) + + assertContains( + t, lines, "Password set", "password confirmation", + ) +} + +func TestPreRegistrationNotRegistered(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.send("PRIVMSG #test :hello") + + line := client.readLine() + if !strings.Contains(line, " 451 ") { + t.Errorf( + "expected ERR_NOTREGISTERED (451), got: %s", + line, + ) + } +} + +func TestNamesNonExistentChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("namesuser") + lines := client.sendAndExpect( + "NAMES #doesnotexist", " 366 ", + ) + + assertContains( + t, lines, " 366 ", + "RPL_ENDOFNAMES for non-existent channel", + ) +} + +func BenchmarkParseMessage(b *testing.B) { + line := ":nick!user@host PRIVMSG #channel :Hello, world!" + + b.ResetTimer() + + for range b.N { + _ = ircserver.ParseMessage(line) + } +} + +func BenchmarkFormatMessage(b *testing.B) { + b.ResetTimer() + + for range b.N { + _ = ircserver.FormatMessage( + "nick!user@host", "PRIVMSG", + "#channel", "Hello, world!", + ) + } +} diff --git a/internal/service/service.go b/internal/service/service.go new file mode 100644 index 0000000..a8304fa --- /dev/null +++ b/internal/service/service.go @@ -0,0 +1,839 @@ +// Package service provides shared business logic for both +// the IRC wire protocol and HTTP/JSON transports. +package service + +import ( + "context" + "crypto/subtle" + "encoding/json" + "fmt" + "log/slog" + "strings" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/logger" + "git.eeqj.de/sneak/neoirc/pkg/irc" + "go.uber.org/fx" +) + +// Params defines the dependencies for creating a Service. +type Params struct { + fx.In + + Logger *logger.Logger + Config *config.Config + Database *db.Database + Broker *broker.Broker +} + +// Service provides shared business logic for IRC commands. +type Service struct { + DB *db.Database + Broker *broker.Broker + Config *config.Config + Log *slog.Logger +} + +// New creates a new Service. +func New(params Params) *Service { + return &Service{ + DB: params.Database, + Broker: params.Broker, + Config: params.Config, + Log: params.Logger.Get(), + } +} + +// IRCError represents an IRC protocol-level error with a +// numeric code that both transports can map to responses. +type IRCError struct { + Code irc.IRCMessageType + Params []string + Message string +} + +func (e *IRCError) Error() string { return e.Message } + +// JoinResult contains the outcome of a channel join. +type JoinResult struct { + ChannelID int64 + IsCreator bool +} + +// DirectMsgResult contains the outcome of a direct message. +type DirectMsgResult struct { + UUID string + AwayMsg string +} + +// FanOut inserts a message and enqueues it to all given +// session IDs, notifying each via the broker. +func (s *Service) FanOut( + ctx context.Context, + command, from, to string, + params, body, meta json.RawMessage, + sessionIDs []int64, +) (int64, string, error) { + dbID, msgUUID, err := s.DB.InsertMessage( + ctx, command, from, to, params, body, meta, + ) + if err != nil { + return 0, "", fmt.Errorf("insert message: %w", err) + } + + for _, sid := range sessionIDs { + _ = s.DB.EnqueueToSession(ctx, sid, dbID) + s.Broker.Notify(sid) + } + + return dbID, msgUUID, nil +} + +// excludeSession returns a copy of ids without the given +// session. +func excludeSession( + ids []int64, + exclude int64, +) []int64 { + out := make([]int64, 0, len(ids)) + + for _, id := range ids { + if id != exclude { + out = append(out, id) + } + } + + return out +} + +// SendChannelMessage validates membership and moderation, +// then fans out a message to all channel members except +// the sender. Returns the database row ID, message UUID, +// and any error. The dbID lets callers enqueue the same +// message to the sender when echo is needed (HTTP +// transport). +func (s *Service) SendChannelMessage( + ctx context.Context, + sessionID int64, + nick, command, channel string, + body, meta json.RawMessage, +) (int64, string, error) { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return 0, "", &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, sessionID, + ) + if !isMember { + return 0, "", &IRCError{ + irc.ErrCannotSendToChan, + []string{channel}, + "Cannot send to channel", + } + } + + // Ban check — banned users cannot send messages. + isBanned, banErr := s.DB.IsSessionBanned( + ctx, chID, sessionID, + ) + if banErr == nil && isBanned { + return 0, "", &IRCError{ + irc.ErrCannotSendToChan, + []string{channel}, + "Cannot send to channel (+b)", + } + } + + moderated, _ := s.DB.IsChannelModerated(ctx, chID) + if moderated { + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + isVoiced, _ := s.DB.IsChannelVoiced( + ctx, chID, sessionID, + ) + + if !isOp && !isVoiced { + return 0, "", &IRCError{ + irc.ErrCannotSendToChan, + []string{channel}, + "Cannot send to channel (+m)", + } + } + } + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + recipients := excludeSession(memberIDs, sessionID) + + dbID, uuid, fanErr := s.FanOut( + ctx, command, nick, channel, + nil, body, meta, recipients, + ) + if fanErr != nil { + return 0, "", fanErr + } + + return dbID, uuid, nil +} + +// SendDirectMessage validates the target and sends a +// direct message, returning the message UUID and any away +// message set on the target. +func (s *Service) SendDirectMessage( + ctx context.Context, + sessionID int64, + nick, command, target string, + body, meta json.RawMessage, +) (*DirectMsgResult, error) { + targetSID, err := s.DB.GetSessionByNick(ctx, target) + if err != nil { + return nil, &IRCError{ + irc.ErrNoSuchNick, + []string{target}, + "No such nick", + } + } + + away, _ := s.DB.GetAway(ctx, targetSID) + + recipients := []int64{targetSID} + if targetSID != sessionID { + recipients = append(recipients, sessionID) + } + + _, uuid, fanErr := s.FanOut( + ctx, command, nick, target, + nil, body, meta, recipients, + ) + if fanErr != nil { + return nil, fanErr + } + + return &DirectMsgResult{UUID: uuid, AwayMsg: away}, nil +} + +// JoinChannel creates or joins a channel, making the +// first joiner the operator. Fans out the JOIN to all +// channel members. +func (s *Service) JoinChannel( + ctx context.Context, + sessionID int64, + nick, channel, suppliedKey string, +) (*JoinResult, error) { + chID, err := s.DB.GetOrCreateChannel(ctx, channel) + if err != nil { + return nil, fmt.Errorf("get/create channel: %w", err) + } + + memberCount, countErr := s.DB.CountChannelMembers( + ctx, chID, + ) + isCreator := countErr == nil && memberCount == 0 + + if !isCreator { + if joinErr := checkJoinRestrictions( + ctx, s.DB, chID, sessionID, + channel, suppliedKey, memberCount, + ); joinErr != nil { + return nil, joinErr + } + } + + if isCreator { + err = s.DB.JoinChannelAsOperator( + ctx, chID, sessionID, + ) + } else { + err = s.DB.JoinChannel(ctx, chID, sessionID) + } + + if err != nil { + return nil, fmt.Errorf("join channel: %w", err) + } + + // Clear invite after successful join. + _ = s.DB.ClearChannelInvite(ctx, chID, sessionID) + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{channel}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdJoin, nick, channel, + nil, body, nil, memberIDs, + ) + + return &JoinResult{ + ChannelID: chID, + IsCreator: isCreator, + }, nil +} + +// PartChannel validates membership, broadcasts PART to +// remaining members, removes the user, and cleans up empty +// channels. +func (s *Service) PartChannel( + ctx context.Context, + sessionID int64, + nick, channel, reason string, +) error { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, sessionID, + ) + if !isMember { + return &IRCError{ + irc.ErrNotOnChannel, + []string{channel}, + "You're not on that channel", + } + } + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + recipients := excludeSession(memberIDs, sessionID) + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdPart, nick, channel, + nil, body, nil, recipients, + ) + + s.DB.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec + s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec + + return nil +} + +// SetTopic validates membership and topic-lock, sets the +// topic, and broadcasts the change. +func (s *Service) SetTopic( + ctx context.Context, + sessionID int64, + nick, channel, topic string, +) error { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, sessionID, + ) + if !isMember { + return &IRCError{ + irc.ErrNotOnChannel, + []string{channel}, + "You're not on that channel", + } + } + + topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) + if topicLocked { + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + if !isOp { + return &IRCError{ + irc.ErrChanOpPrivsNeeded, + []string{channel}, + "You're not channel operator", + } + } + } + + if setErr := s.DB.SetTopic( + ctx, channel, topic, + ); setErr != nil { + return fmt.Errorf("set topic: %w", setErr) + } + + _ = s.DB.SetTopicMeta(ctx, channel, topic, nick) + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{topic}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdTopic, nick, channel, + nil, body, nil, memberIDs, + ) + + return nil +} + +// KickUser validates operator status and target +// membership, broadcasts the KICK, removes the target, +// and cleans up empty channels. +func (s *Service) KickUser( + ctx context.Context, + sessionID int64, + nick, channel, targetNick, reason string, +) error { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + if !isOp { + return &IRCError{ + irc.ErrChanOpPrivsNeeded, + []string{channel}, + "You're not channel operator", + } + } + + targetSID, err := s.DB.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + return &IRCError{ + irc.ErrNoSuchNick, + []string{targetNick}, + "No such nick/channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, targetSID, + ) + if !isMember { + return &IRCError{ + irc.ErrUserNotInChannel, + []string{targetNick, channel}, + "They aren't on that channel", + } + } + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + params, _ := json.Marshal( //nolint:errchkjson + []string{targetNick}, + ) + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdKick, nick, channel, + params, body, nil, memberIDs, + ) + + s.DB.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec + s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec + + return nil +} + +// ChangeNick changes a user's nickname and broadcasts the +// change to all users sharing channels. +func (s *Service) ChangeNick( + ctx context.Context, + sessionID int64, + oldNick, newNick string, +) error { + err := s.DB.ChangeNick(ctx, sessionID, newNick) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE") || + db.IsUniqueConstraintError(err) { + return &IRCError{ + irc.ErrNicknameInUse, + []string{newNick}, + "Nickname is already in use", + } + } + + return &IRCError{ + irc.ErrErroneusNickname, + []string{newNick}, + "Erroneous nickname", + } + } + + s.broadcastNickChange(ctx, sessionID, oldNick, newNick) + + return nil +} + +// BroadcastQuit broadcasts a QUIT to all channel peers, +// parts all channels, and deletes the session. Uses the +// FanOut pattern: one message row fanned out to all unique +// peer sessions. +func (s *Service) BroadcastQuit( + ctx context.Context, + sessionID int64, + nick, reason string, +) { + channels, err := s.DB.GetSessionChannels( + ctx, sessionID, + ) + if err != nil { + return + } + + notified := make(map[int64]bool) + + for _, ch := range channels { + memberIDs, memErr := s.DB.GetChannelMemberIDs( + ctx, ch.ID, + ) + if memErr != nil { + continue + } + + for _, mid := range memberIDs { + if mid == sessionID || notified[mid] { + continue + } + + notified[mid] = true + } + } + + if len(notified) > 0 { + recipients := make([]int64, 0, len(notified)) + for sid := range notified { + recipients = append(recipients, sid) + } + + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + + _, _, _ = s.FanOut( + ctx, irc.CmdQuit, nick, "", + nil, body, nil, recipients, + ) + } + + for _, ch := range channels { + s.DB.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec + s.DB.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec + } + + s.DB.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec +} + +// SetAway sets or clears the away message. Returns true +// if the message was cleared (empty string). +func (s *Service) SetAway( + ctx context.Context, + sessionID int64, + message string, +) (bool, error) { + err := s.DB.SetAway(ctx, sessionID, message) + if err != nil { + return false, fmt.Errorf("set away: %w", err) + } + + return message == "", nil +} + +// Oper validates operator credentials and grants oper +// status to the session. +func (s *Service) Oper( + ctx context.Context, + sessionID int64, + name, password string, +) error { + cfgName := s.Config.OperName + cfgPassword := s.Config.OperPassword + + // Use constant-time comparison and return the same + // error for all failures to prevent information + // leakage about valid operator names. + if cfgName == "" || cfgPassword == "" || + subtle.ConstantTimeCompare( + []byte(name), []byte(cfgName), + ) != 1 || + subtle.ConstantTimeCompare( + []byte(password), []byte(cfgPassword), + ) != 1 { + return &IRCError{ + irc.ErrNoOperHost, + nil, + "No O-lines for your host", + } + } + + _ = s.DB.SetSessionOper(ctx, sessionID, true) + + return nil +} + +// ValidateChannelOp checks that the session is a channel +// operator. Returns the channel ID. +func (s *Service) ValidateChannelOp( + ctx context.Context, + sessionID int64, + channel string, +) (int64, error) { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return 0, &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + if !isOp { + return 0, &IRCError{ + irc.ErrChanOpPrivsNeeded, + []string{channel}, + "You're not channel operator", + } + } + + return chID, nil +} + +// ApplyMemberMode applies +o/-o or +v/-v on a channel +// member after validating the target. +func (s *Service) ApplyMemberMode( + ctx context.Context, + chID int64, + channel, targetNick string, + mode rune, + adding bool, +) error { + targetSID, err := s.DB.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + return &IRCError{ + irc.ErrNoSuchNick, + []string{targetNick}, + "No such nick/channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, targetSID, + ) + if !isMember { + return &IRCError{ + irc.ErrUserNotInChannel, + []string{targetNick, channel}, + "They aren't on that channel", + } + } + + switch mode { + case 'o': + _ = s.DB.SetChannelMemberOperator( + ctx, chID, targetSID, adding, + ) + case 'v': + _ = s.DB.SetChannelMemberVoiced( + ctx, chID, targetSID, adding, + ) + } + + return nil +} + +// SetChannelFlag applies +m/-m or +t/-t on a channel. +func (s *Service) SetChannelFlag( + ctx context.Context, + chID int64, + flag rune, + setting bool, +) error { + switch flag { + case 'm': + if err := s.DB.SetChannelModerated( + ctx, chID, setting, + ); err != nil { + return fmt.Errorf("set moderated: %w", err) + } + case 't': + if err := s.DB.SetChannelTopicLocked( + ctx, chID, setting, + ); err != nil { + return fmt.Errorf("set topic locked: %w", err) + } + case 'i': + if err := s.DB.SetChannelInviteOnly( + ctx, chID, setting, + ); err != nil { + return fmt.Errorf("set invite only: %w", err) + } + case 's': + if err := s.DB.SetChannelSecret( + ctx, chID, setting, + ); err != nil { + return fmt.Errorf("set secret: %w", err) + } + } + + return nil +} + +// BroadcastMode fans out a MODE change to all channel +// members. +func (s *Service) BroadcastMode( + ctx context.Context, + nick, channel string, + chID int64, + modeText string, +) { + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{modeText}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdMode, nick, channel, + nil, body, nil, memberIDs, + ) +} + +// QueryChannelMode returns the channel mode string. +func (s *Service) QueryChannelMode( + ctx context.Context, + chID int64, +) string { + modes := "+" + + moderated, _ := s.DB.IsChannelModerated(ctx, chID) + if moderated { + modes += "m" + } + + topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) + if topicLocked { + modes += "t" + } + + return modes +} + +// broadcastNickChange notifies channel peers of a nick +// change. +func (s *Service) broadcastNickChange( + ctx context.Context, + sessionID int64, + oldNick, newNick string, +) { + channels, err := s.DB.GetSessionChannels( + ctx, sessionID, + ) + if err != nil { + return + } + + body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson + notified := make(map[int64]bool) + + dbID, _, insErr := s.DB.InsertMessage( + ctx, irc.CmdNick, oldNick, "", + nil, body, nil, + ) + if insErr != nil { + return + } + + // Notify the user themselves (for multi-client sync). + _ = s.DB.EnqueueToSession(ctx, sessionID, dbID) + s.Broker.Notify(sessionID) + notified[sessionID] = true + + for _, ch := range channels { + memberIDs, memErr := s.DB.GetChannelMemberIDs( + ctx, ch.ID, + ) + if memErr != nil { + continue + } + + for _, mid := range memberIDs { + if notified[mid] { + continue + } + + notified[mid] = true + + _ = s.DB.EnqueueToSession(ctx, mid, dbID) + s.Broker.Notify(mid) + } + } +} + +// checkJoinRestrictions validates Tier 2 join conditions: +// bans, invite-only, channel key, and user limit. +func checkJoinRestrictions( + ctx context.Context, + database *db.Database, + chID, sessionID int64, + channel, suppliedKey string, + memberCount int64, +) error { + isBanned, banErr := database.IsSessionBanned( + ctx, chID, sessionID, + ) + if banErr == nil && isBanned { + return &IRCError{ + Code: irc.ErrBannedFromChan, + Params: []string{channel}, + Message: "Cannot join channel (+b)", + } + } + + isInviteOnly, ioErr := database.IsChannelInviteOnly( + ctx, chID, + ) + if ioErr == nil && isInviteOnly { + hasInvite, invErr := database.HasChannelInvite( + ctx, chID, sessionID, + ) + if invErr != nil || !hasInvite { + return &IRCError{ + Code: irc.ErrInviteOnlyChan, + Params: []string{channel}, + Message: "Cannot join channel (+i)", + } + } + } + + key, keyErr := database.GetChannelKey(ctx, chID) + if keyErr == nil && key != "" && suppliedKey != key { + return &IRCError{ + Code: irc.ErrBadChannelKey, + Params: []string{channel}, + Message: "Cannot join channel (+k)", + } + } + + limit, limErr := database.GetChannelUserLimit(ctx, chID) + if limErr == nil && limit > 0 && + memberCount >= int64(limit) { + return &IRCError{ + Code: irc.ErrChannelIsFull, + Params: []string{channel}, + Message: "Cannot join channel (+l)", + } + } + + return nil +} diff --git a/internal/service/service_test.go b/internal/service/service_test.go new file mode 100644 index 0000000..345af9c --- /dev/null +++ b/internal/service/service_test.go @@ -0,0 +1,365 @@ +// Tests use a global viper instance for configuration, +// making parallel execution unsafe. +// +//nolint:paralleltest +package service_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "testing" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/globals" + "git.eeqj.de/sneak/neoirc/internal/logger" + "git.eeqj.de/sneak/neoirc/internal/service" + "git.eeqj.de/sneak/neoirc/pkg/irc" + "go.uber.org/fx" + "go.uber.org/fx/fxtest" + "golang.org/x/crypto/bcrypt" +) + +func TestMain(m *testing.M) { + db.SetBcryptCost(bcrypt.MinCost) + os.Exit(m.Run()) +} + +// testEnv holds all dependencies for a service test. +type testEnv struct { + svc *service.Service + db *db.Database + broker *broker.Broker + app *fxtest.App +} + +func newTestEnv(t *testing.T) *testEnv { + t.Helper() + + dbURL := fmt.Sprintf( + "file:svc_test_%p?mode=memory&cache=shared", + t, + ) + + var ( + database *db.Database + svc *service.Service + ) + + brk := broker.New() + + app := fxtest.New(t, + fx.Provide( + func() *globals.Globals { + return &globals.Globals{ //nolint:exhaustruct + Appname: "neoirc-test", + Version: "test", + } + }, + logger.New, + func( + lifecycle fx.Lifecycle, + globs *globals.Globals, + log *logger.Logger, + ) (*config.Config, error) { + cfg, err := config.New( + lifecycle, config.Params{ //nolint:exhaustruct + Globals: globs, Logger: log, + }, + ) + if err != nil { + return nil, fmt.Errorf( + "test config: %w", err, + ) + } + + cfg.DBURL = dbURL + cfg.Port = 0 + cfg.OperName = "admin" + cfg.OperPassword = "secret" + + return cfg, nil + }, + func( + lifecycle fx.Lifecycle, + log *logger.Logger, + cfg *config.Config, + ) (*db.Database, error) { + return db.New(lifecycle, db.Params{ //nolint:exhaustruct + Logger: log, Config: cfg, + }) + }, + func() *broker.Broker { return brk }, + service.New, + ), + fx.Populate(&database, &svc), + ) + + app.RequireStart() + + t.Cleanup(func() { + app.RequireStop() + }) + + return &testEnv{ + svc: svc, + db: database, + broker: brk, + app: app, + } +} + +// createSession is a test helper that creates a session +// and returns the session ID. +func createSession( + ctx context.Context, + t *testing.T, + database *db.Database, + nick string, +) int64 { + t.Helper() + + sessionID, _, _, err := database.CreateSession( + ctx, nick, nick, "localhost", "127.0.0.1", + ) + if err != nil { + t.Fatalf("create session %s: %v", nick, err) + } + + return sessionID +} + +func TestFanOut(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + body, _ := json.Marshal([]string{"hello"}) //nolint:errchkjson + + dbID, uuid, err := env.svc.FanOut( + ctx, irc.CmdPrivmsg, "alice", "#test", + nil, body, nil, + []int64{sid1, sid2}, + ) + if err != nil { + t.Fatalf("FanOut: %v", err) + } + + if dbID == 0 { + t.Error("expected non-zero dbID") + } + + if uuid == "" { + t.Error("expected non-empty UUID") + } +} + +func TestJoinChannel(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid := createSession(ctx, t, env.db, "alice") + + result, err := env.svc.JoinChannel( + ctx, sid, "alice", "#general", "", + ) + if err != nil { + t.Fatalf("JoinChannel: %v", err) + } + + if result.ChannelID == 0 { + t.Error("expected non-zero channel ID") + } + + if !result.IsCreator { + t.Error("first joiner should be creator") + } + + // Second user joins — not creator. + sid2 := createSession(ctx, t, env.db, "bob") + + result2, err := env.svc.JoinChannel( + ctx, sid2, "bob", "#general", "", + ) + if err != nil { + t.Fatalf("JoinChannel bob: %v", err) + } + + if result2.IsCreator { + t.Error("second joiner should not be creator") + } + + if result2.ChannelID != result.ChannelID { + t.Error("both should join the same channel") + } +} + +func TestPartChannel(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid := createSession(ctx, t, env.db, "alice") + + _, err := env.svc.JoinChannel( + ctx, sid, "alice", "#general", "", + ) + if err != nil { + t.Fatalf("JoinChannel: %v", err) + } + + err = env.svc.PartChannel( + ctx, sid, "alice", "#general", "bye", + ) + if err != nil { + t.Fatalf("PartChannel: %v", err) + } + + // Parting a non-existent channel returns error. + err = env.svc.PartChannel( + ctx, sid, "alice", "#nonexistent", "", + ) + if err == nil { + t.Error("expected error for non-existent channel") + } + + var ircErr *service.IRCError + if !errors.As(err, &ircErr) { + t.Errorf("expected IRCError, got %T", err) + } +} + +func TestSendChannelMessage(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + _, err := env.svc.JoinChannel( + ctx, sid1, "alice", "#chat", "", + ) + if err != nil { + t.Fatalf("join alice: %v", err) + } + + _, err = env.svc.JoinChannel( + ctx, sid2, "bob", "#chat", "", + ) + if err != nil { + t.Fatalf("join bob: %v", err) + } + + body, _ := json.Marshal([]string{"hello world"}) //nolint:errchkjson + + dbID, uuid, err := env.svc.SendChannelMessage( + ctx, sid1, "alice", + irc.CmdPrivmsg, "#chat", body, nil, + ) + if err != nil { + t.Fatalf("SendChannelMessage: %v", err) + } + + if dbID == 0 { + t.Error("expected non-zero dbID") + } + + if uuid == "" { + t.Error("expected non-empty UUID") + } + + // Non-member cannot send. + sid3 := createSession(ctx, t, env.db, "charlie") + + _, _, err = env.svc.SendChannelMessage( + ctx, sid3, "charlie", + irc.CmdPrivmsg, "#chat", body, nil, + ) + if err == nil { + t.Error("expected error for non-member send") + } +} + +func TestBroadcastQuit(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + _, err := env.svc.JoinChannel( + ctx, sid1, "alice", "#room", "", + ) + if err != nil { + t.Fatalf("join alice: %v", err) + } + + _, err = env.svc.JoinChannel( + ctx, sid2, "bob", "#room", "", + ) + if err != nil { + t.Fatalf("join bob: %v", err) + } + + // BroadcastQuit should not panic and should clean up. + env.svc.BroadcastQuit( + ctx, sid1, "alice", "Goodbye", + ) + + // Session should be deleted. + _, lookupErr := env.db.GetSessionByNick(ctx, "alice") + if lookupErr == nil { + t.Error("expected session to be deleted after quit") + } +} + +func TestSendChannelMessage_Moderated(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + result, err := env.svc.JoinChannel( + ctx, sid1, "alice", "#modchat", "", + ) + if err != nil { + t.Fatalf("join alice: %v", err) + } + + _, err = env.svc.JoinChannel( + ctx, sid2, "bob", "#modchat", "", + ) + if err != nil { + t.Fatalf("join bob: %v", err) + } + + // Set channel to moderated. + chID := result.ChannelID + _ = env.svc.SetChannelFlag(ctx, chID, 'm', true) + + body, _ := json.Marshal([]string{"test"}) //nolint:errchkjson + + // Bob (non-op, non-voiced) should fail to send. + _, _, err = env.svc.SendChannelMessage( + ctx, sid2, "bob", + irc.CmdPrivmsg, "#modchat", body, nil, + ) + if err == nil { + t.Error("expected error for non-voiced user in moderated channel") + } + + // Alice (operator) should succeed. + _, _, err = env.svc.SendChannelMessage( + ctx, sid1, "alice", + irc.CmdPrivmsg, "#modchat", body, nil, + ) + if err != nil { + t.Errorf("operator should be able to send in moderated channel: %v", err) + } +} diff --git a/pkg/irc/commands.go b/pkg/irc/commands.go index 91893ec..60ef85e 100644 --- a/pkg/irc/commands.go +++ b/pkg/irc/commands.go @@ -21,6 +21,7 @@ const ( CmdPrivmsg = "PRIVMSG" CmdQuit = "QUIT" CmdTopic = "TOPIC" + CmdUser = "USER" CmdWho = "WHO" CmdWhois = "WHOIS" ) -- 2.49.1 From f57a3730533908cad858eaaede1fcd5cdcd32193 Mon Sep 17 00:00:00 2001 From: user Date: Sat, 28 Mar 2026 11:48:01 -0700 Subject: [PATCH 2/2] fix: address 3 blocking review findings for IRC protocol listener 1. ISUPPORT/applyChannelModes: extend IRC MODE handler to support +i/-i, +s/-s, +n/-n (routed through svc.SetChannelFlag), and +H/-H (hashcash bits with parameter parsing). Add 'n' (no external messages) as a proper DB-backed channel flag with is_no_external column (default: on). Update IRC ISUPPORT to CHANMODES=,,H,imnst to match actual support. 2. QueryChannelMode: rewrite to return complete mode string including all boolean flags (n, i, m, s, t) and parameterized modes (k, l, H), matching the HTTP handler's buildChannelModeString logic. Simplify buildChannelModeString to delegate to QueryChannelMode for consistency. 3. Service struct encapsulation: change exported fields (DB, Broker, Config, Log) to unexported (db, broker, config, log). Add NewTestService constructor for use by external test packages. Update ircserver export_test.go to use the new constructor. Closes #89 --- internal/db/queries.go | 46 ++++++ internal/db/schema/001_initial.sql | 1 + internal/handlers/api.go | 56 +------- internal/ircserver/commands.go | 179 +++++++++++++++++++---- internal/ircserver/conn.go | 21 +-- internal/ircserver/export_test.go | 9 +- internal/service/service.go | 220 ++++++++++++++++++----------- 7 files changed, 357 insertions(+), 175 deletions(-) diff --git a/internal/db/queries.go b/internal/db/queries.go index 9029337..d954571 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -2165,6 +2165,52 @@ func (database *Database) SetChannelSecret( return nil } +// --- No External Messages (+n) --- + +// IsChannelNoExternal checks if a channel has +n mode. +func (database *Database) IsChannelNoExternal( + ctx context.Context, + channelID int64, +) (bool, error) { + var isNoExternal int + + err := database.conn.QueryRowContext(ctx, + `SELECT is_no_external FROM channels + WHERE id = ?`, + channelID, + ).Scan(&isNoExternal) + if err != nil { + return false, fmt.Errorf( + "check no external: %w", err, + ) + } + + return isNoExternal != 0, nil +} + +// SetChannelNoExternal sets or unsets +n mode. +func (database *Database) SetChannelNoExternal( + ctx context.Context, + channelID int64, + noExternal bool, +) error { + val := 0 + if noExternal { + val = 1 + } + + _, err := database.conn.ExecContext(ctx, + `UPDATE channels + SET is_no_external = ?, updated_at = ? + WHERE id = ?`, + val, time.Now(), channelID) + if err != nil { + return fmt.Errorf("set no external: %w", err) + } + + return nil +} + // ListAllChannelsWithCountsFiltered returns all channels // with member counts, excluding secret channels that // the given session is not a member of. diff --git a/internal/db/schema/001_initial.sql b/internal/db/schema/001_initial.sql index a29bdaa..e53c48b 100644 --- a/internal/db/schema/001_initial.sql +++ b/internal/db/schema/001_initial.sql @@ -44,6 +44,7 @@ CREATE TABLE IF NOT EXISTS channels ( is_topic_locked INTEGER NOT NULL DEFAULT 1, is_invite_only INTEGER NOT NULL DEFAULT 0, is_secret INTEGER NOT NULL DEFAULT 0, + is_no_external INTEGER NOT NULL DEFAULT 1, channel_key TEXT NOT NULL DEFAULT '', user_limit INTEGER NOT NULL DEFAULT 0, created_at DATETIME DEFAULT CURRENT_TIMESTAMP, diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 22825af..3989583 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -2016,62 +2016,14 @@ func (hdlr *Handlers) handleChannelMode( } // buildChannelModeString constructs the current mode -// string for a channel, including +n (always on), +t, +m, -// +i, +s, +k, +l, and +H with their parameters. +// string for a channel by delegating to the service +// layer's QueryChannelMode, which returns the complete +// mode string including all flags and parameters. func (hdlr *Handlers) buildChannelModeString( ctx context.Context, chID int64, ) string { - modes := "+n" - - isInviteOnly, ioErr := hdlr.params.Database. - IsChannelInviteOnly(ctx, chID) - if ioErr == nil && isInviteOnly { - modes += "i" - } - - isModerated, modErr := hdlr.params.Database. - IsChannelModerated(ctx, chID) - if modErr == nil && isModerated { - modes += "m" - } - - isSecret, secErr := hdlr.params.Database. - IsChannelSecret(ctx, chID) - if secErr == nil && isSecret { - modes += "s" - } - - isTopicLocked, tlErr := hdlr.params.Database. - IsChannelTopicLocked(ctx, chID) - if tlErr == nil && isTopicLocked { - modes += "t" - } - - var modeParams string - - key, keyErr := hdlr.params.Database. - GetChannelKey(ctx, chID) - if keyErr == nil && key != "" { - modes += "k" - modeParams += " " + key - } - - limit, limErr := hdlr.params.Database. - GetChannelUserLimit(ctx, chID) - if limErr == nil && limit > 0 { - modes += "l" - modeParams += " " + strconv.Itoa(limit) - } - - bits, bitsErr := hdlr.params.Database. - GetChannelHashcashBits(ctx, chID) - if bitsErr == nil && bits > 0 { - modes += "H" - modeParams += " " + strconv.Itoa(bits) - } - - return modes + modeParams + return hdlr.svc.QueryChannelMode(ctx, chID) } // queryChannelMode sends RPL_CHANNELMODEIS and diff --git a/internal/ircserver/commands.go b/internal/ircserver/commands.go index 297ea4f..03deaa6 100644 --- a/internal/ircserver/commands.go +++ b/internal/ircserver/commands.go @@ -490,6 +490,124 @@ func (c *Conn) handleChannelMode( ) } +// modeResult holds the delta strings produced by a +// single mode-char application. +type modeResult struct { + applied string + appliedArgs string + consumed int + skip bool +} + +// applyHashcashMode handles +H/-H (hashcash difficulty). +func (c *Conn) applyHashcashMode( + ctx context.Context, + chID int64, + adding bool, + args []string, + argIdx int, +) modeResult { + if !adding { + _ = c.database.SetChannelHashcashBits( + ctx, chID, 0, + ) + + return modeResult{ + applied: "-H", + appliedArgs: "", + consumed: 0, + skip: false, + } + } + + if argIdx >= len(args) { + return modeResult{ + applied: "", + appliedArgs: "", + consumed: 0, + skip: true, + } + } + + bitsStr := args[argIdx] + + bits, parseErr := strconv.Atoi(bitsStr) + if parseErr != nil || + bits < 1 || bits > maxHashcashBits { + c.sendNumeric( + irc.ErrUnknownMode, "H", + "is unknown mode char to me", + ) + + return modeResult{ + applied: "", + appliedArgs: "", + consumed: 1, + skip: true, + } + } + + _ = c.database.SetChannelHashcashBits( + ctx, chID, bits, + ) + + return modeResult{ + applied: "+H", + appliedArgs: " " + bitsStr, + consumed: 1, + skip: false, + } +} + +// applyMemberMode handles +o/-o and +v/-v. +func (c *Conn) applyMemberMode( + ctx context.Context, + chID int64, + channel string, + modeChar rune, + adding bool, + args []string, + argIdx int, +) modeResult { + if argIdx >= len(args) { + return modeResult{ + applied: "", + appliedArgs: "", + consumed: 0, + skip: true, + } + } + + targetNick := args[argIdx] + + err := c.svc.ApplyMemberMode( + ctx, chID, channel, + targetNick, modeChar, adding, + ) + if err != nil { + c.sendIRCError(err) + + return modeResult{ + applied: "", + appliedArgs: "", + consumed: 1, + skip: true, + } + } + + prefix := "+" + if !adding { + prefix = "-" + } + + return modeResult{ + applied: prefix + string(modeChar), + appliedArgs: " " + targetNick, + consumed: 1, + skip: false, + } +} + // applyChannelModes applies mode changes using the // service for individual mode operations. func (c *Conn) applyChannelModes( @@ -505,52 +623,57 @@ func (c *Conn) applyChannelModes( appliedArgs := "" for _, modeChar := range modeStr { + var res modeResult + switch modeChar { case '+': adding = true + + continue case '-': adding = false - case 'm', 't': + + continue + case 'i', 'm', 'n', 's', 't': _ = c.svc.SetChannelFlag( ctx, chID, modeChar, adding, ) - if adding { - applied += "+" + string(modeChar) - } else { - applied += "-" + string(modeChar) - } - case 'o', 'v': - if argIdx >= len(args) { - break + prefix := "+" + if !adding { + prefix = "-" } - targetNick := args[argIdx] - argIdx++ - - err := c.svc.ApplyMemberMode( - ctx, chID, channel, - targetNick, modeChar, adding, + res = modeResult{ + applied: prefix + string(modeChar), + appliedArgs: "", + consumed: 0, + skip: false, + } + case 'H': + res = c.applyHashcashMode( + ctx, chID, adding, args, argIdx, + ) + case 'o', 'v': + res = c.applyMemberMode( + ctx, chID, channel, + modeChar, adding, args, argIdx, ) - if err != nil { - c.sendIRCError(err) - - continue - } - - if adding { - applied += "+" + string(modeChar) - } else { - applied += "-" + string(modeChar) - } - - appliedArgs += " " + targetNick default: c.sendNumeric( irc.ErrUnknownMode, string(modeChar), "is unknown mode char to me", ) + + continue + } + + argIdx += res.consumed + + if !res.skip { + applied += res.applied + appliedArgs += res.appliedArgs } } diff --git a/internal/ircserver/conn.go b/internal/ircserver/conn.go index 173ac19..835f452 100644 --- a/internal/ircserver/conn.go +++ b/internal/ircserver/conn.go @@ -19,15 +19,16 @@ import ( ) const ( - maxLineLen = 512 - readTimeout = 5 * time.Minute - writeTimeout = 30 * time.Second - dnsTimeout = 3 * time.Second - pollInterval = 100 * time.Millisecond - pingInterval = 90 * time.Second - pongDeadline = 30 * time.Second - maxNickLen = 32 - minPasswordLen = 8 + maxLineLen = 512 + readTimeout = 5 * time.Minute + writeTimeout = 30 * time.Second + dnsTimeout = 3 * time.Second + pollInterval = 100 * time.Millisecond + pingInterval = 90 * time.Second + pongDeadline = 30 * time.Second + maxNickLen = 32 + minPasswordLen = 8 + maxHashcashBits = 40 ) // cmdHandler is the signature for registered IRC command @@ -434,7 +435,7 @@ func (c *Conn) deliverWelcome() { "CHANTYPES=#", "NICKLEN=32", "PREFIX=(ov)@+", - "CHANMODES=,,H,mnst", + "CHANMODES=,,H,imnst", "NETWORK="+c.serverSfx, "are supported by this server", ) diff --git a/internal/ircserver/export_test.go b/internal/ircserver/export_test.go index dc75ee7..e1583c9 100644 --- a/internal/ircserver/export_test.go +++ b/internal/ircserver/export_test.go @@ -19,12 +19,9 @@ func NewTestServer( database *db.Database, brk *broker.Broker, ) *Server { - svc := &service.Service{ - DB: database, - Broker: brk, - Config: cfg, - Log: log, - } + svc := service.NewTestService( + database, brk, cfg, log, + ) return &Server{ //nolint:exhaustruct log: log, diff --git a/internal/service/service.go b/internal/service/service.go index a8304fa..6a8e298 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "log/slog" + "strconv" "strings" "git.eeqj.de/sneak/neoirc/internal/broker" @@ -30,19 +31,35 @@ type Params struct { // Service provides shared business logic for IRC commands. type Service struct { - DB *db.Database - Broker *broker.Broker - Config *config.Config - Log *slog.Logger + db *db.Database + broker *broker.Broker + config *config.Config + log *slog.Logger } // New creates a new Service. func New(params Params) *Service { return &Service{ - DB: params.Database, - Broker: params.Broker, - Config: params.Config, - Log: params.Logger.Get(), + db: params.Database, + broker: params.Broker, + config: params.Config, + log: params.Logger.Get(), + } +} + +// NewTestService creates a Service for use in tests +// outside the service package. +func NewTestService( + database *db.Database, + brk *broker.Broker, + cfg *config.Config, + log *slog.Logger, +) *Service { + return &Service{ + db: database, + broker: brk, + config: cfg, + log: log, } } @@ -76,7 +93,7 @@ func (s *Service) FanOut( params, body, meta json.RawMessage, sessionIDs []int64, ) (int64, string, error) { - dbID, msgUUID, err := s.DB.InsertMessage( + dbID, msgUUID, err := s.db.InsertMessage( ctx, command, from, to, params, body, meta, ) if err != nil { @@ -84,8 +101,8 @@ func (s *Service) FanOut( } for _, sid := range sessionIDs { - _ = s.DB.EnqueueToSession(ctx, sid, dbID) - s.Broker.Notify(sid) + _ = s.db.EnqueueToSession(ctx, sid, dbID) + s.broker.Notify(sid) } return dbID, msgUUID, nil @@ -120,7 +137,7 @@ func (s *Service) SendChannelMessage( nick, command, channel string, body, meta json.RawMessage, ) (int64, string, error) { - chID, err := s.DB.GetChannelByName(ctx, channel) + chID, err := s.db.GetChannelByName(ctx, channel) if err != nil { return 0, "", &IRCError{ irc.ErrNoSuchChannel, @@ -129,7 +146,7 @@ func (s *Service) SendChannelMessage( } } - isMember, _ := s.DB.IsChannelMember( + isMember, _ := s.db.IsChannelMember( ctx, chID, sessionID, ) if !isMember { @@ -141,7 +158,7 @@ func (s *Service) SendChannelMessage( } // Ban check — banned users cannot send messages. - isBanned, banErr := s.DB.IsSessionBanned( + isBanned, banErr := s.db.IsSessionBanned( ctx, chID, sessionID, ) if banErr == nil && isBanned { @@ -152,12 +169,12 @@ func (s *Service) SendChannelMessage( } } - moderated, _ := s.DB.IsChannelModerated(ctx, chID) + moderated, _ := s.db.IsChannelModerated(ctx, chID) if moderated { - isOp, _ := s.DB.IsChannelOperator( + isOp, _ := s.db.IsChannelOperator( ctx, chID, sessionID, ) - isVoiced, _ := s.DB.IsChannelVoiced( + isVoiced, _ := s.db.IsChannelVoiced( ctx, chID, sessionID, ) @@ -170,7 +187,7 @@ func (s *Service) SendChannelMessage( } } - memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID) recipients := excludeSession(memberIDs, sessionID) dbID, uuid, fanErr := s.FanOut( @@ -193,7 +210,7 @@ func (s *Service) SendDirectMessage( nick, command, target string, body, meta json.RawMessage, ) (*DirectMsgResult, error) { - targetSID, err := s.DB.GetSessionByNick(ctx, target) + targetSID, err := s.db.GetSessionByNick(ctx, target) if err != nil { return nil, &IRCError{ irc.ErrNoSuchNick, @@ -202,7 +219,7 @@ func (s *Service) SendDirectMessage( } } - away, _ := s.DB.GetAway(ctx, targetSID) + away, _ := s.db.GetAway(ctx, targetSID) recipients := []int64{targetSID} if targetSID != sessionID { @@ -228,19 +245,19 @@ func (s *Service) JoinChannel( sessionID int64, nick, channel, suppliedKey string, ) (*JoinResult, error) { - chID, err := s.DB.GetOrCreateChannel(ctx, channel) + chID, err := s.db.GetOrCreateChannel(ctx, channel) if err != nil { return nil, fmt.Errorf("get/create channel: %w", err) } - memberCount, countErr := s.DB.CountChannelMembers( + memberCount, countErr := s.db.CountChannelMembers( ctx, chID, ) isCreator := countErr == nil && memberCount == 0 if !isCreator { if joinErr := checkJoinRestrictions( - ctx, s.DB, chID, sessionID, + ctx, s.db, chID, sessionID, channel, suppliedKey, memberCount, ); joinErr != nil { return nil, joinErr @@ -248,11 +265,11 @@ func (s *Service) JoinChannel( } if isCreator { - err = s.DB.JoinChannelAsOperator( + err = s.db.JoinChannelAsOperator( ctx, chID, sessionID, ) } else { - err = s.DB.JoinChannel(ctx, chID, sessionID) + err = s.db.JoinChannel(ctx, chID, sessionID) } if err != nil { @@ -260,9 +277,9 @@ func (s *Service) JoinChannel( } // Clear invite after successful join. - _ = s.DB.ClearChannelInvite(ctx, chID, sessionID) + _ = s.db.ClearChannelInvite(ctx, chID, sessionID) - memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID) body, _ := json.Marshal([]string{channel}) //nolint:errchkjson _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast @@ -284,7 +301,7 @@ func (s *Service) PartChannel( sessionID int64, nick, channel, reason string, ) error { - chID, err := s.DB.GetChannelByName(ctx, channel) + chID, err := s.db.GetChannelByName(ctx, channel) if err != nil { return &IRCError{ irc.ErrNoSuchChannel, @@ -293,7 +310,7 @@ func (s *Service) PartChannel( } } - isMember, _ := s.DB.IsChannelMember( + isMember, _ := s.db.IsChannelMember( ctx, chID, sessionID, ) if !isMember { @@ -304,7 +321,7 @@ func (s *Service) PartChannel( } } - memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID) recipients := excludeSession(memberIDs, sessionID) body, _ := json.Marshal([]string{reason}) //nolint:errchkjson @@ -313,8 +330,8 @@ func (s *Service) PartChannel( nil, body, nil, recipients, ) - s.DB.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec - s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec + s.db.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec + s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec return nil } @@ -326,7 +343,7 @@ func (s *Service) SetTopic( sessionID int64, nick, channel, topic string, ) error { - chID, err := s.DB.GetChannelByName(ctx, channel) + chID, err := s.db.GetChannelByName(ctx, channel) if err != nil { return &IRCError{ irc.ErrNoSuchChannel, @@ -335,7 +352,7 @@ func (s *Service) SetTopic( } } - isMember, _ := s.DB.IsChannelMember( + isMember, _ := s.db.IsChannelMember( ctx, chID, sessionID, ) if !isMember { @@ -346,9 +363,9 @@ func (s *Service) SetTopic( } } - topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) + topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID) if topicLocked { - isOp, _ := s.DB.IsChannelOperator( + isOp, _ := s.db.IsChannelOperator( ctx, chID, sessionID, ) if !isOp { @@ -360,15 +377,15 @@ func (s *Service) SetTopic( } } - if setErr := s.DB.SetTopic( + if setErr := s.db.SetTopic( ctx, channel, topic, ); setErr != nil { return fmt.Errorf("set topic: %w", setErr) } - _ = s.DB.SetTopicMeta(ctx, channel, topic, nick) + _ = s.db.SetTopicMeta(ctx, channel, topic, nick) - memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID) body, _ := json.Marshal([]string{topic}) //nolint:errchkjson _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast @@ -387,7 +404,7 @@ func (s *Service) KickUser( sessionID int64, nick, channel, targetNick, reason string, ) error { - chID, err := s.DB.GetChannelByName(ctx, channel) + chID, err := s.db.GetChannelByName(ctx, channel) if err != nil { return &IRCError{ irc.ErrNoSuchChannel, @@ -396,7 +413,7 @@ func (s *Service) KickUser( } } - isOp, _ := s.DB.IsChannelOperator( + isOp, _ := s.db.IsChannelOperator( ctx, chID, sessionID, ) if !isOp { @@ -407,7 +424,7 @@ func (s *Service) KickUser( } } - targetSID, err := s.DB.GetSessionByNick( + targetSID, err := s.db.GetSessionByNick( ctx, targetNick, ) if err != nil { @@ -418,7 +435,7 @@ func (s *Service) KickUser( } } - isMember, _ := s.DB.IsChannelMember( + isMember, _ := s.db.IsChannelMember( ctx, chID, targetSID, ) if !isMember { @@ -429,7 +446,7 @@ func (s *Service) KickUser( } } - memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID) body, _ := json.Marshal([]string{reason}) //nolint:errchkjson params, _ := json.Marshal( //nolint:errchkjson []string{targetNick}, @@ -440,8 +457,8 @@ func (s *Service) KickUser( params, body, nil, memberIDs, ) - s.DB.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec - s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec + s.db.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec + s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec return nil } @@ -453,7 +470,7 @@ func (s *Service) ChangeNick( sessionID int64, oldNick, newNick string, ) error { - err := s.DB.ChangeNick(ctx, sessionID, newNick) + err := s.db.ChangeNick(ctx, sessionID, newNick) if err != nil { if strings.Contains(err.Error(), "UNIQUE") || db.IsUniqueConstraintError(err) { @@ -485,7 +502,7 @@ func (s *Service) BroadcastQuit( sessionID int64, nick, reason string, ) { - channels, err := s.DB.GetSessionChannels( + channels, err := s.db.GetSessionChannels( ctx, sessionID, ) if err != nil { @@ -495,7 +512,7 @@ func (s *Service) BroadcastQuit( notified := make(map[int64]bool) for _, ch := range channels { - memberIDs, memErr := s.DB.GetChannelMemberIDs( + memberIDs, memErr := s.db.GetChannelMemberIDs( ctx, ch.ID, ) if memErr != nil { @@ -526,11 +543,11 @@ func (s *Service) BroadcastQuit( } for _, ch := range channels { - s.DB.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec - s.DB.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec + s.db.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec + s.db.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec } - s.DB.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec + s.db.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec } // SetAway sets or clears the away message. Returns true @@ -540,7 +557,7 @@ func (s *Service) SetAway( sessionID int64, message string, ) (bool, error) { - err := s.DB.SetAway(ctx, sessionID, message) + err := s.db.SetAway(ctx, sessionID, message) if err != nil { return false, fmt.Errorf("set away: %w", err) } @@ -555,8 +572,8 @@ func (s *Service) Oper( sessionID int64, name, password string, ) error { - cfgName := s.Config.OperName - cfgPassword := s.Config.OperPassword + cfgName := s.config.OperName + cfgPassword := s.config.OperPassword // Use constant-time comparison and return the same // error for all failures to prevent information @@ -575,7 +592,7 @@ func (s *Service) Oper( } } - _ = s.DB.SetSessionOper(ctx, sessionID, true) + _ = s.db.SetSessionOper(ctx, sessionID, true) return nil } @@ -587,7 +604,7 @@ func (s *Service) ValidateChannelOp( sessionID int64, channel string, ) (int64, error) { - chID, err := s.DB.GetChannelByName(ctx, channel) + chID, err := s.db.GetChannelByName(ctx, channel) if err != nil { return 0, &IRCError{ irc.ErrNoSuchChannel, @@ -596,7 +613,7 @@ func (s *Service) ValidateChannelOp( } } - isOp, _ := s.DB.IsChannelOperator( + isOp, _ := s.db.IsChannelOperator( ctx, chID, sessionID, ) if !isOp { @@ -619,7 +636,7 @@ func (s *Service) ApplyMemberMode( mode rune, adding bool, ) error { - targetSID, err := s.DB.GetSessionByNick( + targetSID, err := s.db.GetSessionByNick( ctx, targetNick, ) if err != nil { @@ -630,7 +647,7 @@ func (s *Service) ApplyMemberMode( } } - isMember, _ := s.DB.IsChannelMember( + isMember, _ := s.db.IsChannelMember( ctx, chID, targetSID, ) if !isMember { @@ -643,11 +660,11 @@ func (s *Service) ApplyMemberMode( switch mode { case 'o': - _ = s.DB.SetChannelMemberOperator( + _ = s.db.SetChannelMemberOperator( ctx, chID, targetSID, adding, ) case 'v': - _ = s.DB.SetChannelMemberVoiced( + _ = s.db.SetChannelMemberVoiced( ctx, chID, targetSID, adding, ) } @@ -655,7 +672,8 @@ func (s *Service) ApplyMemberMode( return nil } -// SetChannelFlag applies +m/-m or +t/-t on a channel. +// SetChannelFlag applies a simple boolean channel mode +// (+m/-m, +t/-t, +i/-i, +s/-s, +n/-n). func (s *Service) SetChannelFlag( ctx context.Context, chID int64, @@ -664,29 +682,37 @@ func (s *Service) SetChannelFlag( ) error { switch flag { case 'm': - if err := s.DB.SetChannelModerated( + if err := s.db.SetChannelModerated( ctx, chID, setting, ); err != nil { return fmt.Errorf("set moderated: %w", err) } case 't': - if err := s.DB.SetChannelTopicLocked( + if err := s.db.SetChannelTopicLocked( ctx, chID, setting, ); err != nil { return fmt.Errorf("set topic locked: %w", err) } case 'i': - if err := s.DB.SetChannelInviteOnly( + if err := s.db.SetChannelInviteOnly( ctx, chID, setting, ); err != nil { return fmt.Errorf("set invite only: %w", err) } case 's': - if err := s.DB.SetChannelSecret( + if err := s.db.SetChannelSecret( ctx, chID, setting, ); err != nil { return fmt.Errorf("set secret: %w", err) } + case 'n': + if err := s.db.SetChannelNoExternal( + ctx, chID, setting, + ); err != nil { + return fmt.Errorf( + "set no external: %w", err, + ) + } } return nil @@ -700,7 +726,7 @@ func (s *Service) BroadcastMode( chID int64, modeText string, ) { - memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID) body, _ := json.Marshal([]string{modeText}) //nolint:errchkjson _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast @@ -709,24 +735,60 @@ func (s *Service) BroadcastMode( ) } -// QueryChannelMode returns the channel mode string. +// QueryChannelMode returns the complete channel mode +// string including all flags and parameterized modes. func (s *Service) QueryChannelMode( ctx context.Context, chID int64, ) string { modes := "+" - moderated, _ := s.DB.IsChannelModerated(ctx, chID) + noExternal, _ := s.db.IsChannelNoExternal(ctx, chID) + if noExternal { + modes += "n" + } + + inviteOnly, _ := s.db.IsChannelInviteOnly(ctx, chID) + if inviteOnly { + modes += "i" + } + + moderated, _ := s.db.IsChannelModerated(ctx, chID) if moderated { modes += "m" } - topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) + secret, _ := s.db.IsChannelSecret(ctx, chID) + if secret { + modes += "s" + } + + topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID) if topicLocked { modes += "t" } - return modes + var modeParams string + + key, _ := s.db.GetChannelKey(ctx, chID) + if key != "" { + modes += "k" + modeParams += " " + key + } + + limit, _ := s.db.GetChannelUserLimit(ctx, chID) + if limit > 0 { + modes += "l" + modeParams += " " + strconv.Itoa(limit) + } + + bits, _ := s.db.GetChannelHashcashBits(ctx, chID) + if bits > 0 { + modes += "H" + modeParams += " " + strconv.Itoa(bits) + } + + return modes + modeParams } // broadcastNickChange notifies channel peers of a nick @@ -736,7 +798,7 @@ func (s *Service) broadcastNickChange( sessionID int64, oldNick, newNick string, ) { - channels, err := s.DB.GetSessionChannels( + channels, err := s.db.GetSessionChannels( ctx, sessionID, ) if err != nil { @@ -746,7 +808,7 @@ func (s *Service) broadcastNickChange( body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson notified := make(map[int64]bool) - dbID, _, insErr := s.DB.InsertMessage( + dbID, _, insErr := s.db.InsertMessage( ctx, irc.CmdNick, oldNick, "", nil, body, nil, ) @@ -755,12 +817,12 @@ func (s *Service) broadcastNickChange( } // Notify the user themselves (for multi-client sync). - _ = s.DB.EnqueueToSession(ctx, sessionID, dbID) - s.Broker.Notify(sessionID) + _ = s.db.EnqueueToSession(ctx, sessionID, dbID) + s.broker.Notify(sessionID) notified[sessionID] = true for _, ch := range channels { - memberIDs, memErr := s.DB.GetChannelMemberIDs( + memberIDs, memErr := s.db.GetChannelMemberIDs( ctx, ch.ID, ) if memErr != nil { @@ -774,8 +836,8 @@ func (s *Service) broadcastNickChange( notified[mid] = true - _ = s.DB.EnqueueToSession(ctx, mid, dbID) - s.Broker.Notify(mid) + _ = s.db.EnqueueToSession(ctx, mid, dbID) + s.broker.Notify(mid) } } } -- 2.49.1