feat: add IRC wire protocol listener with shared service layer
Some checks failed
check / check (push) Failing after 46s

Add a traditional IRC wire protocol listener (RFC 1459/2812) on
configurable port (default :6667), sharing business logic with
the HTTP API via a new service layer.

- IRC listener: NICK, USER, PASS, JOIN, PART, PRIVMSG, NOTICE,
  TOPIC, MODE, KICK, QUIT, NAMES, LIST, WHOIS, WHO, AWAY, OPER,
  INVITE, LUSERS, MOTD, PING/PONG, CAP
- Service layer: shared logic for both transports including
  channel join (with Tier 2 checks: ban/invite/key/limit),
  message send (with ban + moderation checks), nick change,
  topic, kick, mode, quit broadcast, away, oper, invite
- BroadcastQuit uses FanOut pattern (one insert, N enqueues)
- HTTP handlers delegate to service for all command logic
- Tier 2 mode operations (+b/+i/+s/+k/+l) use service methods

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 18:01:36 -07:00
parent 9a79d92c0d
commit 92d5145ac6
19 changed files with 4795 additions and 1112 deletions

File diff suppressed because it is too large Load Diff

501
internal/ircserver/conn.go Normal file
View File

@@ -0,0 +1,501 @@
package ircserver
import (
"bufio"
"context"
"fmt"
"log/slog"
"net"
"strconv"
"strings"
"sync"
"time"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
const (
maxLineLen = 512
readTimeout = 5 * time.Minute
writeTimeout = 30 * time.Second
dnsTimeout = 3 * time.Second
pollInterval = 100 * time.Millisecond
pingInterval = 90 * time.Second
pongDeadline = 30 * time.Second
maxNickLen = 32
minPasswordLen = 8
)
// cmdHandler is the signature for registered IRC command
// handlers.
type cmdHandler func(ctx context.Context, msg *Message)
// Conn represents a single IRC client TCP connection.
type Conn struct {
conn net.Conn
log *slog.Logger
database *db.Database
brk *broker.Broker
cfg *config.Config
svc *service.Service
serverSfx string
commands map[string]cmdHandler
mu sync.Mutex
nick string
username string
realname string
hostname string
remoteIP string
sessionID int64
clientID int64
registered bool
gotNick bool
gotUser bool
passWord string
lastQueueID int64
closed bool
cancel context.CancelFunc
}
func newConn(
ctx context.Context,
tcpConn net.Conn,
log *slog.Logger,
database *db.Database,
brk *broker.Broker,
cfg *config.Config,
svc *service.Service,
) *Conn {
host, _, _ := net.SplitHostPort(tcpConn.RemoteAddr().String())
srvName := cfg.ServerName
if srvName == "" {
srvName = "neoirc"
}
conn := &Conn{ //nolint:exhaustruct // zero-value defaults
conn: tcpConn,
log: log,
database: database,
brk: brk,
cfg: cfg,
svc: svc,
serverSfx: srvName,
remoteIP: host,
hostname: resolveHost(ctx, host),
}
conn.commands = conn.buildCommandMap()
return conn
}
// buildCommandMap returns a map from IRC command strings
// to handler functions.
func (c *Conn) buildCommandMap() map[string]cmdHandler {
return map[string]cmdHandler{
irc.CmdPing: func(_ context.Context, msg *Message) {
c.handlePing(msg)
},
"PONG": func(context.Context, *Message) {},
irc.CmdNick: c.handleNick,
irc.CmdPrivmsg: c.handlePrivmsg,
irc.CmdNotice: c.handlePrivmsg,
irc.CmdJoin: c.handleJoin,
irc.CmdPart: c.handlePart,
irc.CmdQuit: func(_ context.Context, msg *Message) {
c.handleQuit(msg)
},
irc.CmdTopic: c.handleTopic,
irc.CmdMode: c.handleMode,
irc.CmdNames: c.handleNames,
irc.CmdList: func(ctx context.Context, _ *Message) { c.handleList(ctx) },
irc.CmdWhois: c.handleWhois,
irc.CmdWho: c.handleWho,
irc.CmdLusers: func(ctx context.Context, _ *Message) { c.handleLusers(ctx) },
irc.CmdMotd: func(context.Context, *Message) { c.deliverMOTD() },
irc.CmdOper: c.handleOper,
irc.CmdAway: c.handleAway,
irc.CmdKick: c.handleKick,
irc.CmdPass: c.handlePassPostReg,
"INVITE": c.handleInvite,
"CAP": func(_ context.Context, msg *Message) {
c.handleCAP(msg)
},
"USERHOST": c.handleUserhost,
}
}
// resolveHost does a reverse DNS lookup, returning the IP
// on failure.
func resolveHost(ctx context.Context, addr string) string {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
resolver := &net.Resolver{} //nolint:exhaustruct
names, err := resolver.LookupAddr(ctx, addr)
if err != nil || len(names) == 0 {
return addr
}
return strings.TrimSuffix(names[0], ".")
}
// serve is the main loop for a single IRC client connection.
func (c *Conn) serve(ctx context.Context) {
ctx, c.cancel = context.WithCancel(ctx)
defer c.cleanup(ctx)
scanner := bufio.NewScanner(c.conn)
scanner.Buffer(make([]byte, maxLineLen), maxLineLen)
for {
_ = c.conn.SetReadDeadline(
time.Now().Add(readTimeout),
)
if !scanner.Scan() {
return
}
line := scanner.Text()
if line == "" {
continue
}
msg := ParseMessage(line)
if msg == nil {
continue
}
c.handleMessage(ctx, msg)
if c.closed {
return
}
}
}
func (c *Conn) cleanup(ctx context.Context) {
c.mu.Lock()
wasRegistered := c.registered
sessID := c.sessionID
nick := c.nick
c.closed = true
c.mu.Unlock()
if wasRegistered && sessID > 0 {
c.svc.BroadcastQuit(
ctx, sessID, nick, "Connection closed",
)
}
c.conn.Close() //nolint:errcheck,gosec
}
// send writes a formatted IRC line to the connection.
func (c *Conn) send(line string) {
_ = c.conn.SetWriteDeadline(
time.Now().Add(writeTimeout),
)
_, _ = fmt.Fprintf(c.conn, "%s\r\n", line)
}
// sendNumeric sends a numeric reply from the server.
func (c *Conn) sendNumeric(
code irc.IRCMessageType,
params ...string,
) {
nick := c.nick
if nick == "" {
nick = "*"
}
allParams := make([]string, 0, 1+len(params))
allParams = append(allParams, nick)
allParams = append(allParams, params...)
c.send(FormatMessage(
c.serverSfx, code.Code(), allParams...,
))
}
// sendFromServer sends a message from the server.
func (c *Conn) sendFromServer(
command string, params ...string,
) {
c.send(FormatMessage(c.serverSfx, command, params...))
}
// hostmask returns the client's full hostmask
// (nick!user@host).
func (c *Conn) hostmask() string {
user := c.username
if user == "" {
user = c.nick
}
host := c.hostname
if host == "" {
host = c.remoteIP
}
return c.nick + "!" + user + "@" + host
}
// handleMessage dispatches a parsed IRC message using
// the command handler map.
func (c *Conn) handleMessage(
ctx context.Context,
msg *Message,
) {
// Before registration, only NICK, USER, PASS, PING,
// QUIT, and CAP are accepted.
if !c.registered {
c.handlePreRegistration(ctx, msg)
return
}
handler, ok := c.commands[msg.Command]
if !ok {
c.sendNumeric(
irc.ErrUnknownCommand,
msg.Command, "Unknown command",
)
return
}
handler(ctx, msg)
}
// handlePreRegistration handles messages before the
// connection is registered (NICK+USER received).
func (c *Conn) handlePreRegistration(
ctx context.Context,
msg *Message,
) {
switch msg.Command {
case irc.CmdPass:
if len(msg.Params) < 1 {
c.sendNumeric(
irc.ErrNeedMoreParams,
"PASS", "Not enough parameters",
)
return
}
c.passWord = msg.Params[0]
case irc.CmdNick:
if len(msg.Params) < 1 {
c.sendNumeric(
irc.ErrNoNicknameGiven,
"No nickname given",
)
return
}
c.nick = msg.Params[0]
if len(c.nick) > maxNickLen {
c.nick = c.nick[:maxNickLen]
}
c.gotNick = true
case irc.CmdUser:
if len(msg.Params) < 4 { //nolint:mnd
c.sendNumeric(
irc.ErrNeedMoreParams,
"USER", "Not enough parameters",
)
return
}
c.username = msg.Params[0]
c.realname = msg.Params[3]
c.gotUser = true
case irc.CmdPing:
c.handlePing(msg)
return
case irc.CmdQuit:
c.handleQuit(msg)
return
case "CAP":
c.handleCAP(msg)
return
default:
c.sendNumeric(
irc.ErrNotRegistered,
"You have not registered",
)
return
}
// Try to complete registration once we have both
// NICK and USER.
if c.gotNick && c.gotUser {
c.completeRegistration(ctx)
}
}
// completeRegistration creates a session and sends the
// welcome burst.
func (c *Conn) completeRegistration(ctx context.Context) {
// Check if nick is valid.
if c.nick == "" {
c.sendNumeric(
irc.ErrNoNicknameGiven, "No nickname given",
)
return
}
// Create session in DB.
sessionID, clientID, _, err := c.database.CreateSession(
ctx, c.nick, c.username, c.hostname, c.remoteIP,
)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") ||
strings.Contains(err.Error(), "nick") {
c.sendNumeric(
irc.ErrNicknameInUse,
c.nick, "Nickname is already in use",
)
return
}
c.log.Error(
"failed to create session", "error", err,
)
c.send("ERROR :Internal server error")
c.closed = true
return
}
c.mu.Lock()
c.sessionID = sessionID
c.clientID = clientID
c.registered = true
c.mu.Unlock()
// If PASS was provided before registration, set the
// session password.
if c.passWord != "" && len(c.passWord) >= minPasswordLen {
c.setPassword(ctx, c.passWord)
}
// Send welcome burst.
c.deliverWelcome()
c.deliverLusers(ctx)
c.deliverMOTD()
// Start the message relay goroutine.
go c.relayMessages(ctx)
}
// deliverWelcome sends 001-005 welcome numerics.
func (c *Conn) deliverWelcome() {
c.sendNumeric(irc.RplWelcome, fmt.Sprintf(
"Welcome to the %s Network, %s",
c.serverSfx, c.hostmask(),
))
c.sendNumeric(irc.RplYourHost, fmt.Sprintf(
"Your host is %s, running version neoirc",
c.serverSfx,
))
c.sendNumeric(
irc.RplCreated,
"This server was created recently",
)
c.sendNumeric(
irc.RplMyInfo,
c.serverSfx, "neoirc", "", "mnst",
)
c.sendNumeric(
irc.RplIsupport,
"CHANTYPES=#",
"NICKLEN=32",
"PREFIX=(ov)@+",
"CHANMODES=,,H,mnst",
"NETWORK="+c.serverSfx,
"are supported by this server",
)
}
// deliverLusers sends 251/252/254/255 server statistics.
func (c *Conn) deliverLusers(ctx context.Context) {
users, _ := c.database.GetUserCount(ctx)
opers, _ := c.database.GetOperCount(ctx)
channels, _ := c.database.GetChannelCount(ctx)
c.sendNumeric(irc.RplLuserClient, fmt.Sprintf(
"There are %d users and 0 invisible on 1 servers",
users,
))
c.sendNumeric(
irc.RplLuserOp,
strconv.FormatInt(opers, 10),
"operator(s) online",
)
c.sendNumeric(
irc.RplLuserChannels,
strconv.FormatInt(channels, 10),
"channels formed",
)
c.sendNumeric(irc.RplLuserMe, fmt.Sprintf(
"I have %d clients and 1 servers", users,
))
}
// deliverMOTD sends 375/372/376 MOTD lines.
func (c *Conn) deliverMOTD() {
motd := c.cfg.MOTD
if motd == "" {
c.sendNumeric(
irc.ErrNoMotd, "MOTD File is missing",
)
return
}
c.sendNumeric(irc.RplMotdStart, fmt.Sprintf(
"- %s Message of the Day -", c.serverSfx,
))
for _, line := range strings.Split(motd, "\n") {
c.sendNumeric(irc.RplMotd, "- "+line)
}
c.sendNumeric(
irc.RplEndOfMotd, "End of /MOTD command",
)
}
// setPassword sets a bcrypt password on the session.
func (c *Conn) setPassword(ctx context.Context, pw string) {
// Use the database's auth module to hash and store.
err := c.database.SetPassword(ctx, c.sessionID, pw)
if err != nil {
c.log.Error(
"failed to set password", "error", err,
)
}
}

View File

@@ -0,0 +1,52 @@
package ircserver
import (
"context"
"log/slog"
"net"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/service"
)
// NewTestServer creates a Server suitable for testing.
// The caller must call Stop() when finished.
func NewTestServer(
log *slog.Logger,
cfg *config.Config,
database *db.Database,
brk *broker.Broker,
) *Server {
svc := &service.Service{
DB: database,
Broker: brk,
Config: cfg,
Log: log,
}
return &Server{ //nolint:exhaustruct
log: log,
cfg: cfg,
database: database,
brk: brk,
svc: svc,
conns: make(map[*Conn]struct{}),
}
}
// Start exposes the unexported start method for tests.
func (s *Server) Start(addr string) error {
return s.start(context.Background(), addr)
}
// Stop exposes the unexported stop method for tests.
func (s *Server) Stop() {
s.stop()
}
// Listener returns the server's net.Listener for tests.
func (s *Server) Listener() net.Listener {
return s.listener
}

View File

@@ -0,0 +1,123 @@
// Package ircserver implements a traditional IRC wire protocol
// listener (RFC 1459/2812) that bridges to the neoirc HTTP/JSON
// server internals.
package ircserver
import "strings"
// Message represents a parsed IRC wire protocol message.
type Message struct {
// Prefix is the optional :prefix at the start (may be
// empty for client-to-server messages).
Prefix string
// Command is the IRC command (e.g., "PRIVMSG", "NICK").
Command string
// Params holds the positional parameters, including the
// trailing parameter (which was preceded by ':' on the
// wire).
Params []string
}
// ParseMessage parses a single IRC wire protocol line
// (without the trailing CR-LF) into a Message.
// Returns nil if the line is empty.
//
// IRC message format (RFC 1459 §2.3.1):
//
// [":" prefix SPACE] command { SPACE param } [SPACE ":" trailing]
func ParseMessage(line string) *Message {
if line == "" {
return nil
}
msg := &Message{} //nolint:exhaustruct // fields set below
// Extract prefix if present.
if line[0] == ':' {
idx := strings.IndexByte(line, ' ')
if idx < 0 {
// Only a prefix, no command — invalid.
return nil
}
msg.Prefix = line[1:idx]
line = line[idx+1:]
}
// Skip leading spaces.
line = strings.TrimLeft(line, " ")
if line == "" {
return nil
}
// Extract command.
idx := strings.IndexByte(line, ' ')
if idx < 0 {
msg.Command = strings.ToUpper(line)
return msg
}
msg.Command = strings.ToUpper(line[:idx])
line = line[idx+1:]
// Extract parameters.
for line != "" {
line = strings.TrimLeft(line, " ")
if line == "" {
break
}
// Trailing parameter (everything after ':').
if line[0] == ':' {
msg.Params = append(msg.Params, line[1:])
break
}
idx = strings.IndexByte(line, ' ')
if idx < 0 {
msg.Params = append(msg.Params, line)
break
}
msg.Params = append(msg.Params, line[:idx])
line = line[idx+1:]
}
return msg
}
// FormatMessage formats an IRC message into wire protocol
// format (without the trailing CR-LF).
func FormatMessage(
prefix, command string,
params ...string,
) string {
var buf strings.Builder
if prefix != "" {
buf.WriteByte(':')
buf.WriteString(prefix)
buf.WriteByte(' ')
}
buf.WriteString(command)
for i, param := range params {
buf.WriteByte(' ')
isLast := i == len(params)-1
needsColon := strings.Contains(param, " ") ||
param == "" || param[0] == ':'
if isLast && needsColon {
buf.WriteByte(':')
}
buf.WriteString(param)
}
return buf.String()
}

View File

@@ -0,0 +1,328 @@
package ircserver_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
)
//nolint:funlen // table-driven test
func TestParseMessage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want *ircserver.Message
wantNil bool
}{
{
name: "empty",
input: "",
want: nil,
wantNil: true,
},
{
name: "simple command",
input: "PING",
want: &ircserver.Message{
Prefix: "",
Command: "PING",
Params: nil,
},
wantNil: false,
},
{
name: "command with one param",
input: "NICK alice",
want: &ircserver.Message{
Prefix: "",
Command: "NICK",
Params: []string{"alice"},
},
wantNil: false,
},
{
name: "command case insensitive",
input: "nick Alice",
want: &ircserver.Message{
Prefix: "",
Command: "NICK",
Params: []string{"Alice"},
},
wantNil: false,
},
{
name: "privmsg with trailing",
input: "PRIVMSG #general :hello world",
want: &ircserver.Message{
Prefix: "",
Command: "PRIVMSG",
Params: []string{"#general", "hello world"},
},
wantNil: false,
},
{
name: "with prefix",
input: ":server.example.com 001 alice :Welcome to IRC",
want: &ircserver.Message{
Prefix: "server.example.com",
Command: "001",
Params: []string{"alice", "Welcome to IRC"},
},
wantNil: false,
},
{
name: "user command",
input: "USER alice 0 * :Alice Smith",
want: &ircserver.Message{
Prefix: "",
Command: "USER",
Params: []string{
"alice", "0", "*", "Alice Smith",
},
},
wantNil: false,
},
{
name: "join channel",
input: "JOIN #general",
want: &ircserver.Message{
Prefix: "",
Command: "JOIN",
Params: []string{"#general"},
},
wantNil: false,
},
{
name: "quit with trailing",
input: "QUIT :leaving now",
want: &ircserver.Message{
Prefix: "",
Command: "QUIT",
Params: []string{"leaving now"},
},
wantNil: false,
},
{
name: "quit without reason",
input: "QUIT",
want: &ircserver.Message{
Prefix: "",
Command: "QUIT",
Params: nil,
},
wantNil: false,
},
{
name: "mode query",
input: "MODE #general",
want: &ircserver.Message{
Prefix: "",
Command: "MODE",
Params: []string{"#general"},
},
wantNil: false,
},
{
name: "kick with reason",
input: "KICK #general bob :misbehaving",
want: &ircserver.Message{
Prefix: "",
Command: "KICK",
Params: []string{
"#general", "bob", "misbehaving",
},
},
wantNil: false,
},
{
name: "empty trailing",
input: "PRIVMSG #general :",
want: &ircserver.Message{
Prefix: "",
Command: "PRIVMSG",
Params: []string{"#general", ""},
},
wantNil: false,
},
{
name: "pass command",
input: "PASS mysecret",
want: &ircserver.Message{
Prefix: "",
Command: "PASS",
Params: []string{"mysecret"},
},
wantNil: false,
},
{
name: "ping with server",
input: "PING :irc.example.com",
want: &ircserver.Message{
Prefix: "",
Command: "PING",
Params: []string{"irc.example.com"},
},
wantNil: false,
},
{
name: "topic with trailing spaces",
input: "TOPIC #general :Welcome to the channel!",
want: &ircserver.Message{
Prefix: "",
Command: "TOPIC",
Params: []string{
"#general",
"Welcome to the channel!",
},
},
wantNil: false,
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
got := ircserver.ParseMessage(testCase.input)
if testCase.wantNil {
if got != nil {
t.Fatalf("expected nil, got %+v", got)
}
return
}
if got == nil {
t.Fatal("expected non-nil message")
}
if got.Prefix != testCase.want.Prefix {
t.Errorf(
"prefix: got %q, want %q",
got.Prefix, testCase.want.Prefix,
)
}
if got.Command != testCase.want.Command {
t.Errorf(
"command: got %q, want %q",
got.Command, testCase.want.Command,
)
}
if len(got.Params) != len(testCase.want.Params) {
t.Fatalf(
"params length: got %d, want %d (%v vs %v)",
len(got.Params),
len(testCase.want.Params),
got.Params,
testCase.want.Params,
)
}
for i, p := range got.Params {
if p != testCase.want.Params[i] {
t.Errorf(
"param[%d]: got %q, want %q",
i, p, testCase.want.Params[i],
)
}
}
})
}
}
func TestFormatMessage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prefix string
command string
params []string
want string
}{
{
name: "simple command",
prefix: "",
command: "PING",
params: nil,
want: "PING",
},
{
name: "with prefix",
prefix: "server",
command: "PONG",
params: []string{"server"},
want: ":server PONG server",
},
{
name: "privmsg with trailing",
prefix: "alice!alice@host",
command: "PRIVMSG",
params: []string{"#general", "hello world"},
want: ":alice!alice@host PRIVMSG #general :hello world",
},
{
name: "numeric reply",
prefix: "server",
command: "001",
params: []string{"alice", "Welcome to IRC"},
want: ":server 001 alice :Welcome to IRC",
},
{
name: "empty trailing",
prefix: "server",
command: "PRIVMSG",
params: []string{"#chan", ""},
want: ":server PRIVMSG #chan :",
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
got := ircserver.FormatMessage(
testCase.prefix, testCase.command, testCase.params...,
)
if got != testCase.want {
t.Errorf("got %q, want %q", got, testCase.want)
}
})
}
}
func TestParseFormatRoundTrip(t *testing.T) {
t.Parallel()
// Round-trip only works for lines where the last
// parameter either contains a space (gets ':' prefix
// on format) or is a non-trailing single token.
lines := []string{
"PING",
"NICK alice",
"PRIVMSG #general :hello world",
"JOIN #general",
"MODE #general",
}
for _, line := range lines {
msg := ircserver.ParseMessage(line)
if msg == nil {
t.Fatalf("failed to parse: %q", line)
}
formatted := ircserver.FormatMessage(
msg.Prefix, msg.Command, msg.Params...,
)
if formatted != line {
t.Errorf(
"round-trip failed: input %q, got %q",
line, formatted,
)
}
}
}

319
internal/ircserver/relay.go Normal file
View File

@@ -0,0 +1,319 @@
package ircserver
import (
"context"
"encoding/json"
"strings"
"time"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
// relayMessages polls the client output queue and delivers
// IRC-formatted messages to the TCP connection. It runs
// in a goroutine for the lifetime of the connection.
func (c *Conn) relayMessages(ctx context.Context) {
// Use a ticker as a fallback; primary wakeup is via
// broker notification.
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
default:
}
// Drain any available messages.
delivered := c.drainQueue(ctx)
if delivered {
// Tight loop while there are messages.
continue
}
// Wait for notification or timeout.
waitCh := c.brk.Wait(c.sessionID)
select {
case <-waitCh:
// New message notification — loop back.
case <-ticker.C:
// Periodic check.
case <-ctx.Done():
c.brk.Remove(c.sessionID, waitCh)
return
}
}
}
const relayPollLimit = 100
// drainQueue polls the output queue and delivers all
// pending messages. Returns true if at least one message
// was delivered.
func (c *Conn) drainQueue(ctx context.Context) bool {
msgs, lastID, err := c.database.PollMessages(
ctx, c.clientID, c.lastQueueID, relayPollLimit,
)
if err != nil {
return false
}
if len(msgs) == 0 {
return false
}
for i := range msgs {
c.deliverIRCMessage(ctx, &msgs[i])
}
if lastID > c.lastQueueID {
c.lastQueueID = lastID
}
return true
}
// deliverIRCMessage converts a db.IRCMessage to wire
// protocol and sends it.
//
//nolint:cyclop // dispatch table
func (c *Conn) deliverIRCMessage(
_ context.Context,
msg *db.IRCMessage,
) {
command := msg.Command
// Decode body as []string for the trailing text.
var bodyLines []string
if msg.Body != nil {
_ = json.Unmarshal(msg.Body, &bodyLines)
}
text := ""
if len(bodyLines) > 0 {
text = bodyLines[0]
}
// Route by command type.
switch {
case isNumeric(command):
c.deliverNumeric(msg, text)
case command == irc.CmdPrivmsg || command == irc.CmdNotice:
c.deliverTextMessage(msg, command, text)
case command == irc.CmdJoin:
c.deliverJoin(msg)
case command == irc.CmdPart:
c.deliverPart(msg, text)
case command == irc.CmdNick:
c.deliverNickChange(msg, text)
case command == irc.CmdQuit:
c.deliverQuitMsg(msg, text)
case command == irc.CmdTopic:
c.deliverTopicChange(msg, text)
case command == irc.CmdKick:
c.deliverKickMsg(msg, text)
case command == "INVITE":
c.deliverInviteMsg(msg, text)
case command == irc.CmdMode:
c.deliverMode(msg, text)
case command == irc.CmdPing:
// Server-originated PING — reply with PONG.
c.sendFromServer("PING", c.serverSfx)
default:
// Unknown command — deliver as server notice.
if text != "" {
c.sendFromServer("NOTICE", c.nick, text)
}
}
}
// isNumeric returns true if the command is a 3-digit
// numeric code.
func isNumeric(cmd string) bool {
return len(cmd) == 3 &&
cmd[0] >= '0' && cmd[0] <= '9' &&
cmd[1] >= '0' && cmd[1] <= '9' &&
cmd[2] >= '0' && cmd[2] <= '9'
}
// deliverNumeric sends a numeric reply.
func (c *Conn) deliverNumeric(
msg *db.IRCMessage,
text string,
) {
from := msg.From
if from == "" {
from = c.serverSfx
}
var params []string
if msg.Params != nil {
_ = json.Unmarshal(msg.Params, &params)
}
allParams := make([]string, 0, 1+len(params)+1)
allParams = append(allParams, c.nick)
allParams = append(allParams, params...)
if text != "" {
allParams = append(allParams, text)
}
c.send(FormatMessage(from, msg.Command, allParams...))
}
// deliverTextMessage sends PRIVMSG or NOTICE.
func (c *Conn) deliverTextMessage(
msg *db.IRCMessage,
command, text string,
) {
from := msg.From
target := msg.To
// Don't echo our own messages back.
if strings.EqualFold(from, c.nick) {
return
}
prefix := from
if !strings.Contains(prefix, "!") {
prefix = from + "!" + from + "@*"
}
c.send(FormatMessage(prefix, command, target, text))
}
// deliverJoin sends a JOIN notification.
func (c *Conn) deliverJoin(msg *db.IRCMessage) {
// Don't echo our own JOINs (we already sent them
// during joinChannel).
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
c.send(FormatMessage(prefix, "JOIN", channel))
}
// deliverPart sends a PART notification.
func (c *Conn) deliverPart(msg *db.IRCMessage, text string) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
if text != "" {
c.send(FormatMessage(
prefix, "PART", channel, text,
))
} else {
c.send(FormatMessage(prefix, "PART", channel))
}
}
// deliverNickChange sends a NICK change notification.
func (c *Conn) deliverNickChange(
msg *db.IRCMessage,
newNick string,
) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
c.send(FormatMessage(prefix, "NICK", newNick))
}
// deliverQuitMsg sends a QUIT notification.
func (c *Conn) deliverQuitMsg(
msg *db.IRCMessage,
text string,
) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
if text != "" {
c.send(FormatMessage(
prefix, "QUIT", "Quit: "+text,
))
} else {
c.send(FormatMessage(prefix, "QUIT", "Quit"))
}
}
// deliverTopicChange sends a TOPIC change notification.
func (c *Conn) deliverTopicChange(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
c.send(FormatMessage(prefix, "TOPIC", channel, text))
}
// deliverKickMsg sends a KICK notification.
func (c *Conn) deliverKickMsg(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
var params []string
if msg.Params != nil {
_ = json.Unmarshal(msg.Params, &params)
}
kickTarget := ""
if len(params) > 0 {
kickTarget = params[0]
}
if kickTarget != "" {
c.send(FormatMessage(
prefix, "KICK", channel, kickTarget, text,
))
} else {
c.send(FormatMessage(
prefix, "KICK", channel, "?", text,
))
}
}
// deliverInviteMsg sends an INVITE notification.
func (c *Conn) deliverInviteMsg(
_ *db.IRCMessage,
text string,
) {
c.sendFromServer("NOTICE", c.nick, text)
}
// deliverMode sends a MODE change notification.
func (c *Conn) deliverMode(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
target := msg.To
if text != "" {
c.send(FormatMessage(prefix, "MODE", target, text))
}
}

View File

@@ -0,0 +1,157 @@
package ircserver
import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/service"
"go.uber.org/fx"
)
// Params defines the dependencies for creating an IRC
// Server.
type Params struct {
fx.In
Logger *logger.Logger
Config *config.Config
Database *db.Database
Broker *broker.Broker
Service *service.Service
}
// Server is the TCP IRC protocol server.
type Server struct {
log *slog.Logger
cfg *config.Config
database *db.Database
brk *broker.Broker
svc *service.Service
listener net.Listener
mu sync.Mutex
conns map[*Conn]struct{}
cancel context.CancelFunc
}
// New creates a new IRC Server and registers its lifecycle
// hooks. The listener is only started if IRC_LISTEN_ADDR
// is configured; otherwise the server is inert.
func New(
lifecycle fx.Lifecycle,
params Params,
) *Server {
srv := &Server{
log: params.Logger.Get(),
cfg: params.Config,
database: params.Database,
brk: params.Broker,
svc: params.Service,
conns: make(map[*Conn]struct{}),
listener: nil,
cancel: nil,
mu: sync.Mutex{},
}
listenAddr := params.Config.IRCListenAddr
if listenAddr == "" {
return srv
}
lifecycle.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return srv.start(ctx, listenAddr)
},
OnStop: func(_ context.Context) error {
srv.stop()
return nil
},
})
return srv
}
// start begins listening for TCP connections.
//
//nolint:contextcheck // long-lived server ctx, not the short Fx one
func (s *Server) start(_ context.Context, addr string) error {
ln, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("irc listen: %w", err)
}
s.listener = ln
ctx, cancel := context.WithCancel(context.Background())
s.cancel = cancel
s.log.Info(
"irc server listening", "addr", addr,
)
go s.acceptLoop(ctx)
return nil
}
// stop shuts down the listener and all connections.
func (s *Server) stop() {
if s.cancel != nil {
s.cancel()
}
if s.listener != nil {
s.listener.Close() //nolint:errcheck,gosec
}
s.mu.Lock()
for c := range s.conns {
c.conn.Close() //nolint:errcheck,gosec
}
s.mu.Unlock()
}
// acceptLoop accepts new connections.
func (s *Server) acceptLoop(ctx context.Context) {
for {
tcpConn, err := s.listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return
default:
s.log.Error(
"irc accept error", "error", err,
)
continue
}
}
client := newConn(
ctx, tcpConn, s.log,
s.database, s.brk, s.cfg, s.svc,
)
s.mu.Lock()
s.conns[client] = struct{}{}
s.mu.Unlock()
go func() {
defer func() {
s.mu.Lock()
delete(s.conns, client)
s.mu.Unlock()
}()
client.serve(ctx)
}()
}
}

View File

@@ -0,0 +1,625 @@
package ircserver_test
import (
"bufio"
"database/sql"
"fmt"
"log/slog"
"net"
"os"
"strings"
"testing"
"time"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
_ "modernc.org/sqlite"
)
const testTimeout = 5 * time.Second
func TestMain(m *testing.M) {
db.SetBcryptCost(4)
os.Exit(m.Run())
}
// testEnv holds the shared test infrastructure.
type testEnv struct {
database *db.Database
brk *broker.Broker
cfg *config.Config
srv *ircserver.Server
}
func newTestEnv(t *testing.T) *testEnv {
t.Helper()
dsn := fmt.Sprintf(
"file:%s?mode=memory&cache=shared&_journal_mode=WAL",
t.Name(),
)
conn, err := sql.Open("sqlite", dsn)
if err != nil {
t.Fatalf("open db: %v", err)
}
conn.SetMaxOpenConns(1)
_, err = conn.ExecContext(
t.Context(), "PRAGMA foreign_keys = ON",
)
if err != nil {
t.Fatalf("pragma: %v", err)
}
database := db.NewTestDatabaseFromConn(conn)
err = database.RunMigrations(t.Context())
if err != nil {
t.Fatalf("migrate: %v", err)
}
brk := broker.New()
cfg := &config.Config{ //nolint:exhaustruct
ServerName: "test.irc",
MOTD: "Welcome to test IRC",
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
addr := listener.Addr().String()
err = listener.Close()
if err != nil {
t.Fatalf("close listener: %v", err)
}
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelError}, //nolint:exhaustruct
))
srv := ircserver.NewTestServer(log, cfg, database, brk)
err = srv.Start(addr)
if err != nil {
t.Fatalf("start irc server: %v", err)
}
t.Cleanup(func() {
srv.Stop()
err := conn.Close()
if err != nil {
t.Logf("close db: %v", err)
}
})
return &testEnv{
database: database,
brk: brk,
cfg: cfg,
srv: srv,
}
}
// dial connects to the test server.
func (env *testEnv) dial(t *testing.T) *testClient {
t.Helper()
conn, err := net.DialTimeout(
"tcp",
env.srv.Listener().Addr().String(),
testTimeout,
)
if err != nil {
t.Fatalf("dial: %v", err)
}
t.Cleanup(func() {
err := conn.Close()
if err != nil {
t.Logf("close conn: %v", err)
}
})
return &testClient{
t: t,
conn: conn,
scanner: bufio.NewScanner(conn),
}
}
// testClient wraps a raw TCP connection with helpers.
type testClient struct {
t *testing.T
conn net.Conn
scanner *bufio.Scanner
}
func (tc *testClient) send(line string) {
tc.t.Helper()
_ = tc.conn.SetWriteDeadline(
time.Now().Add(testTimeout),
)
_, err := fmt.Fprintf(tc.conn, "%s\r\n", line)
if err != nil {
tc.t.Fatalf("send: %v", err)
}
}
func (tc *testClient) readLine() string {
tc.t.Helper()
_ = tc.conn.SetReadDeadline(
time.Now().Add(testTimeout),
)
if !tc.scanner.Scan() {
err := tc.scanner.Err()
if err != nil {
tc.t.Fatalf("read: %v", err)
}
tc.t.Fatal("connection closed unexpectedly")
}
return tc.scanner.Text()
}
// readUntil reads lines until one matches the predicate.
func (tc *testClient) readUntil(
pred func(string) bool,
) []string {
tc.t.Helper()
var lines []string
for {
line := tc.readLine()
lines = append(lines, line)
if pred(line) {
return lines
}
}
}
// register sends NICK + USER and reads through the welcome
// burst.
func (tc *testClient) register(nick string) []string {
tc.t.Helper()
tc.send("NICK " + nick)
tc.send("USER " + nick + " 0 * :Test User")
return tc.readUntil(func(line string) bool {
return strings.Contains(line, " 376 ") ||
strings.Contains(line, " 422 ")
})
}
// assertContains checks that at least one line matches the
// given substring.
func assertContains(
t *testing.T,
lines []string,
substr, description string,
) {
t.Helper()
for _, line := range lines {
if strings.Contains(line, substr) {
return
}
}
t.Errorf("did not find %q in output: %s", substr, description)
}
// joinAndDrain joins a channel and reads until
// RPL_ENDOFNAMES.
func (tc *testClient) joinAndDrain(channel string) {
tc.t.Helper()
tc.send("JOIN " + channel)
tc.readUntil(func(line string) bool {
return strings.Contains(line, " 366 ")
})
}
// sendAndExpect sends a command and reads until a line
// containing the expected substring is found.
func (tc *testClient) sendAndExpect(
cmd, expect string,
) []string {
tc.t.Helper()
tc.send(cmd)
return tc.readUntil(func(line string) bool {
return strings.Contains(line, expect)
})
}
func TestRegistration(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
lines := client.register("alice")
assertContains(t, lines, " 001 ", "RPL_WELCOME")
}
func TestWelcomeContainsNick(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
lines := client.register("bob")
for _, line := range lines {
if strings.Contains(line, " 001 ") &&
!strings.Contains(line, "bob") {
t.Errorf("001 should contain nick: %s", line)
}
}
}
func TestPingPong(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("pingtest")
lines := client.sendAndExpect("PING :hello", "PONG")
assertContains(t, lines, "PONG", "PONG response")
}
func TestJoinChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("joiner")
client.send("JOIN #test")
lines := client.readUntil(func(line string) bool {
return strings.Contains(line, " 366 ")
})
assertContains(t, lines, "JOIN", "JOIN echo")
assertContains(t, lines, " 366 ", "RPL_ENDOFNAMES")
}
func TestPrivmsgBetweenClients(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice_pm")
bob := env.dial(t)
bob.register("bob_pm")
alice.joinAndDrain("#chat")
bob.joinAndDrain("#chat")
alice.send("PRIVMSG #chat :hello bob!")
lines := bob.sendAndExpect("PING :sync", "hello bob!")
assertContains(t, lines, "hello bob!", "channel PRIVMSG")
}
func TestNickChange(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("oldnick")
lines := client.sendAndExpect("NICK newnick", "newnick")
assertContains(t, lines, "NICK", "NICK change echo")
}
func TestDuplicateNick(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
first := env.dial(t)
first.register("taken")
second := env.dial(t)
second.send("NICK taken")
second.send("USER taken 0 * :Test")
lines := second.readUntil(func(line string) bool {
return strings.Contains(line, " 433 ")
})
assertContains(t, lines, " 433 ", "ERR_NICKNAMEINUSE")
}
func TestListChannels(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("lister")
client.joinAndDrain("#listtest")
lines := client.sendAndExpect("LIST", " 323 ")
assertContains(t, lines, " 323 ", "RPL_LISTEND") //nolint:misspell // IRC term
}
func TestWhois(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("whoistest")
lines := client.sendAndExpect(
"WHOIS whoistest", " 318 ",
)
assertContains(t, lines, " 311 ", "RPL_WHOISUSER")
assertContains(t, lines, " 318 ", "RPL_ENDOFWHOIS")
}
func TestQuit(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("quitter")
lines := client.sendAndExpect(
"QUIT :goodbye", "ERROR",
)
assertContains(t, lines, "goodbye", "QUIT reason")
}
func TestTopicSetAndGet(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("topicuser")
client.joinAndDrain("#topictest")
lines := client.sendAndExpect(
"TOPIC #topictest :New topic here",
"New topic here",
)
assertContains(
t, lines, "New topic here", "TOPIC echo",
)
}
func TestUnknownCommand(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("unknowncmd")
lines := client.sendAndExpect("FOOBAR", " 421 ")
assertContains(t, lines, " 421 ", "ERR_UNKNOWNCOMMAND")
}
func TestDirectMessage(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
sender := env.dial(t)
sender.register("dmsender")
receiver := env.dial(t)
receiver.register("dmreceiver")
// Give relay goroutines time to start.
time.Sleep(100 * time.Millisecond)
sender.send("PRIVMSG dmreceiver :hello privately")
lines := receiver.readUntil(func(line string) bool {
return strings.Contains(line, "hello privately")
})
assertContains(
t, lines, "hello privately", "direct PRIVMSG",
)
}
func TestCAPNegotiation(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.send("CAP LS 302")
line := client.readLine()
if !strings.Contains(line, "CAP") {
t.Errorf("expected CAP response, got: %s", line)
}
client.send("CAP END")
lines := client.register("capuser")
assertContains(t, lines, " 001 ", "registration after CAP")
}
func TestPartChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("parter")
client.joinAndDrain("#parttest")
lines := client.sendAndExpect(
"PART #parttest :leaving", "PART",
)
assertContains(t, lines, "#parttest", "PART echo")
}
func TestModeQuery(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("modeuser")
client.joinAndDrain("#modetest")
lines := client.sendAndExpect(
"MODE #modetest", " 324 ",
)
assertContains(
t, lines, " 324 ", "RPL_CHANNELMODEIS",
)
}
func TestWhoChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("whouser")
client.joinAndDrain("#whotest")
lines := client.sendAndExpect("WHO #whotest", " 315 ")
assertContains(t, lines, " 352 ", "RPL_WHOREPLY")
}
func TestLusers(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("luseruser")
lines := client.sendAndExpect("LUSERS", " 255 ")
assertContains(t, lines, " 251 ", "RPL_LUSERCLIENT")
}
func TestMotd(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("motduser")
lines := client.sendAndExpect("MOTD", " 376 ")
assertContains(t, lines, " 376 ", "RPL_ENDOFMOTD")
}
func TestAwaySetAndClear(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("awayuser")
setLines := client.sendAndExpect(
"AWAY :brb lunch", " 306 ",
)
assertContains(t, setLines, " 306 ", "RPL_NOWAWAY")
clearLines := client.sendAndExpect("AWAY", " 305 ")
assertContains(t, clearLines, " 305 ", "RPL_UNAWAY")
}
func TestHandlePassPostRegistration(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("passuser")
lines := client.sendAndExpect(
"PASS :mypassword123", "Password set",
)
assertContains(
t, lines, "Password set", "password confirmation",
)
}
func TestPreRegistrationNotRegistered(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.send("PRIVMSG #test :hello")
line := client.readLine()
if !strings.Contains(line, " 451 ") {
t.Errorf(
"expected ERR_NOTREGISTERED (451), got: %s",
line,
)
}
}
func TestNamesNonExistentChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("namesuser")
lines := client.sendAndExpect(
"NAMES #doesnotexist", " 366 ",
)
assertContains(
t, lines, " 366 ",
"RPL_ENDOFNAMES for non-existent channel",
)
}
func BenchmarkParseMessage(b *testing.B) {
line := ":nick!user@host PRIVMSG #channel :Hello, world!"
b.ResetTimer()
for range b.N {
_ = ircserver.ParseMessage(line)
}
}
func BenchmarkFormatMessage(b *testing.B) {
b.ResetTimer()
for range b.N {
_ = ircserver.FormatMessage(
"nick!user@host", "PRIVMSG",
"#channel", "Hello, world!",
)
}
}