diff --git a/README.md b/README.md index 0ad9205..1483a32 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ web client as a convenience/reference implementation, but the API comes first. - [Federation](#federation-server-to-server) - [Storage](#storage) - [Configuration](#configuration) +- [IRC Protocol Listener](#irc-protocol-listener) - [Deployment](#deployment) - [Client Development Guide](#client-development-guide) - [Rate Limiting & Abuse Prevention](#rate-limiting--abuse-prevention) @@ -2227,6 +2228,7 @@ directory is also loaded automatically via | `NEOIRC_OPER_PASSWORD` | string | `""` | Server operator (o-line) password. Both name and password must be set to enable OPER. | | `LOGIN_RATE_LIMIT` | float | `1` | Allowed login attempts per second per IP address. | | `LOGIN_RATE_BURST` | int | `5` | Maximum burst of login attempts per IP before rate limiting kicks in. | +| `IRC_LISTEN_ADDR` | string | `""` | TCP address for the traditional IRC protocol listener (e.g. `:6667`). Disabled if empty. | | `MAINTENANCE_MODE` | bool | `false` | Maintenance mode flag (reserved) | ### Example `.env` file @@ -2243,6 +2245,71 @@ NEOIRC_HASHCASH_BITS=20 --- +## IRC Protocol Listener + +neoirc includes an optional traditional IRC wire protocol listener (RFC +1459/2812) that allows standard IRC clients to connect directly. This enables +backward compatibility with existing IRC clients like irssi, weechat, hexchat, +and others. + +### Enabling + +Set the `IRC_LISTEN_ADDR` environment variable to a TCP address: + +```bash +IRC_LISTEN_ADDR=:6667 +``` + +When unset or empty, the IRC listener is disabled and only the HTTP/JSON API is +available. + +### Supported Commands + +| Category | Commands | +|------------|------------------------------------------------------| +| Connection | `NICK`, `USER`, `PASS`, `QUIT`, `PING`/`PONG`, `CAP` | +| Channels | `JOIN`, `PART`, `MODE`, `TOPIC`, `NAMES`, `LIST`, `KICK`, `INVITE` | +| Messaging | `PRIVMSG`, `NOTICE` | +| Info | `WHO`, `WHOIS`, `LUSERS`, `MOTD`, `AWAY` | +| Operator | `OPER` (requires `NEOIRC_OPER_NAME` and `NEOIRC_OPER_PASSWORD`) | + +### Protocol Details + +- **Wire format**: CR-LF delimited lines, max 512 bytes per line +- **Connection registration**: Clients must send `NICK` and `USER` to register. + An optional `PASS` before registration sets the session password (minimum 8 + characters). +- **CAP negotiation**: `CAP LS` and `CAP END` are silently handled for + compatibility with modern clients. No capabilities are advertised. +- **Channel prefixes**: Channels must start with `#`. If omitted, it's + automatically prepended. +- **First joiner**: The first user to join a channel is automatically granted + operator status (`@`). +- **Channel modes**: `+m` (moderated), `+t` (topic lock), `+o` (operator), + `+v` (voice) + +### Bridge to HTTP API + +Messages sent by IRC clients appear in channels visible to HTTP/JSON API +clients and vice versa. The IRC listener and HTTP API share the same database, +broker, and session infrastructure. A user connected via IRC and a user +connected via the HTTP API can communicate in the same channels seamlessly. + +### Docker Usage + +To expose the IRC port in Docker: + +```bash +docker run -d \ + -p 8080:8080 \ + -p 6667:6667 \ + -e IRC_LISTEN_ADDR=:6667 \ + -v neoirc-data:/var/lib/neoirc \ + neoirc +``` + +--- + ## Deployment ### Docker (Recommended) diff --git a/cmd/neoircd/main.go b/cmd/neoircd/main.go index d8d6601..d5a3988 100644 --- a/cmd/neoircd/main.go +++ b/cmd/neoircd/main.go @@ -2,11 +2,13 @@ package main import ( + "git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/globals" "git.eeqj.de/sneak/neoirc/internal/handlers" "git.eeqj.de/sneak/neoirc/internal/healthcheck" + "git.eeqj.de/sneak/neoirc/internal/ircserver" "git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/middleware" "git.eeqj.de/sneak/neoirc/internal/server" @@ -28,16 +30,22 @@ func main() { fx.New( fx.Provide( + broker.New, config.New, db.New, globals.New, handlers.New, + ircserver.New, logger.New, server.New, middleware.New, healthcheck.New, stats.New, ), - fx.Invoke(func(*server.Server) {}), + fx.Invoke(func( + _ *server.Server, + _ *ircserver.Server, + ) { + }), ).Run() } diff --git a/internal/config/config.go b/internal/config/config.go index 05f6882..c2f9be7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -50,6 +50,7 @@ type Config struct { OperPassword string LoginRateLimit float64 LoginRateBurst int + IRCListenAddr string params *Params log *slog.Logger } @@ -86,6 +87,7 @@ func New( viper.SetDefault("NEOIRC_OPER_PASSWORD", "") viper.SetDefault("LOGIN_RATE_LIMIT", "1") viper.SetDefault("LOGIN_RATE_BURST", "5") + viper.SetDefault("IRC_LISTEN_ADDR", "") err := viper.ReadInConfig() if err != nil { @@ -116,6 +118,7 @@ func New( OperPassword: viper.GetString("NEOIRC_OPER_PASSWORD"), LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"), LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"), + IRCListenAddr: viper.GetString("IRC_LISTEN_ADDR"), log: log, params: ¶ms, } diff --git a/internal/db/testing.go b/internal/db/testing.go new file mode 100644 index 0000000..d036611 --- /dev/null +++ b/internal/db/testing.go @@ -0,0 +1,25 @@ +package db + +import ( + "context" + "database/sql" + "log/slog" +) + +// NewTestDatabaseFromConn creates a Database wrapping an +// existing *sql.DB connection. Intended for integration +// tests in other packages. +func NewTestDatabaseFromConn(conn *sql.DB) *Database { + return &Database{ //nolint:exhaustruct + conn: conn, + log: slog.Default(), + } +} + +// RunMigrations applies all schema migrations. Exposed +// for integration tests in other packages. +func (database *Database) RunMigrations( + ctx context.Context, +) error { + return database.runMigrations(ctx) +} diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index d52a9c2..085b095 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/globals" @@ -213,6 +214,7 @@ func newTestHandlers( Database: database, Healthcheck: hcheck, Stats: tracker, + Broker: broker.New(), }) if err != nil { return nil, fmt.Errorf("test handlers: %w", err) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 5d5c225..4a4412d 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -33,6 +33,7 @@ type Params struct { Database *db.Database Healthcheck *healthcheck.Healthcheck Stats *stats.Tracker + Broker *broker.Broker } const defaultIdleTimeout = 30 * 24 * time.Hour @@ -79,7 +80,7 @@ func New( params: ¶ms, log: params.Logger.Get(), hc: params.Healthcheck, - broker: broker.New(), + broker: params.Broker, hashcashVal: hashcash.NewValidator(resource), channelHashcash: hashcash.NewChannelValidator(), loginLimiter: ratelimit.New(loginRate, loginBurst), diff --git a/internal/ircserver/commands.go b/internal/ircserver/commands.go new file mode 100644 index 0000000..3874d4c --- /dev/null +++ b/internal/ircserver/commands.go @@ -0,0 +1,1562 @@ +package ircserver + +import ( + "context" + "encoding/json" + "strconv" + "strings" + "time" + + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +// handleCAP silently acknowledges CAP negotiation. +func (c *Conn) handleCAP(msg *Message) { + if len(msg.Params) == 0 { + return + } + + sub := strings.ToUpper(msg.Params[0]) + if sub == "LS" { + c.send(FormatMessage( + c.serverSfx, "CAP", "*", "LS", "", + )) + } + + // CAP END and other subcommands are silently ignored. +} + +// handlePing replies with a PONG. +func (c *Conn) handlePing(msg *Message) { + token := c.serverSfx + + if len(msg.Params) > 0 { + token = msg.Params[0] + } + + c.sendFromServer("PONG", c.serverSfx, token) +} + +// handleNick changes the user's nickname. +func (c *Conn) handleNick(ctx context.Context, msg *Message) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + newNick := msg.Params[0] + if len(newNick) > maxNickLen { + newNick = newNick[:maxNickLen] + } + + oldMask := c.hostmask() + oldNick := c.nick + + err := c.database.ChangeNick( + ctx, c.sessionID, newNick, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + c.sendNumeric( + irc.ErrNicknameInUse, + newNick, "Nickname is already in use", + ) + + return + } + + c.sendNumeric( + irc.ErrErroneusNickname, + newNick, "Erroneous nickname", + ) + + return + } + + c.mu.Lock() + c.nick = newNick + c.mu.Unlock() + + // Echo NICK change to the client. + c.send(FormatMessage(oldMask, "NICK", newNick)) + + // Broadcast nick change to shared channels. + c.broadcastNickChange(ctx, oldNick, newNick) +} + +// broadcastNickChange notifies channel peers of a nick +// change. +func (c *Conn) broadcastNickChange( + ctx context.Context, + oldNick, newNick string, +) { + channels, err := c.database.GetSessionChannels( + ctx, c.sessionID, + ) + if err != nil { + return + } + + body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson + notified := make(map[int64]bool) + + for _, ch := range channels { + chID, getErr := c.database.GetChannelByName( + ctx, ch.Name, + ) + if getErr != nil { + continue + } + + memberIDs, memErr := c.database.GetChannelMemberIDs( + ctx, chID, + ) + if memErr != nil { + continue + } + + for _, mid := range memberIDs { + if mid == c.sessionID || notified[mid] { + continue + } + + notified[mid] = true + + dbID, _, insErr := c.database.InsertMessage( + ctx, irc.CmdNick, oldNick, "", + nil, body, nil, + ) + if insErr != nil { + continue + } + + _ = c.database.EnqueueToSession( + ctx, mid, dbID, + ) + c.brk.Notify(mid) + } + } +} + +// handlePrivmsg handles PRIVMSG and NOTICE commands. +func (c *Conn) handlePrivmsg( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoRecipient, + "No recipient given ("+msg.Command+")", + ) + + return + } + + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNoTextToSend, "No text to send", + ) + + return + } + + target := msg.Params[0] + text := msg.Params[1] + body, _ := json.Marshal([]string{text}) //nolint:errchkjson + + if strings.HasPrefix(target, "#") { + c.handleChannelMsg(ctx, msg.Command, target, body) + } else { + c.handleDirectMsg(ctx, msg.Command, target, body) + } +} + +// handleChannelMsg sends a message to a channel. +func (c *Conn) handleChannelMsg( + ctx context.Context, + command, channel string, + body json.RawMessage, +) { + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + isMember, _ := c.database.IsChannelMember( + ctx, chID, c.sessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrCannotSendToChan, + channel, "Cannot send to channel", + ) + + return + } + + // Check moderated mode. + moderated, _ := c.database.IsChannelModerated(ctx, chID) + if moderated { + isOp, _ := c.database.IsChannelOperator( + ctx, chID, c.sessionID, + ) + isVoiced, _ := c.database.IsChannelVoiced( + ctx, chID, c.sessionID, + ) + + if !isOp && !isVoiced { + c.sendNumeric( + irc.ErrCannotSendToChan, + channel, + "Cannot send to channel (+m)", + ) + + return + } + } + + memberIDs, _ := c.database.GetChannelMemberIDs( + ctx, chID, + ) + + // Fan out to all members except sender. + for _, mid := range memberIDs { + if mid == c.sessionID { + continue + } + + dbID, _, insErr := c.database.InsertMessage( + ctx, command, c.nick, channel, + nil, body, nil, + ) + if insErr != nil { + continue + } + + _ = c.database.EnqueueToSession(ctx, mid, dbID) + c.brk.Notify(mid) + } +} + +// handleDirectMsg sends a private message to a user. +func (c *Conn) handleDirectMsg( + ctx context.Context, + command, target string, + body json.RawMessage, +) { + targetID, err := c.database.GetSessionByNick( + ctx, target, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, target, "No such nick", + ) + + return + } + + // Check AWAY status. + away, _ := c.database.GetAway(ctx, targetID) + if away != "" { + c.sendNumeric( + irc.RplAway, target, away, + ) + } + + dbID, _, insErr := c.database.InsertMessage( + ctx, command, c.nick, target, + nil, body, nil, + ) + if insErr != nil { + return + } + + _ = c.database.EnqueueToSession(ctx, targetID, dbID) + c.brk.Notify(targetID) +} + +// handleJoin joins one or more channels. +func (c *Conn) handleJoin(ctx context.Context, msg *Message) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "JOIN", "Not enough parameters", + ) + + return + } + + channels := strings.Split(msg.Params[0], ",") + + for _, chanName := range channels { + chanName = strings.TrimSpace(chanName) + + if !strings.HasPrefix(chanName, "#") { + chanName = "#" + chanName + } + + c.joinChannel(ctx, chanName) + } +} + +// joinChannel joins a single channel. +func (c *Conn) joinChannel( + ctx context.Context, channel string, +) { + chID, err := c.database.GetOrCreateChannel(ctx, channel) + if err != nil { + c.log.Error( + "get/create channel failed", "error", err, + ) + + return + } + + // First joiner becomes operator. + memberCount, countErr := c.database.CountChannelMembers( + ctx, chID, + ) + isCreator := countErr == nil && memberCount == 0 + + if isCreator { + err = c.database.JoinChannelAsOperator( + ctx, chID, c.sessionID, + ) + } else { + err = c.database.JoinChannel( + ctx, chID, c.sessionID, + ) + } + + if err != nil { + return + } + + // Fan out JOIN to all channel members. + memberIDs, _ := c.database.GetChannelMemberIDs( + ctx, chID, + ) + + joinBody, _ := json.Marshal([]string{channel}) //nolint:errchkjson + + for _, mid := range memberIDs { + dbID, _, insErr := c.database.InsertMessage( + ctx, irc.CmdJoin, c.nick, channel, + nil, joinBody, nil, + ) + if insErr != nil { + continue + } + + _ = c.database.EnqueueToSession(ctx, mid, dbID) + c.brk.Notify(mid) + } + + // Send JOIN echo to this client directly on wire. + c.send(FormatMessage(c.hostmask(), "JOIN", channel)) + + // Send topic. + c.deliverTopic(ctx, channel, chID) + + // Send NAMES. + c.deliverNames(ctx, channel, chID) +} + +// deliverTopic sends RPL_TOPIC or RPL_NOTOPIC. +func (c *Conn) deliverTopic( + ctx context.Context, + channel string, + chID int64, +) { + channels, err := c.database.ListChannels( + ctx, c.sessionID, + ) + + topic := "" + + if err == nil { + for _, ch := range channels { + if ch.Name == channel { + topic = ch.Topic + + break + } + } + } + + if topic == "" { + c.sendNumeric( + irc.RplNoTopic, channel, "No topic is set", + ) + + return + } + + c.sendNumeric(irc.RplTopic, channel, topic) + + meta, tmErr := c.database.GetTopicMeta(ctx, chID) + if tmErr == nil && meta != nil { + c.sendNumeric( + irc.RplTopicWhoTime, channel, + meta.SetBy, + strconv.FormatInt(meta.SetAt.Unix(), 10), + ) + } +} + +// deliverNames sends RPL_NAMREPLY and RPL_ENDOFNAMES. +func (c *Conn) deliverNames( + ctx context.Context, + channel string, + chID int64, +) { + members, err := c.database.ChannelMembers(ctx, chID) + if err != nil { + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) + + return + } + + names := make([]string, 0, len(members)) + + for _, member := range members { + prefix := "" + if member.IsOperator { + prefix = "@" + } else if member.IsVoiced { + prefix = "+" + } + + names = append(names, prefix+member.Nick) + } + + nameStr := strings.Join(names, " ") + + c.sendNumeric( + irc.RplNamReply, "=", channel, nameStr, + ) + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) +} + +// handlePart leaves one or more channels. +func (c *Conn) handlePart( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PART", "Not enough parameters", + ) + + return + } + + reason := "" + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + channels := strings.Split(msg.Params[0], ",") + + for _, ch := range channels { + ch = strings.TrimSpace(ch) + c.partChannel(ctx, ch, reason) + } +} + +// partChannel leaves a single channel. +func (c *Conn) partChannel( + ctx context.Context, + channel, reason string, +) { + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + isMember, _ := c.database.IsChannelMember( + ctx, chID, c.sessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrNotOnChannel, + channel, "You're not on that channel", + ) + + return + } + + // Broadcast PART to channel members before leaving. + memberIDs, _ := c.database.GetChannelMemberIDs( + ctx, chID, + ) + + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + + for _, mid := range memberIDs { + if mid == c.sessionID { + continue + } + + dbID, _, insErr := c.database.InsertMessage( + ctx, irc.CmdPart, c.nick, channel, + nil, body, nil, + ) + if insErr != nil { + continue + } + + _ = c.database.EnqueueToSession(ctx, mid, dbID) + c.brk.Notify(mid) + } + + // Echo PART to the client. + if reason != "" { + c.send(FormatMessage( + c.hostmask(), "PART", channel, reason, + )) + } else { + c.send(FormatMessage( + c.hostmask(), "PART", channel, + )) + } + + c.database.PartChannel(ctx, chID, c.sessionID) //nolint:errcheck,gosec + c.database.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec +} + +// handleQuit handles the QUIT command. +func (c *Conn) handleQuit(msg *Message) { + reason := "Client quit" + + if len(msg.Params) > 0 { + reason = msg.Params[0] + } + + c.send("ERROR :Closing Link: " + c.hostname + + " (Quit: " + reason + ")") + c.closed = true +} + +// handleTopic gets or sets a channel topic. +// +//nolint:funlen // coherent flow +func (c *Conn) handleTopic( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "TOPIC", "Not enough parameters", + ) + + return + } + + channel := msg.Params[0] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + // If no second param, query the topic. + if len(msg.Params) < 2 { //nolint:mnd + c.deliverTopic(ctx, channel, chID) + + return + } + + // Set topic — check permissions. + isMember, _ := c.database.IsChannelMember( + ctx, chID, c.sessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrNotOnChannel, + channel, "You're not on that channel", + ) + + return + } + + topicLocked, _ := c.database.IsChannelTopicLocked( + ctx, chID, + ) + if topicLocked { + isOp, _ := c.database.IsChannelOperator( + ctx, chID, c.sessionID, + ) + if !isOp { + c.sendNumeric( + irc.ErrChanOpPrivsNeeded, + channel, + "You're not channel operator", + ) + + return + } + } + + newTopic := msg.Params[1] + + err = c.database.SetTopic(ctx, channel, newTopic) + if err != nil { + return + } + + _ = c.database.SetTopicMeta( + ctx, channel, newTopic, c.nick, + ) + + // Broadcast TOPIC to all members. + memberIDs, _ := c.database.GetChannelMemberIDs( + ctx, chID, + ) + + body, _ := json.Marshal([]string{newTopic}) //nolint:errchkjson + + for _, mid := range memberIDs { + dbID, _, insErr := c.database.InsertMessage( + ctx, irc.CmdTopic, c.nick, channel, + nil, body, nil, + ) + if insErr != nil { + continue + } + + _ = c.database.EnqueueToSession(ctx, mid, dbID) + c.brk.Notify(mid) + } + + // Echo to the setting client on wire. + c.send(FormatMessage( + c.hostmask(), "TOPIC", channel, newTopic, + )) +} + +// handleMode handles MODE queries and changes. +func (c *Conn) handleMode( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "MODE", "Not enough parameters", + ) + + return + } + + target := msg.Params[0] + + if strings.HasPrefix(target, "#") { + c.handleChannelMode(ctx, msg) + } else { + c.handleUserMode(ctx, msg) + } +} + +// handleChannelMode handles MODE for channels. +func (c *Conn) handleChannelMode( + ctx context.Context, + msg *Message, +) { + channel := msg.Params[0] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + // Query mode if no mode string given. + if len(msg.Params) < 2 { //nolint:mnd + modeStr := c.buildChannelModeString(ctx, chID) + c.sendNumeric( + irc.RplChannelModeIs, channel, modeStr, + ) + + created, _ := c.database.GetChannelCreatedAt( + ctx, chID, + ) + if !created.IsZero() { + c.sendNumeric( + irc.RplCreationTime, channel, + strconv.FormatInt(created.Unix(), 10), + ) + } + + return + } + + // Need ops to change modes. + isOp, _ := c.database.IsChannelOperator( + ctx, chID, c.sessionID, + ) + if !isOp { + c.sendNumeric( + irc.ErrChanOpPrivsNeeded, + channel, "You're not channel operator", + ) + + return + } + + modeStr := msg.Params[1] + modeArgs := msg.Params[2:] + + c.applyChannelModes(ctx, channel, chID, modeStr, modeArgs) +} + +// buildChannelModeString constructs the mode string for a +// channel. +func (c *Conn) buildChannelModeString( + ctx context.Context, + chID int64, +) string { + modes := "+" + + moderated, _ := c.database.IsChannelModerated(ctx, chID) + if moderated { + modes += "m" + } + + topicLocked, _ := c.database.IsChannelTopicLocked( + ctx, chID, + ) + if topicLocked { + modes += "t" + } + + if modes == "+" { + modes = "+" + } + + return modes +} + +// applyChannelModes applies mode changes. +// +//nolint:cyclop,funlen // mode parsing is inherently branchy +func (c *Conn) applyChannelModes( + ctx context.Context, + channel string, + chID int64, + modeStr string, + args []string, +) { + adding := true + argIdx := 0 + applied := "" + appliedArgs := "" + + for _, modeChar := range modeStr { + switch modeChar { + case '+': + adding = true + case '-': + adding = false + case 'm': + _ = c.database.SetChannelModerated( + ctx, chID, adding, + ) + if adding { + applied += "+m" + } else { + applied += "-m" + } + case 't': + _ = c.database.SetChannelTopicLocked( + ctx, chID, adding, + ) + if adding { + applied += "+t" + } else { + applied += "-t" + } + case 'o': + if argIdx >= len(args) { + break + } + + targetNick := args[argIdx] + argIdx++ + + c.applyMemberMode( + ctx, chID, channel, + targetNick, 'o', adding, + ) + + if adding { + applied += "+o" + } else { + applied += "-o" + } + + appliedArgs += " " + targetNick + case 'v': + if argIdx >= len(args) { + break + } + + targetNick := args[argIdx] + argIdx++ + + c.applyMemberMode( + ctx, chID, channel, + targetNick, 'v', adding, + ) + + if adding { + applied += "+v" + } else { + applied += "-v" + } + + appliedArgs += " " + targetNick + default: + c.sendNumeric( + irc.ErrUnknownMode, + string(modeChar), + "is unknown mode char to me", + ) + } + } + + if applied != "" { + modeReply := applied + if appliedArgs != "" { + modeReply += appliedArgs + } + + c.send(FormatMessage( + c.hostmask(), "MODE", channel, modeReply, + )) + } +} + +// applyMemberMode applies +o/-o or +v/-v on a member. +func (c *Conn) applyMemberMode( + ctx context.Context, + chID int64, + channel, targetNick string, + mode rune, + adding bool, +) { + targetSessionID, err := c.database.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, + targetNick, "No such nick/channel", + ) + + return + } + + isMember, _ := c.database.IsChannelMember( + ctx, chID, targetSessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrUserNotInChannel, + targetNick, channel, + "They aren't on that channel", + ) + + return + } + + switch mode { + case 'o': + _ = c.database.SetChannelMemberOperator( + ctx, chID, targetSessionID, adding, + ) + case 'v': + _ = c.database.SetChannelMemberVoiced( + ctx, chID, targetSessionID, adding, + ) + } +} + +// handleUserMode handles MODE for users. +func (c *Conn) handleUserMode( + _ context.Context, + msg *Message, +) { + target := msg.Params[0] + + if !strings.EqualFold(target, c.nick) { + c.sendNumeric( + irc.ErrUsersDoNotMatch, + "Can't change mode for other users", + ) + + return + } + + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric(irc.RplUmodeIs, "+") + + return + } + + // We don't support user modes beyond the basics. + c.sendNumeric(irc.RplUmodeIs, "+") +} + +// handleNames replies with channel member list. +func (c *Conn) handleNames( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.RplEndOfNames, + "*", "End of /NAMES list", + ) + + return + } + + channel := msg.Params[0] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.RplEndOfNames, + channel, "End of /NAMES list", + ) + + return + } + + c.deliverNames(ctx, channel, chID) +} + +// handleList sends the channel list. +func (c *Conn) handleList(ctx context.Context) { + channels, err := c.database.ListAllChannelsWithCounts( + ctx, + ) + if err != nil { + c.sendNumeric( + irc.RplListEnd, "End of /LIST", + ) + + return + } + + c.sendNumeric(irc.RplListStart, "Channel", "Users Name") + + for idx := range channels { + c.sendNumeric( + irc.RplList, channels[idx].Name, + strconv.FormatInt( + channels[idx].MemberCount, 10, + ), + channels[idx].Topic, + ) + } + + c.sendNumeric(irc.RplListEnd, "End of /LIST") +} + +// handleWhois replies with user info. +// +//nolint:funlen // WHOIS has many reply fields +func (c *Conn) handleWhois( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + // The target nick may be the second param + // (WHOIS server nick). + target := msg.Params[0] + + if len(msg.Params) > 1 { + target = msg.Params[1] + } + + targetID, err := c.database.GetSessionByNick( + ctx, target, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, target, "No such nick", + ) + c.sendNumeric( + irc.RplEndOfWhois, + target, "End of /WHOIS list", + ) + + return + } + + // Get host info. + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, targetID, + ) + + username := target + hostname := "*" + + if hostInfo != nil { + username = hostInfo.Username + hostname = hostInfo.Hostname + } + + c.sendNumeric( + irc.RplWhoisUser, target, + username, hostname, "*", target, + ) + + c.sendNumeric( + irc.RplWhoisServer, target, + c.serverSfx, "neoirc server", + ) + + // Check oper status. + isOper, _ := c.database.IsSessionOper(ctx, targetID) + if isOper { + c.sendNumeric( + irc.RplWhoisOperator, + target, "is an IRC operator", + ) + } + + // Get channels. + userChannels, _ := c.database.GetSessionChannels( + ctx, targetID, + ) + + if len(userChannels) > 0 { + var chanList []string + + for _, userChan := range userChannels { + chID, getErr := c.database.GetChannelByName( + ctx, userChan.Name, + ) + if getErr != nil { + chanList = append( + chanList, userChan.Name, + ) + + continue + } + + isChOp, _ := c.database.IsChannelOperator( + ctx, chID, targetID, + ) + isVoiced, _ := c.database.IsChannelVoiced( + ctx, chID, targetID, + ) + + prefix := "" + if isChOp { + prefix = "@" + } else if isVoiced { + prefix = "+" + } + + chanList = append( + chanList, prefix+userChan.Name, + ) + } + + c.sendNumeric( + irc.RplWhoisChannels, target, + strings.Join(chanList, " "), + ) + } + + // Idle time. + lastSeen, _ := c.database.GetSessionLastSeen( + ctx, targetID, + ) + created, _ := c.database.GetSessionCreatedAt( + ctx, targetID, + ) + + if !lastSeen.IsZero() { + idle := int64(time.Since(lastSeen).Seconds()) + + signonTS := int64(0) + if !created.IsZero() { + signonTS = created.Unix() + } + + c.sendNumeric( + irc.RplWhoisIdle, target, + strconv.FormatInt(idle, 10), + strconv.FormatInt(signonTS, 10), + "seconds idle, signon time", + ) + } + + // Away. + away, _ := c.database.GetAway(ctx, targetID) + if away != "" { + c.sendNumeric(irc.RplAway, target, away) + } + + c.sendNumeric( + irc.RplEndOfWhois, + target, "End of /WHOIS list", + ) +} + +// handleWho sends WHO replies for a channel. +func (c *Conn) handleWho( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.RplEndOfWho, "*", "End of /WHO list", + ) + + return + } + + target := msg.Params[0] + + if !strings.HasPrefix(target, "#") { + // WHO for a nick. + c.whoNick(ctx, target) + + return + } + + chID, err := c.database.GetChannelByName(ctx, target) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) + + return + } + + members, err := c.database.ChannelMembers(ctx, chID) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) + + return + } + + for _, member := range members { + flags := "H" + if member.IsOperator { + flags += "@" + } else if member.IsVoiced { + flags += "+" + } + + c.sendNumeric( + irc.RplWhoReply, + target, member.Username, member.Hostname, + c.serverSfx, member.Nick, flags, + "0 "+member.Nick, + ) + } + + c.sendNumeric( + irc.RplEndOfWho, target, "End of /WHO list", + ) +} + +// whoNick sends WHO reply for a single nick. +func (c *Conn) whoNick(ctx context.Context, nick string) { + targetID, err := c.database.GetSessionByNick(ctx, nick) + if err != nil { + c.sendNumeric( + irc.RplEndOfWho, nick, "End of /WHO list", + ) + + return + } + + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, targetID, + ) + + username := nick + hostname := "*" + + if hostInfo != nil { + username = hostInfo.Username + hostname = hostInfo.Hostname + } + + c.sendNumeric( + irc.RplWhoReply, + "*", username, hostname, + c.serverSfx, nick, "H", + "0 "+nick, + ) + c.sendNumeric( + irc.RplEndOfWho, nick, "End of /WHO list", + ) +} + +// handleLusers replies with server statistics. +func (c *Conn) handleLusers(ctx context.Context) { + c.deliverLusers(ctx) +} + +// handleOper handles the OPER command. +func (c *Conn) handleOper( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "OPER", "Not enough parameters", + ) + + return + } + + name := msg.Params[0] + password := msg.Params[1] + + cfgName := c.cfg.OperName + cfgPassword := c.cfg.OperPassword + + if cfgName == "" || cfgPassword == "" { + c.sendNumeric( + irc.ErrNoOperHost, "No O-lines for your host", + ) + + return + } + + if name != cfgName || password != cfgPassword { + c.sendNumeric( + irc.ErrPasswdMismatch, "Password incorrect", + ) + + return + } + + _ = c.database.SetSessionOper(ctx, c.sessionID, true) + c.sendNumeric( + irc.RplYoureOper, + "You are now an IRC operator", + ) +} + +// handleAway sets or clears the AWAY status. +func (c *Conn) handleAway( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 || msg.Params[0] == "" { + _ = c.database.SetAway(ctx, c.sessionID, "") + c.sendNumeric( + irc.RplUnaway, + "You are no longer marked as being away", + ) + + return + } + + _ = c.database.SetAway(ctx, c.sessionID, msg.Params[0]) + c.sendNumeric( + irc.RplNowAway, + "You have been marked as being away", + ) +} + +// handleKick kicks a user from a channel. +// +//nolint:funlen // coherent flow +func (c *Conn) handleKick( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "KICK", "Not enough parameters", + ) + + return + } + + channel := msg.Params[0] + targetNick := msg.Params[1] + + reason := targetNick + if len(msg.Params) > 2 { //nolint:mnd + reason = msg.Params[2] + } + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + isOp, _ := c.database.IsChannelOperator( + ctx, chID, c.sessionID, + ) + if !isOp { + c.sendNumeric( + irc.ErrChanOpPrivsNeeded, + channel, "You're not channel operator", + ) + + return + } + + targetSessionID, err := c.database.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, + targetNick, "No such nick/channel", + ) + + return + } + + isMember, _ := c.database.IsChannelMember( + ctx, chID, targetSessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrUserNotInChannel, + targetNick, channel, + "They aren't on that channel", + ) + + return + } + + // Broadcast KICK to all channel members. + memberIDs, _ := c.database.GetChannelMemberIDs( + ctx, chID, + ) + + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + + for _, mid := range memberIDs { + dbID, _, insErr := c.database.InsertMessage( + ctx, irc.CmdKick, c.nick, channel, + nil, body, nil, + ) + if insErr != nil { + continue + } + + _ = c.database.EnqueueToSession(ctx, mid, dbID) + c.brk.Notify(mid) + } + + // Remove from channel. + c.database.PartChannel(ctx, chID, targetSessionID) //nolint:errcheck,gosec + c.database.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec + + // Echo KICK on wire. + c.send(FormatMessage( + c.hostmask(), "KICK", channel, targetNick, reason, + )) +} + +// handlePassPostReg handles PASS after registration (for +// setting a session password). +func (c *Conn) handlePassPostReg( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PASS", "Not enough parameters", + ) + + return + } + + password := msg.Params[0] + if len(password) < minPasswordLen { + c.sendFromServer("NOTICE", c.nick, + "Password must be at least 8 characters", + ) + + return + } + + c.setPassword(ctx, password) + + c.sendFromServer("NOTICE", c.nick, + "Password set. You can reconnect using "+ + "PASS with your nick.", + ) +} + +// handleInvite handles the INVITE command. +func (c *Conn) handleInvite( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 2 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "INVITE", "Not enough parameters", + ) + + return + } + + targetNick := msg.Params[0] + channel := msg.Params[1] + + chID, err := c.database.GetChannelByName(ctx, channel) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchChannel, + channel, "No such channel", + ) + + return + } + + isMember, _ := c.database.IsChannelMember( + ctx, chID, c.sessionID, + ) + if !isMember { + c.sendNumeric( + irc.ErrNotOnChannel, + channel, "You're not on that channel", + ) + + return + } + + targetID, err := c.database.GetSessionByNick( + ctx, targetNick, + ) + if err != nil { + c.sendNumeric( + irc.ErrNoSuchNick, + targetNick, "No such nick", + ) + + return + } + + c.sendNumeric( + irc.RplInviting, targetNick, channel, + ) + + // Send INVITE notice to target. + body, _ := json.Marshal( //nolint:errchkjson + []string{"You have been invited to " + channel}, + ) + + dbID, _, insErr := c.database.InsertMessage( + ctx, "INVITE", c.nick, targetNick, + nil, body, nil, + ) + if insErr == nil { + _ = c.database.EnqueueToSession( + ctx, targetID, dbID, + ) + c.brk.Notify(targetID) + } +} + +// handleUserhost replies with USERHOST info. +func (c *Conn) handleUserhost( + ctx context.Context, + msg *Message, +) { + if len(msg.Params) < 1 { + return + } + + replies := make([]string, 0, len(msg.Params)) + + for _, nick := range msg.Params { + sid, err := c.database.GetSessionByNick(ctx, nick) + if err != nil { + continue + } + + hostInfo, _ := c.database.GetSessionHostInfo( + ctx, sid, + ) + + host := "*" + if hostInfo != nil { + host = hostInfo.Hostname + } + + isOper, _ := c.database.IsSessionOper(ctx, sid) + + operStar := "" + if isOper { + operStar = "*" + } + + replies = append( + replies, + nick+operStar+"=+"+nick+"@"+host, + ) + } + + c.sendNumeric( + irc.RplUserHost, + strings.Join(replies, " "), + ) +} diff --git a/internal/ircserver/conn.go b/internal/ircserver/conn.go new file mode 100644 index 0000000..eddc96d --- /dev/null +++ b/internal/ircserver/conn.go @@ -0,0 +1,551 @@ +package ircserver + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "log/slog" + "net" + "strconv" + "strings" + "sync" + "time" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +const ( + maxLineLen = 512 + readTimeout = 5 * time.Minute + writeTimeout = 30 * time.Second + dnsTimeout = 3 * time.Second + pollInterval = 100 * time.Millisecond + pingInterval = 90 * time.Second + pongDeadline = 30 * time.Second + maxNickLen = 32 + minPasswordLen = 8 +) + +// Conn represents a single IRC client TCP connection. +type Conn struct { + conn net.Conn + log *slog.Logger + database *db.Database + brk *broker.Broker + cfg *config.Config + serverSfx string + + mu sync.Mutex + nick string + username string + realname string + hostname string + remoteIP string + sessionID int64 + clientID int64 + + registered bool + gotNick bool + gotUser bool + passWord string + + lastQueueID int64 + closed bool + cancel context.CancelFunc +} + +func newConn( + ctx context.Context, + tcpConn net.Conn, + log *slog.Logger, + database *db.Database, + brk *broker.Broker, + cfg *config.Config, +) *Conn { + host, _, _ := net.SplitHostPort(tcpConn.RemoteAddr().String()) + + srvName := cfg.ServerName + if srvName == "" { + srvName = "neoirc" + } + + return &Conn{ //nolint:exhaustruct // zero-value defaults + conn: tcpConn, + log: log, + database: database, + brk: brk, + cfg: cfg, + serverSfx: srvName, + remoteIP: host, + hostname: resolveHost(ctx, host), + } +} + +// resolveHost does a reverse DNS lookup, returning the IP +// on failure. +func resolveHost(ctx context.Context, addr string) string { + ctx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + + resolver := &net.Resolver{} //nolint:exhaustruct + + names, err := resolver.LookupAddr(ctx, addr) + if err != nil || len(names) == 0 { + return addr + } + + return strings.TrimSuffix(names[0], ".") +} + +// serve is the main loop for a single IRC client connection. +func (c *Conn) serve(ctx context.Context) { + ctx, c.cancel = context.WithCancel(ctx) + defer c.cleanup(ctx) + + scanner := bufio.NewScanner(c.conn) + scanner.Buffer(make([]byte, maxLineLen), maxLineLen) + + for { + _ = c.conn.SetReadDeadline( + time.Now().Add(readTimeout), + ) + + if !scanner.Scan() { + return + } + + line := scanner.Text() + if line == "" { + continue + } + + msg := ParseMessage(line) + if msg == nil { + continue + } + + c.handleMessage(ctx, msg) + + if c.closed { + return + } + } +} + +func (c *Conn) cleanup(ctx context.Context) { + c.mu.Lock() + wasRegistered := c.registered + sessID := c.sessionID + nick := c.nick + c.closed = true + c.mu.Unlock() + + if wasRegistered && sessID > 0 { + c.broadcastQuit(ctx, nick, "Connection closed") + c.database.DeleteSession(ctx, sessID) //nolint:errcheck,gosec + } + + c.conn.Close() //nolint:errcheck,gosec +} + +func (c *Conn) broadcastQuit( + ctx context.Context, + nick, reason string, +) { + channels, err := c.database.GetSessionChannels( + ctx, c.sessionID, + ) + if err != nil { + return + } + + notified := make(map[int64]bool) + + for _, ch := range channels { + chID, getErr := c.database.GetChannelByName( + ctx, ch.Name, + ) + if getErr != nil { + continue + } + + memberIDs, memErr := c.database.GetChannelMemberIDs( + ctx, chID, + ) + if memErr != nil { + continue + } + + for _, mid := range memberIDs { + if mid == c.sessionID || notified[mid] { + continue + } + + notified[mid] = true + } + } + + body, _ := json.Marshal([]string{reason}) //nolint:errchkjson + + for sid := range notified { + dbID, _, insErr := c.database.InsertMessage( + ctx, irc.CmdQuit, nick, "", nil, body, nil, + ) + if insErr != nil { + continue + } + + _ = c.database.EnqueueToSession(ctx, sid, dbID) + c.brk.Notify(sid) + } + + // Part from all channels so they get cleaned up. + for _, ch := range channels { + c.database.PartChannel(ctx, ch.ID, c.sessionID) //nolint:errcheck,gosec + c.database.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec + } +} + +// send writes a formatted IRC line to the connection. +func (c *Conn) send(line string) { + _ = c.conn.SetWriteDeadline( + time.Now().Add(writeTimeout), + ) + + _, _ = fmt.Fprintf(c.conn, "%s\r\n", line) +} + +// sendNumeric sends a numeric reply from the server. +func (c *Conn) sendNumeric( + code irc.IRCMessageType, + params ...string, +) { + nick := c.nick + if nick == "" { + nick = "*" + } + + allParams := make([]string, 0, 1+len(params)) + allParams = append(allParams, nick) + allParams = append(allParams, params...) + + c.send(FormatMessage( + c.serverSfx, code.Code(), allParams..., + )) +} + +// sendFromServer sends a message from the server. +func (c *Conn) sendFromServer( + command string, params ...string, +) { + c.send(FormatMessage(c.serverSfx, command, params...)) +} + +// hostmask returns the client's full hostmask +// (nick!user@host). +func (c *Conn) hostmask() string { + user := c.username + if user == "" { + user = c.nick + } + + host := c.hostname + if host == "" { + host = c.remoteIP + } + + return c.nick + "!" + user + "@" + host +} + +// handleMessage dispatches a parsed IRC message. +// +//nolint:cyclop // dispatch table is inherently branchy +func (c *Conn) handleMessage( + ctx context.Context, + msg *Message, +) { + // Before registration, only NICK, USER, PASS, PING, + // QUIT, and CAP are accepted. + if !c.registered { + c.handlePreRegistration(ctx, msg) + + return + } + + switch msg.Command { + case irc.CmdPing: + c.handlePing(msg) + case "PONG": + // Silently accept. + case irc.CmdNick: + c.handleNick(ctx, msg) + case irc.CmdPrivmsg, irc.CmdNotice: + c.handlePrivmsg(ctx, msg) + case irc.CmdJoin: + c.handleJoin(ctx, msg) + case irc.CmdPart: + c.handlePart(ctx, msg) + case irc.CmdQuit: + c.handleQuit(msg) + case irc.CmdTopic: + c.handleTopic(ctx, msg) + case irc.CmdMode: + c.handleMode(ctx, msg) + case irc.CmdNames: + c.handleNames(ctx, msg) + case irc.CmdList: + c.handleList(ctx) + case irc.CmdWhois: + c.handleWhois(ctx, msg) + case irc.CmdWho: + c.handleWho(ctx, msg) + case irc.CmdLusers: + c.handleLusers(ctx) + case irc.CmdMotd: + c.deliverMOTD() + case irc.CmdOper: + c.handleOper(ctx, msg) + case irc.CmdAway: + c.handleAway(ctx, msg) + case irc.CmdKick: + c.handleKick(ctx, msg) + case irc.CmdPass: + c.handlePassPostReg(ctx, msg) + case "INVITE": + c.handleInvite(ctx, msg) + case "CAP": + c.handleCAP(msg) + case "USERHOST": + c.handleUserhost(ctx, msg) + default: + c.sendNumeric( + irc.ErrUnknownCommand, + msg.Command, "Unknown command", + ) + } +} + +// handlePreRegistration handles messages before the +// connection is registered (NICK+USER received). +func (c *Conn) handlePreRegistration( + ctx context.Context, + msg *Message, +) { + switch msg.Command { + case irc.CmdPass: + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNeedMoreParams, + "PASS", "Not enough parameters", + ) + + return + } + + c.passWord = msg.Params[0] + case irc.CmdNick: + if len(msg.Params) < 1 { + c.sendNumeric( + irc.ErrNoNicknameGiven, + "No nickname given", + ) + + return + } + + c.nick = msg.Params[0] + if len(c.nick) > maxNickLen { + c.nick = c.nick[:maxNickLen] + } + + c.gotNick = true + case irc.CmdUser: + if len(msg.Params) < 4 { //nolint:mnd + c.sendNumeric( + irc.ErrNeedMoreParams, + "USER", "Not enough parameters", + ) + + return + } + + c.username = msg.Params[0] + c.realname = msg.Params[3] + c.gotUser = true + case irc.CmdPing: + c.handlePing(msg) + + return + case irc.CmdQuit: + c.handleQuit(msg) + + return + case "CAP": + c.handleCAP(msg) + + return + default: + c.sendNumeric( + irc.ErrNotRegistered, + "You have not registered", + ) + + return + } + + // Try to complete registration once we have both + // NICK and USER. + if c.gotNick && c.gotUser { + c.completeRegistration(ctx) + } +} + +// completeRegistration creates a session and sends the +// welcome burst. +func (c *Conn) completeRegistration(ctx context.Context) { + // Check if nick is valid. + if c.nick == "" { + c.sendNumeric( + irc.ErrNoNicknameGiven, "No nickname given", + ) + + return + } + + // Create session in DB. + sessionID, clientID, _, err := c.database.CreateSession( + ctx, c.nick, c.username, c.hostname, c.remoteIP, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint") || + strings.Contains(err.Error(), "nick") { + c.sendNumeric( + irc.ErrNicknameInUse, + c.nick, "Nickname is already in use", + ) + + return + } + + c.log.Error( + "failed to create session", "error", err, + ) + c.send("ERROR :Internal server error") + c.closed = true + + return + } + + c.mu.Lock() + c.sessionID = sessionID + c.clientID = clientID + c.registered = true + c.mu.Unlock() + + // If PASS was provided before registration, set the + // session password. + if c.passWord != "" && len(c.passWord) >= minPasswordLen { + c.setPassword(ctx, c.passWord) + } + + // Send welcome burst. + c.deliverWelcome() + c.deliverLusers(ctx) + c.deliverMOTD() + + // Start the message relay goroutine. + go c.relayMessages(ctx) +} + +// deliverWelcome sends 001-005 welcome numerics. +func (c *Conn) deliverWelcome() { + c.sendNumeric(irc.RplWelcome, fmt.Sprintf( + "Welcome to the %s Network, %s", + c.serverSfx, c.hostmask(), + )) + c.sendNumeric(irc.RplYourHost, fmt.Sprintf( + "Your host is %s, running version neoirc", + c.serverSfx, + )) + c.sendNumeric( + irc.RplCreated, + "This server was created recently", + ) + c.sendNumeric( + irc.RplMyInfo, + c.serverSfx, "neoirc", "", "mnst", + ) + c.sendNumeric( + irc.RplIsupport, + "CHANTYPES=#", + "NICKLEN=32", + "PREFIX=(ov)@+", + "CHANMODES=,,H,mnst", + "NETWORK="+c.serverSfx, + "are supported by this server", + ) +} + +// deliverLusers sends 251/252/254/255 server statistics. +func (c *Conn) deliverLusers(ctx context.Context) { + users, _ := c.database.GetUserCount(ctx) + opers, _ := c.database.GetOperCount(ctx) + channels, _ := c.database.GetChannelCount(ctx) + + c.sendNumeric(irc.RplLuserClient, fmt.Sprintf( + "There are %d users and 0 invisible on 1 servers", + users, + )) + c.sendNumeric( + irc.RplLuserOp, + strconv.FormatInt(opers, 10), + "operator(s) online", + ) + c.sendNumeric( + irc.RplLuserChannels, + strconv.FormatInt(channels, 10), + "channels formed", + ) + c.sendNumeric(irc.RplLuserMe, fmt.Sprintf( + "I have %d clients and 1 servers", users, + )) +} + +// deliverMOTD sends 375/372/376 MOTD lines. +func (c *Conn) deliverMOTD() { + motd := c.cfg.MOTD + if motd == "" { + c.sendNumeric( + irc.ErrNoMotd, "MOTD File is missing", + ) + + return + } + + c.sendNumeric(irc.RplMotdStart, fmt.Sprintf( + "- %s Message of the Day -", c.serverSfx, + )) + + for _, line := range strings.Split(motd, "\n") { + c.sendNumeric(irc.RplMotd, "- "+line) + } + + c.sendNumeric( + irc.RplEndOfMotd, "End of /MOTD command", + ) +} + +// setPassword sets a bcrypt password on the session. +func (c *Conn) setPassword(ctx context.Context, pw string) { + // Use the database's auth module to hash and store. + err := c.database.SetPassword(ctx, c.sessionID, pw) + if err != nil { + c.log.Error( + "failed to set password", "error", err, + ) + } +} diff --git a/internal/ircserver/export_test.go b/internal/ircserver/export_test.go new file mode 100644 index 0000000..be355a1 --- /dev/null +++ b/internal/ircserver/export_test.go @@ -0,0 +1,43 @@ +package ircserver + +import ( + "context" + "log/slog" + "net" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" +) + +// NewTestServer creates a Server suitable for testing. +// The caller must call Stop() when finished. +func NewTestServer( + log *slog.Logger, + cfg *config.Config, + database *db.Database, + brk *broker.Broker, +) *Server { + return &Server{ //nolint:exhaustruct + log: log, + cfg: cfg, + database: database, + brk: brk, + conns: make(map[*Conn]struct{}), + } +} + +// Start exposes the unexported start method for tests. +func (s *Server) Start(addr string) error { + return s.start(context.Background(), addr) +} + +// Stop exposes the unexported stop method for tests. +func (s *Server) Stop() { + s.stop() +} + +// Listener returns the server's net.Listener for tests. +func (s *Server) Listener() net.Listener { + return s.listener +} diff --git a/internal/ircserver/parser.go b/internal/ircserver/parser.go new file mode 100644 index 0000000..5374bc0 --- /dev/null +++ b/internal/ircserver/parser.go @@ -0,0 +1,123 @@ +// Package ircserver implements a traditional IRC wire protocol +// listener (RFC 1459/2812) that bridges to the neoirc HTTP/JSON +// server internals. +package ircserver + +import "strings" + +// Message represents a parsed IRC wire protocol message. +type Message struct { + // Prefix is the optional :prefix at the start (may be + // empty for client-to-server messages). + Prefix string + // Command is the IRC command (e.g., "PRIVMSG", "NICK"). + Command string + // Params holds the positional parameters, including the + // trailing parameter (which was preceded by ':' on the + // wire). + Params []string +} + +// ParseMessage parses a single IRC wire protocol line +// (without the trailing CR-LF) into a Message. +// Returns nil if the line is empty. +// +// IRC message format (RFC 1459 §2.3.1): +// +// [":" prefix SPACE] command { SPACE param } [SPACE ":" trailing] +func ParseMessage(line string) *Message { + if line == "" { + return nil + } + + msg := &Message{} //nolint:exhaustruct // fields set below + + // Extract prefix if present. + if line[0] == ':' { + idx := strings.IndexByte(line, ' ') + if idx < 0 { + // Only a prefix, no command — invalid. + return nil + } + + msg.Prefix = line[1:idx] + line = line[idx+1:] + } + + // Skip leading spaces. + line = strings.TrimLeft(line, " ") + if line == "" { + return nil + } + + // Extract command. + idx := strings.IndexByte(line, ' ') + if idx < 0 { + msg.Command = strings.ToUpper(line) + + return msg + } + + msg.Command = strings.ToUpper(line[:idx]) + line = line[idx+1:] + + // Extract parameters. + for line != "" { + line = strings.TrimLeft(line, " ") + if line == "" { + break + } + + // Trailing parameter (everything after ':'). + if line[0] == ':' { + msg.Params = append(msg.Params, line[1:]) + + break + } + + idx = strings.IndexByte(line, ' ') + if idx < 0 { + msg.Params = append(msg.Params, line) + + break + } + + msg.Params = append(msg.Params, line[:idx]) + line = line[idx+1:] + } + + return msg +} + +// FormatMessage formats an IRC message into wire protocol +// format (without the trailing CR-LF). +func FormatMessage( + prefix, command string, + params ...string, +) string { + var buf strings.Builder + + if prefix != "" { + buf.WriteByte(':') + buf.WriteString(prefix) + buf.WriteByte(' ') + } + + buf.WriteString(command) + + for i, param := range params { + buf.WriteByte(' ') + + isLast := i == len(params)-1 + needsColon := strings.Contains(param, " ") || + param == "" || param[0] == ':' + + if isLast && needsColon { + buf.WriteByte(':') + } + + buf.WriteString(param) + } + + return buf.String() +} diff --git a/internal/ircserver/parser_test.go b/internal/ircserver/parser_test.go new file mode 100644 index 0000000..8585eab --- /dev/null +++ b/internal/ircserver/parser_test.go @@ -0,0 +1,328 @@ +package ircserver_test + +import ( + "testing" + + "git.eeqj.de/sneak/neoirc/internal/ircserver" +) + +//nolint:funlen // table-driven test +func TestParseMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want *ircserver.Message + wantNil bool + }{ + { + name: "empty", + input: "", + want: nil, + wantNil: true, + }, + { + name: "simple command", + input: "PING", + want: &ircserver.Message{ + Prefix: "", + Command: "PING", + Params: nil, + }, + wantNil: false, + }, + { + name: "command with one param", + input: "NICK alice", + want: &ircserver.Message{ + Prefix: "", + Command: "NICK", + Params: []string{"alice"}, + }, + wantNil: false, + }, + { + name: "command case insensitive", + input: "nick Alice", + want: &ircserver.Message{ + Prefix: "", + Command: "NICK", + Params: []string{"Alice"}, + }, + wantNil: false, + }, + { + name: "privmsg with trailing", + input: "PRIVMSG #general :hello world", + want: &ircserver.Message{ + Prefix: "", + Command: "PRIVMSG", + Params: []string{"#general", "hello world"}, + }, + wantNil: false, + }, + { + name: "with prefix", + input: ":server.example.com 001 alice :Welcome to IRC", + want: &ircserver.Message{ + Prefix: "server.example.com", + Command: "001", + Params: []string{"alice", "Welcome to IRC"}, + }, + wantNil: false, + }, + { + name: "user command", + input: "USER alice 0 * :Alice Smith", + want: &ircserver.Message{ + Prefix: "", + Command: "USER", + Params: []string{ + "alice", "0", "*", "Alice Smith", + }, + }, + wantNil: false, + }, + { + name: "join channel", + input: "JOIN #general", + want: &ircserver.Message{ + Prefix: "", + Command: "JOIN", + Params: []string{"#general"}, + }, + wantNil: false, + }, + { + name: "quit with trailing", + input: "QUIT :leaving now", + want: &ircserver.Message{ + Prefix: "", + Command: "QUIT", + Params: []string{"leaving now"}, + }, + wantNil: false, + }, + { + name: "quit without reason", + input: "QUIT", + want: &ircserver.Message{ + Prefix: "", + Command: "QUIT", + Params: nil, + }, + wantNil: false, + }, + { + name: "mode query", + input: "MODE #general", + want: &ircserver.Message{ + Prefix: "", + Command: "MODE", + Params: []string{"#general"}, + }, + wantNil: false, + }, + { + name: "kick with reason", + input: "KICK #general bob :misbehaving", + want: &ircserver.Message{ + Prefix: "", + Command: "KICK", + Params: []string{ + "#general", "bob", "misbehaving", + }, + }, + wantNil: false, + }, + { + name: "empty trailing", + input: "PRIVMSG #general :", + want: &ircserver.Message{ + Prefix: "", + Command: "PRIVMSG", + Params: []string{"#general", ""}, + }, + wantNil: false, + }, + { + name: "pass command", + input: "PASS mysecret", + want: &ircserver.Message{ + Prefix: "", + Command: "PASS", + Params: []string{"mysecret"}, + }, + wantNil: false, + }, + { + name: "ping with server", + input: "PING :irc.example.com", + want: &ircserver.Message{ + Prefix: "", + Command: "PING", + Params: []string{"irc.example.com"}, + }, + wantNil: false, + }, + { + name: "topic with trailing spaces", + input: "TOPIC #general :Welcome to the channel!", + want: &ircserver.Message{ + Prefix: "", + Command: "TOPIC", + Params: []string{ + "#general", + "Welcome to the channel!", + }, + }, + wantNil: false, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := ircserver.ParseMessage(testCase.input) + if testCase.wantNil { + if got != nil { + t.Fatalf("expected nil, got %+v", got) + } + + return + } + + if got == nil { + t.Fatal("expected non-nil message") + } + + if got.Prefix != testCase.want.Prefix { + t.Errorf( + "prefix: got %q, want %q", + got.Prefix, testCase.want.Prefix, + ) + } + + if got.Command != testCase.want.Command { + t.Errorf( + "command: got %q, want %q", + got.Command, testCase.want.Command, + ) + } + + if len(got.Params) != len(testCase.want.Params) { + t.Fatalf( + "params length: got %d, want %d (%v vs %v)", + len(got.Params), + len(testCase.want.Params), + got.Params, + testCase.want.Params, + ) + } + + for i, p := range got.Params { + if p != testCase.want.Params[i] { + t.Errorf( + "param[%d]: got %q, want %q", + i, p, testCase.want.Params[i], + ) + } + } + }) + } +} + +func TestFormatMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prefix string + command string + params []string + want string + }{ + { + name: "simple command", + prefix: "", + command: "PING", + params: nil, + want: "PING", + }, + { + name: "with prefix", + prefix: "server", + command: "PONG", + params: []string{"server"}, + want: ":server PONG server", + }, + { + name: "privmsg with trailing", + prefix: "alice!alice@host", + command: "PRIVMSG", + params: []string{"#general", "hello world"}, + want: ":alice!alice@host PRIVMSG #general :hello world", + }, + { + name: "numeric reply", + prefix: "server", + command: "001", + params: []string{"alice", "Welcome to IRC"}, + want: ":server 001 alice :Welcome to IRC", + }, + { + name: "empty trailing", + prefix: "server", + command: "PRIVMSG", + params: []string{"#chan", ""}, + want: ":server PRIVMSG #chan :", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := ircserver.FormatMessage( + testCase.prefix, testCase.command, testCase.params..., + ) + if got != testCase.want { + t.Errorf("got %q, want %q", got, testCase.want) + } + }) + } +} + +func TestParseFormatRoundTrip(t *testing.T) { + t.Parallel() + + // Round-trip only works for lines where the last + // parameter either contains a space (gets ':' prefix + // on format) or is a non-trailing single token. + lines := []string{ + "PING", + "NICK alice", + "PRIVMSG #general :hello world", + "JOIN #general", + "MODE #general", + } + + for _, line := range lines { + msg := ircserver.ParseMessage(line) + if msg == nil { + t.Fatalf("failed to parse: %q", line) + } + + formatted := ircserver.FormatMessage( + msg.Prefix, msg.Command, msg.Params..., + ) + if formatted != line { + t.Errorf( + "round-trip failed: input %q, got %q", + line, formatted, + ) + } + } +} diff --git a/internal/ircserver/relay.go b/internal/ircserver/relay.go new file mode 100644 index 0000000..c1f3436 --- /dev/null +++ b/internal/ircserver/relay.go @@ -0,0 +1,319 @@ +package ircserver + +import ( + "context" + "encoding/json" + "strings" + "time" + + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/pkg/irc" +) + +// relayMessages polls the client output queue and delivers +// IRC-formatted messages to the TCP connection. It runs +// in a goroutine for the lifetime of the connection. +func (c *Conn) relayMessages(ctx context.Context) { + // Use a ticker as a fallback; primary wakeup is via + // broker notification. + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + default: + } + + // Drain any available messages. + delivered := c.drainQueue(ctx) + + if delivered { + // Tight loop while there are messages. + continue + } + + // Wait for notification or timeout. + waitCh := c.brk.Wait(c.sessionID) + + select { + case <-waitCh: + // New message notification — loop back. + case <-ticker.C: + // Periodic check. + case <-ctx.Done(): + c.brk.Remove(c.sessionID, waitCh) + + return + } + } +} + +const relayPollLimit = 100 + +// drainQueue polls the output queue and delivers all +// pending messages. Returns true if at least one message +// was delivered. +func (c *Conn) drainQueue(ctx context.Context) bool { + msgs, lastID, err := c.database.PollMessages( + ctx, c.clientID, c.lastQueueID, relayPollLimit, + ) + if err != nil { + return false + } + + if len(msgs) == 0 { + return false + } + + for i := range msgs { + c.deliverIRCMessage(ctx, &msgs[i]) + } + + if lastID > c.lastQueueID { + c.lastQueueID = lastID + } + + return true +} + +// deliverIRCMessage converts a db.IRCMessage to wire +// protocol and sends it. +// +//nolint:cyclop // dispatch table +func (c *Conn) deliverIRCMessage( + _ context.Context, + msg *db.IRCMessage, +) { + command := msg.Command + + // Decode body as []string for the trailing text. + var bodyLines []string + + if msg.Body != nil { + _ = json.Unmarshal(msg.Body, &bodyLines) + } + + text := "" + if len(bodyLines) > 0 { + text = bodyLines[0] + } + + // Route by command type. + switch { + case isNumeric(command): + c.deliverNumeric(msg, text) + case command == irc.CmdPrivmsg || command == irc.CmdNotice: + c.deliverTextMessage(msg, command, text) + case command == irc.CmdJoin: + c.deliverJoin(msg) + case command == irc.CmdPart: + c.deliverPart(msg, text) + case command == irc.CmdNick: + c.deliverNickChange(msg, text) + case command == irc.CmdQuit: + c.deliverQuitMsg(msg, text) + case command == irc.CmdTopic: + c.deliverTopicChange(msg, text) + case command == irc.CmdKick: + c.deliverKickMsg(msg, text) + case command == "INVITE": + c.deliverInviteMsg(msg, text) + case command == irc.CmdMode: + c.deliverMode(msg, text) + case command == irc.CmdPing: + // Server-originated PING — reply with PONG. + c.sendFromServer("PING", c.serverSfx) + default: + // Unknown command — deliver as server notice. + if text != "" { + c.sendFromServer("NOTICE", c.nick, text) + } + } +} + +// isNumeric returns true if the command is a 3-digit +// numeric code. +func isNumeric(cmd string) bool { + return len(cmd) == 3 && + cmd[0] >= '0' && cmd[0] <= '9' && + cmd[1] >= '0' && cmd[1] <= '9' && + cmd[2] >= '0' && cmd[2] <= '9' +} + +// deliverNumeric sends a numeric reply. +func (c *Conn) deliverNumeric( + msg *db.IRCMessage, + text string, +) { + from := msg.From + if from == "" { + from = c.serverSfx + } + + var params []string + + if msg.Params != nil { + _ = json.Unmarshal(msg.Params, ¶ms) + } + + allParams := make([]string, 0, 1+len(params)+1) + allParams = append(allParams, c.nick) + allParams = append(allParams, params...) + + if text != "" { + allParams = append(allParams, text) + } + + c.send(FormatMessage(from, msg.Command, allParams...)) +} + +// deliverTextMessage sends PRIVMSG or NOTICE. +func (c *Conn) deliverTextMessage( + msg *db.IRCMessage, + command, text string, +) { + from := msg.From + target := msg.To + + // Don't echo our own messages back. + if strings.EqualFold(from, c.nick) { + return + } + + prefix := from + if !strings.Contains(prefix, "!") { + prefix = from + "!" + from + "@*" + } + + c.send(FormatMessage(prefix, command, target, text)) +} + +// deliverJoin sends a JOIN notification. +func (c *Conn) deliverJoin(msg *db.IRCMessage) { + // Don't echo our own JOINs (we already sent them + // during joinChannel). + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + c.send(FormatMessage(prefix, "JOIN", channel)) +} + +// deliverPart sends a PART notification. +func (c *Conn) deliverPart(msg *db.IRCMessage, text string) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + if text != "" { + c.send(FormatMessage( + prefix, "PART", channel, text, + )) + } else { + c.send(FormatMessage(prefix, "PART", channel)) + } +} + +// deliverNickChange sends a NICK change notification. +func (c *Conn) deliverNickChange( + msg *db.IRCMessage, + newNick string, +) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + + c.send(FormatMessage(prefix, "NICK", newNick)) +} + +// deliverQuitMsg sends a QUIT notification. +func (c *Conn) deliverQuitMsg( + msg *db.IRCMessage, + text string, +) { + if strings.EqualFold(msg.From, c.nick) { + return + } + + prefix := msg.From + "!" + msg.From + "@*" + + if text != "" { + c.send(FormatMessage( + prefix, "QUIT", "Quit: "+text, + )) + } else { + c.send(FormatMessage(prefix, "QUIT", "Quit")) + } +} + +// deliverTopicChange sends a TOPIC change notification. +func (c *Conn) deliverTopicChange( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + c.send(FormatMessage(prefix, "TOPIC", channel, text)) +} + +// deliverKickMsg sends a KICK notification. +func (c *Conn) deliverKickMsg( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + channel := msg.To + + var params []string + + if msg.Params != nil { + _ = json.Unmarshal(msg.Params, ¶ms) + } + + kickTarget := "" + if len(params) > 0 { + kickTarget = params[0] + } + + if kickTarget != "" { + c.send(FormatMessage( + prefix, "KICK", channel, kickTarget, text, + )) + } else { + c.send(FormatMessage( + prefix, "KICK", channel, "?", text, + )) + } +} + +// deliverInviteMsg sends an INVITE notification. +func (c *Conn) deliverInviteMsg( + _ *db.IRCMessage, + text string, +) { + c.sendFromServer("NOTICE", c.nick, text) +} + +// deliverMode sends a MODE change notification. +func (c *Conn) deliverMode( + msg *db.IRCMessage, + text string, +) { + prefix := msg.From + "!" + msg.From + "@*" + target := msg.To + + if text != "" { + c.send(FormatMessage(prefix, "MODE", target, text)) + } +} diff --git a/internal/ircserver/server.go b/internal/ircserver/server.go new file mode 100644 index 0000000..56e3b7a --- /dev/null +++ b/internal/ircserver/server.go @@ -0,0 +1,153 @@ +package ircserver + +import ( + "context" + "fmt" + "log/slog" + "net" + "sync" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/logger" + "go.uber.org/fx" +) + +// Params defines the dependencies for creating an IRC +// Server. +type Params struct { + fx.In + + Logger *logger.Logger + Config *config.Config + Database *db.Database + Broker *broker.Broker +} + +// Server is the TCP IRC protocol server. +type Server struct { + log *slog.Logger + cfg *config.Config + database *db.Database + brk *broker.Broker + listener net.Listener + mu sync.Mutex + conns map[*Conn]struct{} + cancel context.CancelFunc +} + +// New creates a new IRC Server and registers its lifecycle +// hooks. The listener is only started if IRC_LISTEN_ADDR +// is configured; otherwise the server is inert. +func New( + lifecycle fx.Lifecycle, + params Params, +) *Server { + srv := &Server{ + log: params.Logger.Get(), + cfg: params.Config, + database: params.Database, + brk: params.Broker, + conns: make(map[*Conn]struct{}), + listener: nil, + cancel: nil, + mu: sync.Mutex{}, + } + + listenAddr := params.Config.IRCListenAddr + if listenAddr == "" { + return srv + } + + lifecycle.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return srv.start(ctx, listenAddr) + }, + OnStop: func(_ context.Context) error { + srv.stop() + + return nil + }, + }) + + return srv +} + +// start begins listening for TCP connections. +// +//nolint:contextcheck // long-lived server ctx, not the short Fx one +func (s *Server) start(_ context.Context, addr string) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("irc listen: %w", err) + } + + s.listener = ln + + ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + + s.log.Info( + "irc server listening", "addr", addr, + ) + + go s.acceptLoop(ctx) + + return nil +} + +// stop shuts down the listener and all connections. +func (s *Server) stop() { + if s.cancel != nil { + s.cancel() + } + + if s.listener != nil { + s.listener.Close() //nolint:errcheck,gosec + } + + s.mu.Lock() + for c := range s.conns { + c.conn.Close() //nolint:errcheck,gosec + } + s.mu.Unlock() +} + +// acceptLoop accepts new connections. +func (s *Server) acceptLoop(ctx context.Context) { + for { + tcpConn, err := s.listener.Accept() + if err != nil { + select { + case <-ctx.Done(): + return + default: + s.log.Error( + "irc accept error", "error", err, + ) + + continue + } + } + + client := newConn( + ctx, tcpConn, s.log, + s.database, s.brk, s.cfg, + ) + + s.mu.Lock() + s.conns[client] = struct{}{} + s.mu.Unlock() + + go func() { + defer func() { + s.mu.Lock() + delete(s.conns, client) + s.mu.Unlock() + }() + + client.serve(ctx) + }() + } +} diff --git a/internal/ircserver/server_test.go b/internal/ircserver/server_test.go new file mode 100644 index 0000000..fb21e6b --- /dev/null +++ b/internal/ircserver/server_test.go @@ -0,0 +1,625 @@ +package ircserver_test + +import ( + "bufio" + "database/sql" + "fmt" + "log/slog" + "net" + "os" + "strings" + "testing" + "time" + + "git.eeqj.de/sneak/neoirc/internal/broker" + "git.eeqj.de/sneak/neoirc/internal/config" + "git.eeqj.de/sneak/neoirc/internal/db" + "git.eeqj.de/sneak/neoirc/internal/ircserver" + + _ "modernc.org/sqlite" +) + +const testTimeout = 5 * time.Second + +func TestMain(m *testing.M) { + db.SetBcryptCost(4) + + os.Exit(m.Run()) +} + +// testEnv holds the shared test infrastructure. +type testEnv struct { + database *db.Database + brk *broker.Broker + cfg *config.Config + srv *ircserver.Server +} + +func newTestEnv(t *testing.T) *testEnv { + t.Helper() + + dsn := fmt.Sprintf( + "file:%s?mode=memory&cache=shared&_journal_mode=WAL", + t.Name(), + ) + + conn, err := sql.Open("sqlite", dsn) + if err != nil { + t.Fatalf("open db: %v", err) + } + + conn.SetMaxOpenConns(1) + + _, err = conn.ExecContext( + t.Context(), "PRAGMA foreign_keys = ON", + ) + if err != nil { + t.Fatalf("pragma: %v", err) + } + + database := db.NewTestDatabaseFromConn(conn) + + err = database.RunMigrations(t.Context()) + if err != nil { + t.Fatalf("migrate: %v", err) + } + + brk := broker.New() + + cfg := &config.Config{ //nolint:exhaustruct + ServerName: "test.irc", + MOTD: "Welcome to test IRC", + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + addr := listener.Addr().String() + + err = listener.Close() + if err != nil { + t.Fatalf("close listener: %v", err) + } + + log := slog.New(slog.NewTextHandler( + os.Stderr, + &slog.HandlerOptions{Level: slog.LevelError}, //nolint:exhaustruct + )) + + srv := ircserver.NewTestServer(log, cfg, database, brk) + + err = srv.Start(addr) + if err != nil { + t.Fatalf("start irc server: %v", err) + } + + t.Cleanup(func() { + srv.Stop() + + err := conn.Close() + if err != nil { + t.Logf("close db: %v", err) + } + }) + + return &testEnv{ + database: database, + brk: brk, + cfg: cfg, + srv: srv, + } +} + +// dial connects to the test server. +func (env *testEnv) dial(t *testing.T) *testClient { + t.Helper() + + conn, err := net.DialTimeout( + "tcp", + env.srv.Listener().Addr().String(), + testTimeout, + ) + if err != nil { + t.Fatalf("dial: %v", err) + } + + t.Cleanup(func() { + err := conn.Close() + if err != nil { + t.Logf("close conn: %v", err) + } + }) + + return &testClient{ + t: t, + conn: conn, + scanner: bufio.NewScanner(conn), + } +} + +// testClient wraps a raw TCP connection with helpers. +type testClient struct { + t *testing.T + conn net.Conn + scanner *bufio.Scanner +} + +func (tc *testClient) send(line string) { + tc.t.Helper() + + _ = tc.conn.SetWriteDeadline( + time.Now().Add(testTimeout), + ) + + _, err := fmt.Fprintf(tc.conn, "%s\r\n", line) + if err != nil { + tc.t.Fatalf("send: %v", err) + } +} + +func (tc *testClient) readLine() string { + tc.t.Helper() + + _ = tc.conn.SetReadDeadline( + time.Now().Add(testTimeout), + ) + + if !tc.scanner.Scan() { + err := tc.scanner.Err() + if err != nil { + tc.t.Fatalf("read: %v", err) + } + + tc.t.Fatal("connection closed unexpectedly") + } + + return tc.scanner.Text() +} + +// readUntil reads lines until one matches the predicate. +func (tc *testClient) readUntil( + pred func(string) bool, +) []string { + tc.t.Helper() + + var lines []string + + for { + line := tc.readLine() + lines = append(lines, line) + + if pred(line) { + return lines + } + } +} + +// register sends NICK + USER and reads through the welcome +// burst. +func (tc *testClient) register(nick string) []string { + tc.t.Helper() + + tc.send("NICK " + nick) + tc.send("USER " + nick + " 0 * :Test User") + + return tc.readUntil(func(line string) bool { + return strings.Contains(line, " 376 ") || + strings.Contains(line, " 422 ") + }) +} + +// assertContains checks that at least one line matches the +// given substring. +func assertContains( + t *testing.T, + lines []string, + substr, description string, +) { + t.Helper() + + for _, line := range lines { + if strings.Contains(line, substr) { + return + } + } + + t.Errorf("did not find %q in output: %s", substr, description) +} + +// joinAndDrain joins a channel and reads until +// RPL_ENDOFNAMES. +func (tc *testClient) joinAndDrain(channel string) { + tc.t.Helper() + + tc.send("JOIN " + channel) + + tc.readUntil(func(line string) bool { + return strings.Contains(line, " 366 ") + }) +} + +// sendAndExpect sends a command and reads until a line +// containing the expected substring is found. +func (tc *testClient) sendAndExpect( + cmd, expect string, +) []string { + tc.t.Helper() + + tc.send(cmd) + + return tc.readUntil(func(line string) bool { + return strings.Contains(line, expect) + }) +} + +func TestRegistration(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + lines := client.register("alice") + assertContains(t, lines, " 001 ", "RPL_WELCOME") +} + +func TestWelcomeContainsNick(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + lines := client.register("bob") + + for _, line := range lines { + if strings.Contains(line, " 001 ") && + !strings.Contains(line, "bob") { + t.Errorf("001 should contain nick: %s", line) + } + } +} + +func TestPingPong(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("pingtest") + lines := client.sendAndExpect("PING :hello", "PONG") + assertContains(t, lines, "PONG", "PONG response") +} + +func TestJoinChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("joiner") + client.send("JOIN #test") + + lines := client.readUntil(func(line string) bool { + return strings.Contains(line, " 366 ") + }) + + assertContains(t, lines, "JOIN", "JOIN echo") + assertContains(t, lines, " 366 ", "RPL_ENDOFNAMES") +} + +func TestPrivmsgBetweenClients(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + alice := env.dial(t) + alice.register("alice_pm") + + bob := env.dial(t) + bob.register("bob_pm") + + alice.joinAndDrain("#chat") + bob.joinAndDrain("#chat") + + alice.send("PRIVMSG #chat :hello bob!") + lines := bob.sendAndExpect("PING :sync", "hello bob!") + assertContains(t, lines, "hello bob!", "channel PRIVMSG") +} + +func TestNickChange(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("oldnick") + lines := client.sendAndExpect("NICK newnick", "newnick") + assertContains(t, lines, "NICK", "NICK change echo") +} + +func TestDuplicateNick(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + first := env.dial(t) + first.register("taken") + + second := env.dial(t) + second.send("NICK taken") + second.send("USER taken 0 * :Test") + + lines := second.readUntil(func(line string) bool { + return strings.Contains(line, " 433 ") + }) + + assertContains(t, lines, " 433 ", "ERR_NICKNAMEINUSE") +} + +func TestListChannels(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("lister") + client.joinAndDrain("#listtest") + lines := client.sendAndExpect("LIST", " 323 ") + assertContains(t, lines, " 323 ", "RPL_LISTEND") //nolint:misspell // IRC term +} + +func TestWhois(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("whoistest") + lines := client.sendAndExpect( + "WHOIS whoistest", " 318 ", + ) + + assertContains(t, lines, " 311 ", "RPL_WHOISUSER") + assertContains(t, lines, " 318 ", "RPL_ENDOFWHOIS") +} + +func TestQuit(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("quitter") + lines := client.sendAndExpect( + "QUIT :goodbye", "ERROR", + ) + + assertContains(t, lines, "goodbye", "QUIT reason") +} + +func TestTopicSetAndGet(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("topicuser") + client.joinAndDrain("#topictest") + lines := client.sendAndExpect( + "TOPIC #topictest :New topic here", + "New topic here", + ) + + assertContains( + t, lines, "New topic here", "TOPIC echo", + ) +} + +func TestUnknownCommand(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("unknowncmd") + lines := client.sendAndExpect("FOOBAR", " 421 ") + assertContains(t, lines, " 421 ", "ERR_UNKNOWNCOMMAND") +} + +func TestDirectMessage(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + + sender := env.dial(t) + sender.register("dmsender") + + receiver := env.dial(t) + receiver.register("dmreceiver") + + // Give relay goroutines time to start. + time.Sleep(100 * time.Millisecond) + + sender.send("PRIVMSG dmreceiver :hello privately") + + lines := receiver.readUntil(func(line string) bool { + return strings.Contains(line, "hello privately") + }) + + assertContains( + t, lines, "hello privately", "direct PRIVMSG", + ) +} + +func TestCAPNegotiation(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.send("CAP LS 302") + + line := client.readLine() + if !strings.Contains(line, "CAP") { + t.Errorf("expected CAP response, got: %s", line) + } + + client.send("CAP END") + lines := client.register("capuser") + assertContains(t, lines, " 001 ", "registration after CAP") +} + +func TestPartChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("parter") + client.joinAndDrain("#parttest") + lines := client.sendAndExpect( + "PART #parttest :leaving", "PART", + ) + + assertContains(t, lines, "#parttest", "PART echo") +} + +func TestModeQuery(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("modeuser") + client.joinAndDrain("#modetest") + lines := client.sendAndExpect( + "MODE #modetest", " 324 ", + ) + + assertContains( + t, lines, " 324 ", "RPL_CHANNELMODEIS", + ) +} + +func TestWhoChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("whouser") + client.joinAndDrain("#whotest") + lines := client.sendAndExpect("WHO #whotest", " 315 ") + assertContains(t, lines, " 352 ", "RPL_WHOREPLY") +} + +func TestLusers(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("luseruser") + lines := client.sendAndExpect("LUSERS", " 255 ") + assertContains(t, lines, " 251 ", "RPL_LUSERCLIENT") +} + +func TestMotd(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("motduser") + lines := client.sendAndExpect("MOTD", " 376 ") + assertContains(t, lines, " 376 ", "RPL_ENDOFMOTD") +} + +func TestAwaySetAndClear(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("awayuser") + + setLines := client.sendAndExpect( + "AWAY :brb lunch", " 306 ", + ) + assertContains(t, setLines, " 306 ", "RPL_NOWAWAY") + + clearLines := client.sendAndExpect("AWAY", " 305 ") + assertContains(t, clearLines, " 305 ", "RPL_UNAWAY") +} + +func TestHandlePassPostRegistration(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("passuser") + lines := client.sendAndExpect( + "PASS :mypassword123", "Password set", + ) + + assertContains( + t, lines, "Password set", "password confirmation", + ) +} + +func TestPreRegistrationNotRegistered(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.send("PRIVMSG #test :hello") + + line := client.readLine() + if !strings.Contains(line, " 451 ") { + t.Errorf( + "expected ERR_NOTREGISTERED (451), got: %s", + line, + ) + } +} + +func TestNamesNonExistentChannel(t *testing.T) { + t.Parallel() + + env := newTestEnv(t) + client := env.dial(t) + + client.register("namesuser") + lines := client.sendAndExpect( + "NAMES #doesnotexist", " 366 ", + ) + + assertContains( + t, lines, " 366 ", + "RPL_ENDOFNAMES for non-existent channel", + ) +} + +func BenchmarkParseMessage(b *testing.B) { + line := ":nick!user@host PRIVMSG #channel :Hello, world!" + + b.ResetTimer() + + for range b.N { + _ = ircserver.ParseMessage(line) + } +} + +func BenchmarkFormatMessage(b *testing.B) { + b.ResetTimer() + + for range b.N { + _ = ircserver.FormatMessage( + "nick!user@host", "PRIVMSG", + "#channel", "Hello, world!", + ) + } +} diff --git a/pkg/irc/commands.go b/pkg/irc/commands.go index 8b79e7d..60ef85e 100644 --- a/pkg/irc/commands.go +++ b/pkg/irc/commands.go @@ -3,6 +3,7 @@ package irc // IRC command names (RFC 1459 / RFC 2812). const ( CmdAway = "AWAY" + CmdInvite = "INVITE" CmdJoin = "JOIN" CmdKick = "KICK" CmdList = "LIST" @@ -20,6 +21,7 @@ const ( CmdPrivmsg = "PRIVMSG" CmdQuit = "QUIT" CmdTopic = "TOPIC" + CmdUser = "USER" CmdWho = "WHO" CmdWhois = "WHOIS" )