diff --git a/Dockerfile b/Dockerfile index 273f53f..32acda7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,7 +53,7 @@ RUN apk add --no-cache ca-certificates \ COPY --from=builder /neoircd /usr/local/bin/neoircd USER neoirc -EXPOSE 8080 +EXPOSE 8080 6667 HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ CMD wget -qO- http://localhost:8080/.well-known/healthcheck.json || exit 1 ENTRYPOINT ["neoircd"] diff --git a/README.md b/README.md index be7d0f6..fb97988 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ web client as a convenience/reference implementation, but the API comes first. - [Federation](#federation-server-to-server) - [Storage](#storage) - [Configuration](#configuration) +- [IRC Protocol Listener](#irc-protocol-listener) - [Deployment](#deployment) - [Client Development Guide](#client-development-guide) - [Rate Limiting & Abuse Prevention](#rate-limiting--abuse-prevention) @@ -2266,6 +2267,7 @@ directory is also loaded automatically via | `NEOIRC_OPER_PASSWORD` | string | `""` | Server operator (o-line) password. Both name and password must be set to enable OPER. | | `LOGIN_RATE_LIMIT` | float | `1` | Allowed login attempts per second per IP address. | | `LOGIN_RATE_BURST` | int | `5` | Maximum burst of login attempts per IP before rate limiting kicks in. | +| `IRC_LISTEN_ADDR` | string | `:6667` | TCP address for the traditional IRC protocol listener. Set to empty string to disable. | | `MAINTENANCE_MODE` | bool | `false` | Maintenance mode flag (reserved) | ### Example `.env` file @@ -2282,6 +2284,69 @@ NEOIRC_HASHCASH_BITS=20 --- +## IRC Protocol Listener + +neoirc includes an optional traditional IRC wire protocol listener (RFC +1459/2812) that allows standard IRC clients to connect directly. This enables +backward compatibility with existing IRC clients like irssi, weechat, hexchat, +and others. + +### Configuration + +The IRC listener is **enabled by default** on `:6667`. To disable it, set +`IRC_LISTEN_ADDR` to an empty string: + +```bash +IRC_LISTEN_ADDR= +``` + +### Supported Commands + +| Category | Commands | +|------------|------------------------------------------------------| +| Connection | `NICK`, `USER`, `PASS`, `QUIT`, `PING`/`PONG`, `CAP` | +| Channels | `JOIN`, `PART`, `MODE`, `TOPIC`, `NAMES`, `LIST`, `KICK`, `INVITE` | +| Messaging | `PRIVMSG`, `NOTICE` | +| Info | `WHO`, `WHOIS`, `LUSERS`, `MOTD`, `AWAY` | +| Operator | `OPER` (requires `NEOIRC_OPER_NAME` and `NEOIRC_OPER_PASSWORD`) | + +### Protocol Details + +- **Wire format**: CR-LF delimited lines, max 512 bytes per line +- **Connection registration**: Clients must send `NICK` and `USER` to register. + An optional `PASS` before registration sets the session password (minimum 8 + characters). +- **CAP negotiation**: `CAP LS` and `CAP END` are silently handled for + compatibility with modern clients. No capabilities are advertised. +- **Channel prefixes**: Channels must start with `#`. If omitted, it's + automatically prepended. +- **First joiner**: The first user to join a channel is automatically granted + operator status (`@`). +- **Channel modes**: `+m` (moderated), `+t` (topic lock), `+o` (operator), + `+v` (voice) + +### Bridge to HTTP API + +Messages sent by IRC clients appear in channels visible to HTTP/JSON API +clients and vice versa. The IRC listener and HTTP API share the same database, +broker, and session infrastructure. A user connected via IRC and a user +connected via the HTTP API can communicate in the same channels seamlessly. + +### Docker Usage + +To expose the IRC port in Docker (the listener is enabled by default on +`:6667`): + +```bash +docker run -d \ + -p 8080:8080 \ + -p 6667:6667 \ + -v neoirc-data:/var/lib/neoirc \ + neoirc +``` + +--- + ## Deployment ### Docker (Recommended) diff --git a/cmd/neoircd/main.go b/cmd/neoircd/main.go index d8d6601..8a5b141 100644 --- a/cmd/neoircd/main.go +++ b/cmd/neoircd/main.go @@ -2,14 +2,17 @@ package main import ( + "git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/globals" "git.eeqj.de/sneak/neoirc/internal/handlers" "git.eeqj.de/sneak/neoirc/internal/healthcheck" + "git.eeqj.de/sneak/neoirc/internal/ircserver" "git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/middleware" "git.eeqj.de/sneak/neoirc/internal/server" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/internal/stats" "go.uber.org/fx" ) @@ -28,16 +31,23 @@ func main() { fx.New( fx.Provide( + broker.New, config.New, db.New, globals.New, handlers.New, + ircserver.New, logger.New, server.New, middleware.New, healthcheck.New, + service.New, stats.New, ), - fx.Invoke(func(*server.Server) {}), + fx.Invoke(func( + _ *server.Server, + _ *ircserver.Server, + ) { + }), ).Run() } diff --git a/internal/config/config.go b/internal/config/config.go index 05f6882..c3377d7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -50,6 +50,7 @@ type Config struct { OperPassword string LoginRateLimit float64 LoginRateBurst int + IRCListenAddr string params *Params log *slog.Logger } @@ -86,6 +87,7 @@ func New( viper.SetDefault("NEOIRC_OPER_PASSWORD", "") viper.SetDefault("LOGIN_RATE_LIMIT", "1") viper.SetDefault("LOGIN_RATE_BURST", "5") + viper.SetDefault("IRC_LISTEN_ADDR", ":6667") err := viper.ReadInConfig() if err != nil { @@ -116,6 +118,7 @@ func New( OperPassword: viper.GetString("NEOIRC_OPER_PASSWORD"), LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"), LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"), + IRCListenAddr: viper.GetString("IRC_LISTEN_ADDR"), log: log, params: ¶ms, } diff --git a/internal/db/testing.go b/internal/db/testing.go new file mode 100644 index 0000000..d036611 --- /dev/null +++ b/internal/db/testing.go @@ -0,0 +1,25 @@ +package db + +import ( + "context" + "database/sql" + "log/slog" +) + +// NewTestDatabaseFromConn creates a Database wrapping an +// existing *sql.DB connection. Intended for integration +// tests in other packages. +func NewTestDatabaseFromConn(conn *sql.DB) *Database { + return &Database{ //nolint:exhaustruct + conn: conn, + log: slog.Default(), + } +} + +// RunMigrations applies all schema migrations. Exposed +// for integration tests in other packages. +func (database *Database) RunMigrations( + ctx context.Context, +) error { + return database.runMigrations(ctx) +} diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 34f1987..58f4cb6 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -2,7 +2,6 @@ package handlers import ( "context" - "crypto/subtle" "encoding/json" "errors" "fmt" @@ -15,6 +14,7 @@ import ( "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/hashcash" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/pkg/irc" "github.com/go-chi/chi/v5" ) @@ -182,51 +182,6 @@ func (hdlr *Handlers) requireAuth( return sessionID, clientID, nick, true } -// fanOut stores a message and enqueues it to all specified -// session IDs, then notifies them. -func (hdlr *Handlers) fanOut( - request *http.Request, - command, from, target string, - body json.RawMessage, - meta json.RawMessage, - sessionIDs []int64, -) (string, error) { - dbID, msgUUID, err := hdlr.params.Database.InsertMessage( - request.Context(), command, from, target, nil, body, meta, - ) - if err != nil { - return "", fmt.Errorf("insert message: %w", err) - } - - for _, sid := range sessionIDs { - enqErr := hdlr.params.Database.EnqueueToSession( - request.Context(), sid, dbID, - ) - if enqErr != nil { - hdlr.log.Error("enqueue failed", - "error", enqErr, "session_id", sid) - } - - hdlr.broker.Notify(sid) - } - - return msgUUID, nil -} - -// fanOutSilent calls fanOut and discards the UUID. -func (hdlr *Handlers) fanOutSilent( - request *http.Request, - command, from, target string, - body json.RawMessage, - sessionIDs []int64, -) error { - _, err := hdlr.fanOut( - request, command, from, target, body, nil, sessionIDs, - ) - - return err -} - // HandleCreateSession creates a new user session. func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc { return func( @@ -1218,6 +1173,43 @@ func (hdlr *Handlers) respondIRCError( http.StatusOK) } +// handleServiceError maps a service-layer error to an IRC +// numeric reply or a generic HTTP 500 response. Returns +// true if an error was handled (response sent). +func (hdlr *Handlers) handleServiceError( + writer http.ResponseWriter, + request *http.Request, + clientID, sessionID int64, + nick string, + err error, +) bool { + if err == nil { + return false + } + + var ircErr *service.IRCError + if errors.As(err, &ircErr) { + hdlr.respondIRCError( + writer, request, clientID, sessionID, + ircErr.Code, nick, ircErr.Params, + ircErr.Message, + ) + + return true + } + + hdlr.log.Error( + "service error", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return true +} + func (hdlr *Handlers) handleChannelMsg( writer http.ResponseWriter, request *http.Request, @@ -1228,141 +1220,45 @@ func (hdlr *Handlers) handleChannelMsg( ) { ctx := request.Context() - chID, ok := hdlr.resolveChannelForSend( - writer, request, - sessionID, clientID, nick, target, + // Hashcash validation is HTTP-specific; needs chID. + isNotice := command == irc.CmdNotice + + if !isNotice { + chID, chErr := hdlr.params.Database. + GetChannelByName(ctx, target) + if chErr == nil { + hashcashErr := hdlr.validateChannelHashcash( + request, clientID, sessionID, + writer, nick, target, body, meta, chID, + ) + if hashcashErr != nil { + return + } + } + } + + // Delegate validation + fan-out to service layer. + dbID, uuid, err := hdlr.svc.SendChannelMessage( + ctx, sessionID, nick, + command, target, body, meta, ) - if !ok { - return - } - - // Enforce +b (ban): banned users cannot send. - isBanned, banErr := hdlr.params.Database. - IsSessionBanned(ctx, chID, sessionID) - if banErr == nil && isBanned { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrCannotSendToChan, nick, []string{target}, - "Cannot send to channel (+b)", - ) - - return - } - - // Enforce +m (moderated): only +o and +v can send. - if !hdlr.checkModeratedSend( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, target, chID, + clientID, sessionID, nick, err, ) { return } - // NOTICE skips hashcash validation on +H channels. - if command != irc.CmdNotice { - hashcashErr := hdlr.validateChannelHashcash( - request, clientID, sessionID, - writer, nick, target, body, meta, chID, - ) - if hashcashErr != nil { - return - } - } - - hdlr.sendChannelMsg( - writer, request, command, nick, target, - body, meta, chID, + // HTTP echo: enqueue to sender so all their clients + // see the message in long-poll responses. + _ = hdlr.params.Database.EnqueueToSession( + ctx, sessionID, dbID, ) -} + hdlr.broker.Notify(sessionID) -// resolveChannelForSend validates that the channel exists -// and the session is a member. Returns the channel ID. -func (hdlr *Handlers) resolveChannelForSend( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, target string, -) (int64, bool) { - ctx := request.Context() - - chID, err := hdlr.params.Database.GetChannelByName( - ctx, target, - ) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchChannel, nick, []string{target}, - "No such channel", - ) - - return 0, false - } - - isMember, memErr := hdlr.params.Database.IsChannelMember( - ctx, chID, sessionID, - ) - if memErr != nil { - hdlr.log.Error( - "check membership failed", "error", memErr, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return 0, false - } - - if !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrCannotSendToChan, nick, []string{target}, - "Cannot send to channel", - ) - - return 0, false - } - - return chID, true -} - -// checkModeratedSend checks if a user can send to a -// moderated channel. Returns true if sending is allowed. -func (hdlr *Handlers) checkModeratedSend( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, target string, - chID int64, -) bool { - ctx := request.Context() - - isModerated, err := hdlr.params.Database. - IsChannelModerated(ctx, chID) - if err != nil || !isModerated { - return true - } - - isOp, opErr := hdlr.params.Database. - IsChannelOperator(ctx, chID, sessionID) - if opErr == nil && isOp { - return true - } - - isVoiced, vErr := hdlr.params.Database. - IsChannelVoiced(ctx, chID, sessionID) - if vErr == nil && isVoiced { - return true - } - - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrCannotSendToChan, nick, - []string{target}, - "Cannot send to channel (+m)", - ) - - return false + hdlr.respondJSON(writer, request, + map[string]string{"id": uuid, "status": "sent"}, + http.StatusOK) } // validateChannelHashcash checks whether the channel @@ -1519,49 +1415,6 @@ func (hdlr *Handlers) extractHashcashFromMeta( return stamp } -func (hdlr *Handlers) sendChannelMsg( - writer http.ResponseWriter, - request *http.Request, - command, nick, target string, - body json.RawMessage, - meta json.RawMessage, - chID int64, -) { - memberIDs, err := hdlr.params.Database.GetChannelMemberIDs( - request.Context(), chID, - ) - if err != nil { - hdlr.log.Error( - "get channel members failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - msgUUID, err := hdlr.fanOut( - request, command, nick, target, body, meta, memberIDs, - ) - if err != nil { - hdlr.log.Error("send message failed", "error", err) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - hdlr.respondJSON(writer, request, - map[string]string{"id": msgUUID, "status": "sent"}, - http.StatusOK) -} - func (hdlr *Handlers) handleDirectMsg( writer http.ResponseWriter, request *http.Request, @@ -1570,60 +1423,32 @@ func (hdlr *Handlers) handleDirectMsg( body json.RawMessage, meta json.RawMessage, ) { - targetSID, err := hdlr.params.Database.GetSessionByNick( - request.Context(), target, + result, err := hdlr.svc.SendDirectMessage( + request.Context(), sessionID, nick, + command, target, body, meta, ) - if err != nil { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.ErrNoSuchNick, nick, []string{target}, - "No such nick/channel", - ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) - - return - } - - recipients := []int64{targetSID} - if targetSID != sessionID { - recipients = append(recipients, sessionID) - } - - msgUUID, err := hdlr.fanOut( - request, command, nick, target, body, meta, recipients, - ) - if err != nil { - hdlr.log.Error("send dm failed", "error", err) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } // Per RFC 2812: NOTICE must NOT trigger auto-replies // including RPL_AWAY. - if command != irc.CmdNotice { - awayMsg, awayErr := hdlr.params.Database.GetAway( - request.Context(), targetSID, + if command != irc.CmdNotice && result.AwayMsg != "" { + hdlr.enqueueNumeric( + request.Context(), clientID, + irc.RplAway, nick, + []string{target}, result.AwayMsg, ) - if awayErr == nil && awayMsg != "" { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.RplAway, nick, - []string{target}, awayMsg, - ) - hdlr.broker.Notify(sessionID) - } + hdlr.broker.Notify(sessionID) } hdlr.respondJSON(writer, request, - map[string]string{"id": msgUUID, "status": "sent"}, + map[string]string{ + "id": result.UUID, "status": "sent", + }, http.StatusOK) } @@ -1679,141 +1504,20 @@ func (hdlr *Handlers) executeJoin( sessionID, clientID int64, nick, channel, suppliedKey string, ) { - ctx := request.Context() - - chID, isCreator, ok := hdlr.resolveJoinChannel( - writer, request, channel, + result, err := hdlr.svc.JoinChannel( + request.Context(), sessionID, nick, + channel, suppliedKey, ) - if !ok { - return - } - - // Tier 2 join checks only apply to non-creators. - if !isCreator { - if !hdlr.checkJoinAllowed( - writer, request, - sessionID, clientID, nick, - channel, chID, suppliedKey, - ) { - return - } - } - - if err := hdlr.addMemberToChannel( - ctx, chID, sessionID, isCreator, - ); err != nil { - hdlr.log.Error( - "join channel failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - hdlr.broadcastJoin( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, channel, chID, - ) -} - -// resolveJoinChannel gets or creates the channel and -// determines if the joiner would be the first member -// (i.e., the channel creator/op). -func (hdlr *Handlers) resolveJoinChannel( - writer http.ResponseWriter, - request *http.Request, - channel string, -) (int64, bool, bool) { - ctx := request.Context() - - chID, err := hdlr.params.Database.GetOrCreateChannel( - ctx, channel, - ) - if err != nil { - hdlr.log.Error( - "get/create channel failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return 0, false, false + clientID, sessionID, nick, err, + ) { + return } - memberCount, countErr := hdlr.params.Database. - CountChannelMembers(ctx, chID) - if countErr != nil { - hdlr.log.Error( - "count members failed", "error", countErr, - ) - } - - isCreator := countErr == nil && memberCount == 0 - - return chID, isCreator, true -} - -// addMemberToChannel adds a session to a channel, with -// operator privileges if they're the first member. -func (hdlr *Handlers) addMemberToChannel( - ctx context.Context, - chID, sessionID int64, - isCreator bool, -) error { - if isCreator { - err := hdlr.params.Database.JoinChannelAsOperator( - ctx, chID, sessionID, - ) - if err != nil { - return fmt.Errorf( - "join as operator: %w", err, - ) - } - - return nil - } - - err := hdlr.params.Database.JoinChannel( - ctx, chID, sessionID, - ) - if err != nil { - return fmt.Errorf("join channel: %w", err) - } - - return nil -} - -// broadcastJoin fans out the JOIN, clears invites, and -// sends numerics. -func (hdlr *Handlers) broadcastJoin( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel string, - chID int64, -) { - ctx := request.Context() - - memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( - ctx, chID, - ) - - _ = hdlr.fanOutSilent( - request, irc.CmdJoin, nick, channel, nil, memberIDs, - ) - - _ = hdlr.params.Database.ClearChannelInvite( - ctx, chID, sessionID, - ) - hdlr.deliverJoinNumerics( - request, clientID, sessionID, nick, channel, chID, + request, clientID, sessionID, nick, + channel, result.ChannelID, ) hdlr.respondJSON(writer, request, @@ -2039,15 +1743,12 @@ func (hdlr *Handlers) handlePart( body json.RawMessage, ) { if target == "" { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.ErrNeedMoreParams, nick, []string{irc.CmdPart}, + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdPart}, "Not enough parameters", ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) return } @@ -2057,51 +1758,27 @@ func (hdlr *Handlers) handlePart( channel = "#" + channel } - chID, err := hdlr.params.Database.GetChannelByName( - request.Context(), channel, - ) - if err != nil { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.ErrNoSuchChannel, nick, []string{channel}, - "No such channel", - ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) - - return + // Extract reason from body for the service call. + reason := "" + if body != nil { + var lines []string + if json.Unmarshal(body, &lines) == nil && + len(lines) > 0 { + reason = lines[0] + } } - memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( - request.Context(), chID, + err := hdlr.svc.PartChannel( + request.Context(), sessionID, + nick, channel, reason, ) - - _ = hdlr.fanOutSilent( - request, irc.CmdPart, nick, channel, body, memberIDs, - ) - - err = hdlr.params.Database.PartChannel( - request.Context(), chID, sessionID, - ) - if err != nil { - hdlr.log.Error( - "part channel failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - _ = hdlr.params.Database.DeleteChannelIfEmpty( - request.Context(), chID, - ) - hdlr.respondJSON(writer, request, map[string]string{ "status": "parted", @@ -2162,34 +1839,16 @@ func (hdlr *Handlers) executeNickChange( sessionID, clientID int64, nick, newNick string, ) { - err := hdlr.params.Database.ChangeNick( - request.Context(), sessionID, newNick, + err := hdlr.svc.ChangeNick( + request.Context(), sessionID, nick, newNick, ) - if err != nil { - if db.IsUniqueConstraintError(err) { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNicknameInUse, nick, []string{newNick}, - "Nickname is already in use", - ) - - return - } - - hdlr.log.Error( - "change nick failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - hdlr.broadcastNick(request, sessionID, nick, newNick) - hdlr.respondJSON(writer, request, map[string]string{ "status": "ok", "nick": newNick, @@ -2197,70 +1856,19 @@ func (hdlr *Handlers) executeNickChange( http.StatusOK) } -func (hdlr *Handlers) broadcastNick( - request *http.Request, - sessionID int64, - oldNick, newNick string, -) { - channels, _ := hdlr.params.Database. - GetSessionChannels( - request.Context(), sessionID, - ) - - notified := map[int64]bool{sessionID: true} - - nickBody, err := json.Marshal([]string{newNick}) - if err != nil { - hdlr.log.Error( - "marshal nick body", "error", err, - ) - - return - } - - dbID, _, _ := hdlr.params.Database.InsertMessage( - request.Context(), irc.CmdNick, oldNick, "", - nil, json.RawMessage(nickBody), nil, - ) - - _ = hdlr.params.Database.EnqueueToSession( - request.Context(), sessionID, dbID, - ) - - hdlr.broker.Notify(sessionID) - - for _, chanInfo := range channels { - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs( - request.Context(), chanInfo.ID, - ) - - for _, mid := range memberIDs { - if !notified[mid] { - notified[mid] = true - - _ = hdlr.params.Database.EnqueueToSession( - request.Context(), mid, dbID, - ) - - hdlr.broker.Notify(mid) - } - } - } -} - func (hdlr *Handlers) handleTopic( writer http.ResponseWriter, request *http.Request, sessionID, clientID int64, nick, target string, - body json.RawMessage, + _ json.RawMessage, bodyLines func() []string, ) { if target == "" { hdlr.respondIRCError( writer, request, clientID, sessionID, - irc.ErrNeedMoreParams, nick, []string{irc.CmdTopic}, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdTopic}, "Not enough parameters", ) @@ -2271,7 +1879,8 @@ func (hdlr *Handlers) handleTopic( if len(lines) == 0 { hdlr.respondIRCError( writer, request, clientID, sessionID, - irc.ErrNeedMoreParams, nick, []string{irc.CmdTopic}, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdTopic}, "Not enough parameters", ) @@ -2283,131 +1892,24 @@ func (hdlr *Handlers) handleTopic( channel = "#" + channel } - chID, err := hdlr.params.Database.GetChannelByName( - request.Context(), channel, + topic := strings.Join(lines, " ") + + err := hdlr.svc.SetTopic( + request.Context(), sessionID, + nick, channel, topic, ) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchChannel, nick, []string{channel}, - "No such channel", - ) - - return - } - - isMember, err := hdlr.params.Database.IsChannelMember( - request.Context(), chID, sessionID, - ) - if err != nil { - hdlr.log.Error( - "check membership failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return - } - - if !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNotOnChannel, nick, []string{channel}, - "You're not on that channel", - ) - - return - } - - hdlr.executeTopic( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, - channel, strings.Join(lines, " "), - body, chID, - ) -} - -func (hdlr *Handlers) executeTopic( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel, topic string, - body json.RawMessage, - chID int64, -) { - ctx := request.Context() - - // Enforce +t: only operators can change topic when - // topic lock is active. - isLocked, lockErr := hdlr.params.Database. - IsChannelTopicLocked(ctx, chID) - if lockErr == nil && isLocked { - isOp, opErr := hdlr.params.Database. - IsChannelOperator(ctx, chID, sessionID) - if opErr != nil || !isOp { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrChanOpPrivsNeeded, nick, - []string{channel}, - "You're not channel operator", - ) - - return - } - } - - setErr := hdlr.params.Database.SetTopicMeta( - request.Context(), channel, topic, nick, - ) - if setErr != nil { - hdlr.log.Error( - "set topic failed", "error", setErr, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + clientID, sessionID, nick, err, + ) { return } - memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( - request.Context(), chID, + hdlr.deliverSetTopicNumerics( + request.Context(), clientID, sessionID, + nick, channel, topic, ) - _ = hdlr.fanOutSilent( - request, irc.CmdTopic, nick, channel, body, memberIDs, - ) - - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.RplTopic, nick, []string{channel}, topic, - ) - - // 333 RPL_TOPICWHOTIME - topicMeta, tmErr := hdlr.params.Database. - GetTopicMeta(request.Context(), chID) - if tmErr == nil && topicMeta != nil { - hdlr.enqueueNumeric( - request.Context(), clientID, - irc.RplTopicWhoTime, nick, - []string{ - channel, - topicMeta.SetBy, - strconv.FormatInt( - topicMeta.SetAt.Unix(), 10, - ), - }, - "", - ) - } - - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, map[string]string{ "status": "ok", "topic": topic, @@ -2415,6 +1917,43 @@ func (hdlr *Handlers) executeTopic( http.StatusOK) } +// deliverSetTopicNumerics sends RPL_TOPIC and +// RPL_TOPICWHOTIME to the client after a topic change. +func (hdlr *Handlers) deliverSetTopicNumerics( + ctx context.Context, + clientID, sessionID int64, + nick, channel, topic string, +) { + hdlr.enqueueNumeric( + ctx, clientID, irc.RplTopic, nick, + []string{channel}, topic, + ) + + chID, chErr := hdlr.params.Database.GetChannelByName( + ctx, channel, + ) + if chErr == nil { + topicMeta, tmErr := hdlr.params.Database. + GetTopicMeta(ctx, chID) + if tmErr == nil && topicMeta != nil { + hdlr.enqueueNumeric( + ctx, clientID, + irc.RplTopicWhoTime, nick, + []string{ + channel, + topicMeta.SetBy, + strconv.FormatInt( + topicMeta.SetAt.Unix(), 10, + ), + }, + "", + ) + } + } + + hdlr.broker.Notify(sessionID) +} + // dispatchInfoCommand handles informational IRC commands // that produce server-side numerics (MOTD, PING). func (hdlr *Handlers) dispatchInfoCommand( @@ -2457,51 +1996,17 @@ func (hdlr *Handlers) handleQuit( nick string, body json.RawMessage, ) { - channels, _ := hdlr.params.Database. - GetSessionChannels( - request.Context(), sessionID, - ) - - notified := map[int64]bool{} - - var dbID int64 - - if len(channels) > 0 { - dbID, _, _ = hdlr.params.Database.InsertMessage( - request.Context(), irc.CmdQuit, nick, "", - nil, body, nil, - ) - } - - for _, chanInfo := range channels { - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs( - request.Context(), chanInfo.ID, - ) - - for _, mid := range memberIDs { - if mid != sessionID && !notified[mid] { - notified[mid] = true - - _ = hdlr.params.Database.EnqueueToSession( - request.Context(), mid, dbID, - ) - - hdlr.broker.Notify(mid) - } + reason := "Client quit" + if body != nil { + var lines []string + if json.Unmarshal(body, &lines) == nil && + len(lines) > 0 { + reason = lines[0] } - - _ = hdlr.params.Database.PartChannel( - request.Context(), chanInfo.ID, sessionID, - ) - - _ = hdlr.params.Database.DeleteChannelIfEmpty( - request.Context(), chanInfo.ID, - ) } - _ = hdlr.params.Database.DeleteSession( - request.Context(), sessionID, + hdlr.svc.BroadcastQuit( + request.Context(), sessionID, nick, reason, ) hdlr.clearAuthCookie(writer, request) @@ -2695,30 +2200,6 @@ func (hdlr *Handlers) queryChannelMode( // requireChannelOp checks that the session has +o in the // channel. If not, it sends ERR_CHANOPRIVSNEEDED and // returns false. -func (hdlr *Handlers) requireChannelOp( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel string, - chID int64, -) bool { - isOp, err := hdlr.params.Database.IsChannelOperator( - request.Context(), chID, sessionID, - ) - if err != nil || !isOp { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrChanOpPrivsNeeded, nick, - []string{channel}, - "You're not channel operator", - ) - - return false - } - - return true -} - // applyChannelMode handles setting channel modes. // Supports +o/-o, +v/-v, +m/-m, +t/-t, +H/-H. func (hdlr *Handlers) applyChannelMode( @@ -2894,67 +2375,6 @@ func (hdlr *Handlers) applyParameterizedMode( } } -// resolveUserModeTarget validates a user-mode change -// target and returns the target session ID if valid. -// Returns -1 on error (error response already sent). -func (hdlr *Handlers) resolveUserModeTarget( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel string, - chID int64, - modeArgs []string, -) (int64, string, bool) { - if !hdlr.requireChannelOp( - writer, request, - sessionID, clientID, nick, channel, chID, - ) { - return -1, "", false - } - - if len(modeArgs) < 2 { //nolint:mnd // mode + nick - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNeedMoreParams, nick, - []string{irc.CmdMode}, - "Not enough parameters", - ) - - return -1, "", false - } - - ctx := request.Context() - targetNick := modeArgs[1] - - targetSID, err := hdlr.params.Database. - GetSessionByNick(ctx, targetNick) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchNick, nick, - []string{targetNick}, - "No such nick/channel", - ) - - return -1, "", false - } - - isMember, memErr := hdlr.params.Database. - IsChannelMember(ctx, chID, targetSID) - if memErr != nil || !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrUserNotInChannel, nick, - []string{targetNick, channel}, - "They aren't on that channel", - ) - - return -1, "", false - } - - return targetSID, targetNick, true -} - // applyUserMode handles +o/-o and +v/-v mode changes. // isOperMode=true for +o/-o, false for +v/-v. func (hdlr *Handlers) applyUserMode( @@ -2966,46 +2386,53 @@ func (hdlr *Handlers) applyUserMode( modeArgs []string, isOperMode bool, ) { - ctx := request.Context() - - targetSID, targetNick, ok := hdlr.resolveUserModeTarget( - writer, request, - sessionID, clientID, nick, - channel, chID, modeArgs, + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + request.Context(), sessionID, channel, ) - if !ok { + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, opErr, + ) { return } + if len(modeArgs) < 2 { //nolint:mnd // mode + nick + hdlr.respondIRCError( + writer, request, clientID, sessionID, + irc.ErrNeedMoreParams, nick, + []string{irc.CmdMode}, + "Not enough parameters", + ) + + return + } + + targetNick := modeArgs[1] setting := strings.HasPrefix(modeArgs[0], "+") - var err error + var modeChar rune if isOperMode { - err = hdlr.params.Database.SetChannelMemberOperator( - ctx, chID, targetSID, setting, - ) + modeChar = 'o' } else { - err = hdlr.params.Database.SetChannelMemberVoiced( - ctx, chID, targetSID, setting, - ) + modeChar = 'v' } - if err != nil { - hdlr.log.Error( - "set user mode failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + err := hdlr.svc.ApplyMemberMode( + request.Context(), chID, channel, + targetNick, modeChar, setting, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - hdlr.broadcastUserModeChange( - request, nick, channel, chID, - modeArgs[0], targetNick, + modeText := modeArgs[0] + " " + targetNick + hdlr.svc.BroadcastMode( + request.Context(), nick, channel, chID, + modeText, ) hdlr.respondJSON(writer, request, @@ -3013,32 +2440,6 @@ func (hdlr *Handlers) applyUserMode( http.StatusOK) } -// broadcastUserModeChange fans out a user-mode change -// to all channel members. -func (hdlr *Handlers) broadcastUserModeChange( - request *http.Request, - nick, channel string, - chID int64, - modeStr, targetNick string, -) { - ctx := request.Context() - - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chID) - - modeBody, err := json.Marshal( - []string{modeStr, targetNick}, - ) - if err != nil { - return - } - - _ = hdlr.fanOutSilent( - request, irc.CmdMode, nick, channel, - json.RawMessage(modeBody), memberIDs, - ) -} - // setChannelFlag handles +m/-m and +t/-t mode changes. func (hdlr *Handlers) setChannelFlag( writer http.ResponseWriter, @@ -3049,70 +2450,35 @@ func (hdlr *Handlers) setChannelFlag( flag string, setting bool, ) { - ctx := request.Context() - - if !hdlr.requireChannelOp( + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + request.Context(), sessionID, channel, + ) + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, channel, chID, + clientID, sessionID, nick, opErr, ) { return } - var err error - - switch flag { - case "m": - err = hdlr.params.Database.SetChannelModerated( - ctx, chID, setting, - ) - case "t": - err = hdlr.params.Database.SetChannelTopicLocked( - ctx, chID, setting, - ) - case "i": - err = hdlr.params.Database.SetChannelInviteOnly( - ctx, chID, setting, - ) - case "s": - err = hdlr.params.Database.SetChannelSecret( - ctx, chID, setting, - ) - } - - if err != nil { - hdlr.log.Error( - "set channel flag failed", "error", err, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + err := hdlr.svc.SetChannelFlag( + request.Context(), chID, rune(flag[0]), setting, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - // Broadcast the MODE change. modeStr := "+" + flag if !setting { modeStr = "-" + flag } - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chID) - - modeBody, mErr := json.Marshal([]string{modeStr}) - if mErr != nil { - hdlr.log.Error( - "marshal mode body", "error", mErr, - ) - - return - } - - _ = hdlr.fanOutSilent( - request, irc.CmdMode, nick, channel, - json.RawMessage(modeBody), memberIDs, + hdlr.svc.BroadcastMode( + request.Context(), nick, channel, chID, + modeStr, ) hdlr.respondJSON(writer, request, @@ -3140,6 +2506,17 @@ func (hdlr *Handlers) setHashcashMode( ) { ctx := request.Context() + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + ctx, sessionID, channel, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, opErr, + ) { + return + } + if len(modeArgs) < 2 { //nolint:mnd // +H requires a bits arg hdlr.respondIRCError( writer, request, clientID, sessionID, @@ -3203,6 +2580,17 @@ func (hdlr *Handlers) clearHashcashMode( ) { ctx := request.Context() + // Validate operator status via service. + _, opErr := hdlr.svc.ValidateChannelOp( + ctx, sessionID, channel, + ) + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, opErr, + ) { + return + } + err := hdlr.params.Database.SetChannelHashcashBits( ctx, chID, 0, ) @@ -4453,52 +3841,16 @@ func (hdlr *Handlers) HandleLogout() http.HandlerFunc { } // cleanupUser parts the user from all channels (notifying -// members) and deletes the session. +// members) and deletes the session via the shared service +// layer. func (hdlr *Handlers) cleanupUser( ctx context.Context, sessionID int64, nick string, ) { - channels, _ := hdlr.params.Database. - GetSessionChannels(ctx, sessionID) - - notified := map[int64]bool{} - - var quitDBID int64 - - if len(channels) > 0 { - quitDBID, _, _ = hdlr.params.Database.InsertMessage( - ctx, irc.CmdQuit, nick, "", - nil, nil, nil, - ) - } - - for _, chanInfo := range channels { - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chanInfo.ID) - - for _, mid := range memberIDs { - if mid != sessionID && !notified[mid] { - notified[mid] = true - - _ = hdlr.params.Database.EnqueueToSession( - ctx, mid, quitDBID, - ) - - hdlr.broker.Notify(mid) - } - } - - _ = hdlr.params.Database.PartChannel( - ctx, chanInfo.ID, sessionID, - ) - - _ = hdlr.params.Database.DeleteChannelIfEmpty( - ctx, chanInfo.ID, - ) - } - - _ = hdlr.params.Database.DeleteSession(ctx, sessionID) + hdlr.svc.BroadcastQuit( + ctx, sessionID, nick, "Connection closed", + ) } // HandleUsersMe returns the current user's session info. @@ -4545,7 +3897,8 @@ func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc { } } -// handleOper handles the OPER command for server operator authentication. +// handleOper handles the OPER command for server operator +// authentication. func (hdlr *Handlers) handleOper( writer http.ResponseWriter, request *http.Request, @@ -4567,39 +3920,13 @@ func (hdlr *Handlers) handleOper( return } - operName := lines[0] - operPass := lines[1] - - cfgName := hdlr.params.Config.OperName - cfgPass := hdlr.params.Config.OperPassword - - if cfgName == "" || cfgPass == "" || - subtle.ConstantTimeCompare([]byte(operName), []byte(cfgName)) != 1 || - subtle.ConstantTimeCompare([]byte(operPass), []byte(cfgPass)) != 1 { - hdlr.enqueueNumeric( - ctx, clientID, irc.ErrNoOperHost, nick, - nil, "No O-lines for your host", - ) - hdlr.broker.Notify(sessionID) - hdlr.respondJSON(writer, request, - map[string]string{"status": "error"}, - http.StatusOK) - - return - } - - err := hdlr.params.Database.SetSessionOper( - ctx, sessionID, true, + err := hdlr.svc.Oper( + ctx, sessionID, lines[0], lines[1], ) - if err != nil { - hdlr.log.Error( - "set oper failed", "error", err, - ) - hdlr.respondError( - writer, request, "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } @@ -4633,21 +3960,17 @@ func (hdlr *Handlers) handleAway( awayMsg = strings.Join(lines, " ") } - err := hdlr.params.Database.SetAway( + cleared, err := hdlr.svc.SetAway( ctx, sessionID, awayMsg, ) - if err != nil { - hdlr.log.Error("set away failed", "error", err) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - + if hdlr.handleServiceError( + writer, request, + clientID, sessionID, nick, err, + ) { return } - if awayMsg == "" { + if cleared { // 305 RPL_UNAWAY hdlr.enqueueNumeric( ctx, clientID, irc.RplUnaway, nick, nil, @@ -4676,6 +3999,8 @@ func (hdlr *Handlers) handleKick( body json.RawMessage, bodyLines func() []string, ) { + _ = body + if target == "" { hdlr.respondIRCError( writer, request, clientID, sessionID, @@ -4711,178 +4036,22 @@ func (hdlr *Handlers) handleKick( reason = lines[1] } - hdlr.executeKick( - writer, request, - sessionID, clientID, nick, - channel, targetNick, reason, body, + err := hdlr.svc.KickUser( + request.Context(), sessionID, nick, + channel, targetNick, reason, ) -} - -func (hdlr *Handlers) executeKick( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel, targetNick, reason string, - _ json.RawMessage, -) { - ctx := request.Context() - - chID, targetSID, ok := hdlr.validateKick( + if hdlr.handleServiceError( writer, request, - sessionID, clientID, nick, - channel, targetNick, - ) - if !ok { - return - } - - if !hdlr.broadcastKick( - writer, request, - nick, channel, targetNick, reason, chID, + clientID, sessionID, nick, err, ) { return } - // Remove the kicked user from the channel. - _ = hdlr.params.Database.PartChannel( - ctx, chID, targetSID, - ) - - // Clean up empty channel. - _ = hdlr.params.Database.DeleteChannelIfEmpty( - ctx, chID, - ) - hdlr.respondJSON(writer, request, map[string]string{"status": "ok"}, http.StatusOK) } -// validateKick checks the channel exists, the kicker is -// an operator, and the target is in the channel. -func (hdlr *Handlers) validateKick( - writer http.ResponseWriter, - request *http.Request, - sessionID, clientID int64, - nick, channel, targetNick string, -) (int64, int64, bool) { - ctx := request.Context() - - chID, err := hdlr.params.Database.GetChannelByName( - ctx, channel, - ) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchChannel, nick, - []string{channel}, - "No such channel", - ) - - return 0, 0, false - } - - if !hdlr.requireChannelOp( - writer, request, - sessionID, clientID, nick, channel, chID, - ) { - return 0, 0, false - } - - targetSID, err := hdlr.params.Database. - GetSessionByNick(ctx, targetNick) - if err != nil { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrNoSuchNick, nick, - []string{targetNick}, - "No such nick/channel", - ) - - return 0, 0, false - } - - isMember, memErr := hdlr.params.Database. - IsChannelMember(ctx, chID, targetSID) - if memErr != nil || !isMember { - hdlr.respondIRCError( - writer, request, clientID, sessionID, - irc.ErrUserNotInChannel, nick, - []string{targetNick, channel}, - "They aren't on that channel", - ) - - return 0, 0, false - } - - return chID, targetSID, true -} - -// broadcastKick inserts a KICK message and fans it out -// to all channel members. -func (hdlr *Handlers) broadcastKick( - writer http.ResponseWriter, - request *http.Request, - nick, channel, targetNick, reason string, - chID int64, -) bool { - ctx := request.Context() - - memberIDs, _ := hdlr.params.Database. - GetChannelMemberIDs(ctx, chID) - - kickBody, bErr := json.Marshal([]string{reason}) - if bErr != nil { - hdlr.log.Error("marshal kick body", "error", bErr) - - return false - } - - kickParams, pErr := json.Marshal( - []string{targetNick}, - ) - if pErr != nil { - hdlr.log.Error( - "marshal kick params", "error", pErr, - ) - - return false - } - - dbID, _, insertErr := hdlr.params.Database. - InsertMessage( - ctx, irc.CmdKick, nick, channel, - json.RawMessage(kickParams), - json.RawMessage(kickBody), nil, - ) - if insertErr != nil { - hdlr.log.Error( - "insert kick message", "error", insertErr, - ) - hdlr.respondError( - writer, request, - "internal error", - http.StatusInternalServerError, - ) - - return false - } - - for _, sid := range memberIDs { - enqErr := hdlr.params.Database.EnqueueToSession( - ctx, sid, dbID, - ) - if enqErr != nil { - hdlr.log.Error("enqueue kick failed", - "error", enqErr, "session_id", sid) - } - - hdlr.broker.Notify(sid) - } - - return true -} - // deliverWhoisIdle sends RPL_WHOISIDLE (317) with idle // time and signon time. func (hdlr *Handlers) deliverWhoisIdle( diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index 6b0a477..ad70bdb 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/globals" @@ -27,6 +28,7 @@ import ( "git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/middleware" "git.eeqj.de/sneak/neoirc/internal/server" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/internal/stats" "go.uber.org/fx" "go.uber.org/fx/fxtest" @@ -206,6 +208,14 @@ func newTestHandlers( hcheck *healthcheck.Healthcheck, tracker *stats.Tracker, ) (*handlers.Handlers, error) { + brk := broker.New() + svc := service.New(service.Params{ //nolint:exhaustruct + Logger: log, + Config: cfg, + Database: database, + Broker: brk, + }) + hdlr, err := handlers.New(lifecycle, handlers.Params{ //nolint:exhaustruct Logger: log, Globals: globs, @@ -213,6 +223,8 @@ func newTestHandlers( Database: database, Healthcheck: hcheck, Stats: tracker, + Broker: brk, + Service: svc, }) if err != nil { return nil, fmt.Errorf("test handlers: %w", err) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 5d5c225..d185b5d 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -17,6 +17,7 @@ import ( "git.eeqj.de/sneak/neoirc/internal/healthcheck" "git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/ratelimit" + "git.eeqj.de/sneak/neoirc/internal/service" "git.eeqj.de/sneak/neoirc/internal/stats" "go.uber.org/fx" ) @@ -33,6 +34,8 @@ type Params struct { Database *db.Database Healthcheck *healthcheck.Healthcheck Stats *stats.Tracker + Broker *broker.Broker + Service *service.Service } const defaultIdleTimeout = 30 * 24 * time.Hour @@ -48,6 +51,7 @@ type Handlers struct { log *slog.Logger hc *healthcheck.Healthcheck broker *broker.Broker + svc *service.Service hashcashVal *hashcash.Validator channelHashcash *hashcash.ChannelValidator loginLimiter *ratelimit.Limiter @@ -79,7 +83,8 @@ func New( params: ¶ms, log: params.Logger.Get(), hc: params.Healthcheck, - broker: broker.New(), + broker: params.Broker, + svc: params.Service, hashcashVal: hashcash.NewValidator(resource), channelHashcash: hashcash.NewChannelValidator(), loginLimiter: ratelimit.New(loginRate, loginBurst), diff --git a/internal/ircserver/commands.go b/internal/ircserver/commands.go new file mode 100644 index 0000000..9ba1e57 --- /dev/null +++ b/internal/ircserver/commands.go @@ -0,0 +1,1178 @@ +package ircserver + +import ( + "context" + "encoding/json" + "errors" + "strconv" + "strings" + "time" + + "git.eeqj.de/sneak/neoirc/internal/service" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +// sendIRCError maps a service.IRCError to an IRC numeric +// reply on the wire. +func (c *Conn) sendIRCError(err error) { + var ircErr *service.IRCError + if errors.As(err, &ircErr) { + args := make([]string, 0, len(ircErr.Params)+1) + args = append(args, ircErr.Params...) + args = append(args, ircErr.Message) + c.sendNumeric(ircErr.Code, args...) + } +} + +// handleCAP silently acknowledges CAP negotiation. +func (c *Conn) handleCAP(msg *Message) { + if len(msg.Params) == 0 { + return + } + + sub := strings.ToUpper(msg.Params[0]) + if sub == "LS" { + c.send(FormatMessage( + c.serverSfx, "CAP", "*", "LS", "", + )) + } + + // CAP END and other subcommands are silently ignored. +} + +// handlePing replies with a PONG. +func (c *Conn) handlePing(msg *Message) { + token := c.serverSfx + + if len(msg.Params) > 0 { + token = msg.Params[0] + } + + c.sendFromServer("PONG", c.serverSfx, token) +} + +// handleNick changes the user's nickname via the shared +// service layer. +func (c *Conn) handleNick( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + newNick := msg.Params[0] + if len(newNick) > maxNickLen { + newNick = newNick[:maxNickLen] + } + + oldMask := c.hostmask() + oldNick := c.nick + + err := c.svc.ChangeNick( + ctx, c.sessionID, oldNick, newNick, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + c.mu.Lock() + c.nick = newNick + c.mu.Unlock() + + // Echo NICK change to the client on wire. + c.send(FormatMessage(oldMask, "NICK", newNick)) +} + +// handlePrivmsg handles PRIVMSG and NOTICE commands via +// the shared service layer. +func (c *Conn) handlePrivmsg( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoRecipient, + "No recipient given ("+msg.Command+")", + ) + + return + } + + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNoTextToSend, "No text to send", + ) + + return + } + + target := msg.Params[0] + text := msg.Params[1] + body, _ := json.Marshal([]string{text}) //nolint:errchkjson + + if strings.HasPrefix(target, "#") { + _, _, err := c.svc.SendChannelMessage( + ctx, c.sessionID, c.nick, + msg.Command, target, body, nil, + ) + if err != nil { + c.sendIRCError(err) + } + } else { + result, err := c.svc.SendDirectMessage( + ctx, c.sessionID, c.nick, + msg.Command, target, body, nil, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + if result.AwayMsg != "" { + c.sendNumeric( + irc.RplAway, target, result.AwayMsg, + ) + } + } +} + +// handleJoin joins one or more channels via the shared +// service layer. +func (c *Conn) handleJoin( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "JOIN", "Not enough parameters", + ) + + return + } + + channels := strings.Split(msg.Params[0], ",") + + for _, chanName := range channels { + chanName = strings.TrimSpace(chanName) + + if !strings.HasPrefix(chanName, "#") { + chanName = "#" + chanName + } + + c.joinChannel(ctx, chanName) + } +} + +// joinChannel joins a single channel using the service +// and delivers topic/names on the wire. +func (c *Conn) joinChannel( + ctx context.Context, + channel string, +) { + result, err := c.svc.JoinChannel( + ctx, c.sessionID, c.nick, channel, + ) + if err != nil { + c.sendIRCError(err) + + if !errors.As(err, new(*service.IRCError)) { + c.log.Error( + "join channel failed", "error", err, + ) + } + + return + } + + // Send JOIN echo to this client directly on wire. + c.send(FormatMessage(c.hostmask(), "JOIN", channel)) + + // Send topic. + c.deliverTopic(ctx, channel, result.ChannelID) + + // Send NAMES. + c.deliverNames(ctx, channel, result.ChannelID) +} + +// deliverTopic sends RPL_TOPIC or RPL_NOTOPIC. +func (c *Conn) deliverTopic( + ctx context.Context, + channel string, + chID int64, +) { + channels, err := c.database.ListChannels( + ctx, c.sessionID, + ) + + topic := "" + + if err == nil { + for _, ch := range channels { + if ch.Name == channel { + topic = ch.Topic + + break + } + } + } + + if topic == "" { + c.sendNumeric( + irc.RplNoTopic, channel, "No topic is set", + ) + + return + } + + c.sendNumeric(irc.RplTopic, channel, topic) + + meta, tmErr := c.database.GetTopicMeta(ctx, chID) + if tmErr == nil && meta != nil { + c.sendNumeric( + irc.RplTopicWhoTime, channel, + meta.SetBy, + strconv.FormatInt(meta.SetAt.Unix(), 10), + ) + } +} + +// deliverNames sends RPL_NAMREPLY and RPL_ENDOFNAMES. +func (c *Conn) deliverNames( + ctx context.Context, + channel string, + chID int64, +) { + members, err := c.database.ChannelMembers(ctx, chID) + if err != nil { + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) + + return + } + + names := make([]string, 0, len(members)) + + for _, member := range members { + prefix := "" + if member.IsOperator { + prefix = "@" + } else if member.IsVoiced { + prefix = "+" + } + + names = append(names, prefix+member.Nick) + } + + nameStr := strings.Join(names, " ") + + c.sendNumeric( + irc.RplNamReply, "=", channel, nameStr, + ) + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) +} + +// handlePart leaves one or more channels via the shared +// service layer. +func (c *Conn) handlePart( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PART", "Not enough parameters", + ) + + return + } + + reason := "" + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + channels := strings.Split(msg.Params[0], ",") + + for _, ch := range channels { + ch = strings.TrimSpace(ch) + c.partChannel(ctx, ch, reason) + } +} + +// partChannel leaves a single channel using the service. +func (c *Conn) partChannel( + ctx context.Context, + channel, reason string, +) { + err := c.svc.PartChannel( + ctx, c.sessionID, c.nick, channel, reason, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + // Echo PART to the client on wire. + if reason != "" { + c.send(FormatMessage( + c.hostmask(), "PART", channel, reason, + )) + } else { + c.send(FormatMessage( + c.hostmask(), "PART", channel, + )) + } +} + +// handleQuit handles the QUIT command. +func (c *Conn) handleQuit(msg *Message) { + reason := "Client quit" + + if len(msg.Params) > 0 { + reason = msg.Params[0] + } + + c.send("ERROR :Closing Link: " + c.hostname + + " (Quit: " + reason + ")") + c.closed = true +} + +// handleTopic gets or sets a channel topic via the shared +// service layer. +func (c *Conn) handleTopic( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "TOPIC", "Not enough parameters", + ) + + return + } + + channel := msg.Params[0] + + // If no second param, query the topic. + if len(msg.Params) < 2 { //nolint:mnd + c.queryTopic(ctx, channel) + + return + } + + // Set topic via service. + newTopic := msg.Params[1] + + err := c.svc.SetTopic( + ctx, c.sessionID, c.nick, channel, newTopic, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + // Echo TOPIC to the setting client on wire. + c.send(FormatMessage( + c.hostmask(), "TOPIC", channel, newTopic, + )) +} + +// queryTopic sends the current topic for a channel. +func (c *Conn) queryTopic( + ctx context.Context, + channel string, +) { + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + c.deliverTopic(ctx, channel, chID) +} + +// handleMode handles MODE queries and changes. +func (c *Conn) handleMode( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "MODE", "Not enough parameters", + ) + + return + } + + target := msg.Params[0] + + if strings.HasPrefix(target, "#") { + c.handleChannelMode(ctx, msg) + } else { + c.handleUserMode(msg) + } +} + +// handleChannelMode handles MODE for channels. +func (c *Conn) handleChannelMode( + ctx context.Context, + msg *Message, +) { + channel := msg.Params[0] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + // Query mode if no mode string given. + if len(msg.Params) < 2 { //nolint:mnd + modeStr := c.svc.QueryChannelMode(ctx, chID) + c.sendNumeric( + irc.RplChannelModeIs, channel, modeStr, + ) + + created, _ := c.database.GetChannelCreatedAt( + ctx, chID, + ) + if !created.IsZero() { + c.sendNumeric( + irc.RplCreationTime, channel, + strconv.FormatInt(created.Unix(), 10), + ) + } + + return + } + + // Need ops to change modes — validated by service. + _, opErr := c.svc.ValidateChannelOp( + ctx, c.sessionID, channel, + ) + if opErr != nil { + c.sendIRCError(opErr) + + return + } + + modeStr := msg.Params[1] + modeArgs := msg.Params[2:] + + c.applyChannelModes( + ctx, channel, chID, modeStr, modeArgs, + ) +} + +// applyChannelModes applies mode changes using the +// service for individual mode operations. +func (c *Conn) applyChannelModes( + ctx context.Context, + channel string, + chID int64, + modeStr string, + args []string, +) { + adding := true + argIdx := 0 + applied := "" + appliedArgs := "" + + for _, modeChar := range modeStr { + switch modeChar { + case '+': + adding = true + case '-': + adding = false + case 'm', 't': + _ = c.svc.SetChannelFlag( + ctx, chID, modeChar, adding, + ) + + if adding { + applied += "+" + string(modeChar) + } else { + applied += "-" + string(modeChar) + } + case 'o', 'v': + if argIdx >= len(args) { + break + } + + targetNick := args[argIdx] + argIdx++ + + err := c.svc.ApplyMemberMode( + ctx, chID, channel, + targetNick, modeChar, adding, + ) + if err != nil { + c.sendIRCError(err) + + continue + } + + if adding { + applied += "+" + string(modeChar) + } else { + applied += "-" + string(modeChar) + } + + appliedArgs += " " + targetNick + default: + c.sendNumeric( + irc.ErrUnknownMode, + string(modeChar), + "is unknown mode char to me", + ) + } + } + + if applied != "" { + modeReply := applied + if appliedArgs != "" { + modeReply += appliedArgs + } + + c.send(FormatMessage( + c.hostmask(), "MODE", channel, modeReply, + )) + + c.svc.BroadcastMode( + ctx, c.nick, channel, chID, modeReply, + ) + } +} + +// handleUserMode handles MODE for users. +func (c *Conn) handleUserMode(msg *Message) { + target := msg.Params[0] + + if !strings.EqualFold(target, c.nick) { + c.sendNumeric( + irc.ErrUsersDoNotMatch, + "Can't change mode for other users", + ) + + return + } + + // We don't support user modes beyond the basics. + c.sendNumeric(irc.RplUmodeIs, "+") +} + +// handleNames replies with channel member list. +func (c *Conn) handleNames( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.RplEndOfNames, + "*", "End of /NAMES list", + ) + + return + } + + channel := msg.Params[0] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) + + return + } + + c.deliverNames(ctx, channel, chID) +} + +// handleList sends the channel list. +func (c *Conn) handleList(ctx context.Context) { + channels, err := c.database.ListAllChannelsWithCounts( + ctx, + ) + if err != nil { + c.sendNumeric( + irc.RplListEnd, "End of /LIST", + ) + + return + } + + c.sendNumeric(irc.RplListStart, "Channel", "Users Name") + + for idx := range channels { + c.sendNumeric( + irc.RplList, channels[idx].Name, + strconv.FormatInt( + channels[idx].MemberCount, 10, + ), + channels[idx].Topic, + ) + } + + c.sendNumeric(irc.RplListEnd, "End of /LIST") +} + +// handleWhois replies with user info. Individual numeric +// replies are split into focused helper methods. +func (c *Conn) handleWhois( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + target := msg.Params[0] + + if len(msg.Params) > 1 { + target = msg.Params[1] + } + + targetID, err := c.database.GetSessionByNick( + ctx, target, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, target, "No such nick", + ) + c.sendNumeric( + irc.RplEndOfWhois, + target, "End of /WHOIS list", + ) + + return + } + + c.whoisUser(ctx, target, targetID) + c.whoisServer(target) + c.whoisOper(ctx, target, targetID) + c.whoisChannels(ctx, target, targetID) + c.whoisIdle(ctx, target, targetID) + c.whoisAway(ctx, target, targetID) + + c.sendNumeric( + irc.RplEndOfWhois, + target, "End of /WHOIS list", + ) +} + +// whoisUser sends 311 RPL_WHOISUSER. +func (c *Conn) whoisUser( + ctx context.Context, + target string, + targetID int64, +) { + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, targetID, + ) + + username := target + hostname := "*" + + if hostInfo != nil { + username = hostInfo.Username + hostname = hostInfo.Hostname + } + + c.sendNumeric( + irc.RplWhoisUser, target, + username, hostname, "*", target, + ) +} + +// whoisServer sends 312 RPL_WHOISSERVER. +func (c *Conn) whoisServer(target string) { + c.sendNumeric( + irc.RplWhoisServer, target, + c.serverSfx, "neoirc server", + ) +} + +// whoisOper sends 313 RPL_WHOISOPERATOR if applicable. +func (c *Conn) whoisOper( + ctx context.Context, + target string, + targetID int64, +) { + isOper, _ := c.database.IsSessionOper(ctx, targetID) + if isOper { + c.sendNumeric( + irc.RplWhoisOperator, + target, "is an IRC operator", + ) + } +} + +// whoisChannels sends 319 RPL_WHOISCHANNELS. +func (c *Conn) whoisChannels( + ctx context.Context, + target string, + targetID int64, +) { + userChannels, _ := c.database.GetSessionChannels( + ctx, targetID, + ) + if len(userChannels) == 0 { + return + } + + chanList := make([]string, 0, len(userChannels)) + + for _, userChan := range userChannels { + chID, getErr := c.database.GetChannelByName( + ctx, userChan.Name, + ) + if getErr != nil { + chanList = append(chanList, userChan.Name) + + continue + } + + isChOp, _ := c.database.IsChannelOperator( + ctx, chID, targetID, + ) + isVoiced, _ := c.database.IsChannelVoiced( + ctx, chID, targetID, + ) + + prefix := "" + if isChOp { + prefix = "@" + } else if isVoiced { + prefix = "+" + } + + chanList = append(chanList, prefix+userChan.Name) + } + + c.sendNumeric( + irc.RplWhoisChannels, target, + strings.Join(chanList, " "), + ) +} + +// whoisIdle sends 317 RPL_WHOISIDLE. +func (c *Conn) whoisIdle( + ctx context.Context, + target string, + targetID int64, +) { + lastSeen, _ := c.database.GetSessionLastSeen( + ctx, targetID, + ) + created, _ := c.database.GetSessionCreatedAt( + ctx, targetID, + ) + + if lastSeen.IsZero() { + return + } + + idle := int64(time.Since(lastSeen).Seconds()) + + signonTS := int64(0) + if !created.IsZero() { + signonTS = created.Unix() + } + + c.sendNumeric( + irc.RplWhoisIdle, target, + strconv.FormatInt(idle, 10), + strconv.FormatInt(signonTS, 10), + "seconds idle, signon time", + ) +} + +// whoisAway sends 301 RPL_AWAY if the target is away. +func (c *Conn) whoisAway( + ctx context.Context, + target string, + targetID int64, +) { + away, _ := c.database.GetAway(ctx, targetID) + if away != "" { + c.sendNumeric(irc.RplAway, target, away) + } +} + +// handleWho sends WHO replies for a channel. +func (c *Conn) handleWho( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.RplEndOfWho, "*", "End of /WHO list", + ) + + return + } + + target := msg.Params[0] + + if !strings.HasPrefix(target, "#") { + // WHO for a nick. + c.whoNick(ctx, target) + + return + } + + chID, err := c.database.GetChannelByName(ctx, target) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) + + return + } + + members, err := c.database.ChannelMembers(ctx, chID) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) + + return + } + + for _, member := range members { + flags := "H" + if member.IsOperator { + flags += "@" + } else if member.IsVoiced { + flags += "+" + } + + c.sendNumeric( + irc.RplWhoReply, + target, member.Username, member.Hostname, + c.serverSfx, member.Nick, flags, + "0 "+member.Nick, + ) + } + + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) +} + +// whoNick sends WHO reply for a single nick. +func (c *Conn) whoNick(ctx context.Context, nick string) { + targetID, err := c.database.GetSessionByNick(ctx, nick) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, nick, "End of /WHO list", + ) + + return + } + + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, targetID, + ) + + username := nick + hostname := "*" + + if hostInfo != nil { + username = hostInfo.Username + hostname = hostInfo.Hostname + } + + c.sendNumeric( + irc.RplWhoReply, + "*", username, hostname, + c.serverSfx, nick, "H", + "0 "+nick, + ) + c.sendNumeric( + irc.RplEndOfWho, nick, "End of /WHO list", + ) +} + +// handleLusers replies with server statistics. +func (c *Conn) handleLusers(ctx context.Context) { + c.deliverLusers(ctx) +} + +// handleOper handles the OPER command via the shared +// service layer. +func (c *Conn) handleOper( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "OPER", "Not enough parameters", + ) + + return + } + + err := c.svc.Oper( + ctx, c.sessionID, + msg.Params[0], msg.Params[1], + ) + if err != nil { + c.sendIRCError(err) + + return + } + + c.sendNumeric( + irc.RplYoureOper, + "You are now an IRC operator", + ) +} + +// handleAway sets or clears the AWAY status via the +// shared service layer. +func (c *Conn) handleAway( + ctx context.Context, + msg *Message, +) { + message := "" + if len(msg.Params) > 0 { + message = msg.Params[0] + } + + cleared, err := c.svc.SetAway( + ctx, c.sessionID, message, + ) + if err != nil { + c.log.Error("set away failed", "error", err) + + return + } + + if cleared { + c.sendNumeric( + irc.RplUnaway, + "You are no longer marked as being away", + ) + } else { + c.sendNumeric( + irc.RplNowAway, + "You have been marked as being away", + ) + } +} + +// handleKick kicks a user from a channel via the shared +// service layer. +func (c *Conn) handleKick( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "KICK", "Not enough parameters", + ) + + return + } + + channel := msg.Params[0] + targetNick := msg.Params[1] + + reason := targetNick + if len(msg.Params) > 2 { //nolint:mnd + reason = msg.Params[2] + } + + err := c.svc.KickUser( + ctx, c.sessionID, c.nick, + channel, targetNick, reason, + ) + if err != nil { + c.sendIRCError(err) + + return + } + + // Echo KICK on wire. + c.send(FormatMessage( + c.hostmask(), "KICK", channel, targetNick, reason, + )) +} + +// handlePassPostReg handles PASS after registration (for +// setting a session password). +func (c *Conn) handlePassPostReg( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PASS", "Not enough parameters", + ) + + return + } + + password := msg.Params[0] + if len(password) < minPasswordLen { + c.sendFromServer("NOTICE", c.nick, + "Password must be at least 8 characters", + ) + + return + } + + c.setPassword(ctx, password) + + c.sendFromServer("NOTICE", c.nick, + "Password set. You can reconnect using "+ + "PASS with your nick.", + ) +} + +// handleInvite handles the INVITE command. +func (c *Conn) handleInvite( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "INVITE", "Not enough parameters", + ) + + return + } + + targetNick := msg.Params[0] + channel := msg.Params[1] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + isMember, _ := c.database.IsChannelMember( + ctx, chID, c.sessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrNotOnChannel, + channel, "You're not on that channel", + ) + + return + } + + targetID, err := c.database.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, + targetNick, "No such nick", + ) + + return + } + + c.sendNumeric( + irc.RplInviting, targetNick, channel, + ) + + // Send INVITE notice to target via service fan-out. + body, _ := json.Marshal( //nolint:errchkjson + []string{"You have been invited to " + channel}, + ) + + _, _, _ = c.svc.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, "INVITE", c.nick, targetNick, + nil, body, nil, []int64{targetID}, + ) +} + +// handleUserhost replies with USERHOST info. +func (c *Conn) handleUserhost( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + return + } + + replies := make([]string, 0, len(msg.Params)) + + for _, nick := range msg.Params { + sid, err := c.database.GetSessionByNick(ctx, nick) + if err != nil { + continue + } + + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, sid, + ) + + host := "*" + if hostInfo != nil { + host = hostInfo.Hostname + } + + isOper, _ := c.database.IsSessionOper(ctx, sid) + + operStar := "" + if isOper { + operStar = "*" + } + + replies = append( + replies, + nick+operStar+"=+"+nick+"@"+host, + ) + } + + c.sendNumeric( + irc.RplUserHost, + strings.Join(replies, " "), + ) +} diff --git a/internal/ircserver/conn.go b/internal/ircserver/conn.go new file mode 100644 index 0000000..173ac19 --- /dev/null +++ b/internal/ircserver/conn.go @@ -0,0 +1,501 @@ +package ircserver + +import ( + "bufio" + "context" + "fmt" + "log/slog" + "net" + "strconv" + "strings" + "sync" + "time" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/service" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +const ( + maxLineLen = 512 + readTimeout = 5 * time.Minute + writeTimeout = 30 * time.Second + dnsTimeout = 3 * time.Second + pollInterval = 100 * time.Millisecond + pingInterval = 90 * time.Second + pongDeadline = 30 * time.Second + maxNickLen = 32 + minPasswordLen = 8 +) + +// cmdHandler is the signature for registered IRC command +// handlers. +type cmdHandler func(ctx context.Context, msg *Message) + +// Conn represents a single IRC client TCP connection. +type Conn struct { + conn net.Conn + log *slog.Logger + database *db.Database + brk *broker.Broker + cfg *config.Config + svc *service.Service + serverSfx string + commands map[string]cmdHandler + + mu sync.Mutex + nick string + username string + realname string + hostname string + remoteIP string + sessionID int64 + clientID int64 + + registered bool + gotNick bool + gotUser bool + passWord string + + lastQueueID int64 + closed bool + cancel context.CancelFunc +} + +func newConn( + ctx context.Context, + tcpConn net.Conn, + log *slog.Logger, + database *db.Database, + brk *broker.Broker, + cfg *config.Config, + svc *service.Service, +) *Conn { + host, _, _ := net.SplitHostPort(tcpConn.RemoteAddr().String()) + + srvName := cfg.ServerName + if srvName == "" { + srvName = "neoirc" + } + + conn := &Conn{ //nolint:exhaustruct // zero-value defaults + conn: tcpConn, + log: log, + database: database, + brk: brk, + cfg: cfg, + svc: svc, + serverSfx: srvName, + remoteIP: host, + hostname: resolveHost(ctx, host), + } + + conn.commands = conn.buildCommandMap() + + return conn +} + +// buildCommandMap returns a map from IRC command strings +// to handler functions. +func (c *Conn) buildCommandMap() map[string]cmdHandler { + return map[string]cmdHandler{ + irc.CmdPing: func(_ context.Context, msg *Message) { + c.handlePing(msg) + }, + "PONG": func(context.Context, *Message) {}, + irc.CmdNick: c.handleNick, + irc.CmdPrivmsg: c.handlePrivmsg, + irc.CmdNotice: c.handlePrivmsg, + irc.CmdJoin: c.handleJoin, + irc.CmdPart: c.handlePart, + irc.CmdQuit: func(_ context.Context, msg *Message) { + c.handleQuit(msg) + }, + irc.CmdTopic: c.handleTopic, + irc.CmdMode: c.handleMode, + irc.CmdNames: c.handleNames, + irc.CmdList: func(ctx context.Context, _ *Message) { c.handleList(ctx) }, + irc.CmdWhois: c.handleWhois, + irc.CmdWho: c.handleWho, + irc.CmdLusers: func(ctx context.Context, _ *Message) { c.handleLusers(ctx) }, + irc.CmdMotd: func(context.Context, *Message) { c.deliverMOTD() }, + irc.CmdOper: c.handleOper, + irc.CmdAway: c.handleAway, + irc.CmdKick: c.handleKick, + irc.CmdPass: c.handlePassPostReg, + "INVITE": c.handleInvite, + "CAP": func(_ context.Context, msg *Message) { + c.handleCAP(msg) + }, + "USERHOST": c.handleUserhost, + } +} + +// resolveHost does a reverse DNS lookup, returning the IP +// on failure. +func resolveHost(ctx context.Context, addr string) string { + ctx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + + resolver := &net.Resolver{} //nolint:exhaustruct + + names, err := resolver.LookupAddr(ctx, addr) + if err != nil || len(names) == 0 { + return addr + } + + return strings.TrimSuffix(names[0], ".") +} + +// serve is the main loop for a single IRC client connection. +func (c *Conn) serve(ctx context.Context) { + ctx, c.cancel = context.WithCancel(ctx) + defer c.cleanup(ctx) + + scanner := bufio.NewScanner(c.conn) + scanner.Buffer(make([]byte, maxLineLen), maxLineLen) + + for { + _ = c.conn.SetReadDeadline( + time.Now().Add(readTimeout), + ) + + if !scanner.Scan() { + return + } + + line := scanner.Text() + if line == "" { + continue + } + + msg := ParseMessage(line) + if msg == nil { + continue + } + + c.handleMessage(ctx, msg) + + if c.closed { + return + } + } +} + +func (c *Conn) cleanup(ctx context.Context) { + c.mu.Lock() + wasRegistered := c.registered + sessID := c.sessionID + nick := c.nick + c.closed = true + c.mu.Unlock() + + if wasRegistered && sessID > 0 { + c.svc.BroadcastQuit( + ctx, sessID, nick, "Connection closed", + ) + } + + c.conn.Close() //nolint:errcheck,gosec +} + +// send writes a formatted IRC line to the connection. +func (c *Conn) send(line string) { + _ = c.conn.SetWriteDeadline( + time.Now().Add(writeTimeout), + ) + + _, _ = fmt.Fprintf(c.conn, "%s\r\n", line) +} + +// sendNumeric sends a numeric reply from the server. +func (c *Conn) sendNumeric( + code irc.IRCMessageType, + params ...string, +) { + nick := c.nick + if nick == "" { + nick = "*" + } + + allParams := make([]string, 0, 1+len(params)) + allParams = append(allParams, nick) + allParams = append(allParams, params...) + + c.send(FormatMessage( + c.serverSfx, code.Code(), allParams..., + )) +} + +// sendFromServer sends a message from the server. +func (c *Conn) sendFromServer( + command string, params ...string, +) { + c.send(FormatMessage(c.serverSfx, command, params...)) +} + +// hostmask returns the client's full hostmask +// (nick!user@host). +func (c *Conn) hostmask() string { + user := c.username + if user == "" { + user = c.nick + } + + host := c.hostname + if host == "" { + host = c.remoteIP + } + + return c.nick + "!" + user + "@" + host +} + +// handleMessage dispatches a parsed IRC message using +// the command handler map. +func (c *Conn) handleMessage( + ctx context.Context, + msg *Message, +) { + // Before registration, only NICK, USER, PASS, PING, + // QUIT, and CAP are accepted. + if !c.registered { + c.handlePreRegistration(ctx, msg) + + return + } + + handler, ok := c.commands[msg.Command] + if !ok { + c.sendNumeric( + irc.ErrUnknownCommand, + msg.Command, "Unknown command", + ) + + return + } + + handler(ctx, msg) +} + +// handlePreRegistration handles messages before the +// connection is registered (NICK+USER received). +func (c *Conn) handlePreRegistration( + ctx context.Context, + msg *Message, +) { + switch msg.Command { + case irc.CmdPass: + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PASS", "Not enough parameters", + ) + + return + } + + c.passWord = msg.Params[0] + case irc.CmdNick: + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, + "No nickname given", + ) + + return + } + + c.nick = msg.Params[0] + if len(c.nick) > maxNickLen { + c.nick = c.nick[:maxNickLen] + } + + c.gotNick = true + case irc.CmdUser: + if len(msg.Params) < 4 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "USER", "Not enough parameters", + ) + + return + } + + c.username = msg.Params[0] + c.realname = msg.Params[3] + c.gotUser = true + case irc.CmdPing: + c.handlePing(msg) + + return + case irc.CmdQuit: + c.handleQuit(msg) + + return + case "CAP": + c.handleCAP(msg) + + return + default: + c.sendNumeric( + irc.ErrNotRegistered, + "You have not registered", + ) + + return + } + + // Try to complete registration once we have both + // NICK and USER. + if c.gotNick && c.gotUser { + c.completeRegistration(ctx) + } +} + +// completeRegistration creates a session and sends the +// welcome burst. +func (c *Conn) completeRegistration(ctx context.Context) { + // Check if nick is valid. + if c.nick == "" { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + // Create session in DB. + sessionID, clientID, _, err := c.database.CreateSession( + ctx, c.nick, c.username, c.hostname, c.remoteIP, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint") || + strings.Contains(err.Error(), "nick") { + c.sendNumeric( + irc.ErrNicknameInUse, + c.nick, "Nickname is already in use", + ) + + return + } + + c.log.Error( + "failed to create session", "error", err, + ) + c.send("ERROR :Internal server error") + c.closed = true + + return + } + + c.mu.Lock() + c.sessionID = sessionID + c.clientID = clientID + c.registered = true + c.mu.Unlock() + + // If PASS was provided before registration, set the + // session password. + if c.passWord != "" && len(c.passWord) >= minPasswordLen { + c.setPassword(ctx, c.passWord) + } + + // Send welcome burst. + c.deliverWelcome() + c.deliverLusers(ctx) + c.deliverMOTD() + + // Start the message relay goroutine. + go c.relayMessages(ctx) +} + +// deliverWelcome sends 001-005 welcome numerics. +func (c *Conn) deliverWelcome() { + c.sendNumeric(irc.RplWelcome, fmt.Sprintf( + "Welcome to the %s Network, %s", + c.serverSfx, c.hostmask(), + )) + c.sendNumeric(irc.RplYourHost, fmt.Sprintf( + "Your host is %s, running version neoirc", + c.serverSfx, + )) + c.sendNumeric( + irc.RplCreated, + "This server was created recently", + ) + c.sendNumeric( + irc.RplMyInfo, + c.serverSfx, "neoirc", "", "mnst", + ) + c.sendNumeric( + irc.RplIsupport, + "CHANTYPES=#", + "NICKLEN=32", + "PREFIX=(ov)@+", + "CHANMODES=,,H,mnst", + "NETWORK="+c.serverSfx, + "are supported by this server", + ) +} + +// deliverLusers sends 251/252/254/255 server statistics. +func (c *Conn) deliverLusers(ctx context.Context) { + users, _ := c.database.GetUserCount(ctx) + opers, _ := c.database.GetOperCount(ctx) + channels, _ := c.database.GetChannelCount(ctx) + + c.sendNumeric(irc.RplLuserClient, fmt.Sprintf( + "There are %d users and 0 invisible on 1 servers", + users, + )) + c.sendNumeric( + irc.RplLuserOp, + strconv.FormatInt(opers, 10), + "operator(s) online", + ) + c.sendNumeric( + irc.RplLuserChannels, + strconv.FormatInt(channels, 10), + "channels formed", + ) + c.sendNumeric(irc.RplLuserMe, fmt.Sprintf( + "I have %d clients and 1 servers", users, + )) +} + +// deliverMOTD sends 375/372/376 MOTD lines. +func (c *Conn) deliverMOTD() { + motd := c.cfg.MOTD + if motd == "" { + c.sendNumeric( + irc.ErrNoMotd, "MOTD File is missing", + ) + + return + } + + c.sendNumeric(irc.RplMotdStart, fmt.Sprintf( + "- %s Message of the Day -", c.serverSfx, + )) + + for _, line := range strings.Split(motd, "\n") { + c.sendNumeric(irc.RplMotd, "- "+line) + } + + c.sendNumeric( + irc.RplEndOfMotd, "End of /MOTD command", + ) +} + +// setPassword sets a bcrypt password on the session. +func (c *Conn) setPassword(ctx context.Context, pw string) { + // Use the database's auth module to hash and store. + err := c.database.SetPassword(ctx, c.sessionID, pw) + if err != nil { + c.log.Error( + "failed to set password", "error", err, + ) + } +} diff --git a/internal/ircserver/export_test.go b/internal/ircserver/export_test.go new file mode 100644 index 0000000..dc75ee7 --- /dev/null +++ b/internal/ircserver/export_test.go @@ -0,0 +1,52 @@ +package ircserver + +import ( + "context" + "log/slog" + "net" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/service" +) + +// NewTestServer creates a Server suitable for testing. +// The caller must call Stop() when finished. +func NewTestServer( + log *slog.Logger, + cfg *config.Config, + database *db.Database, + brk *broker.Broker, +) *Server { + svc := &service.Service{ + DB: database, + Broker: brk, + Config: cfg, + Log: log, + } + + return &Server{ //nolint:exhaustruct + log: log, + cfg: cfg, + database: database, + brk: brk, + svc: svc, + conns: make(map[*Conn]struct{}), + } +} + +// Start exposes the unexported start method for tests. +func (s *Server) Start(addr string) error { + return s.start(context.Background(), addr) +} + +// Stop exposes the unexported stop method for tests. +func (s *Server) Stop() { + s.stop() +} + +// Listener returns the server's net.Listener for tests. +func (s *Server) Listener() net.Listener { + return s.listener +} diff --git a/internal/ircserver/parser.go b/internal/ircserver/parser.go new file mode 100644 index 0000000..5374bc0 --- /dev/null +++ b/internal/ircserver/parser.go @@ -0,0 +1,123 @@ +// Package ircserver implements a traditional IRC wire protocol +// listener (RFC 1459/2812) that bridges to the neoirc HTTP/JSON +// server internals. +package ircserver + +import "strings" + +// Message represents a parsed IRC wire protocol message. +type Message struct { + // Prefix is the optional :prefix at the start (may be + // empty for client-to-server messages). + Prefix string + // Command is the IRC command (e.g., "PRIVMSG", "NICK"). + Command string + // Params holds the positional parameters, including the + // trailing parameter (which was preceded by ':' on the + // wire). + Params []string +} + +// ParseMessage parses a single IRC wire protocol line +// (without the trailing CR-LF) into a Message. +// Returns nil if the line is empty. +// +// IRC message format (RFC 1459 §2.3.1): +// +// [":" prefix SPACE] command { SPACE param } [SPACE ":" trailing] +func ParseMessage(line string) *Message { + if line == "" { + return nil + } + + msg := &Message{} //nolint:exhaustruct // fields set below + + // Extract prefix if present. + if line[0] == ':' { + idx := strings.IndexByte(line, ' ') + if idx < 0 { + // Only a prefix, no command — invalid. + return nil + } + + msg.Prefix = line[1:idx] + line = line[idx+1:] + } + + // Skip leading spaces. + line = strings.TrimLeft(line, " ") + if line == "" { + return nil + } + + // Extract command. + idx := strings.IndexByte(line, ' ') + if idx < 0 { + msg.Command = strings.ToUpper(line) + + return msg + } + + msg.Command = strings.ToUpper(line[:idx]) + line = line[idx+1:] + + // Extract parameters. + for line != "" { + line = strings.TrimLeft(line, " ") + if line == "" { + break + } + + // Trailing parameter (everything after ':'). + if line[0] == ':' { + msg.Params = append(msg.Params, line[1:]) + + break + } + + idx = strings.IndexByte(line, ' ') + if idx < 0 { + msg.Params = append(msg.Params, line) + + break + } + + msg.Params = append(msg.Params, line[:idx]) + line = line[idx+1:] + } + + return msg +} + +// FormatMessage formats an IRC message into wire protocol +// format (without the trailing CR-LF). +func FormatMessage( + prefix, command string, + params ...string, +) string { + var buf strings.Builder + + if prefix != "" { + buf.WriteByte(':') + buf.WriteString(prefix) + buf.WriteByte(' ') + } + + buf.WriteString(command) + + for i, param := range params { + buf.WriteByte(' ') + + isLast := i == len(params)-1 + needsColon := strings.Contains(param, " ") || + param == "" || param[0] == ':' + + if isLast && needsColon { + buf.WriteByte(':') + } + + buf.WriteString(param) + } + + return buf.String() +} diff --git a/internal/ircserver/parser_test.go b/internal/ircserver/parser_test.go new file mode 100644 index 0000000..8585eab --- /dev/null +++ b/internal/ircserver/parser_test.go @@ -0,0 +1,328 @@ +package ircserver_test + +import ( + "testing" + + "git.eeqj.de/sneak/neoirc/internal/ircserver" +) + +//nolint:funlen // table-driven test +func TestParseMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want *ircserver.Message + wantNil bool + }{ + { + name: "empty", + input: "", + want: nil, + wantNil: true, + }, + { + name: "simple command", + input: "PING", + want: &ircserver.Message{ + Prefix: "", + Command: "PING", + Params: nil, + }, + wantNil: false, + }, + { + name: "command with one param", + input: "NICK alice", + want: &ircserver.Message{ + Prefix: "", + Command: "NICK", + Params: []string{"alice"}, + }, + wantNil: false, + }, + { + name: "command case insensitive", + input: "nick Alice", + want: &ircserver.Message{ + Prefix: "", + Command: "NICK", + Params: []string{"Alice"}, + }, + wantNil: false, + }, + { + name: "privmsg with trailing", + input: "PRIVMSG #general :hello world", + want: &ircserver.Message{ + Prefix: "", + Command: "PRIVMSG", + Params: []string{"#general", "hello world"}, + }, + wantNil: false, + }, + { + name: "with prefix", + input: ":server.example.com 001 alice :Welcome to IRC", + want: &ircserver.Message{ + Prefix: "server.example.com", + Command: "001", + Params: []string{"alice", "Welcome to IRC"}, + }, + wantNil: false, + }, + { + name: "user command", + input: "USER alice 0 * :Alice Smith", + want: &ircserver.Message{ + Prefix: "", + Command: "USER", + Params: []string{ + "alice", "0", "*", "Alice Smith", + }, + }, + wantNil: false, + }, + { + name: "join channel", + input: "JOIN #general", + want: &ircserver.Message{ + Prefix: "", + Command: "JOIN", + Params: []string{"#general"}, + }, + wantNil: false, + }, + { + name: "quit with trailing", + input: "QUIT :leaving now", + want: &ircserver.Message{ + Prefix: "", + Command: "QUIT", + Params: []string{"leaving now"}, + }, + wantNil: false, + }, + { + name: "quit without reason", + input: "QUIT", + want: &ircserver.Message{ + Prefix: "", + Command: "QUIT", + Params: nil, + }, + wantNil: false, + }, + { + name: "mode query", + input: "MODE #general", + want: &ircserver.Message{ + Prefix: "", + Command: "MODE", + Params: []string{"#general"}, + }, + wantNil: false, + }, + { + name: "kick with reason", + input: "KICK #general bob :misbehaving", + want: &ircserver.Message{ + Prefix: "", + Command: "KICK", + Params: []string{ + "#general", "bob", "misbehaving", + }, + }, + wantNil: false, + }, + { + name: "empty trailing", + input: "PRIVMSG #general :", + want: &ircserver.Message{ + Prefix: "", + Command: "PRIVMSG", + Params: []string{"#general", ""}, + }, + wantNil: false, + }, + { + name: "pass command", + input: "PASS mysecret", + want: &ircserver.Message{ + Prefix: "", + Command: "PASS", + Params: []string{"mysecret"}, + }, + wantNil: false, + }, + { + name: "ping with server", + input: "PING :irc.example.com", + want: &ircserver.Message{ + Prefix: "", + Command: "PING", + Params: []string{"irc.example.com"}, + }, + wantNil: false, + }, + { + name: "topic with trailing spaces", + input: "TOPIC #general :Welcome to the channel!", + want: &ircserver.Message{ + Prefix: "", + Command: "TOPIC", + Params: []string{ + "#general", + "Welcome to the channel!", + }, + }, + wantNil: false, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := ircserver.ParseMessage(testCase.input) + if testCase.wantNil { + if got != nil { + t.Fatalf("expected nil, got %+v", got) + } + + return + } + + if got == nil { + t.Fatal("expected non-nil message") + } + + if got.Prefix != testCase.want.Prefix { + t.Errorf( + "prefix: got %q, want %q", + got.Prefix, testCase.want.Prefix, + ) + } + + if got.Command != testCase.want.Command { + t.Errorf( + "command: got %q, want %q", + got.Command, testCase.want.Command, + ) + } + + if len(got.Params) != len(testCase.want.Params) { + t.Fatalf( + "params length: got %d, want %d (%v vs %v)", + len(got.Params), + len(testCase.want.Params), + got.Params, + testCase.want.Params, + ) + } + + for i, p := range got.Params { + if p != testCase.want.Params[i] { + t.Errorf( + "param[%d]: got %q, want %q", + i, p, testCase.want.Params[i], + ) + } + } + }) + } +} + +func TestFormatMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prefix string + command string + params []string + want string + }{ + { + name: "simple command", + prefix: "", + command: "PING", + params: nil, + want: "PING", + }, + { + name: "with prefix", + prefix: "server", + command: "PONG", + params: []string{"server"}, + want: ":server PONG server", + }, + { + name: "privmsg with trailing", + prefix: "alice!alice@host", + command: "PRIVMSG", + params: []string{"#general", "hello world"}, + want: ":alice!alice@host PRIVMSG #general :hello world", + }, + { + name: "numeric reply", + prefix: "server", + command: "001", + params: []string{"alice", "Welcome to IRC"}, + want: ":server 001 alice :Welcome to IRC", + }, + { + name: "empty trailing", + prefix: "server", + command: "PRIVMSG", + params: []string{"#chan", ""}, + want: ":server PRIVMSG #chan :", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := ircserver.FormatMessage( + testCase.prefix, testCase.command, testCase.params..., + ) + if got != testCase.want { + t.Errorf("got %q, want %q", got, testCase.want) + } + }) + } +} + +func TestParseFormatRoundTrip(t *testing.T) { + t.Parallel() + + // Round-trip only works for lines where the last + // parameter either contains a space (gets ':' prefix + // on format) or is a non-trailing single token. + lines := []string{ + "PING", + "NICK alice", + "PRIVMSG #general :hello world", + "JOIN #general", + "MODE #general", + } + + for _, line := range lines { + msg := ircserver.ParseMessage(line) + if msg == nil { + t.Fatalf("failed to parse: %q", line) + } + + formatted := ircserver.FormatMessage( + msg.Prefix, msg.Command, msg.Params..., + ) + if formatted != line { + t.Errorf( + "round-trip failed: input %q, got %q", + line, formatted, + ) + } + } +} diff --git a/internal/ircserver/relay.go b/internal/ircserver/relay.go new file mode 100644 index 0000000..c1f3436 --- /dev/null +++ b/internal/ircserver/relay.go @@ -0,0 +1,319 @@ +package ircserver + +import ( + "context" + "encoding/json" + "strings" + "time" + + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +// relayMessages polls the client output queue and delivers +// IRC-formatted messages to the TCP connection. It runs +// in a goroutine for the lifetime of the connection. +func (c *Conn) relayMessages(ctx context.Context) { + // Use a ticker as a fallback; primary wakeup is via + // broker notification. + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + default: + } + + // Drain any available messages. + delivered := c.drainQueue(ctx) + + if delivered { + // Tight loop while there are messages. + continue + } + + // Wait for notification or timeout. + waitCh := c.brk.Wait(c.sessionID) + + select { + case <-waitCh: + // New message notification — loop back. + case <-ticker.C: + // Periodic check. + case <-ctx.Done(): + c.brk.Remove(c.sessionID, waitCh) + + return + } + } +} + +const relayPollLimit = 100 + +// drainQueue polls the output queue and delivers all +// pending messages. Returns true if at least one message +// was delivered. +func (c *Conn) drainQueue(ctx context.Context) bool { + msgs, lastID, err := c.database.PollMessages( + ctx, c.clientID, c.lastQueueID, relayPollLimit, + ) + if err != nil { + return false + } + + if len(msgs) == 0 { + return false + } + + for i := range msgs { + c.deliverIRCMessage(ctx, &msgs[i]) + } + + if lastID > c.lastQueueID { + c.lastQueueID = lastID + } + + return true +} + +// deliverIRCMessage converts a db.IRCMessage to wire +// protocol and sends it. +// +//nolint:cyclop // dispatch table +func (c *Conn) deliverIRCMessage( + _ context.Context, + msg *db.IRCMessage, +) { + command := msg.Command + + // Decode body as []string for the trailing text. + var bodyLines []string + + if msg.Body != nil { + _ = json.Unmarshal(msg.Body, &bodyLines) + } + + text := "" + if len(bodyLines) > 0 { + text = bodyLines[0] + } + + // Route by command type. + switch { + case isNumeric(command): + c.deliverNumeric(msg, text) + case command == irc.CmdPrivmsg || command == irc.CmdNotice: + c.deliverTextMessage(msg, command, text) + case command == irc.CmdJoin: + c.deliverJoin(msg) + case command == irc.CmdPart: + c.deliverPart(msg, text) + case command == irc.CmdNick: + c.deliverNickChange(msg, text) + case command == irc.CmdQuit: + c.deliverQuitMsg(msg, text) + case command == irc.CmdTopic: + c.deliverTopicChange(msg, text) + case command == irc.CmdKick: + c.deliverKickMsg(msg, text) + case command == "INVITE": + c.deliverInviteMsg(msg, text) + case command == irc.CmdMode: + c.deliverMode(msg, text) + case command == irc.CmdPing: + // Server-originated PING — reply with PONG. + c.sendFromServer("PING", c.serverSfx) + default: + // Unknown command — deliver as server notice. + if text != "" { + c.sendFromServer("NOTICE", c.nick, text) + } + } +} + +// isNumeric returns true if the command is a 3-digit +// numeric code. +func isNumeric(cmd string) bool { + return len(cmd) == 3 && + cmd[0] >= '0' && cmd[0] <= '9' && + cmd[1] >= '0' && cmd[1] <= '9' && + cmd[2] >= '0' && cmd[2] <= '9' +} + +// deliverNumeric sends a numeric reply. +func (c *Conn) deliverNumeric( + msg *db.IRCMessage, + text string, +) { + from := msg.From + if from == "" { + from = c.serverSfx + } + + var params []string + + if msg.Params != nil { + _ = json.Unmarshal(msg.Params, ¶ms) + } + + allParams := make([]string, 0, 1+len(params)+1) + allParams = append(allParams, c.nick) + allParams = append(allParams, params...) + + if text != "" { + allParams = append(allParams, text) + } + + c.send(FormatMessage(from, msg.Command, allParams...)) +} + +// deliverTextMessage sends PRIVMSG or NOTICE. +func (c *Conn) deliverTextMessage( + msg *db.IRCMessage, + command, text string, +) { + from := msg.From + target := msg.To + + // Don't echo our own messages back. + if strings.EqualFold(from, c.nick) { + return + } + + prefix := from + if !strings.Contains(prefix, "!") { + prefix = from + "!" + from + "@*" + } + + c.send(FormatMessage(prefix, command, target, text)) +} + +// deliverJoin sends a JOIN notification. +func (c *Conn) deliverJoin(msg *db.IRCMessage) { + // Don't echo our own JOINs (we already sent them + // during joinChannel). + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + c.send(FormatMessage(prefix, "JOIN", channel)) +} + +// deliverPart sends a PART notification. +func (c *Conn) deliverPart(msg *db.IRCMessage, text string) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + if text != "" { + c.send(FormatMessage( + prefix, "PART", channel, text, + )) + } else { + c.send(FormatMessage(prefix, "PART", channel)) + } +} + +// deliverNickChange sends a NICK change notification. +func (c *Conn) deliverNickChange( + msg *db.IRCMessage, + newNick string, +) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + + c.send(FormatMessage(prefix, "NICK", newNick)) +} + +// deliverQuitMsg sends a QUIT notification. +func (c *Conn) deliverQuitMsg( + msg *db.IRCMessage, + text string, +) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + + if text != "" { + c.send(FormatMessage( + prefix, "QUIT", "Quit: "+text, + )) + } else { + c.send(FormatMessage(prefix, "QUIT", "Quit")) + } +} + +// deliverTopicChange sends a TOPIC change notification. +func (c *Conn) deliverTopicChange( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + c.send(FormatMessage(prefix, "TOPIC", channel, text)) +} + +// deliverKickMsg sends a KICK notification. +func (c *Conn) deliverKickMsg( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + var params []string + + if msg.Params != nil { + _ = json.Unmarshal(msg.Params, ¶ms) + } + + kickTarget := "" + if len(params) > 0 { + kickTarget = params[0] + } + + if kickTarget != "" { + c.send(FormatMessage( + prefix, "KICK", channel, kickTarget, text, + )) + } else { + c.send(FormatMessage( + prefix, "KICK", channel, "?", text, + )) + } +} + +// deliverInviteMsg sends an INVITE notification. +func (c *Conn) deliverInviteMsg( + _ *db.IRCMessage, + text string, +) { + c.sendFromServer("NOTICE", c.nick, text) +} + +// deliverMode sends a MODE change notification. +func (c *Conn) deliverMode( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + target := msg.To + + if text != "" { + c.send(FormatMessage(prefix, "MODE", target, text)) + } +} diff --git a/internal/ircserver/server.go b/internal/ircserver/server.go new file mode 100644 index 0000000..0ee9256 --- /dev/null +++ b/internal/ircserver/server.go @@ -0,0 +1,157 @@ +package ircserver + +import ( + "context" + "fmt" + "log/slog" + "net" + "sync" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/logger" + "git.eeqj.de/sneak/neoirc/internal/service" + "go.uber.org/fx" +) + +// Params defines the dependencies for creating an IRC +// Server. +type Params struct { + fx.In + + Logger *logger.Logger + Config *config.Config + Database *db.Database + Broker *broker.Broker + Service *service.Service +} + +// Server is the TCP IRC protocol server. +type Server struct { + log *slog.Logger + cfg *config.Config + database *db.Database + brk *broker.Broker + svc *service.Service + listener net.Listener + mu sync.Mutex + conns map[*Conn]struct{} + cancel context.CancelFunc +} + +// New creates a new IRC Server and registers its lifecycle +// hooks. The listener is only started if IRC_LISTEN_ADDR +// is configured; otherwise the server is inert. +func New( + lifecycle fx.Lifecycle, + params Params, +) *Server { + srv := &Server{ + log: params.Logger.Get(), + cfg: params.Config, + database: params.Database, + brk: params.Broker, + svc: params.Service, + conns: make(map[*Conn]struct{}), + listener: nil, + cancel: nil, + mu: sync.Mutex{}, + } + + listenAddr := params.Config.IRCListenAddr + if listenAddr == "" { + return srv + } + + lifecycle.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return srv.start(ctx, listenAddr) + }, + OnStop: func(_ context.Context) error { + srv.stop() + + return nil + }, + }) + + return srv +} + +// start begins listening for TCP connections. +// +//nolint:contextcheck // long-lived server ctx, not the short Fx one +func (s *Server) start(_ context.Context, addr string) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("irc listen: %w", err) + } + + s.listener = ln + + ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + + s.log.Info( + "irc server listening", "addr", addr, + ) + + go s.acceptLoop(ctx) + + return nil +} + +// stop shuts down the listener and all connections. +func (s *Server) stop() { + if s.cancel != nil { + s.cancel() + } + + if s.listener != nil { + s.listener.Close() //nolint:errcheck,gosec + } + + s.mu.Lock() + for c := range s.conns { + c.conn.Close() //nolint:errcheck,gosec + } + s.mu.Unlock() +} + +// acceptLoop accepts new connections. +func (s *Server) acceptLoop(ctx context.Context) { + for { + tcpConn, err := s.listener.Accept() + if err != nil { + select { + case <-ctx.Done(): + return + default: + s.log.Error( + "irc accept error", "error", err, + ) + + continue + } + } + + client := newConn( + ctx, tcpConn, s.log, + s.database, s.brk, s.cfg, s.svc, + ) + + s.mu.Lock() + s.conns[client] = struct{}{} + s.mu.Unlock() + + go func() { + defer func() { + s.mu.Lock() + delete(s.conns, client) + s.mu.Unlock() + }() + + client.serve(ctx) + }() + } +} diff --git a/internal/ircserver/server_test.go b/internal/ircserver/server_test.go new file mode 100644 index 0000000..fb21e6b --- /dev/null +++ b/internal/ircserver/server_test.go @@ -0,0 +1,625 @@ +package ircserver_test + +import ( + "bufio" + "database/sql" + "fmt" + "log/slog" + "net" + "os" + "strings" + "testing" + "time" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/ircserver" + + _ "modernc.org/sqlite" +) + +const testTimeout = 5 * time.Second + +func TestMain(m *testing.M) { + db.SetBcryptCost(4) + + os.Exit(m.Run()) +} + +// testEnv holds the shared test infrastructure. +type testEnv struct { + database *db.Database + brk *broker.Broker + cfg *config.Config + srv *ircserver.Server +} + +func newTestEnv(t *testing.T) *testEnv { + t.Helper() + + dsn := fmt.Sprintf( + "file:%s?mode=memory&cache=shared&_journal_mode=WAL", + t.Name(), + ) + + conn, err := sql.Open("sqlite", dsn) + if err != nil { + t.Fatalf("open db: %v", err) + } + + conn.SetMaxOpenConns(1) + + _, err = conn.ExecContext( + t.Context(), "PRAGMA foreign_keys = ON", + ) + if err != nil { + t.Fatalf("pragma: %v", err) + } + + database := db.NewTestDatabaseFromConn(conn) + + err = database.RunMigrations(t.Context()) + if err != nil { + t.Fatalf("migrate: %v", err) + } + + brk := broker.New() + + cfg := &config.Config{ //nolint:exhaustruct + ServerName: "test.irc", + MOTD: "Welcome to test IRC", + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + addr := listener.Addr().String() + + err = listener.Close() + if err != nil { + t.Fatalf("close listener: %v", err) + } + + log := slog.New(slog.NewTextHandler( + os.Stderr, + &slog.HandlerOptions{Level: slog.LevelError}, //nolint:exhaustruct + )) + + srv := ircserver.NewTestServer(log, cfg, database, brk) + + err = srv.Start(addr) + if err != nil { + t.Fatalf("start irc server: %v", err) + } + + t.Cleanup(func() { + srv.Stop() + + err := conn.Close() + if err != nil { + t.Logf("close db: %v", err) + } + }) + + return &testEnv{ + database: database, + brk: brk, + cfg: cfg, + srv: srv, + } +} + +// dial connects to the test server. +func (env *testEnv) dial(t *testing.T) *testClient { + t.Helper() + + conn, err := net.DialTimeout( + "tcp", + env.srv.Listener().Addr().String(), + testTimeout, + ) + if err != nil { + t.Fatalf("dial: %v", err) + } + + t.Cleanup(func() { + err := conn.Close() + if err != nil { + t.Logf("close conn: %v", err) + } + }) + + return &testClient{ + t: t, + conn: conn, + scanner: bufio.NewScanner(conn), + } +} + +// testClient wraps a raw TCP connection with helpers. +type testClient struct { + t *testing.T + conn net.Conn + scanner *bufio.Scanner +} + +func (tc *testClient) send(line string) { + tc.t.Helper() + + _ = tc.conn.SetWriteDeadline( + time.Now().Add(testTimeout), + ) + + _, err := fmt.Fprintf(tc.conn, "%s\r\n", line) + if err != nil { + tc.t.Fatalf("send: %v", err) + } +} + +func (tc *testClient) readLine() string { + tc.t.Helper() + + _ = tc.conn.SetReadDeadline( + time.Now().Add(testTimeout), + ) + + if !tc.scanner.Scan() { + err := tc.scanner.Err() + if err != nil { + tc.t.Fatalf("read: %v", err) + } + + tc.t.Fatal("connection closed unexpectedly") + } + + return tc.scanner.Text() +} + +// readUntil reads lines until one matches the predicate. +func (tc *testClient) readUntil( + pred func(string) bool, +) []string { + tc.t.Helper() + + var lines []string + + for { + line := tc.readLine() + lines = append(lines, line) + + if pred(line) { + return lines + } + } +} + +// register sends NICK + USER and reads through the welcome +// burst. +func (tc *testClient) register(nick string) []string { + tc.t.Helper() + + tc.send("NICK " + nick) + tc.send("USER " + nick + " 0 * :Test User") + + return tc.readUntil(func(line string) bool { + return strings.Contains(line, " 376 ") || + strings.Contains(line, " 422 ") + }) +} + +// assertContains checks that at least one line matches the +// given substring. +func assertContains( + t *testing.T, + lines []string, + substr, description string, +) { + t.Helper() + + for _, line := range lines { + if strings.Contains(line, substr) { + return + } + } + + t.Errorf("did not find %q in output: %s", substr, description) +} + +// joinAndDrain joins a channel and reads until +// RPL_ENDOFNAMES. +func (tc *testClient) joinAndDrain(channel string) { + tc.t.Helper() + + tc.send("JOIN " + channel) + + tc.readUntil(func(line string) bool { + return strings.Contains(line, " 366 ") + }) +} + +// sendAndExpect sends a command and reads until a line +// containing the expected substring is found. +func (tc *testClient) sendAndExpect( + cmd, expect string, +) []string { + tc.t.Helper() + + tc.send(cmd) + + return tc.readUntil(func(line string) bool { + return strings.Contains(line, expect) + }) +} + +func TestRegistration(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + lines := client.register("alice") + assertContains(t, lines, " 001 ", "RPL_WELCOME") +} + +func TestWelcomeContainsNick(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + lines := client.register("bob") + + for _, line := range lines { + if strings.Contains(line, " 001 ") && + !strings.Contains(line, "bob") { + t.Errorf("001 should contain nick: %s", line) + } + } +} + +func TestPingPong(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("pingtest") + lines := client.sendAndExpect("PING :hello", "PONG") + assertContains(t, lines, "PONG", "PONG response") +} + +func TestJoinChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("joiner") + client.send("JOIN #test") + + lines := client.readUntil(func(line string) bool { + return strings.Contains(line, " 366 ") + }) + + assertContains(t, lines, "JOIN", "JOIN echo") + assertContains(t, lines, " 366 ", "RPL_ENDOFNAMES") +} + +func TestPrivmsgBetweenClients(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + alice := env.dial(t) + alice.register("alice_pm") + + bob := env.dial(t) + bob.register("bob_pm") + + alice.joinAndDrain("#chat") + bob.joinAndDrain("#chat") + + alice.send("PRIVMSG #chat :hello bob!") + lines := bob.sendAndExpect("PING :sync", "hello bob!") + assertContains(t, lines, "hello bob!", "channel PRIVMSG") +} + +func TestNickChange(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("oldnick") + lines := client.sendAndExpect("NICK newnick", "newnick") + assertContains(t, lines, "NICK", "NICK change echo") +} + +func TestDuplicateNick(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + first := env.dial(t) + first.register("taken") + + second := env.dial(t) + second.send("NICK taken") + second.send("USER taken 0 * :Test") + + lines := second.readUntil(func(line string) bool { + return strings.Contains(line, " 433 ") + }) + + assertContains(t, lines, " 433 ", "ERR_NICKNAMEINUSE") +} + +func TestListChannels(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("lister") + client.joinAndDrain("#listtest") + lines := client.sendAndExpect("LIST", " 323 ") + assertContains(t, lines, " 323 ", "RPL_LISTEND") //nolint:misspell // IRC term +} + +func TestWhois(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("whoistest") + lines := client.sendAndExpect( + "WHOIS whoistest", " 318 ", + ) + + assertContains(t, lines, " 311 ", "RPL_WHOISUSER") + assertContains(t, lines, " 318 ", "RPL_ENDOFWHOIS") +} + +func TestQuit(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("quitter") + lines := client.sendAndExpect( + "QUIT :goodbye", "ERROR", + ) + + assertContains(t, lines, "goodbye", "QUIT reason") +} + +func TestTopicSetAndGet(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("topicuser") + client.joinAndDrain("#topictest") + lines := client.sendAndExpect( + "TOPIC #topictest :New topic here", + "New topic here", + ) + + assertContains( + t, lines, "New topic here", "TOPIC echo", + ) +} + +func TestUnknownCommand(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("unknowncmd") + lines := client.sendAndExpect("FOOBAR", " 421 ") + assertContains(t, lines, " 421 ", "ERR_UNKNOWNCOMMAND") +} + +func TestDirectMessage(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + sender := env.dial(t) + sender.register("dmsender") + + receiver := env.dial(t) + receiver.register("dmreceiver") + + // Give relay goroutines time to start. + time.Sleep(100 * time.Millisecond) + + sender.send("PRIVMSG dmreceiver :hello privately") + + lines := receiver.readUntil(func(line string) bool { + return strings.Contains(line, "hello privately") + }) + + assertContains( + t, lines, "hello privately", "direct PRIVMSG", + ) +} + +func TestCAPNegotiation(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.send("CAP LS 302") + + line := client.readLine() + if !strings.Contains(line, "CAP") { + t.Errorf("expected CAP response, got: %s", line) + } + + client.send("CAP END") + lines := client.register("capuser") + assertContains(t, lines, " 001 ", "registration after CAP") +} + +func TestPartChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("parter") + client.joinAndDrain("#parttest") + lines := client.sendAndExpect( + "PART #parttest :leaving", "PART", + ) + + assertContains(t, lines, "#parttest", "PART echo") +} + +func TestModeQuery(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("modeuser") + client.joinAndDrain("#modetest") + lines := client.sendAndExpect( + "MODE #modetest", " 324 ", + ) + + assertContains( + t, lines, " 324 ", "RPL_CHANNELMODEIS", + ) +} + +func TestWhoChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("whouser") + client.joinAndDrain("#whotest") + lines := client.sendAndExpect("WHO #whotest", " 315 ") + assertContains(t, lines, " 352 ", "RPL_WHOREPLY") +} + +func TestLusers(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("luseruser") + lines := client.sendAndExpect("LUSERS", " 255 ") + assertContains(t, lines, " 251 ", "RPL_LUSERCLIENT") +} + +func TestMotd(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("motduser") + lines := client.sendAndExpect("MOTD", " 376 ") + assertContains(t, lines, " 376 ", "RPL_ENDOFMOTD") +} + +func TestAwaySetAndClear(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("awayuser") + + setLines := client.sendAndExpect( + "AWAY :brb lunch", " 306 ", + ) + assertContains(t, setLines, " 306 ", "RPL_NOWAWAY") + + clearLines := client.sendAndExpect("AWAY", " 305 ") + assertContains(t, clearLines, " 305 ", "RPL_UNAWAY") +} + +func TestHandlePassPostRegistration(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("passuser") + lines := client.sendAndExpect( + "PASS :mypassword123", "Password set", + ) + + assertContains( + t, lines, "Password set", "password confirmation", + ) +} + +func TestPreRegistrationNotRegistered(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.send("PRIVMSG #test :hello") + + line := client.readLine() + if !strings.Contains(line, " 451 ") { + t.Errorf( + "expected ERR_NOTREGISTERED (451), got: %s", + line, + ) + } +} + +func TestNamesNonExistentChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("namesuser") + lines := client.sendAndExpect( + "NAMES #doesnotexist", " 366 ", + ) + + assertContains( + t, lines, " 366 ", + "RPL_ENDOFNAMES for non-existent channel", + ) +} + +func BenchmarkParseMessage(b *testing.B) { + line := ":nick!user@host PRIVMSG #channel :Hello, world!" + + b.ResetTimer() + + for range b.N { + _ = ircserver.ParseMessage(line) + } +} + +func BenchmarkFormatMessage(b *testing.B) { + b.ResetTimer() + + for range b.N { + _ = ircserver.FormatMessage( + "nick!user@host", "PRIVMSG", + "#channel", "Hello, world!", + ) + } +} diff --git a/internal/service/service.go b/internal/service/service.go new file mode 100644 index 0000000..11f3f19 --- /dev/null +++ b/internal/service/service.go @@ -0,0 +1,745 @@ +// Package service provides shared business logic for both +// the IRC wire protocol and HTTP/JSON transports. +package service + +import ( + "context" + "crypto/subtle" + "encoding/json" + "fmt" + "log/slog" + "strings" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/logger" + "git.eeqj.de/sneak/neoirc/pkg/irc" + "go.uber.org/fx" +) + +// Params defines the dependencies for creating a Service. +type Params struct { + fx.In + + Logger *logger.Logger + Config *config.Config + Database *db.Database + Broker *broker.Broker +} + +// Service provides shared business logic for IRC commands. +type Service struct { + DB *db.Database + Broker *broker.Broker + Config *config.Config + Log *slog.Logger +} + +// New creates a new Service. +func New(params Params) *Service { + return &Service{ + DB: params.Database, + Broker: params.Broker, + Config: params.Config, + Log: params.Logger.Get(), + } +} + +// IRCError represents an IRC protocol-level error with a +// numeric code that both transports can map to responses. +type IRCError struct { + Code irc.IRCMessageType + Params []string + Message string +} + +func (e *IRCError) Error() string { return e.Message } + +// JoinResult contains the outcome of a channel join. +type JoinResult struct { + ChannelID int64 + IsCreator bool +} + +// DirectMsgResult contains the outcome of a direct message. +type DirectMsgResult struct { + UUID string + AwayMsg string +} + +// FanOut inserts a message and enqueues it to all given +// session IDs, notifying each via the broker. +func (s *Service) FanOut( + ctx context.Context, + command, from, to string, + params, body, meta json.RawMessage, + sessionIDs []int64, +) (int64, string, error) { + dbID, msgUUID, err := s.DB.InsertMessage( + ctx, command, from, to, params, body, meta, + ) + if err != nil { + return 0, "", fmt.Errorf("insert message: %w", err) + } + + for _, sid := range sessionIDs { + _ = s.DB.EnqueueToSession(ctx, sid, dbID) + s.Broker.Notify(sid) + } + + return dbID, msgUUID, nil +} + +// excludeSession returns a copy of ids without the given +// session. +func excludeSession( + ids []int64, + exclude int64, +) []int64 { + out := make([]int64, 0, len(ids)) + + for _, id := range ids { + if id != exclude { + out = append(out, id) + } + } + + return out +} + +// SendChannelMessage validates membership and moderation, +// then fans out a message to all channel members except +// the sender. Returns the database row ID, message UUID, +// and any error. The dbID lets callers enqueue the same +// message to the sender when echo is needed (HTTP +// transport). +func (s *Service) SendChannelMessage( + ctx context.Context, + sessionID int64, + nick, command, channel string, + body, meta json.RawMessage, +) (int64, string, error) { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return 0, "", &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, sessionID, + ) + if !isMember { + return 0, "", &IRCError{ + irc.ErrCannotSendToChan, + []string{channel}, + "Cannot send to channel", + } + } + + moderated, _ := s.DB.IsChannelModerated(ctx, chID) + if moderated { + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + isVoiced, _ := s.DB.IsChannelVoiced( + ctx, chID, sessionID, + ) + + if !isOp && !isVoiced { + return 0, "", &IRCError{ + irc.ErrCannotSendToChan, + []string{channel}, + "Cannot send to channel (+m)", + } + } + } + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + recipients := excludeSession(memberIDs, sessionID) + + dbID, uuid, fanErr := s.FanOut( + ctx, command, nick, channel, + nil, body, meta, recipients, + ) + if fanErr != nil { + return 0, "", fanErr + } + + return dbID, uuid, nil +} + +// SendDirectMessage validates the target and sends a +// direct message, returning the message UUID and any away +// message set on the target. +func (s *Service) SendDirectMessage( + ctx context.Context, + sessionID int64, + nick, command, target string, + body, meta json.RawMessage, +) (*DirectMsgResult, error) { + targetSID, err := s.DB.GetSessionByNick(ctx, target) + if err != nil { + return nil, &IRCError{ + irc.ErrNoSuchNick, + []string{target}, + "No such nick", + } + } + + away, _ := s.DB.GetAway(ctx, targetSID) + + recipients := []int64{targetSID} + if targetSID != sessionID { + recipients = append(recipients, sessionID) + } + + _, uuid, fanErr := s.FanOut( + ctx, command, nick, target, + nil, body, meta, recipients, + ) + if fanErr != nil { + return nil, fanErr + } + + return &DirectMsgResult{UUID: uuid, AwayMsg: away}, nil +} + +// JoinChannel creates or joins a channel, making the +// first joiner the operator. Fans out the JOIN to all +// channel members. +func (s *Service) JoinChannel( + ctx context.Context, + sessionID int64, + nick, channel string, +) (*JoinResult, error) { + chID, err := s.DB.GetOrCreateChannel(ctx, channel) + if err != nil { + return nil, fmt.Errorf("get/create channel: %w", err) + } + + memberCount, countErr := s.DB.CountChannelMembers( + ctx, chID, + ) + isCreator := countErr == nil && memberCount == 0 + + if isCreator { + err = s.DB.JoinChannelAsOperator( + ctx, chID, sessionID, + ) + } else { + err = s.DB.JoinChannel(ctx, chID, sessionID) + } + + if err != nil { + return nil, fmt.Errorf("join channel: %w", err) + } + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{channel}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdJoin, nick, channel, + nil, body, nil, memberIDs, + ) + + return &JoinResult{ + ChannelID: chID, + IsCreator: isCreator, + }, nil +} + +// PartChannel validates membership, broadcasts PART to +// remaining members, removes the user, and cleans up empty +// channels. +func (s *Service) PartChannel( + ctx context.Context, + sessionID int64, + nick, channel, reason string, +) error { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, sessionID, + ) + if !isMember { + return &IRCError{ + irc.ErrNotOnChannel, + []string{channel}, + "You're not on that channel", + } + } + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + recipients := excludeSession(memberIDs, sessionID) + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdPart, nick, channel, + nil, body, nil, recipients, + ) + + s.DB.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec + s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec + + return nil +} + +// SetTopic validates membership and topic-lock, sets the +// topic, and broadcasts the change. +func (s *Service) SetTopic( + ctx context.Context, + sessionID int64, + nick, channel, topic string, +) error { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, sessionID, + ) + if !isMember { + return &IRCError{ + irc.ErrNotOnChannel, + []string{channel}, + "You're not on that channel", + } + } + + topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) + if topicLocked { + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + if !isOp { + return &IRCError{ + irc.ErrChanOpPrivsNeeded, + []string{channel}, + "You're not channel operator", + } + } + } + + if setErr := s.DB.SetTopic( + ctx, channel, topic, + ); setErr != nil { + return fmt.Errorf("set topic: %w", setErr) + } + + _ = s.DB.SetTopicMeta(ctx, channel, topic, nick) + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{topic}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdTopic, nick, channel, + nil, body, nil, memberIDs, + ) + + return nil +} + +// KickUser validates operator status and target +// membership, broadcasts the KICK, removes the target, +// and cleans up empty channels. +func (s *Service) KickUser( + ctx context.Context, + sessionID int64, + nick, channel, targetNick, reason string, +) error { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + if !isOp { + return &IRCError{ + irc.ErrChanOpPrivsNeeded, + []string{channel}, + "You're not channel operator", + } + } + + targetSID, err := s.DB.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + return &IRCError{ + irc.ErrNoSuchNick, + []string{targetNick}, + "No such nick/channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, targetSID, + ) + if !isMember { + return &IRCError{ + irc.ErrUserNotInChannel, + []string{targetNick, channel}, + "They aren't on that channel", + } + } + + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + params, _ := json.Marshal( //nolint:errchkjson + []string{targetNick}, + ) + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdKick, nick, channel, + params, body, nil, memberIDs, + ) + + s.DB.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec + s.DB.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec + + return nil +} + +// ChangeNick changes a user's nickname and broadcasts the +// change to all users sharing channels. +func (s *Service) ChangeNick( + ctx context.Context, + sessionID int64, + oldNick, newNick string, +) error { + err := s.DB.ChangeNick(ctx, sessionID, newNick) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE") || + db.IsUniqueConstraintError(err) { + return &IRCError{ + irc.ErrNicknameInUse, + []string{newNick}, + "Nickname is already in use", + } + } + + return &IRCError{ + irc.ErrErroneusNickname, + []string{newNick}, + "Erroneous nickname", + } + } + + s.broadcastNickChange(ctx, sessionID, oldNick, newNick) + + return nil +} + +// BroadcastQuit broadcasts a QUIT to all channel peers, +// parts all channels, and deletes the session. Uses the +// FanOut pattern: one message row fanned out to all unique +// peer sessions. +func (s *Service) BroadcastQuit( + ctx context.Context, + sessionID int64, + nick, reason string, +) { + channels, err := s.DB.GetSessionChannels( + ctx, sessionID, + ) + if err != nil { + return + } + + notified := make(map[int64]bool) + + for _, ch := range channels { + memberIDs, memErr := s.DB.GetChannelMemberIDs( + ctx, ch.ID, + ) + if memErr != nil { + continue + } + + for _, mid := range memberIDs { + if mid == sessionID || notified[mid] { + continue + } + + notified[mid] = true + } + } + + if len(notified) > 0 { + recipients := make([]int64, 0, len(notified)) + for sid := range notified { + recipients = append(recipients, sid) + } + + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + + _, _, _ = s.FanOut( + ctx, irc.CmdQuit, nick, "", + nil, body, nil, recipients, + ) + } + + for _, ch := range channels { + s.DB.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec + s.DB.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec + } + + s.DB.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec +} + +// SetAway sets or clears the away message. Returns true +// if the message was cleared (empty string). +func (s *Service) SetAway( + ctx context.Context, + sessionID int64, + message string, +) (bool, error) { + err := s.DB.SetAway(ctx, sessionID, message) + if err != nil { + return false, fmt.Errorf("set away: %w", err) + } + + return message == "", nil +} + +// Oper validates operator credentials and grants oper +// status to the session. +func (s *Service) Oper( + ctx context.Context, + sessionID int64, + name, password string, +) error { + cfgName := s.Config.OperName + cfgPassword := s.Config.OperPassword + + // Use constant-time comparison and return the same + // error for all failures to prevent information + // leakage about valid operator names. + if cfgName == "" || cfgPassword == "" || + subtle.ConstantTimeCompare( + []byte(name), []byte(cfgName), + ) != 1 || + subtle.ConstantTimeCompare( + []byte(password), []byte(cfgPassword), + ) != 1 { + return &IRCError{ + irc.ErrNoOperHost, + nil, + "No O-lines for your host", + } + } + + _ = s.DB.SetSessionOper(ctx, sessionID, true) + + return nil +} + +// ValidateChannelOp checks that the session is a channel +// operator. Returns the channel ID. +func (s *Service) ValidateChannelOp( + ctx context.Context, + sessionID int64, + channel string, +) (int64, error) { + chID, err := s.DB.GetChannelByName(ctx, channel) + if err != nil { + return 0, &IRCError{ + irc.ErrNoSuchChannel, + []string{channel}, + "No such channel", + } + } + + isOp, _ := s.DB.IsChannelOperator( + ctx, chID, sessionID, + ) + if !isOp { + return 0, &IRCError{ + irc.ErrChanOpPrivsNeeded, + []string{channel}, + "You're not channel operator", + } + } + + return chID, nil +} + +// ApplyMemberMode applies +o/-o or +v/-v on a channel +// member after validating the target. +func (s *Service) ApplyMemberMode( + ctx context.Context, + chID int64, + channel, targetNick string, + mode rune, + adding bool, +) error { + targetSID, err := s.DB.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + return &IRCError{ + irc.ErrNoSuchNick, + []string{targetNick}, + "No such nick/channel", + } + } + + isMember, _ := s.DB.IsChannelMember( + ctx, chID, targetSID, + ) + if !isMember { + return &IRCError{ + irc.ErrUserNotInChannel, + []string{targetNick, channel}, + "They aren't on that channel", + } + } + + switch mode { + case 'o': + _ = s.DB.SetChannelMemberOperator( + ctx, chID, targetSID, adding, + ) + case 'v': + _ = s.DB.SetChannelMemberVoiced( + ctx, chID, targetSID, adding, + ) + } + + return nil +} + +// SetChannelFlag applies +m/-m or +t/-t on a channel. +func (s *Service) SetChannelFlag( + ctx context.Context, + chID int64, + flag rune, + setting bool, +) error { + switch flag { + case 'm': + if err := s.DB.SetChannelModerated( + ctx, chID, setting, + ); err != nil { + return fmt.Errorf("set moderated: %w", err) + } + case 't': + if err := s.DB.SetChannelTopicLocked( + ctx, chID, setting, + ); err != nil { + return fmt.Errorf("set topic locked: %w", err) + } + } + + return nil +} + +// BroadcastMode fans out a MODE change to all channel +// members. +func (s *Service) BroadcastMode( + ctx context.Context, + nick, channel string, + chID int64, + modeText string, +) { + memberIDs, _ := s.DB.GetChannelMemberIDs(ctx, chID) + body, _ := json.Marshal([]string{modeText}) //nolint:errchkjson + + _, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast + ctx, irc.CmdMode, nick, channel, + nil, body, nil, memberIDs, + ) +} + +// QueryChannelMode returns the channel mode string. +func (s *Service) QueryChannelMode( + ctx context.Context, + chID int64, +) string { + modes := "+" + + moderated, _ := s.DB.IsChannelModerated(ctx, chID) + if moderated { + modes += "m" + } + + topicLocked, _ := s.DB.IsChannelTopicLocked(ctx, chID) + if topicLocked { + modes += "t" + } + + return modes +} + +// broadcastNickChange notifies channel peers of a nick +// change. +func (s *Service) broadcastNickChange( + ctx context.Context, + sessionID int64, + oldNick, newNick string, +) { + channels, err := s.DB.GetSessionChannels( + ctx, sessionID, + ) + if err != nil { + return + } + + body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson + notified := make(map[int64]bool) + + dbID, _, insErr := s.DB.InsertMessage( + ctx, irc.CmdNick, oldNick, "", + nil, body, nil, + ) + if insErr != nil { + return + } + + // Notify the user themselves (for multi-client sync). + _ = s.DB.EnqueueToSession(ctx, sessionID, dbID) + s.Broker.Notify(sessionID) + notified[sessionID] = true + + for _, ch := range channels { + memberIDs, memErr := s.DB.GetChannelMemberIDs( + ctx, ch.ID, + ) + if memErr != nil { + continue + } + + for _, mid := range memberIDs { + if notified[mid] { + continue + } + + notified[mid] = true + + _ = s.DB.EnqueueToSession(ctx, mid, dbID) + s.Broker.Notify(mid) + } + } +} diff --git a/internal/service/service_test.go b/internal/service/service_test.go new file mode 100644 index 0000000..6d407d5 --- /dev/null +++ b/internal/service/service_test.go @@ -0,0 +1,365 @@ +// Tests use a global viper instance for configuration, +// making parallel execution unsafe. +// +//nolint:paralleltest +package service_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "testing" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/globals" + "git.eeqj.de/sneak/neoirc/internal/logger" + "git.eeqj.de/sneak/neoirc/internal/service" + "git.eeqj.de/sneak/neoirc/pkg/irc" + "go.uber.org/fx" + "go.uber.org/fx/fxtest" + "golang.org/x/crypto/bcrypt" +) + +func TestMain(m *testing.M) { + db.SetBcryptCost(bcrypt.MinCost) + os.Exit(m.Run()) +} + +// testEnv holds all dependencies for a service test. +type testEnv struct { + svc *service.Service + db *db.Database + broker *broker.Broker + app *fxtest.App +} + +func newTestEnv(t *testing.T) *testEnv { + t.Helper() + + dbURL := fmt.Sprintf( + "file:svc_test_%p?mode=memory&cache=shared", + t, + ) + + var ( + database *db.Database + svc *service.Service + ) + + brk := broker.New() + + app := fxtest.New(t, + fx.Provide( + func() *globals.Globals { + return &globals.Globals{ //nolint:exhaustruct + Appname: "neoirc-test", + Version: "test", + } + }, + logger.New, + func( + lifecycle fx.Lifecycle, + globs *globals.Globals, + log *logger.Logger, + ) (*config.Config, error) { + cfg, err := config.New( + lifecycle, config.Params{ //nolint:exhaustruct + Globals: globs, Logger: log, + }, + ) + if err != nil { + return nil, fmt.Errorf( + "test config: %w", err, + ) + } + + cfg.DBURL = dbURL + cfg.Port = 0 + cfg.OperName = "admin" + cfg.OperPassword = "secret" + + return cfg, nil + }, + func( + lifecycle fx.Lifecycle, + log *logger.Logger, + cfg *config.Config, + ) (*db.Database, error) { + return db.New(lifecycle, db.Params{ //nolint:exhaustruct + Logger: log, Config: cfg, + }) + }, + func() *broker.Broker { return brk }, + service.New, + ), + fx.Populate(&database, &svc), + ) + + app.RequireStart() + + t.Cleanup(func() { + app.RequireStop() + }) + + return &testEnv{ + svc: svc, + db: database, + broker: brk, + app: app, + } +} + +// createSession is a test helper that creates a session +// and returns the session ID. +func createSession( + ctx context.Context, + t *testing.T, + database *db.Database, + nick string, +) int64 { + t.Helper() + + sessionID, _, _, err := database.CreateSession( + ctx, nick, nick, "localhost", "127.0.0.1", + ) + if err != nil { + t.Fatalf("create session %s: %v", nick, err) + } + + return sessionID +} + +func TestFanOut(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + body, _ := json.Marshal([]string{"hello"}) //nolint:errchkjson + + dbID, uuid, err := env.svc.FanOut( + ctx, irc.CmdPrivmsg, "alice", "#test", + nil, body, nil, + []int64{sid1, sid2}, + ) + if err != nil { + t.Fatalf("FanOut: %v", err) + } + + if dbID == 0 { + t.Error("expected non-zero dbID") + } + + if uuid == "" { + t.Error("expected non-empty UUID") + } +} + +func TestJoinChannel(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid := createSession(ctx, t, env.db, "alice") + + result, err := env.svc.JoinChannel( + ctx, sid, "alice", "#general", + ) + if err != nil { + t.Fatalf("JoinChannel: %v", err) + } + + if result.ChannelID == 0 { + t.Error("expected non-zero channel ID") + } + + if !result.IsCreator { + t.Error("first joiner should be creator") + } + + // Second user joins — not creator. + sid2 := createSession(ctx, t, env.db, "bob") + + result2, err := env.svc.JoinChannel( + ctx, sid2, "bob", "#general", + ) + if err != nil { + t.Fatalf("JoinChannel bob: %v", err) + } + + if result2.IsCreator { + t.Error("second joiner should not be creator") + } + + if result2.ChannelID != result.ChannelID { + t.Error("both should join the same channel") + } +} + +func TestPartChannel(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid := createSession(ctx, t, env.db, "alice") + + _, err := env.svc.JoinChannel( + ctx, sid, "alice", "#general", + ) + if err != nil { + t.Fatalf("JoinChannel: %v", err) + } + + err = env.svc.PartChannel( + ctx, sid, "alice", "#general", "bye", + ) + if err != nil { + t.Fatalf("PartChannel: %v", err) + } + + // Parting a non-existent channel returns error. + err = env.svc.PartChannel( + ctx, sid, "alice", "#nonexistent", "", + ) + if err == nil { + t.Error("expected error for non-existent channel") + } + + var ircErr *service.IRCError + if !errors.As(err, &ircErr) { + t.Errorf("expected IRCError, got %T", err) + } +} + +func TestSendChannelMessage(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + _, err := env.svc.JoinChannel( + ctx, sid1, "alice", "#chat", + ) + if err != nil { + t.Fatalf("join alice: %v", err) + } + + _, err = env.svc.JoinChannel( + ctx, sid2, "bob", "#chat", + ) + if err != nil { + t.Fatalf("join bob: %v", err) + } + + body, _ := json.Marshal([]string{"hello world"}) //nolint:errchkjson + + dbID, uuid, err := env.svc.SendChannelMessage( + ctx, sid1, "alice", + irc.CmdPrivmsg, "#chat", body, nil, + ) + if err != nil { + t.Fatalf("SendChannelMessage: %v", err) + } + + if dbID == 0 { + t.Error("expected non-zero dbID") + } + + if uuid == "" { + t.Error("expected non-empty UUID") + } + + // Non-member cannot send. + sid3 := createSession(ctx, t, env.db, "charlie") + + _, _, err = env.svc.SendChannelMessage( + ctx, sid3, "charlie", + irc.CmdPrivmsg, "#chat", body, nil, + ) + if err == nil { + t.Error("expected error for non-member send") + } +} + +func TestBroadcastQuit(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + _, err := env.svc.JoinChannel( + ctx, sid1, "alice", "#room", + ) + if err != nil { + t.Fatalf("join alice: %v", err) + } + + _, err = env.svc.JoinChannel( + ctx, sid2, "bob", "#room", + ) + if err != nil { + t.Fatalf("join bob: %v", err) + } + + // BroadcastQuit should not panic and should clean up. + env.svc.BroadcastQuit( + ctx, sid1, "alice", "Goodbye", + ) + + // Session should be deleted. + _, lookupErr := env.db.GetSessionByNick(ctx, "alice") + if lookupErr == nil { + t.Error("expected session to be deleted after quit") + } +} + +func TestSendChannelMessage_Moderated(t *testing.T) { + env := newTestEnv(t) + ctx := t.Context() + + sid1 := createSession(ctx, t, env.db, "alice") + sid2 := createSession(ctx, t, env.db, "bob") + + result, err := env.svc.JoinChannel( + ctx, sid1, "alice", "#modchat", + ) + if err != nil { + t.Fatalf("join alice: %v", err) + } + + _, err = env.svc.JoinChannel( + ctx, sid2, "bob", "#modchat", + ) + if err != nil { + t.Fatalf("join bob: %v", err) + } + + // Set channel to moderated. + chID := result.ChannelID + _ = env.svc.SetChannelFlag(ctx, chID, 'm', true) + + body, _ := json.Marshal([]string{"test"}) //nolint:errchkjson + + // Bob (non-op, non-voiced) should fail to send. + _, _, err = env.svc.SendChannelMessage( + ctx, sid2, "bob", + irc.CmdPrivmsg, "#modchat", body, nil, + ) + if err == nil { + t.Error("expected error for non-voiced user in moderated channel") + } + + // Alice (operator) should succeed. + _, _, err = env.svc.SendChannelMessage( + ctx, sid1, "alice", + irc.CmdPrivmsg, "#modchat", body, nil, + ) + if err != nil { + t.Errorf("operator should be able to send in moderated channel: %v", err) + } +} diff --git a/pkg/irc/commands.go b/pkg/irc/commands.go index 91893ec..60ef85e 100644 --- a/pkg/irc/commands.go +++ b/pkg/irc/commands.go @@ -21,6 +21,7 @@ const ( CmdPrivmsg = "PRIVMSG" CmdQuit = "QUIT" CmdTopic = "TOPIC" + CmdUser = "USER" CmdWho = "WHO" CmdWhois = "WHOIS" )