2 Commits

Author SHA1 Message Date
clawbot
d0e925bf70 docs: add reverse proxy security note to login rate limiting section
All checks were successful
check / check (push) Successful in 1m8s
2026-03-17 04:49:49 -07:00
user
e519ffa1e6 feat: add per-IP rate limiting to login endpoint
Add a token-bucket rate limiter (golang.org/x/time/rate) that limits
login attempts per client IP on POST /api/v1/login. Returns 429 Too
Many Requests with a Retry-After header when the limit is exceeded.

Configurable via LOGIN_RATE_LIMIT (requests/sec, default 1) and
LOGIN_RATE_BURST (burst size, default 5). Stale per-IP entries are
automatically cleaned up every 10 minutes.

Only the login endpoint is rate-limited per sneak's instruction —
session creation and registration use hashcash proof-of-work instead.
2026-03-17 04:48:46 -07:00
38 changed files with 1509 additions and 13789 deletions

View File

@@ -53,7 +53,7 @@ RUN apk add --no-cache ca-certificates \
COPY --from=builder /neoircd /usr/local/bin/neoircd
USER neoirc
EXPOSE 8080 6667
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget -qO- http://localhost:8080/.well-known/healthcheck.json || exit 1
ENTRYPOINT ["neoircd"]

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks ensure-web-dist
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks
BINARY := neoircd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -7,21 +7,10 @@ LDFLAGS := -X main.Version=$(VERSION) -X main.Buildarch=$(BUILDARCH)
all: check build
# ensure-web-dist creates placeholder files so //go:embed dist/* in
# web/embed.go resolves without a full Node.js build. The real SPA is
# built by the web-builder Docker stage; these placeholders let
# "make test" and "make build" work outside Docker.
ensure-web-dist:
@if [ ! -d web/dist ]; then \
mkdir -p web/dist && \
touch web/dist/index.html web/dist/style.css web/dist/app.js && \
echo "==> Created placeholder web/dist/ for go:embed"; \
fi
build: ensure-web-dist
build:
go build -ldflags "$(LDFLAGS)" -o bin/$(BINARY) ./cmd/neoircd
lint: ensure-web-dist
lint:
golangci-lint run --config .golangci.yml ./...
fmt:
@@ -31,8 +20,8 @@ fmt:
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test: ensure-web-dist
go test -timeout 30s -race -cover ./... || go test -timeout 30s -race -v ./...
test:
go test -timeout 30s -v -race -cover ./...
# check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong

876
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -2,17 +2,14 @@
package main
import (
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/globals"
"git.eeqj.de/sneak/neoirc/internal/handlers"
"git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/middleware"
"git.eeqj.de/sneak/neoirc/internal/server"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx"
)
@@ -31,23 +28,16 @@ func main() {
fx.New(
fx.Provide(
broker.New,
config.New,
db.New,
globals.New,
handlers.New,
ircserver.New,
logger.New,
server.New,
middleware.New,
healthcheck.New,
service.New,
stats.New,
),
fx.Invoke(func(
_ *server.Server,
_ *ircserver.Server,
) {
}),
fx.Invoke(func(*server.Server) {}),
).Run()
}

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strconv"
"strings"
@@ -29,19 +28,16 @@ var errHTTP = errors.New("HTTP error")
// Client wraps HTTP calls to the neoirc server API.
type Client struct {
BaseURL string
Token string
HTTPClient *http.Client
}
// NewClient creates a new API client with a cookie jar
// for automatic auth cookie management.
// NewClient creates a new API client.
func NewClient(baseURL string) *Client {
jar, _ := cookiejar.New(nil)
return &Client{
return &Client{ //nolint:exhaustruct // Token set after CreateSession
BaseURL: baseURL,
HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine
Timeout: httpTimeout,
Jar: jar,
},
}
}
@@ -83,6 +79,8 @@ func (client *Client) CreateSession(
return nil, fmt.Errorf("decode session: %w", err)
}
client.Token = resp.Token
return &resp, nil
}
@@ -123,7 +121,6 @@ func (client *Client) PollMessages(
Timeout: time.Duration(
timeout+pollExtraTime,
) * time.Second,
Jar: client.HTTPClient.Jar,
}
params := url.Values{}
@@ -148,6 +145,10 @@ func (client *Client) PollMessages(
return nil, fmt.Errorf("new request: %w", err)
}
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
resp, err := pollClient.Do(request)
if err != nil {
return nil, fmt.Errorf("poll request: %w", err)
@@ -303,6 +304,12 @@ func (client *Client) do(
"Content-Type", "application/json",
)
if client.Token != "" {
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
}
resp, err := client.HTTPClient.Do(request)
if err != nil {
return nil, fmt.Errorf("http: %w", err)

View File

@@ -7,8 +7,6 @@ import (
"fmt"
"math/big"
"time"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
)
const (
@@ -39,23 +37,6 @@ func MintHashcash(bits int, resource string) string {
}
}
// MintChannelHashcash computes a hashcash stamp bound to
// a specific channel and message body. The stamp format
// is 1:bits:YYMMDD:channel:bodyhash:counter where
// bodyhash is the hex-encoded SHA-256 of the message
// body bytes. Delegates to the internal/hashcash package.
func MintChannelHashcash(
bits int,
channel string,
body []byte,
) string {
bodyHash := hashcash.BodyHash(body)
return hashcash.MintChannelStamp(
bits, channel, bodyHash,
)
}
// hasLeadingZeroBits checks if hash has at least numBits
// leading zero bits.
func hasLeadingZeroBits(

View File

@@ -10,8 +10,9 @@ type SessionRequest struct {
// SessionResponse is the response from session creation.
type SessionResponse struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
ID int64 `json:"id"`
Nick string `json:"nick"`
Token string `json:"token"`
}
// StateResponse is the response from GET /api/v1/state.

View File

@@ -46,11 +46,8 @@ type Config struct {
FederationKey string
SessionIdleTimeout string
HashcashBits int
OperName string
OperPassword string
LoginRateLimit float64
LoginRateBurst int
IRCListenAddr string
params *Params
log *slog.Logger
}
@@ -83,11 +80,8 @@ func New(
viper.SetDefault("FEDERATION_KEY", "")
viper.SetDefault("SESSION_IDLE_TIMEOUT", "720h")
viper.SetDefault("NEOIRC_HASHCASH_BITS", "20")
viper.SetDefault("NEOIRC_OPER_NAME", "")
viper.SetDefault("NEOIRC_OPER_PASSWORD", "")
viper.SetDefault("LOGIN_RATE_LIMIT", "1")
viper.SetDefault("LOGIN_RATE_BURST", "5")
viper.SetDefault("IRC_LISTEN_ADDR", ":6667")
err := viper.ReadInConfig()
if err != nil {
@@ -114,11 +108,8 @@ func New(
FederationKey: viper.GetString("FEDERATION_KEY"),
SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"),
HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"),
OperName: viper.GetString("NEOIRC_OPER_NAME"),
OperPassword: viper.GetString("NEOIRC_OPER_PASSWORD"),
LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"),
LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"),
IRCListenAddr: viper.GetString("IRC_LISTEN_ADDR"),
log: log,
params: &params,
}

View File

@@ -10,46 +10,93 @@ import (
"golang.org/x/crypto/bcrypt"
)
//nolint:gochecknoglobals // var so tests can override via SetBcryptCost
var bcryptCost = bcrypt.DefaultCost
// SetBcryptCost overrides the bcrypt cost.
// Use bcrypt.MinCost in tests to avoid slow hashing.
func SetBcryptCost(cost int) { bcryptCost = cost }
const bcryptCost = bcrypt.DefaultCost
var errNoPassword = errors.New(
"account has no password set",
)
// SetPassword sets a bcrypt-hashed password on a session,
// enabling multi-client login via POST /api/v1/login.
func (database *Database) SetPassword(
// RegisterUser creates a session with a hashed password
// and returns session ID, client ID, and token.
func (database *Database) RegisterUser(
ctx context.Context,
sessionID int64,
password string,
) error {
nick, password string,
) (int64, int64, string, error) {
hash, err := bcrypt.GenerateFromPassword(
[]byte(password), bcryptCost,
)
if err != nil {
return fmt.Errorf("hash password: %w", err)
return 0, 0, "", fmt.Errorf(
"hash password: %w", err,
)
}
_, err = database.conn.ExecContext(ctx,
"UPDATE sessions SET password_hash = ? WHERE id = ?",
string(hash), sessionID)
sessionUUID := uuid.New().String()
clientUUID := uuid.New().String()
token, err := generateToken()
if err != nil {
return fmt.Errorf("set password: %w", err)
return 0, 0, "", err
}
return nil
now := time.Now()
transaction, err := database.conn.BeginTx(ctx, nil)
if err != nil {
return 0, 0, "", fmt.Errorf(
"begin tx: %w", err,
)
}
res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, password_hash,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
sessionUUID, nick, string(hash), now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create session: %w", err,
)
}
sessionID, _ := res.LastInsertId()
tokenHash := hashToken(token)
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create client: %w", err,
)
}
clientID, _ := clientRes.LastInsertId()
err = transaction.Commit()
if err != nil {
return 0, 0, "", fmt.Errorf(
"commit registration: %w", err,
)
}
return sessionID, clientID, token, nil
}
// LoginUser verifies a nick/password and creates a new
// client token.
func (database *Database) LoginUser(
ctx context.Context,
nick, password, remoteIP, hostname string,
nick, password string,
) (int64, int64, string, error) {
var (
sessionID int64
@@ -96,11 +143,10 @@ func (database *Database) LoginUser(
res, err := database.conn.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token, ip, hostname,
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
if err != nil {
return 0, 0, "", fmt.Errorf(
"create login client: %w", err,

View File

@@ -6,65 +6,63 @@ import (
_ "modernc.org/sqlite"
)
func TestSetPassword(t *testing.T) {
func TestRegisterUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err :=
database.CreateSession(ctx, "passuser", "", "", "")
sessionID, clientID, token, err :=
database.RegisterUser(ctx, "reguser", "password123")
if err != nil {
t.Fatal(err)
}
err = database.SetPassword(
ctx, sessionID, "password123",
)
if err != nil {
t.Fatal(err)
}
// Verify we can now log in with the password.
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "passuser", "password123", "", "")
if err != nil {
t.Fatal(err)
}
if loginSID == 0 || loginCID == 0 || loginToken == "" {
if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token")
}
// Verify session works via token lookup.
sid, cid, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if sid != sessionID || cid != clientID {
t.Fatal("session/client id mismatch")
}
if nick != "reguser" {
t.Fatalf("expected reguser, got %s", nick)
}
}
func TestSetPasswordThenWrongLogin(t *testing.T) {
func TestRegisterUserDuplicateNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err :=
database.CreateSession(ctx, "wrongpw", "", "", "")
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "dupnick", "password123")
if err != nil {
t.Fatal(err)
}
err = database.SetPassword(
ctx, sessionID, "correctpass",
)
if err != nil {
t.Fatal(err)
_ = regSID
_ = regCID
_ = regToken
dupSID, dupCID, dupToken, dupErr :=
database.RegisterUser(ctx, "dupnick", "other12345")
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
}
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
_ = dupSID
_ = dupCID
_ = dupToken
}
func TestLoginUser(t *testing.T) {
@@ -73,26 +71,23 @@ func TestLoginUser(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err :=
database.CreateSession(ctx, "loginuser", "", "", "")
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "loginuser", "mypassword")
if err != nil {
t.Fatal(err)
}
err = database.SetPassword(
ctx, sessionID, "mypassword",
)
_ = regSID
_ = regCID
_ = regToken
sessionID, clientID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword")
if err != nil {
t.Fatal(err)
}
loginSID, loginCID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword", "", "")
if err != nil {
t.Fatal(err)
}
if loginSID == 0 || loginCID == 0 || token == "" {
if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token")
}
@@ -108,6 +103,33 @@ func TestLoginUser(t *testing.T) {
}
}
func TestLoginUserWrongPassword(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "wrongpw", "correctpass")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
}
func TestLoginUserNoPassword(t *testing.T) {
t.Parallel()
@@ -116,7 +138,7 @@ func TestLoginUserNoPassword(t *testing.T) {
// Create anonymous session (no password).
anonSID, anonCID, anonToken, err :=
database.CreateSession(ctx, "anon", "", "", "")
database.CreateSession(ctx, "anon")
if err != nil {
t.Fatal(err)
}
@@ -126,7 +148,7 @@ func TestLoginUserNoPassword(t *testing.T) {
_ = anonToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "anon", "anything1", "", "")
database.LoginUser(ctx, "anon", "anything1")
if loginErr == nil {
t.Fatal(
"expected error for login on passwordless account",
@@ -145,7 +167,7 @@ func TestLoginUserNonexistent(t *testing.T) {
ctx := t.Context()
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "ghost", "password123", "", "")
database.LoginUser(ctx, "ghost", "password123")
if err == nil {
t.Fatal("expected error for nonexistent user")
}

View File

@@ -135,21 +135,13 @@ type migration struct {
func (database *Database) runMigrations(
ctx context.Context,
) error {
bootstrap, err := SchemaFiles.ReadFile(
"schema/000.sql",
)
_, err := database.conn.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)`)
if err != nil {
return fmt.Errorf(
"read bootstrap migration: %w", err,
)
}
_, err = database.conn.ExecContext(
ctx, string(bootstrap),
)
if err != nil {
return fmt.Errorf(
"execute bootstrap migration: %w", err,
"create schema_migrations: %w", err,
)
}
@@ -278,11 +270,6 @@ func (database *Database) loadMigrations() (
continue
}
// Skip bootstrap migration; it is executed separately.
if version == 0 {
continue
}
content, readErr := SchemaFiles.ReadFile(
"schema/" + entry.Name(),
)

View File

@@ -1,14 +0,0 @@
package db_test
import (
"os"
"testing"
"git.eeqj.de/sneak/neoirc/internal/db"
"golang.org/x/crypto/bcrypt"
)
func TestMain(m *testing.M) {
db.SetBcryptCost(bcrypt.MinCost)
os.Exit(m.Run())
}

File diff suppressed because it is too large Load Diff

View File

@@ -34,7 +34,7 @@ func TestCreateSession(t *testing.T) {
ctx := t.Context()
sessionID, _, token, err := database.CreateSession(
ctx, "alice", "", "", "",
ctx, "alice",
)
if err != nil {
t.Fatal(err)
@@ -45,7 +45,7 @@ func TestCreateSession(t *testing.T) {
}
_, _, dupToken, dupErr := database.CreateSession(
ctx, "alice", "", "", "",
ctx, "alice",
)
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
@@ -54,249 +54,13 @@ func TestCreateSession(t *testing.T) {
_ = dupToken
}
// assertSessionHostInfo creates a session and verifies
// the stored username and hostname match expectations.
func assertSessionHostInfo(
t *testing.T,
database *db.Database,
nick, inputUser, inputHost,
expectUser, expectHost string,
) {
t.Helper()
sessionID, _, _, err := database.CreateSession(
t.Context(), nick, inputUser, inputHost, "",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
t.Context(), sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.Username != expectUser {
t.Fatalf(
"expected username %s, got %s",
expectUser, info.Username,
)
}
if info.Hostname != expectHost {
t.Fatalf(
"expected hostname %s, got %s",
expectHost, info.Hostname,
)
}
}
func TestCreateSessionWithUserHost(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
assertSessionHostInfo(
t, database,
"hostuser", "myident", "example.com",
"myident", "example.com",
)
}
func TestCreateSessionDefaultUsername(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
// Empty username defaults to nick.
assertSessionHostInfo(
t, database,
"defaultu", "", "host.local",
"defaultu", "host.local",
)
}
func TestCreateSessionStoresIP(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, clientID, _, err := database.CreateSession(
ctx, "ipuser", "ident", "host.example.com",
"192.168.1.42",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.IP != "192.168.1.42" {
t.Fatalf(
"expected session IP 192.168.1.42, got %s",
info.IP,
)
}
clientInfo, err := database.GetClientHostInfo(
ctx, clientID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "192.168.1.42" {
t.Fatalf(
"expected client IP 192.168.1.42, got %s",
clientInfo.IP,
)
}
if clientInfo.Hostname != "host.example.com" {
t.Fatalf(
"expected client hostname host.example.com, got %s",
clientInfo.Hostname,
)
}
}
func TestGetClientHostInfoNotFound(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
_, err := database.GetClientHostInfo(
t.Context(), 99999,
)
if err == nil {
t.Fatal("expected error for nonexistent client")
}
}
func TestGetSessionHostInfoNotFound(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
_, err := database.GetSessionHostInfo(
t.Context(), 99999,
)
if err == nil {
t.Fatal("expected error for nonexistent session")
}
}
func TestFormatHostmask(t *testing.T) {
t.Parallel()
result := db.FormatHostmask(
"nick", "user", "host.com",
)
if result != "nick!user@host.com" {
t.Fatalf(
"expected nick!user@host.com, got %s",
result,
)
}
}
func TestFormatHostmaskDefaults(t *testing.T) {
t.Parallel()
result := db.FormatHostmask("nick", "", "")
if result != "nick!nick@*" {
t.Fatalf(
"expected nick!nick@*, got %s",
result,
)
}
}
func TestMemberInfoHostmask(t *testing.T) {
t.Parallel()
member := &db.MemberInfo{ //nolint:exhaustruct // test only uses hostmask fields
Nick: "alice",
Username: "aliceident",
Hostname: "alice.example.com",
}
hostmask := member.Hostmask()
expected := "alice!aliceident@alice.example.com"
if hostmask != expected {
t.Fatalf(
"expected %s, got %s", expected, hostmask,
)
}
}
func TestChannelMembersIncludeUserHost(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "memuser", "myuser", "myhost.net", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(
ctx, "#hostchan",
)
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
members, err := database.ChannelMembers(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(members) != 1 {
t.Fatalf(
"expected 1 member, got %d", len(members),
)
}
if members[0].Username != "myuser" {
t.Fatalf(
"expected username myuser, got %s",
members[0].Username,
)
}
if members[0].Hostname != "myhost.net" {
t.Fatalf(
"expected hostname myhost.net, got %s",
members[0].Hostname,
)
}
}
func TestGetSessionByToken(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
_, _, token, err := database.CreateSession(ctx, "bob", "", "", "")
_, _, token, err := database.CreateSession(ctx, "bob")
if err != nil {
t.Fatal(err)
}
@@ -329,7 +93,7 @@ func TestGetSessionByNick(t *testing.T) {
ctx := t.Context()
charlieID, charlieClientID, charlieToken, err :=
database.CreateSession(ctx, "charlie", "", "", "")
database.CreateSession(ctx, "charlie")
if err != nil {
t.Fatal(err)
}
@@ -386,7 +150,7 @@ func TestJoinAndPart(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, "user1", "", "", "")
sid, _, _, err := database.CreateSession(ctx, "user1")
if err != nil {
t.Fatal(err)
}
@@ -435,7 +199,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
t.Fatal(err)
}
sid, _, _, err := database.CreateSession(ctx, "temp", "", "", "")
sid, _, _, err := database.CreateSession(ctx, "temp")
if err != nil {
t.Fatal(err)
}
@@ -470,7 +234,7 @@ func createSessionWithChannels(
ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, nick, "", "", "")
sid, _, _, err := database.CreateSession(ctx, nick)
if err != nil {
t.Fatal(err)
}
@@ -553,7 +317,7 @@ func TestChangeNick(t *testing.T) {
ctx := t.Context()
sid, _, token, err := database.CreateSession(
ctx, "old", "", "", "",
ctx, "old",
)
if err != nil {
t.Fatal(err)
@@ -637,7 +401,7 @@ func TestPollMessages(t *testing.T) {
ctx := t.Context()
sid, _, token, err := database.CreateSession(
ctx, "poller", "", "", "",
ctx, "poller",
)
if err != nil {
t.Fatal(err)
@@ -744,7 +508,7 @@ func TestDeleteSession(t *testing.T) {
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "deleteme", "", "", "",
ctx, "deleteme",
)
if err != nil {
t.Fatal(err)
@@ -784,12 +548,12 @@ func TestChannelMembers(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sid1, _, _, err := database.CreateSession(ctx, "m1", "", "", "")
sid1, _, _, err := database.CreateSession(ctx, "m1")
if err != nil {
t.Fatal(err)
}
sid2, _, _, err := database.CreateSession(ctx, "m2", "", "", "")
sid2, _, _, err := database.CreateSession(ctx, "m2")
if err != nil {
t.Fatal(err)
}
@@ -847,7 +611,7 @@ func TestEnqueueToClient(t *testing.T) {
ctx := t.Context()
_, _, token, err := database.CreateSession(
ctx, "enqclient", "", "", "",
ctx, "enqclient",
)
if err != nil {
t.Fatal(err)
@@ -887,604 +651,3 @@ func TestEnqueueToClient(t *testing.T) {
t.Fatalf("expected 1, got %d", len(msgs))
}
}
func TestSetAndCheckSessionOper(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.CreateSession(
ctx, "opernick", "", "", "",
)
if err != nil {
t.Fatal(err)
}
// Initially not oper.
isOper, err := database.IsSessionOper(ctx, sessionID)
if err != nil {
t.Fatal(err)
}
if isOper {
t.Fatal("expected session not to be oper")
}
// Set oper.
err = database.SetSessionOper(ctx, sessionID, true)
if err != nil {
t.Fatal(err)
}
isOper, err = database.IsSessionOper(ctx, sessionID)
if err != nil {
t.Fatal(err)
}
if !isOper {
t.Fatal("expected session to be oper")
}
// Unset oper.
err = database.SetSessionOper(ctx, sessionID, false)
if err != nil {
t.Fatal(err)
}
isOper, err = database.IsSessionOper(ctx, sessionID)
if err != nil {
t.Fatal(err)
}
if isOper {
t.Fatal("expected session not to be oper")
}
}
func TestGetLatestClientForSession(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.CreateSession(
ctx, "clientnick", "", "", "10.0.0.1",
)
if err != nil {
t.Fatal(err)
}
clientInfo, err := database.GetLatestClientForSession(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "10.0.0.1" {
t.Fatalf(
"expected IP 10.0.0.1, got %s",
clientInfo.IP,
)
}
}
func TestGetOperCount(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
// Create two sessions.
sid1, _, _, err := database.CreateSession(
ctx, "user1", "", "", "",
)
if err != nil {
t.Fatal(err)
}
sid2, _, _, err := database.CreateSession(
ctx, "user2", "", "", "",
)
_ = sid2
if err != nil {
t.Fatal(err)
}
// Initially zero opers.
count, err := database.GetOperCount(ctx)
if err != nil {
t.Fatal(err)
}
if count != 0 {
t.Fatalf("expected 0 opers, got %d", count)
}
// Set one as oper.
err = database.SetSessionOper(ctx, sid1, true)
if err != nil {
t.Fatal(err)
}
count, err = database.GetOperCount(ctx)
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatalf("expected 1 oper, got %d", count)
}
}
// --- Tier 2 Tests ---
func TestWildcardMatch(t *testing.T) {
t.Parallel()
tests := []struct {
pattern string
input string
match bool
}{
{"*!*@*", "nick!user@host", true},
{"*!*@*.example.com", "nick!user@foo.example.com", true},
{"*!*@*.example.com", "nick!user@other.net", false},
{"badnick!*@*", "badnick!user@host", true},
{"badnick!*@*", "goodnick!user@host", false},
{"nick!user@host", "nick!user@host", true},
{"nick!user@host", "nick!user@other", false},
{"*", "anything", true},
{"?ick!*@*", "nick!user@host", true},
{"?ick!*@*", "nn!user@host", false},
// Case-insensitive.
{"Nick!*@*", "nick!user@host", true},
}
for _, tc := range tests {
result := db.MatchBanMask(tc.pattern, tc.input)
if result != tc.match {
t.Errorf(
"MatchBanMask(%q, %q) = %v, want %v",
tc.pattern, tc.input, result, tc.match,
)
}
}
}
func TestChannelBanCRUD(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#test")
if err != nil {
t.Fatal(err)
}
// No bans initially.
bans, err := database.ListChannelBans(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(bans) != 0 {
t.Fatalf("expected 0 bans, got %d", len(bans))
}
// Add a ban.
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op",
)
if err != nil {
t.Fatal(err)
}
bans, err = database.ListChannelBans(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(bans) != 1 {
t.Fatalf("expected 1 ban, got %d", len(bans))
}
if bans[0].Mask != "*!*@evil.com" {
t.Fatalf("wrong mask: %s", bans[0].Mask)
}
// Duplicate add is ignored (OR IGNORE).
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op2",
)
if err != nil {
t.Fatal(err)
}
bans, _ = database.ListChannelBans(ctx, chID)
if len(bans) != 1 {
t.Fatalf("expected 1 ban after dup, got %d", len(bans))
}
// Remove ban.
err = database.RemoveChannelBan(
ctx, chID, "*!*@evil.com",
)
if err != nil {
t.Fatal(err)
}
bans, _ = database.ListChannelBans(ctx, chID)
if len(bans) != 0 {
t.Fatalf("expected 0 bans after remove, got %d", len(bans))
}
}
func TestIsSessionBanned(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "victim", "victim", "evil.com", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(ctx, "#bantest")
if err != nil {
t.Fatal(err)
}
// Not banned initially.
banned, err := database.IsSessionBanned(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if banned {
t.Fatal("should not be banned initially")
}
// Add ban matching the hostmask.
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op",
)
if err != nil {
t.Fatal(err)
}
banned, err = database.IsSessionBanned(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if !banned {
t.Fatal("should be banned")
}
}
func TestChannelInviteOnly(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#invite")
if err != nil {
t.Fatal(err)
}
// Default: not invite-only.
isIO, err := database.IsChannelInviteOnly(ctx, chID)
if err != nil {
t.Fatal(err)
}
if isIO {
t.Fatal("should not be invite-only by default")
}
// Set invite-only.
err = database.SetChannelInviteOnly(ctx, chID, true)
if err != nil {
t.Fatal(err)
}
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
if !isIO {
t.Fatal("should be invite-only")
}
// Unset.
err = database.SetChannelInviteOnly(ctx, chID, false)
if err != nil {
t.Fatal(err)
}
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
if isIO {
t.Fatal("should not be invite-only")
}
}
func TestChannelInviteCRUD(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "invited", "", "", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(ctx, "#inv")
if err != nil {
t.Fatal(err)
}
// No invite initially.
has, err := database.HasChannelInvite(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if has {
t.Fatal("should not have invite")
}
// Add invite.
err = database.AddChannelInvite(ctx, chID, sid, "op")
if err != nil {
t.Fatal(err)
}
has, _ = database.HasChannelInvite(ctx, chID, sid)
if !has {
t.Fatal("should have invite")
}
// Clear invite.
err = database.ClearChannelInvite(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
has, _ = database.HasChannelInvite(ctx, chID, sid)
if has {
t.Fatal("invite should be cleared")
}
}
func TestChannelSecret(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#secret")
if err != nil {
t.Fatal(err)
}
// Default: not secret.
isSec, err := database.IsChannelSecret(ctx, chID)
if err != nil {
t.Fatal(err)
}
if isSec {
t.Fatal("should not be secret by default")
}
err = database.SetChannelSecret(ctx, chID, true)
if err != nil {
t.Fatal(err)
}
isSec, _ = database.IsChannelSecret(ctx, chID)
if !isSec {
t.Fatal("should be secret")
}
}
// createTestSession is a helper to create a session and
// return only the session ID.
func createTestSession(
t *testing.T,
database *db.Database,
nick string,
) int64 {
t.Helper()
sid, _, _, err := database.CreateSession(
t.Context(), nick, "", "", "",
)
if err != nil {
t.Fatalf("create session %s: %v", nick, err)
}
return sid
}
func TestSecretChannelFiltering(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
// Create two sessions.
sid1 := createTestSession(t, database, "member")
sid2 := createTestSession(t, database, "outsider")
// Create a secret channel.
chID, _ := database.GetOrCreateChannel(ctx, "#secret")
_ = database.SetChannelSecret(ctx, chID, true)
_ = database.JoinChannel(ctx, chID, sid1)
// Create a non-secret channel.
chID2, _ := database.GetOrCreateChannel(ctx, "#public")
_ = database.JoinChannel(ctx, chID2, sid1)
// Member should see both.
list, err := database.ListAllChannelsWithCountsFiltered(
ctx, sid1,
)
if err != nil {
t.Fatal(err)
}
if len(list) != 2 {
t.Fatalf("member should see 2 channels, got %d", len(list))
}
// Outsider should only see public.
list, _ = database.ListAllChannelsWithCountsFiltered(
ctx, sid2,
)
if len(list) != 1 {
t.Fatalf("outsider should see 1 channel, got %d", len(list))
}
if list[0].Name != "#public" {
t.Fatalf("outsider should see #public, got %s", list[0].Name)
}
}
func TestWhoisChannelFiltering(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid1 := createTestSession(t, database, "target")
sid2 := createTestSession(t, database, "querier")
// Create secret channel, target joins it.
chID, _ := database.GetOrCreateChannel(ctx, "#hidden")
_ = database.SetChannelSecret(ctx, chID, true)
_ = database.JoinChannel(ctx, chID, sid1)
// Querier (non-member) should not see the channel.
channels, err := database.GetSessionChannelsFiltered(
ctx, sid1, sid2,
)
if err != nil {
t.Fatal(err)
}
if len(channels) != 0 {
t.Fatalf(
"querier should see 0 channels, got %d",
len(channels),
)
}
// Target querying self should see it.
channels, _ = database.GetSessionChannelsFiltered(
ctx, sid1, sid1,
)
if len(channels) != 1 {
t.Fatalf(
"self-query should see 1 channel, got %d",
len(channels),
)
}
}
//nolint:dupl // structurally similar to TestChannelUserLimit
func TestChannelKey(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#keyed")
if err != nil {
t.Fatal(err)
}
// Default: no key.
key, err := database.GetChannelKey(ctx, chID)
if err != nil {
t.Fatal(err)
}
if key != "" {
t.Fatalf("expected empty key, got %q", key)
}
// Set key.
err = database.SetChannelKey(ctx, chID, "secret123")
if err != nil {
t.Fatal(err)
}
key, _ = database.GetChannelKey(ctx, chID)
if key != "secret123" {
t.Fatalf("expected secret123, got %q", key)
}
// Clear key.
err = database.SetChannelKey(ctx, chID, "")
if err != nil {
t.Fatal(err)
}
key, _ = database.GetChannelKey(ctx, chID)
if key != "" {
t.Fatalf("expected empty key, got %q", key)
}
}
//nolint:dupl // structurally similar to TestChannelKey
func TestChannelUserLimit(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#limited")
if err != nil {
t.Fatal(err)
}
// Default: no limit.
limit, err := database.GetChannelUserLimit(ctx, chID)
if err != nil {
t.Fatal(err)
}
if limit != 0 {
t.Fatalf("expected 0 limit, got %d", limit)
}
// Set limit.
err = database.SetChannelUserLimit(ctx, chID, 50)
if err != nil {
t.Fatal(err)
}
limit, _ = database.GetChannelUserLimit(ctx, chID)
if limit != 50 {
t.Fatalf("expected 50, got %d", limit)
}
// Clear limit.
err = database.SetChannelUserLimit(ctx, chID, 0)
if err != nil {
t.Fatal(err)
}
limit, _ = database.GetChannelUserLimit(ctx, chID)
if limit != 0 {
t.Fatalf("expected 0, got %d", limit)
}
}

View File

@@ -1,6 +0,0 @@
-- Bootstrap: create the schema_migrations table itself.
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
INSERT OR IGNORE INTO schema_migrations (version) VALUES (0);

View File

@@ -6,10 +6,6 @@ CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT NOT NULL UNIQUE,
nick TEXT NOT NULL UNIQUE,
username TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '',
ip TEXT NOT NULL DEFAULT '',
is_oper INTEGER NOT NULL DEFAULT 0,
password_hash TEXT NOT NULL DEFAULT '',
signing_key TEXT NOT NULL DEFAULT '',
away_message TEXT NOT NULL DEFAULT '',
@@ -24,8 +20,6 @@ CREATE TABLE IF NOT EXISTS clients (
uuid TEXT NOT NULL UNIQUE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE,
ip TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);
@@ -39,47 +33,15 @@ CREATE TABLE IF NOT EXISTS channels (
topic TEXT NOT NULL DEFAULT '',
topic_set_by TEXT NOT NULL DEFAULT '',
topic_set_at DATETIME,
hashcash_bits INTEGER NOT NULL DEFAULT 0,
is_moderated INTEGER NOT NULL DEFAULT 0,
is_topic_locked INTEGER NOT NULL DEFAULT 1,
is_invite_only INTEGER NOT NULL DEFAULT 0,
is_secret INTEGER NOT NULL DEFAULT 0,
is_no_external INTEGER NOT NULL DEFAULT 1,
channel_key TEXT NOT NULL DEFAULT '',
user_limit INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- Channel bans
CREATE TABLE IF NOT EXISTS channel_bans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
mask TEXT NOT NULL,
set_by TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, mask)
);
CREATE INDEX IF NOT EXISTS idx_channel_bans_channel ON channel_bans(channel_id);
-- Channel invites (in-memory would be simpler but DB survives restarts)
CREATE TABLE IF NOT EXISTS channel_invites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
invited_by TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, session_id)
);
CREATE INDEX IF NOT EXISTS idx_channel_invites_channel ON channel_invites(channel_id);
-- Channel members
CREATE TABLE IF NOT EXISTS channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
is_operator INTEGER NOT NULL DEFAULT 0,
is_voiced INTEGER NOT NULL DEFAULT 0,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, session_id)
);
@@ -99,14 +61,6 @@ CREATE TABLE IF NOT EXISTS messages (
CREATE INDEX IF NOT EXISTS idx_messages_to_id ON messages(msg_to, id);
CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
-- Spent hashcash tokens for replay prevention (1-year TTL)
CREATE TABLE IF NOT EXISTS spent_hashcash (
id INTEGER PRIMARY KEY AUTOINCREMENT,
stamp_hash TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_spent_hashcash_created ON spent_hashcash(created_at);
-- Per-client message queues for fan-out delivery
CREATE TABLE IF NOT EXISTS client_queues (
id INTEGER PRIMARY KEY AUTOINCREMENT,

View File

@@ -1,25 +0,0 @@
package db
import (
"context"
"database/sql"
"log/slog"
)
// NewTestDatabaseFromConn creates a Database wrapping an
// existing *sql.DB connection. Intended for integration
// tests in other packages.
func NewTestDatabaseFromConn(conn *sql.DB) *Database {
return &Database{ //nolint:exhaustruct
conn: conn,
log: slog.Default(),
}
}
// RunMigrations applies all schema migrations. Exposed
// for integration tests in other packages.
func (database *Database) RunMigrations(
ctx context.Context,
) error {
return database.runMigrations(ctx)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -2,14 +2,151 @@ package handlers
import (
"encoding/json"
"net"
"net/http"
"strings"
"git.eeqj.de/sneak/neoirc/pkg/irc"
"git.eeqj.de/sneak/neoirc/internal/db"
)
const minPasswordLength = 8
// clientIP extracts the client IP address from the request.
// It checks X-Forwarded-For and X-Real-IP headers before
// falling back to RemoteAddr.
func clientIP(request *http.Request) string {
if forwarded := request.Header.Get("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For may contain a comma-separated list;
// the first entry is the original client.
parts := strings.SplitN(forwarded, ",", 2) //nolint:mnd // split into two parts
ip := strings.TrimSpace(parts[0])
if ip != "" {
return ip
}
}
if realIP := request.Header.Get("X-Real-IP"); realIP != "" {
return strings.TrimSpace(realIP)
}
host, _, err := net.SplitHostPort(request.RemoteAddr)
if err != nil {
return request.RemoteAddr
}
return host
}
// HandleRegister creates a new user with a password.
func (hdlr *Handlers) HandleRegister() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
request.Body = http.MaxBytesReader(
writer, request.Body, hdlr.maxBodySize(),
)
hdlr.handleRegister(writer, request)
}
}
func (hdlr *Handlers) handleRegister(
writer http.ResponseWriter,
request *http.Request,
) {
type registerRequest struct {
Nick string `json:"nick"`
Password string `json:"password"`
}
var payload registerRequest
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
hdlr.respondError(
writer, request,
"invalid nick format",
http.StatusBadRequest,
)
return
}
if len(payload.Password) < minPasswordLength {
hdlr.respondError(
writer, request,
"password must be at least 8 characters",
http.StatusBadRequest,
)
return
}
sessionID, clientID, token, err :=
hdlr.params.Database.RegisterUser(
request.Context(),
payload.Nick,
payload.Password,
)
if err != nil {
hdlr.handleRegisterError(
writer, request, err,
)
return
}
hdlr.stats.IncrSessions()
hdlr.stats.IncrConnections()
hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusCreated)
}
func (hdlr *Handlers) handleRegisterError(
writer http.ResponseWriter,
request *http.Request,
err error,
) {
if db.IsUniqueConstraintError(err) {
hdlr.respondError(
writer, request,
"nick already taken",
http.StatusConflict,
)
return
}
hdlr.log.Error(
"register user failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
}
// HandleLogin authenticates a user with nick and password.
func (hdlr *Handlers) HandleLogin() http.HandlerFunc {
return func(
@@ -73,27 +210,11 @@ func (hdlr *Handlers) handleLogin(
return
}
hdlr.executeLogin(
writer, request, payload.Nick, payload.Password,
)
}
func (hdlr *Handlers) executeLogin(
writer http.ResponseWriter,
request *http.Request,
nick, password string,
) {
remoteIP := clientIP(request)
hostname := resolveHostname(
request.Context(), remoteIP,
)
sessionID, clientID, token, err :=
hdlr.params.Database.LoginUser(
request.Context(),
nick, password,
remoteIP, hostname,
payload.Nick,
payload.Password,
)
if err != nil {
hdlr.respondError(
@@ -108,75 +229,18 @@ func (hdlr *Handlers) executeLogin(
hdlr.stats.IncrConnections()
hdlr.deliverMOTD(
request, clientID, sessionID, nick,
request, clientID, sessionID, payload.Nick,
)
// Initialize channel state so the new client knows
// which channels the session already belongs to.
hdlr.initChannelState(
request, clientID, sessionID, nick,
request, clientID, sessionID, payload.Nick,
)
hdlr.setAuthCookie(writer, request, token)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": nick,
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusOK)
}
// handlePass handles the IRC PASS command to set a
// password on the authenticated session, enabling
// multi-client login via POST /api/v1/login.
func (hdlr *Handlers) handlePass(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
bodyLines func() []string,
) {
lines := bodyLines()
if len(lines) == 0 || lines[0] == "" {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Not enough parameters",
)
return
}
password := lines[0]
if len(password) < minPasswordLength {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Password must be at least 8 characters",
)
return
}
err := hdlr.params.Database.SetPassword(
request.Context(), sessionID, password,
)
if err != nil {
hdlr.log.Error(
"set password failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}

View File

@@ -17,7 +17,6 @@ import (
"git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx"
)
@@ -34,29 +33,20 @@ type Params struct {
Database *db.Database
Healthcheck *healthcheck.Healthcheck
Stats *stats.Tracker
Broker *broker.Broker
Service *service.Service
}
const defaultIdleTimeout = 30 * 24 * time.Hour
// spentHashcashTTL is how long spent hashcash tokens are
// retained for replay prevention. Per issue requirements,
// this is 1 year.
const spentHashcashTTL = 365 * 24 * time.Hour
// Handlers manages HTTP request handling.
type Handlers struct {
params *Params
log *slog.Logger
hc *healthcheck.Healthcheck
broker *broker.Broker
svc *service.Service
hashcashVal *hashcash.Validator
channelHashcash *hashcash.ChannelValidator
loginLimiter *ratelimit.Limiter
stats *stats.Tracker
cancelCleanup context.CancelFunc
params *Params
log *slog.Logger
hc *healthcheck.Healthcheck
broker *broker.Broker
hashcashVal *hashcash.Validator
loginLimiter *ratelimit.Limiter
stats *stats.Tracker
cancelCleanup context.CancelFunc
}
// New creates a new Handlers instance.
@@ -80,15 +70,13 @@ func New(
}
hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
params: &params,
log: params.Logger.Get(),
hc: params.Healthcheck,
broker: params.Broker,
svc: params.Service,
hashcashVal: hashcash.NewValidator(resource),
channelHashcash: hashcash.NewChannelValidator(),
loginLimiter: ratelimit.New(loginRate, loginBurst),
stats: params.Stats,
params: &params,
log: params.Logger.Get(),
hc: params.Healthcheck,
broker: broker.New(),
hashcashVal: hashcash.NewValidator(resource),
loginLimiter: ratelimit.New(loginRate, loginBurst),
stats: params.Stats,
}
lifecycle.Append(fx.Hook{
@@ -314,20 +302,4 @@ func (hdlr *Handlers) pruneQueuesAndMessages(
)
}
}
// Prune spent hashcash tokens older than 1 year.
hashcashCutoff := time.Now().Add(-spentHashcashTTL)
pruned, err := hdlr.params.Database.
PruneSpentHashcash(ctx, hashcashCutoff)
if err != nil {
hdlr.log.Error(
"spent hashcash pruning failed", "error", err,
)
} else if pruned > 0 {
hdlr.log.Info(
"pruned spent hashcash tokens",
"deleted", pruned,
)
}
}

View File

@@ -1,186 +0,0 @@
package hashcash
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"time"
)
var (
errBodyHashMismatch = errors.New(
"body hash mismatch",
)
errBodyHashMissing = errors.New(
"body hash missing",
)
)
// ChannelValidator checks hashcash stamps for
// per-channel PRIVMSG validation. It verifies that
// stamps are bound to a specific channel and message
// body. Replay prevention is handled externally via
// the database spent_hashcash table for persistence
// across server restarts (1-year TTL).
type ChannelValidator struct{}
// NewChannelValidator creates a ChannelValidator.
func NewChannelValidator() *ChannelValidator {
return &ChannelValidator{}
}
// BodyHash computes the hex-encoded SHA-256 hash of a
// message body for use in hashcash stamp validation.
func BodyHash(body []byte) string {
hash := sha256.Sum256(body)
return hex.EncodeToString(hash[:])
}
// ValidateStamp checks a channel hashcash stamp. It
// verifies the stamp format, difficulty, date, channel
// binding, body hash binding, and proof-of-work. Replay
// detection is NOT performed here — callers must check
// the spent_hashcash table separately.
//
// Stamp format: 1:bits:YYMMDD:channel:bodyhash:counter.
func (cv *ChannelValidator) ValidateStamp(
stamp string,
requiredBits int,
channel string,
bodyHash string,
) error {
if requiredBits <= 0 {
return nil
}
parts := strings.Split(stamp, ":")
if len(parts) != stampFields {
return fmt.Errorf(
"%w: expected %d, got %d",
errInvalidFields, stampFields, len(parts),
)
}
version := parts[0]
bitsStr := parts[1]
dateStr := parts[2]
resource := parts[3]
stampBodyHash := parts[4]
headerErr := validateChannelHeader(
version, bitsStr, resource,
requiredBits, channel,
)
if headerErr != nil {
return headerErr
}
stampTime, parseErr := parseStampDate(dateStr)
if parseErr != nil {
return parseErr
}
timeErr := validateTime(stampTime)
if timeErr != nil {
return timeErr
}
bodyErr := validateBodyHash(
stampBodyHash, bodyHash,
)
if bodyErr != nil {
return bodyErr
}
return validateProof(stamp, requiredBits)
}
// StampHash returns a deterministic hash of a stamp
// string for use as a spent-token key.
func StampHash(stamp string) string {
hash := sha256.Sum256([]byte(stamp))
return hex.EncodeToString(hash[:])
}
func validateChannelHeader(
version, bitsStr, resource string,
requiredBits int,
channel string,
) error {
if version != stampVersion {
return fmt.Errorf(
"%w: %s", errBadVersion, version,
)
}
claimedBits, err := strconv.Atoi(bitsStr)
if err != nil || claimedBits < requiredBits {
return fmt.Errorf(
"%w: need %d bits",
errInsufficientBits, requiredBits,
)
}
if resource != channel {
return fmt.Errorf(
"%w: got %q, want %q",
errWrongResource, resource, channel,
)
}
return nil
}
func validateBodyHash(
stampBodyHash, expectedBodyHash string,
) error {
if stampBodyHash == "" {
return errBodyHashMissing
}
if stampBodyHash != expectedBodyHash {
return fmt.Errorf(
"%w: got %q, want %q",
errBodyHashMismatch,
stampBodyHash, expectedBodyHash,
)
}
return nil
}
// MintChannelStamp computes a channel hashcash stamp
// with the given difficulty, channel name, and body hash.
// This is intended for clients to generate stamps before
// sending PRIVMSG to hashcash-protected channels.
//
// Stamp format: 1:bits:YYMMDD:channel:bodyhash:counter.
func MintChannelStamp(
bits int,
channel string,
bodyHash string,
) string {
date := time.Now().UTC().Format(dateFormatShort)
prefix := fmt.Sprintf(
"1:%d:%s:%s:%s:",
bits, date, channel, bodyHash,
)
counter := uint64(0)
for {
stamp := prefix + strconv.FormatUint(counter, 16)
hash := sha256.Sum256([]byte(stamp))
if hasLeadingZeroBits(hash[:], bits) {
return stamp
}
counter++
}
}

View File

@@ -1,244 +0,0 @@
package hashcash_test
import (
"crypto/sha256"
"encoding/hex"
"testing"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
)
const (
testChannel = "#general"
testBodyText = `["hello world"]`
)
func testBodyHash() string {
hash := sha256.Sum256([]byte(testBodyText))
return hex.EncodeToString(hash[:])
}
func TestChannelValidateHappyPath(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err != nil {
t.Fatalf("valid channel stamp rejected: %v", err)
}
}
func TestChannelValidateWrongChannel(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, testBits, "#other", bodyHash,
)
if err == nil {
t.Fatal("expected channel mismatch error")
}
}
func TestChannelValidateWrongBodyHash(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
wrongHash := sha256.Sum256([]byte("different body"))
wrongBodyHash := hex.EncodeToString(wrongHash[:])
err := validator.ValidateStamp(
stamp, testBits, testChannel, wrongBodyHash,
)
if err == nil {
t.Fatal("expected body hash mismatch error")
}
}
func TestChannelValidateInsufficientBits(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Mint with 2 bits but require 4.
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, 4, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected insufficient bits error")
}
}
func TestChannelValidateZeroBitsSkips(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
"garbage", 0, "#ch", "abc",
)
if err != nil {
t.Fatalf("zero bits should skip: %v", err)
}
}
func TestChannelValidateBadFormat(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
"not:valid", testBits, testChannel, "abc",
)
if err == nil {
t.Fatal("expected bad format error")
}
}
func TestChannelValidateBadVersion(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := "2:2:260317:#general:" + bodyHash + ":counter"
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected bad version error")
}
}
func TestChannelValidateExpiredStamp(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Mint with a very old date by manually constructing.
stamp := mintStampWithDate(
t, testBits, testChannel, "200101",
)
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected expired stamp error")
}
}
func TestChannelValidateMissingBodyHash(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Construct a stamp with empty body hash field.
stamp := mintStampWithDate(
t, testBits, testChannel, todayDate(),
)
// This uses the session-style stamp which has empty
// ext field — body hash is missing.
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected missing body hash error")
}
}
func TestBodyHash(t *testing.T) {
t.Parallel()
body := []byte(`["hello world"]`)
bodyHash := hashcash.BodyHash(body)
if len(bodyHash) != 64 {
t.Fatalf(
"expected 64-char hex hash, got %d",
len(bodyHash),
)
}
// Same input should produce same hash.
bodyHash2 := hashcash.BodyHash(body)
if bodyHash != bodyHash2 {
t.Fatal("body hash not deterministic")
}
// Different input should produce different hash.
bodyHash3 := hashcash.BodyHash([]byte("different"))
if bodyHash == bodyHash3 {
t.Fatal("different inputs produced same hash")
}
}
func TestStampHash(t *testing.T) {
t.Parallel()
hash1 := hashcash.StampHash("stamp1")
hash2 := hashcash.StampHash("stamp2")
if hash1 == hash2 {
t.Fatal("different stamps produced same hash")
}
// Same input should be deterministic.
hash1b := hashcash.StampHash("stamp1")
if hash1 != hash1b {
t.Fatal("stamp hash not deterministic")
}
}
func TestMintChannelStamp(t *testing.T) {
t.Parallel()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
if stamp == "" {
t.Fatal("expected non-empty stamp")
}
// Validate the minted stamp.
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err != nil {
t.Fatalf("minted stamp failed validation: %v", err)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,502 +0,0 @@
package ircserver
import (
"bufio"
"context"
"fmt"
"log/slog"
"net"
"strconv"
"strings"
"sync"
"time"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
const (
maxLineLen = 512
readTimeout = 5 * time.Minute
writeTimeout = 30 * time.Second
dnsTimeout = 3 * time.Second
pollInterval = 100 * time.Millisecond
pingInterval = 90 * time.Second
pongDeadline = 30 * time.Second
maxNickLen = 32
minPasswordLen = 8
maxHashcashBits = 40
)
// cmdHandler is the signature for registered IRC command
// handlers.
type cmdHandler func(ctx context.Context, msg *Message)
// Conn represents a single IRC client TCP connection.
type Conn struct {
conn net.Conn
log *slog.Logger
database *db.Database
brk *broker.Broker
cfg *config.Config
svc *service.Service
serverSfx string
commands map[string]cmdHandler
mu sync.Mutex
nick string
username string
realname string
hostname string
remoteIP string
sessionID int64
clientID int64
registered bool
gotNick bool
gotUser bool
passWord string
lastQueueID int64
closed bool
cancel context.CancelFunc
}
func newConn(
ctx context.Context,
tcpConn net.Conn,
log *slog.Logger,
database *db.Database,
brk *broker.Broker,
cfg *config.Config,
svc *service.Service,
) *Conn {
host, _, _ := net.SplitHostPort(tcpConn.RemoteAddr().String())
srvName := cfg.ServerName
if srvName == "" {
srvName = "neoirc"
}
conn := &Conn{ //nolint:exhaustruct // zero-value defaults
conn: tcpConn,
log: log,
database: database,
brk: brk,
cfg: cfg,
svc: svc,
serverSfx: srvName,
remoteIP: host,
hostname: resolveHost(ctx, host),
}
conn.commands = conn.buildCommandMap()
return conn
}
// buildCommandMap returns a map from IRC command strings
// to handler functions.
func (c *Conn) buildCommandMap() map[string]cmdHandler {
return map[string]cmdHandler{
irc.CmdPing: func(_ context.Context, msg *Message) {
c.handlePing(msg)
},
"PONG": func(context.Context, *Message) {},
irc.CmdNick: c.handleNick,
irc.CmdPrivmsg: c.handlePrivmsg,
irc.CmdNotice: c.handlePrivmsg,
irc.CmdJoin: c.handleJoin,
irc.CmdPart: c.handlePart,
irc.CmdQuit: func(_ context.Context, msg *Message) {
c.handleQuit(msg)
},
irc.CmdTopic: c.handleTopic,
irc.CmdMode: c.handleMode,
irc.CmdNames: c.handleNames,
irc.CmdList: func(ctx context.Context, _ *Message) { c.handleList(ctx) },
irc.CmdWhois: c.handleWhois,
irc.CmdWho: c.handleWho,
irc.CmdLusers: func(ctx context.Context, _ *Message) { c.handleLusers(ctx) },
irc.CmdMotd: func(context.Context, *Message) { c.deliverMOTD() },
irc.CmdOper: c.handleOper,
irc.CmdAway: c.handleAway,
irc.CmdKick: c.handleKick,
irc.CmdPass: c.handlePassPostReg,
"INVITE": c.handleInvite,
"CAP": func(_ context.Context, msg *Message) {
c.handleCAP(msg)
},
"USERHOST": c.handleUserhost,
}
}
// resolveHost does a reverse DNS lookup, returning the IP
// on failure.
func resolveHost(ctx context.Context, addr string) string {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
resolver := &net.Resolver{} //nolint:exhaustruct
names, err := resolver.LookupAddr(ctx, addr)
if err != nil || len(names) == 0 {
return addr
}
return strings.TrimSuffix(names[0], ".")
}
// serve is the main loop for a single IRC client connection.
func (c *Conn) serve(ctx context.Context) {
ctx, c.cancel = context.WithCancel(ctx)
defer c.cleanup(ctx)
scanner := bufio.NewScanner(c.conn)
scanner.Buffer(make([]byte, maxLineLen), maxLineLen)
for {
_ = c.conn.SetReadDeadline(
time.Now().Add(readTimeout),
)
if !scanner.Scan() {
return
}
line := scanner.Text()
if line == "" {
continue
}
msg := ParseMessage(line)
if msg == nil {
continue
}
c.handleMessage(ctx, msg)
if c.closed {
return
}
}
}
func (c *Conn) cleanup(ctx context.Context) {
c.mu.Lock()
wasRegistered := c.registered
sessID := c.sessionID
nick := c.nick
c.closed = true
c.mu.Unlock()
if wasRegistered && sessID > 0 {
c.svc.BroadcastQuit(
ctx, sessID, nick, "Connection closed",
)
}
c.conn.Close() //nolint:errcheck,gosec
}
// send writes a formatted IRC line to the connection.
func (c *Conn) send(line string) {
_ = c.conn.SetWriteDeadline(
time.Now().Add(writeTimeout),
)
_, _ = fmt.Fprintf(c.conn, "%s\r\n", line)
}
// sendNumeric sends a numeric reply from the server.
func (c *Conn) sendNumeric(
code irc.IRCMessageType,
params ...string,
) {
nick := c.nick
if nick == "" {
nick = "*"
}
allParams := make([]string, 0, 1+len(params))
allParams = append(allParams, nick)
allParams = append(allParams, params...)
c.send(FormatMessage(
c.serverSfx, code.Code(), allParams...,
))
}
// sendFromServer sends a message from the server.
func (c *Conn) sendFromServer(
command string, params ...string,
) {
c.send(FormatMessage(c.serverSfx, command, params...))
}
// hostmask returns the client's full hostmask
// (nick!user@host).
func (c *Conn) hostmask() string {
user := c.username
if user == "" {
user = c.nick
}
host := c.hostname
if host == "" {
host = c.remoteIP
}
return c.nick + "!" + user + "@" + host
}
// handleMessage dispatches a parsed IRC message using
// the command handler map.
func (c *Conn) handleMessage(
ctx context.Context,
msg *Message,
) {
// Before registration, only NICK, USER, PASS, PING,
// QUIT, and CAP are accepted.
if !c.registered {
c.handlePreRegistration(ctx, msg)
return
}
handler, ok := c.commands[msg.Command]
if !ok {
c.sendNumeric(
irc.ErrUnknownCommand,
msg.Command, "Unknown command",
)
return
}
handler(ctx, msg)
}
// handlePreRegistration handles messages before the
// connection is registered (NICK+USER received).
func (c *Conn) handlePreRegistration(
ctx context.Context,
msg *Message,
) {
switch msg.Command {
case irc.CmdPass:
if len(msg.Params) < 1 {
c.sendNumeric(
irc.ErrNeedMoreParams,
"PASS", "Not enough parameters",
)
return
}
c.passWord = msg.Params[0]
case irc.CmdNick:
if len(msg.Params) < 1 {
c.sendNumeric(
irc.ErrNoNicknameGiven,
"No nickname given",
)
return
}
c.nick = msg.Params[0]
if len(c.nick) > maxNickLen {
c.nick = c.nick[:maxNickLen]
}
c.gotNick = true
case irc.CmdUser:
if len(msg.Params) < 4 { //nolint:mnd
c.sendNumeric(
irc.ErrNeedMoreParams,
"USER", "Not enough parameters",
)
return
}
c.username = msg.Params[0]
c.realname = msg.Params[3]
c.gotUser = true
case irc.CmdPing:
c.handlePing(msg)
return
case irc.CmdQuit:
c.handleQuit(msg)
return
case "CAP":
c.handleCAP(msg)
return
default:
c.sendNumeric(
irc.ErrNotRegistered,
"You have not registered",
)
return
}
// Try to complete registration once we have both
// NICK and USER.
if c.gotNick && c.gotUser {
c.completeRegistration(ctx)
}
}
// completeRegistration creates a session and sends the
// welcome burst.
func (c *Conn) completeRegistration(ctx context.Context) {
// Check if nick is valid.
if c.nick == "" {
c.sendNumeric(
irc.ErrNoNicknameGiven, "No nickname given",
)
return
}
// Create session in DB.
sessionID, clientID, _, err := c.database.CreateSession(
ctx, c.nick, c.username, c.hostname, c.remoteIP,
)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") ||
strings.Contains(err.Error(), "nick") {
c.sendNumeric(
irc.ErrNicknameInUse,
c.nick, "Nickname is already in use",
)
return
}
c.log.Error(
"failed to create session", "error", err,
)
c.send("ERROR :Internal server error")
c.closed = true
return
}
c.mu.Lock()
c.sessionID = sessionID
c.clientID = clientID
c.registered = true
c.mu.Unlock()
// If PASS was provided before registration, set the
// session password.
if c.passWord != "" && len(c.passWord) >= minPasswordLen {
c.setPassword(ctx, c.passWord)
}
// Send welcome burst.
c.deliverWelcome()
c.deliverLusers(ctx)
c.deliverMOTD()
// Start the message relay goroutine.
go c.relayMessages(ctx)
}
// deliverWelcome sends 001-005 welcome numerics.
func (c *Conn) deliverWelcome() {
c.sendNumeric(irc.RplWelcome, fmt.Sprintf(
"Welcome to the %s Network, %s",
c.serverSfx, c.hostmask(),
))
c.sendNumeric(irc.RplYourHost, fmt.Sprintf(
"Your host is %s, running version neoirc",
c.serverSfx,
))
c.sendNumeric(
irc.RplCreated,
"This server was created recently",
)
c.sendNumeric(
irc.RplMyInfo,
c.serverSfx, "neoirc", "", "mnst",
)
c.sendNumeric(
irc.RplIsupport,
"CHANTYPES=#",
"NICKLEN=32",
"PREFIX=(ov)@+",
"CHANMODES=,,H,imnst",
"NETWORK="+c.serverSfx,
"are supported by this server",
)
}
// deliverLusers sends 251/252/254/255 server statistics.
func (c *Conn) deliverLusers(ctx context.Context) {
users, _ := c.database.GetUserCount(ctx)
opers, _ := c.database.GetOperCount(ctx)
channels, _ := c.database.GetChannelCount(ctx)
c.sendNumeric(irc.RplLuserClient, fmt.Sprintf(
"There are %d users and 0 invisible on 1 servers",
users,
))
c.sendNumeric(
irc.RplLuserOp,
strconv.FormatInt(opers, 10),
"operator(s) online",
)
c.sendNumeric(
irc.RplLuserChannels,
strconv.FormatInt(channels, 10),
"channels formed",
)
c.sendNumeric(irc.RplLuserMe, fmt.Sprintf(
"I have %d clients and 1 servers", users,
))
}
// deliverMOTD sends 375/372/376 MOTD lines.
func (c *Conn) deliverMOTD() {
motd := c.cfg.MOTD
if motd == "" {
c.sendNumeric(
irc.ErrNoMotd, "MOTD File is missing",
)
return
}
c.sendNumeric(irc.RplMotdStart, fmt.Sprintf(
"- %s Message of the Day -", c.serverSfx,
))
for _, line := range strings.Split(motd, "\n") {
c.sendNumeric(irc.RplMotd, "- "+line)
}
c.sendNumeric(
irc.RplEndOfMotd, "End of /MOTD command",
)
}
// setPassword sets a bcrypt password on the session.
func (c *Conn) setPassword(ctx context.Context, pw string) {
// Use the database's auth module to hash and store.
err := c.database.SetPassword(ctx, c.sessionID, pw)
if err != nil {
c.log.Error(
"failed to set password", "error", err,
)
}
}

View File

@@ -1,49 +0,0 @@
package ircserver
import (
"context"
"log/slog"
"net"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/service"
)
// NewTestServer creates a Server suitable for testing.
// The caller must call Stop() when finished.
func NewTestServer(
log *slog.Logger,
cfg *config.Config,
database *db.Database,
brk *broker.Broker,
) *Server {
svc := service.NewTestService(
database, brk, cfg, log,
)
return &Server{ //nolint:exhaustruct
log: log,
cfg: cfg,
database: database,
brk: brk,
svc: svc,
conns: make(map[*Conn]struct{}),
}
}
// Start exposes the unexported start method for tests.
func (s *Server) Start(addr string) error {
return s.start(context.Background(), addr)
}
// Stop exposes the unexported stop method for tests.
func (s *Server) Stop() {
s.stop()
}
// Listener returns the server's net.Listener for tests.
func (s *Server) Listener() net.Listener {
return s.listener
}

View File

@@ -1,913 +0,0 @@
package ircserver_test
import (
"strings"
"testing"
"time"
)
// TestIntegrationTwoClients is a comprehensive integration
// test that spawns the IRC server programmatically, connects
// two real TCP clients, and verifies all major IRC features
// including cross-client message delivery.
//
// The test runs sequentially through IRC features because
// both clients share the same channel state. Each section
// builds on the previous one (e.g. alice and bob must be
// JOINed before PRIVMSG can be tested).
//
//nolint:cyclop,funlen,maintidx // integration test
func TestIntegrationTwoClients(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
bob := env.dial(t)
// ── Registration ──────────────────────────────────
aliceWelcome := alice.register("alice")
assertContains(
t, aliceWelcome, " 001 ", "RPL_WELCOME alice",
)
assertContains(
t, aliceWelcome, " 002 ", "RPL_YOURHOST alice",
)
assertContains(
t, aliceWelcome, " 003 ", "RPL_CREATED alice",
)
assertContains(
t, aliceWelcome, " 004 ", "RPL_MYINFO alice",
)
assertContains(
t, aliceWelcome, "alice",
"nick in welcome burst",
)
bobWelcome := bob.register("bob")
assertContains(
t, bobWelcome, " 001 ", "RPL_WELCOME bob",
)
assertContains(
t, bobWelcome, "bob",
"nick in welcome burst",
)
// ── JOIN and cross-client visibility ──────────────
alice.send("JOIN #integration")
aliceJoinLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
assertContains(
t, aliceJoinLines, "JOIN",
"alice receives JOIN echo",
)
assertContains(
t, aliceJoinLines, " 366 ",
"RPL_ENDOFNAMES for alice",
)
bob.send("JOIN #integration")
bobJoinLines := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
assertContains(
t, bobJoinLines, "JOIN",
"bob receives JOIN echo",
)
// Alice should see bob's JOIN via relay.
aliceSeesBob := alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
assertContains(
t, aliceSeesBob, "bob",
"alice sees bob's JOIN",
)
// ── PRIVMSG (channel) — alice to bob ──────────────
alice.send("PRIVMSG #integration :hello from alice")
bobGetsMsg := bob.readUntil(func(l string) bool {
return strings.Contains(l, "hello from alice")
})
assertContains(
t, bobGetsMsg, "hello from alice",
"bob receives alice's channel message",
)
// ── PRIVMSG (channel) — bob to alice ──────────────
bob.send("PRIVMSG #integration :hello from bob")
aliceGetsMsg := alice.readUntil(func(l string) bool {
return strings.Contains(l, "hello from bob")
})
assertContains(
t, aliceGetsMsg, "hello from bob",
"alice receives bob's channel message",
)
// ── PRIVMSG (DM) — alice to bob ──────────────────
alice.send("PRIVMSG bob :secret message")
bobDM := bob.readUntil(func(l string) bool {
return strings.Contains(l, "secret message")
})
assertContains(
t, bobDM, "secret message",
"bob receives alice's DM",
)
assertContains(
t, bobDM, "alice",
"DM from field is alice",
)
// ── PRIVMSG (DM) — bob to alice ──────────────────
bob.send("PRIVMSG alice :reply to you")
aliceDM := alice.readUntil(func(l string) bool {
return strings.Contains(l, "reply to you")
})
assertContains(
t, aliceDM, "reply to you",
"alice receives bob's DM",
)
// ── NOTICE (channel) ──────────────────────────────
alice.send("NOTICE #integration :notice msg")
bobNotice := bob.readUntil(func(l string) bool {
return strings.Contains(l, "notice msg")
})
assertContains(
t, bobNotice, "NOTICE",
"bob receives NOTICE command",
)
assertContains(
t, bobNotice, "notice msg",
"bob receives NOTICE text",
)
// ── NOTICE (DM) ──────────────────────────────────
bob.send("NOTICE alice :dm notice")
aliceNotice := alice.readUntil(func(l string) bool {
return strings.Contains(l, "dm notice")
})
assertContains(
t, aliceNotice, "dm notice",
"alice receives DM NOTICE",
)
// ── TOPIC ─────────────────────────────────────────
// alice is the channel creator so she is +o.
alice.send("TOPIC #integration :Integration Test Topic")
aliceTopic := alice.readUntil(func(l string) bool {
return strings.Contains(
l, "Integration Test Topic",
)
})
assertContains(
t, aliceTopic, "Integration Test Topic",
"alice sees TOPIC echo",
)
bobTopic := bob.readUntil(func(l string) bool {
return strings.Contains(
l, "Integration Test Topic",
)
})
assertContains(
t, bobTopic, "Integration Test Topic",
"bob receives TOPIC change",
)
// ── MODE (query) ──────────────────────────────────
alice.send("MODE #integration")
aliceMode := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 324 ")
})
assertContains(
t, aliceMode, " 324 ",
"RPL_CHANNELMODEIS",
)
// ── MODE (+m moderated, then -m) ──────────────────
alice.send("MODE #integration +m")
aliceModeM := alice.readUntil(func(l string) bool {
return strings.Contains(l, "MODE") &&
strings.Contains(l, "+m")
})
assertContains(
t, aliceModeM, "+m",
"alice sees MODE +m echo",
)
bobModeM := bob.readUntil(func(l string) bool {
return strings.Contains(l, "+m")
})
assertContains(
t, bobModeM, "+m",
"bob sees MODE +m relay",
)
// Revert moderated mode.
alice.send("MODE #integration -m")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "-m")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "-m")
})
// ── MODE (+v voice, then -v) ──────────────────────
alice.send("MODE #integration +v bob")
aliceVoice := alice.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
assertContains(
t, aliceVoice, "+v",
"alice sees +v echo",
)
bobVoice := bob.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
assertContains(
t, bobVoice, "+v",
"bob receives +v relay",
)
// Remove voice.
alice.send("MODE #integration -v bob")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "-v")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "-v")
})
// ── NAMES ─────────────────────────────────────────
alice.send("NAMES #integration")
aliceNames := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
assertContains(
t, aliceNames, " 353 ",
"RPL_NAMREPLY",
)
assertContains(
t, aliceNames, " 366 ",
"RPL_ENDOFNAMES",
)
// Both nicks should appear in the name list.
foundBothNames := false
for _, line := range aliceNames {
if strings.Contains(line, " 353 ") &&
strings.Contains(line, "alice") &&
strings.Contains(line, "bob") {
foundBothNames = true
break
}
}
if !foundBothNames {
t.Error("NAMES reply should list both alice and bob")
}
// ── LIST ──────────────────────────────────────────
alice.send("LIST")
aliceList := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 323 ")
})
assertContains(
t, aliceList, " 322 ",
"RPL_LIST entry",
)
assertContains(
t, aliceList, "#integration",
"LIST includes #integration",
)
assertContains(
t, aliceList, " 323 ", //nolint:misspell // IRC RPL_LISTEND
"RPL_LISTEND", //nolint:misspell // IRC term
)
// ── WHO ───────────────────────────────────────────
bob.send("WHO #integration")
bobWho := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 315 ")
})
assertContains(
t, bobWho, " 352 ",
"RPL_WHOREPLY",
)
assertContains(
t, bobWho, " 315 ",
"RPL_ENDOFWHO",
)
// ── WHOIS ─────────────────────────────────────────
alice.send("WHOIS bob")
aliceWhois := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 318 ")
})
assertContains(
t, aliceWhois, " 311 ",
"RPL_WHOISUSER",
)
assertContains(
t, aliceWhois, " 312 ",
"RPL_WHOISSERVER",
)
assertContains(
t, aliceWhois, " 318 ",
"RPL_ENDOFWHOIS",
)
// ── WHOIS with channels ───────────────────────────
bob.send("WHOIS alice")
bobWhois := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 318 ")
})
assertContains(
t, bobWhois, " 319 ",
"RPL_WHOISCHANNELS",
)
assertContains(
t, bobWhois, "#integration",
"WHOIS shows #integration channel",
)
// ── LUSERS ────────────────────────────────────────
alice.send("LUSERS")
aliceLusers := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 255 ")
})
assertContains(
t, aliceLusers, " 251 ",
"RPL_LUSERCLIENT",
)
assertContains(
t, aliceLusers, " 255 ",
"RPL_LUSERME",
)
// ── NICK change ───────────────────────────────────
bob.send("NICK bobby")
bobNick := bob.readUntil(func(l string) bool {
return strings.Contains(l, "NICK") &&
strings.Contains(l, "bobby")
})
assertContains(
t, bobNick, "bobby",
"bob sees NICK change to bobby",
)
// alice should see the nick change relayed.
aliceNick := alice.readUntil(func(l string) bool {
return strings.Contains(l, "bobby")
})
assertContains(
t, aliceNick, "NICK",
"alice sees NICK command",
)
assertContains(
t, aliceNick, "bobby",
"alice sees new nick bobby",
)
// Change it back for remaining tests.
bob.send("NICK bob")
bob.readUntil(func(l string) bool {
return strings.Contains(l, "bob")
})
alice.readUntil(func(l string) bool {
return strings.Contains(l, "NICK") &&
strings.Contains(l, "bob")
})
// ── Duplicate NICK ────────────────────────────────
bob.send("NICK alice")
bobDupNick := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 433 ")
})
assertContains(
t, bobDupNick, " 433 ",
"ERR_NICKNAMEINUSE",
)
// ── KICK ──────────────────────────────────────────
// alice is op; she kicks bob.
alice.send("KICK #integration bob :testing kick")
aliceKick := alice.readUntil(func(l string) bool {
return strings.Contains(l, "KICK")
})
assertContains(
t, aliceKick, "KICK",
"alice sees KICK echo",
)
assertContains(
t, aliceKick, "bob",
"KICK mentions bob",
)
bobKick := bob.readUntil(func(l string) bool {
return strings.Contains(l, "KICK")
})
assertContains(
t, bobKick, "KICK",
"bob receives KICK",
)
assertContains(
t, bobKick, "testing kick",
"KICK reason is relayed",
)
// bob rejoins.
bob.joinAndDrain("#integration")
// Drain alice's view of the rejoin.
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
// ── KICK non-op should fail ───────────────────────
bob.send("KICK #integration alice :nope")
bobKickFail := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 482 ")
})
assertContains(
t, bobKickFail, " 482 ",
"ERR_CHANOPRIVSNEEDED",
)
// ── TOPIC lock (+t default) ───────────────────────
// +t is default, so bob should not be able to set
// topic.
bob.send("TOPIC #integration :bob tries topic")
bobTopicFail := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 482 ")
})
assertContains(
t, bobTopicFail, " 482 ",
"ERR_CHANOPRIVSNEEDED for topic",
)
// ── PING / PONG ───────────────────────────────────
alice.send("PING :testtoken")
alicePong := alice.readUntil(func(l string) bool {
return strings.Contains(l, "PONG")
})
assertContains(
t, alicePong, "PONG",
"PONG response received",
)
// ── Unknown command ───────────────────────────────
bob.send("FOOBAR")
bobUnknown := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 421 ")
})
assertContains(
t, bobUnknown, " 421 ",
"ERR_UNKNOWNCOMMAND",
)
// ── MOTD ──────────────────────────────────────────
alice.send("MOTD")
aliceMOTD := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 376 ")
})
assertContains(
t, aliceMOTD, " 376 ",
"RPL_ENDOFMOTD",
)
// ── AWAY (set, check via DM, clear) ───────────────
alice.send("AWAY :gone fishing")
aliceAway := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 306 ")
})
assertContains(
t, aliceAway, " 306 ",
"RPL_NOWAWAY",
)
// bob DMs alice — should get RPL_AWAY.
bob.send("PRIVMSG alice :are you there?")
bobAwayReply := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 301 ")
})
assertContains(
t, bobAwayReply, " 301 ",
"RPL_AWAY for bob when messaging alice",
)
assertContains(
t, bobAwayReply, "gone fishing",
"away message relayed",
)
// Clear away.
alice.send("AWAY")
alice.readUntil(func(l string) bool {
return strings.Contains(l, " 305 ")
})
// ── PASS (set password post-registration) ─────────
alice.send("PASS :mypassword123")
alicePass := alice.readUntil(func(l string) bool {
return strings.Contains(l, "Password set")
})
assertContains(
t, alicePass, "Password set",
"password set confirmation",
)
// ── MODE -t/+t topic lock toggle ──────────────────
alice.send("MODE #integration -t")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "-t")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "-t")
})
// Now bob should be able to set topic.
bob.send("TOPIC #integration :bob sets topic now")
bobTopicOK := bob.readUntil(func(l string) bool {
return strings.Contains(l, "bob sets topic now")
})
assertContains(
t, bobTopicOK, "bob sets topic now",
"bob can set topic after -t",
)
// alice sees the topic change.
aliceTopicRelay := alice.readUntil(func(l string) bool {
return strings.Contains(l, "bob sets topic now")
})
assertContains(
t, aliceTopicRelay, "bob sets topic now",
"alice sees bob's topic after -t",
)
// Restore +t.
alice.send("MODE #integration +t")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "+t")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "+t")
})
// ── DM to nonexistent nick ────────────────────────
alice.send("PRIVMSG nobody123 :hello")
aliceNoSuch := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 401 ")
})
assertContains(
t, aliceNoSuch, " 401 ",
"ERR_NOSUCHNICK",
)
// ── PART with reason ──────────────────────────────
bob.send("PART #integration :bye for now")
bobPart := bob.readUntil(func(l string) bool {
return strings.Contains(l, "PART")
})
assertContains(
t, bobPart, "PART",
"bob sees PART echo",
)
// alice sees bob PART via relay.
alicePart := alice.readUntil(func(l string) bool {
return strings.Contains(l, "PART") &&
strings.Contains(l, "bob")
})
assertContains(
t, alicePart, "bob",
"alice sees bob's PART",
)
assertContains(
t, alicePart, "bye for now",
"PART reason is relayed",
)
// bob rejoins for remaining tests.
bob.joinAndDrain("#integration")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
// ── PART non-existent channel ─────────────────────
bob.send("PART #nonexistent")
bobPartFail := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 403 ") ||
strings.Contains(l, " 442 ")
})
foundPartErr := false
for _, line := range bobPartFail {
if strings.Contains(line, " 403 ") ||
strings.Contains(line, " 442 ") {
foundPartErr = true
break
}
}
if !foundPartErr {
t.Error(
"expected ERR_NOSUCHCHANNEL or " +
"ERR_NOTONCHANNEL",
)
}
// ── User MODE query ───────────────────────────────
alice.send("MODE alice")
aliceUMode := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 221 ")
})
assertContains(
t, aliceUMode, " 221 ",
"RPL_UMODEIS",
)
// ── Multiple channel operation ────────────────────
alice.send("JOIN #second")
alice.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
bob.send("JOIN #second")
bob.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
// Drain alice seeing bob join.
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
alice.send("PRIVMSG #second :cross-channel test")
bobCross := bob.readUntil(func(l string) bool {
return strings.Contains(l, "cross-channel test")
})
assertContains(
t, bobCross, "cross-channel test",
"bob receives message in #second",
)
// Clean up #second.
alice.send("PART #second")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "PART")
})
bob.send("PART #second")
bob.readUntil(func(l string) bool {
return strings.Contains(l, "PART")
})
// ── QUIT ──────────────────────────────────────────
bob.send("QUIT :integration test done")
bobQuit := bob.readUntil(func(l string) bool {
return strings.Contains(l, "ERROR")
})
assertContains(
t, bobQuit, "integration test done",
"QUIT reason echoed",
)
// alice should see bob's QUIT via relay.
aliceQuit := alice.readUntil(func(l string) bool {
return strings.Contains(l, "QUIT") &&
strings.Contains(l, "bob")
})
assertContains(
t, aliceQuit, "bob",
"alice sees bob's QUIT",
)
}
// TestIntegrationModeSecret tests +s (secret) channel
// mode — verifies that +s can be set and the mode is
// reflected in MODE queries.
func TestIntegrationModeSecret(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice")
alice.joinAndDrain("#secretroom")
// Set +s.
alice.send("MODE #secretroom +s")
aliceLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, "+s")
})
assertContains(
t, aliceLines, "+s",
"alice sees MODE +s confirmation",
)
// Verify mode is reflected in query.
alice.send("MODE #secretroom")
modeLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 324 ")
})
assertContains(
t, modeLines, "s",
"channel mode includes s",
)
}
// TestIntegrationModeModerated tests +m (moderated) mode
// — non-voiced users cannot send.
func TestIntegrationModeModerated(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice")
bob := env.dial(t)
bob.register("bob")
alice.joinAndDrain("#modtest")
bob.joinAndDrain("#modtest")
// Drain alice's view of bob's join.
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
// Set +m.
alice.send("MODE #modtest +m")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "+m")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "+m")
})
// bob should get an error trying to send.
bob.send("PRIVMSG #modtest :should fail")
bobLines := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 404 ") ||
strings.Contains(l, " 482 ")
})
foundModErr := false
for _, line := range bobLines {
if strings.Contains(line, " 404 ") ||
strings.Contains(line, " 482 ") {
foundModErr = true
break
}
}
if !foundModErr {
t.Error(
"non-voiced user should not be able to send " +
"in +m channel",
)
}
// Grant +v to bob, then he should be able to send.
alice.send("MODE #modtest +v bob")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
bob.send("PRIVMSG #modtest :voiced message")
aliceLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, "voiced message")
})
assertContains(
t, aliceLines, "voiced message",
"alice receives voiced bob's message",
)
}
// TestIntegrationThirdClientObserver verifies that a third
// client observing the same channel receives messages from
// the other two.
func TestIntegrationThirdClientObserver(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice")
bob := env.dial(t)
bob.register("bob")
carol := env.dial(t)
carol.register("carol")
alice.joinAndDrain("#trio")
bob.joinAndDrain("#trio")
carol.joinAndDrain("#trio")
// Drain join notifications.
time.Sleep(100 * time.Millisecond)
// alice sends; both bob and carol should receive.
alice.send("PRIVMSG #trio :hello trio")
bobLines := bob.readUntil(func(l string) bool {
return strings.Contains(l, "hello trio")
})
assertContains(
t, bobLines, "hello trio",
"bob receives trio message",
)
carolLines := carol.readUntil(func(l string) bool {
return strings.Contains(l, "hello trio")
})
assertContains(
t, carolLines, "hello trio",
"carol receives trio message",
)
}

View File

@@ -1,123 +0,0 @@
// Package ircserver implements a traditional IRC wire protocol
// listener (RFC 1459/2812) that bridges to the neoirc HTTP/JSON
// server internals.
package ircserver
import "strings"
// Message represents a parsed IRC wire protocol message.
type Message struct {
// Prefix is the optional :prefix at the start (may be
// empty for client-to-server messages).
Prefix string
// Command is the IRC command (e.g., "PRIVMSG", "NICK").
Command string
// Params holds the positional parameters, including the
// trailing parameter (which was preceded by ':' on the
// wire).
Params []string
}
// ParseMessage parses a single IRC wire protocol line
// (without the trailing CR-LF) into a Message.
// Returns nil if the line is empty.
//
// IRC message format (RFC 1459 §2.3.1):
//
// [":" prefix SPACE] command { SPACE param } [SPACE ":" trailing]
func ParseMessage(line string) *Message {
if line == "" {
return nil
}
msg := &Message{} //nolint:exhaustruct // fields set below
// Extract prefix if present.
if line[0] == ':' {
idx := strings.IndexByte(line, ' ')
if idx < 0 {
// Only a prefix, no command — invalid.
return nil
}
msg.Prefix = line[1:idx]
line = line[idx+1:]
}
// Skip leading spaces.
line = strings.TrimLeft(line, " ")
if line == "" {
return nil
}
// Extract command.
idx := strings.IndexByte(line, ' ')
if idx < 0 {
msg.Command = strings.ToUpper(line)
return msg
}
msg.Command = strings.ToUpper(line[:idx])
line = line[idx+1:]
// Extract parameters.
for line != "" {
line = strings.TrimLeft(line, " ")
if line == "" {
break
}
// Trailing parameter (everything after ':').
if line[0] == ':' {
msg.Params = append(msg.Params, line[1:])
break
}
idx = strings.IndexByte(line, ' ')
if idx < 0 {
msg.Params = append(msg.Params, line)
break
}
msg.Params = append(msg.Params, line[:idx])
line = line[idx+1:]
}
return msg
}
// FormatMessage formats an IRC message into wire protocol
// format (without the trailing CR-LF).
func FormatMessage(
prefix, command string,
params ...string,
) string {
var buf strings.Builder
if prefix != "" {
buf.WriteByte(':')
buf.WriteString(prefix)
buf.WriteByte(' ')
}
buf.WriteString(command)
for i, param := range params {
buf.WriteByte(' ')
isLast := i == len(params)-1
needsColon := strings.Contains(param, " ") ||
param == "" || param[0] == ':'
if isLast && needsColon {
buf.WriteByte(':')
}
buf.WriteString(param)
}
return buf.String()
}

View File

@@ -1,328 +0,0 @@
package ircserver_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
)
//nolint:funlen // table-driven test
func TestParseMessage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want *ircserver.Message
wantNil bool
}{
{
name: "empty",
input: "",
want: nil,
wantNil: true,
},
{
name: "simple command",
input: "PING",
want: &ircserver.Message{
Prefix: "",
Command: "PING",
Params: nil,
},
wantNil: false,
},
{
name: "command with one param",
input: "NICK alice",
want: &ircserver.Message{
Prefix: "",
Command: "NICK",
Params: []string{"alice"},
},
wantNil: false,
},
{
name: "command case insensitive",
input: "nick Alice",
want: &ircserver.Message{
Prefix: "",
Command: "NICK",
Params: []string{"Alice"},
},
wantNil: false,
},
{
name: "privmsg with trailing",
input: "PRIVMSG #general :hello world",
want: &ircserver.Message{
Prefix: "",
Command: "PRIVMSG",
Params: []string{"#general", "hello world"},
},
wantNil: false,
},
{
name: "with prefix",
input: ":server.example.com 001 alice :Welcome to IRC",
want: &ircserver.Message{
Prefix: "server.example.com",
Command: "001",
Params: []string{"alice", "Welcome to IRC"},
},
wantNil: false,
},
{
name: "user command",
input: "USER alice 0 * :Alice Smith",
want: &ircserver.Message{
Prefix: "",
Command: "USER",
Params: []string{
"alice", "0", "*", "Alice Smith",
},
},
wantNil: false,
},
{
name: "join channel",
input: "JOIN #general",
want: &ircserver.Message{
Prefix: "",
Command: "JOIN",
Params: []string{"#general"},
},
wantNil: false,
},
{
name: "quit with trailing",
input: "QUIT :leaving now",
want: &ircserver.Message{
Prefix: "",
Command: "QUIT",
Params: []string{"leaving now"},
},
wantNil: false,
},
{
name: "quit without reason",
input: "QUIT",
want: &ircserver.Message{
Prefix: "",
Command: "QUIT",
Params: nil,
},
wantNil: false,
},
{
name: "mode query",
input: "MODE #general",
want: &ircserver.Message{
Prefix: "",
Command: "MODE",
Params: []string{"#general"},
},
wantNil: false,
},
{
name: "kick with reason",
input: "KICK #general bob :misbehaving",
want: &ircserver.Message{
Prefix: "",
Command: "KICK",
Params: []string{
"#general", "bob", "misbehaving",
},
},
wantNil: false,
},
{
name: "empty trailing",
input: "PRIVMSG #general :",
want: &ircserver.Message{
Prefix: "",
Command: "PRIVMSG",
Params: []string{"#general", ""},
},
wantNil: false,
},
{
name: "pass command",
input: "PASS mysecret",
want: &ircserver.Message{
Prefix: "",
Command: "PASS",
Params: []string{"mysecret"},
},
wantNil: false,
},
{
name: "ping with server",
input: "PING :irc.example.com",
want: &ircserver.Message{
Prefix: "",
Command: "PING",
Params: []string{"irc.example.com"},
},
wantNil: false,
},
{
name: "topic with trailing spaces",
input: "TOPIC #general :Welcome to the channel!",
want: &ircserver.Message{
Prefix: "",
Command: "TOPIC",
Params: []string{
"#general",
"Welcome to the channel!",
},
},
wantNil: false,
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
got := ircserver.ParseMessage(testCase.input)
if testCase.wantNil {
if got != nil {
t.Fatalf("expected nil, got %+v", got)
}
return
}
if got == nil {
t.Fatal("expected non-nil message")
}
if got.Prefix != testCase.want.Prefix {
t.Errorf(
"prefix: got %q, want %q",
got.Prefix, testCase.want.Prefix,
)
}
if got.Command != testCase.want.Command {
t.Errorf(
"command: got %q, want %q",
got.Command, testCase.want.Command,
)
}
if len(got.Params) != len(testCase.want.Params) {
t.Fatalf(
"params length: got %d, want %d (%v vs %v)",
len(got.Params),
len(testCase.want.Params),
got.Params,
testCase.want.Params,
)
}
for i, p := range got.Params {
if p != testCase.want.Params[i] {
t.Errorf(
"param[%d]: got %q, want %q",
i, p, testCase.want.Params[i],
)
}
}
})
}
}
func TestFormatMessage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prefix string
command string
params []string
want string
}{
{
name: "simple command",
prefix: "",
command: "PING",
params: nil,
want: "PING",
},
{
name: "with prefix",
prefix: "server",
command: "PONG",
params: []string{"server"},
want: ":server PONG server",
},
{
name: "privmsg with trailing",
prefix: "alice!alice@host",
command: "PRIVMSG",
params: []string{"#general", "hello world"},
want: ":alice!alice@host PRIVMSG #general :hello world",
},
{
name: "numeric reply",
prefix: "server",
command: "001",
params: []string{"alice", "Welcome to IRC"},
want: ":server 001 alice :Welcome to IRC",
},
{
name: "empty trailing",
prefix: "server",
command: "PRIVMSG",
params: []string{"#chan", ""},
want: ":server PRIVMSG #chan :",
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
got := ircserver.FormatMessage(
testCase.prefix, testCase.command, testCase.params...,
)
if got != testCase.want {
t.Errorf("got %q, want %q", got, testCase.want)
}
})
}
}
func TestParseFormatRoundTrip(t *testing.T) {
t.Parallel()
// Round-trip only works for lines where the last
// parameter either contains a space (gets ':' prefix
// on format) or is a non-trailing single token.
lines := []string{
"PING",
"NICK alice",
"PRIVMSG #general :hello world",
"JOIN #general",
"MODE #general",
}
for _, line := range lines {
msg := ircserver.ParseMessage(line)
if msg == nil {
t.Fatalf("failed to parse: %q", line)
}
formatted := ircserver.FormatMessage(
msg.Prefix, msg.Command, msg.Params...,
)
if formatted != line {
t.Errorf(
"round-trip failed: input %q, got %q",
line, formatted,
)
}
}
}

View File

@@ -1,319 +0,0 @@
package ircserver
import (
"context"
"encoding/json"
"strings"
"time"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
// relayMessages polls the client output queue and delivers
// IRC-formatted messages to the TCP connection. It runs
// in a goroutine for the lifetime of the connection.
func (c *Conn) relayMessages(ctx context.Context) {
// Use a ticker as a fallback; primary wakeup is via
// broker notification.
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
default:
}
// Drain any available messages.
delivered := c.drainQueue(ctx)
if delivered {
// Tight loop while there are messages.
continue
}
// Wait for notification or timeout.
waitCh := c.brk.Wait(c.sessionID)
select {
case <-waitCh:
// New message notification — loop back.
case <-ticker.C:
// Periodic check.
case <-ctx.Done():
c.brk.Remove(c.sessionID, waitCh)
return
}
}
}
const relayPollLimit = 100
// drainQueue polls the output queue and delivers all
// pending messages. Returns true if at least one message
// was delivered.
func (c *Conn) drainQueue(ctx context.Context) bool {
msgs, lastID, err := c.database.PollMessages(
ctx, c.clientID, c.lastQueueID, relayPollLimit,
)
if err != nil {
return false
}
if len(msgs) == 0 {
return false
}
for i := range msgs {
c.deliverIRCMessage(ctx, &msgs[i])
}
if lastID > c.lastQueueID {
c.lastQueueID = lastID
}
return true
}
// deliverIRCMessage converts a db.IRCMessage to wire
// protocol and sends it.
//
//nolint:cyclop // dispatch table
func (c *Conn) deliverIRCMessage(
_ context.Context,
msg *db.IRCMessage,
) {
command := msg.Command
// Decode body as []string for the trailing text.
var bodyLines []string
if msg.Body != nil {
_ = json.Unmarshal(msg.Body, &bodyLines)
}
text := ""
if len(bodyLines) > 0 {
text = bodyLines[0]
}
// Route by command type.
switch {
case isNumeric(command):
c.deliverNumeric(msg, text)
case command == irc.CmdPrivmsg || command == irc.CmdNotice:
c.deliverTextMessage(msg, command, text)
case command == irc.CmdJoin:
c.deliverJoin(msg)
case command == irc.CmdPart:
c.deliverPart(msg, text)
case command == irc.CmdNick:
c.deliverNickChange(msg, text)
case command == irc.CmdQuit:
c.deliverQuitMsg(msg, text)
case command == irc.CmdTopic:
c.deliverTopicChange(msg, text)
case command == irc.CmdKick:
c.deliverKickMsg(msg, text)
case command == "INVITE":
c.deliverInviteMsg(msg, text)
case command == irc.CmdMode:
c.deliverMode(msg, text)
case command == irc.CmdPing:
// Server-originated PING — reply with PONG.
c.sendFromServer("PING", c.serverSfx)
default:
// Unknown command — deliver as server notice.
if text != "" {
c.sendFromServer("NOTICE", c.nick, text)
}
}
}
// isNumeric returns true if the command is a 3-digit
// numeric code.
func isNumeric(cmd string) bool {
return len(cmd) == 3 &&
cmd[0] >= '0' && cmd[0] <= '9' &&
cmd[1] >= '0' && cmd[1] <= '9' &&
cmd[2] >= '0' && cmd[2] <= '9'
}
// deliverNumeric sends a numeric reply.
func (c *Conn) deliverNumeric(
msg *db.IRCMessage,
text string,
) {
from := msg.From
if from == "" {
from = c.serverSfx
}
var params []string
if msg.Params != nil {
_ = json.Unmarshal(msg.Params, &params)
}
allParams := make([]string, 0, 1+len(params)+1)
allParams = append(allParams, c.nick)
allParams = append(allParams, params...)
if text != "" {
allParams = append(allParams, text)
}
c.send(FormatMessage(from, msg.Command, allParams...))
}
// deliverTextMessage sends PRIVMSG or NOTICE.
func (c *Conn) deliverTextMessage(
msg *db.IRCMessage,
command, text string,
) {
from := msg.From
target := msg.To
// Don't echo our own messages back.
if strings.EqualFold(from, c.nick) {
return
}
prefix := from
if !strings.Contains(prefix, "!") {
prefix = from + "!" + from + "@*"
}
c.send(FormatMessage(prefix, command, target, text))
}
// deliverJoin sends a JOIN notification.
func (c *Conn) deliverJoin(msg *db.IRCMessage) {
// Don't echo our own JOINs (we already sent them
// during joinChannel).
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
c.send(FormatMessage(prefix, "JOIN", channel))
}
// deliverPart sends a PART notification.
func (c *Conn) deliverPart(msg *db.IRCMessage, text string) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
if text != "" {
c.send(FormatMessage(
prefix, "PART", channel, text,
))
} else {
c.send(FormatMessage(prefix, "PART", channel))
}
}
// deliverNickChange sends a NICK change notification.
func (c *Conn) deliverNickChange(
msg *db.IRCMessage,
newNick string,
) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
c.send(FormatMessage(prefix, "NICK", newNick))
}
// deliverQuitMsg sends a QUIT notification.
func (c *Conn) deliverQuitMsg(
msg *db.IRCMessage,
text string,
) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
if text != "" {
c.send(FormatMessage(
prefix, "QUIT", "Quit: "+text,
))
} else {
c.send(FormatMessage(prefix, "QUIT", "Quit"))
}
}
// deliverTopicChange sends a TOPIC change notification.
func (c *Conn) deliverTopicChange(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
c.send(FormatMessage(prefix, "TOPIC", channel, text))
}
// deliverKickMsg sends a KICK notification.
func (c *Conn) deliverKickMsg(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
var params []string
if msg.Params != nil {
_ = json.Unmarshal(msg.Params, &params)
}
kickTarget := ""
if len(params) > 0 {
kickTarget = params[0]
}
if kickTarget != "" {
c.send(FormatMessage(
prefix, "KICK", channel, kickTarget, text,
))
} else {
c.send(FormatMessage(
prefix, "KICK", channel, "?", text,
))
}
}
// deliverInviteMsg sends an INVITE notification.
func (c *Conn) deliverInviteMsg(
_ *db.IRCMessage,
text string,
) {
c.sendFromServer("NOTICE", c.nick, text)
}
// deliverMode sends a MODE change notification.
func (c *Conn) deliverMode(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
target := msg.To
if text != "" {
c.send(FormatMessage(prefix, "MODE", target, text))
}
}

View File

@@ -1,157 +0,0 @@
package ircserver
import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/service"
"go.uber.org/fx"
)
// Params defines the dependencies for creating an IRC
// Server.
type Params struct {
fx.In
Logger *logger.Logger
Config *config.Config
Database *db.Database
Broker *broker.Broker
Service *service.Service
}
// Server is the TCP IRC protocol server.
type Server struct {
log *slog.Logger
cfg *config.Config
database *db.Database
brk *broker.Broker
svc *service.Service
listener net.Listener
mu sync.Mutex
conns map[*Conn]struct{}
cancel context.CancelFunc
}
// New creates a new IRC Server and registers its lifecycle
// hooks. The listener is only started if IRC_LISTEN_ADDR
// is configured; otherwise the server is inert.
func New(
lifecycle fx.Lifecycle,
params Params,
) *Server {
srv := &Server{
log: params.Logger.Get(),
cfg: params.Config,
database: params.Database,
brk: params.Broker,
svc: params.Service,
conns: make(map[*Conn]struct{}),
listener: nil,
cancel: nil,
mu: sync.Mutex{},
}
listenAddr := params.Config.IRCListenAddr
if listenAddr == "" {
return srv
}
lifecycle.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return srv.start(ctx, listenAddr)
},
OnStop: func(_ context.Context) error {
srv.stop()
return nil
},
})
return srv
}
// start begins listening for TCP connections.
//
//nolint:contextcheck // long-lived server ctx, not the short Fx one
func (s *Server) start(_ context.Context, addr string) error {
ln, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("irc listen: %w", err)
}
s.listener = ln
ctx, cancel := context.WithCancel(context.Background())
s.cancel = cancel
s.log.Info(
"irc server listening", "addr", addr,
)
go s.acceptLoop(ctx)
return nil
}
// stop shuts down the listener and all connections.
func (s *Server) stop() {
if s.cancel != nil {
s.cancel()
}
if s.listener != nil {
s.listener.Close() //nolint:errcheck,gosec
}
s.mu.Lock()
for c := range s.conns {
c.conn.Close() //nolint:errcheck,gosec
}
s.mu.Unlock()
}
// acceptLoop accepts new connections.
func (s *Server) acceptLoop(ctx context.Context) {
for {
tcpConn, err := s.listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return
default:
s.log.Error(
"irc accept error", "error", err,
)
continue
}
}
client := newConn(
ctx, tcpConn, s.log,
s.database, s.brk, s.cfg, s.svc,
)
s.mu.Lock()
s.conns[client] = struct{}{}
s.mu.Unlock()
go func() {
defer func() {
s.mu.Lock()
delete(s.conns, client)
s.mu.Unlock()
}()
client.serve(ctx)
}()
}
}

View File

@@ -1,625 +0,0 @@
package ircserver_test
import (
"bufio"
"database/sql"
"fmt"
"log/slog"
"net"
"os"
"strings"
"testing"
"time"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
_ "modernc.org/sqlite"
)
const testTimeout = 5 * time.Second
func TestMain(m *testing.M) {
db.SetBcryptCost(4)
os.Exit(m.Run())
}
// testEnv holds the shared test infrastructure.
type testEnv struct {
database *db.Database
brk *broker.Broker
cfg *config.Config
srv *ircserver.Server
}
func newTestEnv(t *testing.T) *testEnv {
t.Helper()
dsn := fmt.Sprintf(
"file:%s?mode=memory&cache=shared&_journal_mode=WAL",
t.Name(),
)
conn, err := sql.Open("sqlite", dsn)
if err != nil {
t.Fatalf("open db: %v", err)
}
conn.SetMaxOpenConns(1)
_, err = conn.ExecContext(
t.Context(), "PRAGMA foreign_keys = ON",
)
if err != nil {
t.Fatalf("pragma: %v", err)
}
database := db.NewTestDatabaseFromConn(conn)
err = database.RunMigrations(t.Context())
if err != nil {
t.Fatalf("migrate: %v", err)
}
brk := broker.New()
cfg := &config.Config{ //nolint:exhaustruct
ServerName: "test.irc",
MOTD: "Welcome to test IRC",
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
addr := listener.Addr().String()
err = listener.Close()
if err != nil {
t.Fatalf("close listener: %v", err)
}
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelError}, //nolint:exhaustruct
))
srv := ircserver.NewTestServer(log, cfg, database, brk)
err = srv.Start(addr)
if err != nil {
t.Fatalf("start irc server: %v", err)
}
t.Cleanup(func() {
srv.Stop()
err := conn.Close()
if err != nil {
t.Logf("close db: %v", err)
}
})
return &testEnv{
database: database,
brk: brk,
cfg: cfg,
srv: srv,
}
}
// dial connects to the test server.
func (env *testEnv) dial(t *testing.T) *testClient {
t.Helper()
conn, err := net.DialTimeout(
"tcp",
env.srv.Listener().Addr().String(),
testTimeout,
)
if err != nil {
t.Fatalf("dial: %v", err)
}
t.Cleanup(func() {
err := conn.Close()
if err != nil {
t.Logf("close conn: %v", err)
}
})
return &testClient{
t: t,
conn: conn,
scanner: bufio.NewScanner(conn),
}
}
// testClient wraps a raw TCP connection with helpers.
type testClient struct {
t *testing.T
conn net.Conn
scanner *bufio.Scanner
}
func (tc *testClient) send(line string) {
tc.t.Helper()
_ = tc.conn.SetWriteDeadline(
time.Now().Add(testTimeout),
)
_, err := fmt.Fprintf(tc.conn, "%s\r\n", line)
if err != nil {
tc.t.Fatalf("send: %v", err)
}
}
func (tc *testClient) readLine() string {
tc.t.Helper()
_ = tc.conn.SetReadDeadline(
time.Now().Add(testTimeout),
)
if !tc.scanner.Scan() {
err := tc.scanner.Err()
if err != nil {
tc.t.Fatalf("read: %v", err)
}
tc.t.Fatal("connection closed unexpectedly")
}
return tc.scanner.Text()
}
// readUntil reads lines until one matches the predicate.
func (tc *testClient) readUntil(
pred func(string) bool,
) []string {
tc.t.Helper()
var lines []string
for {
line := tc.readLine()
lines = append(lines, line)
if pred(line) {
return lines
}
}
}
// register sends NICK + USER and reads through the welcome
// burst.
func (tc *testClient) register(nick string) []string {
tc.t.Helper()
tc.send("NICK " + nick)
tc.send("USER " + nick + " 0 * :Test User")
return tc.readUntil(func(line string) bool {
return strings.Contains(line, " 376 ") ||
strings.Contains(line, " 422 ")
})
}
// assertContains checks that at least one line matches the
// given substring.
func assertContains(
t *testing.T,
lines []string,
substr, description string,
) {
t.Helper()
for _, line := range lines {
if strings.Contains(line, substr) {
return
}
}
t.Errorf("did not find %q in output: %s", substr, description)
}
// joinAndDrain joins a channel and reads until
// RPL_ENDOFNAMES.
func (tc *testClient) joinAndDrain(channel string) {
tc.t.Helper()
tc.send("JOIN " + channel)
tc.readUntil(func(line string) bool {
return strings.Contains(line, " 366 ")
})
}
// sendAndExpect sends a command and reads until a line
// containing the expected substring is found.
func (tc *testClient) sendAndExpect(
cmd, expect string,
) []string {
tc.t.Helper()
tc.send(cmd)
return tc.readUntil(func(line string) bool {
return strings.Contains(line, expect)
})
}
func TestRegistration(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
lines := client.register("alice")
assertContains(t, lines, " 001 ", "RPL_WELCOME")
}
func TestWelcomeContainsNick(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
lines := client.register("bob")
for _, line := range lines {
if strings.Contains(line, " 001 ") &&
!strings.Contains(line, "bob") {
t.Errorf("001 should contain nick: %s", line)
}
}
}
func TestPingPong(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("pingtest")
lines := client.sendAndExpect("PING :hello", "PONG")
assertContains(t, lines, "PONG", "PONG response")
}
func TestJoinChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("joiner")
client.send("JOIN #test")
lines := client.readUntil(func(line string) bool {
return strings.Contains(line, " 366 ")
})
assertContains(t, lines, "JOIN", "JOIN echo")
assertContains(t, lines, " 366 ", "RPL_ENDOFNAMES")
}
func TestPrivmsgBetweenClients(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice_pm")
bob := env.dial(t)
bob.register("bob_pm")
alice.joinAndDrain("#chat")
bob.joinAndDrain("#chat")
alice.send("PRIVMSG #chat :hello bob!")
lines := bob.sendAndExpect("PING :sync", "hello bob!")
assertContains(t, lines, "hello bob!", "channel PRIVMSG")
}
func TestNickChange(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("oldnick")
lines := client.sendAndExpect("NICK newnick", "newnick")
assertContains(t, lines, "NICK", "NICK change echo")
}
func TestDuplicateNick(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
first := env.dial(t)
first.register("taken")
second := env.dial(t)
second.send("NICK taken")
second.send("USER taken 0 * :Test")
lines := second.readUntil(func(line string) bool {
return strings.Contains(line, " 433 ")
})
assertContains(t, lines, " 433 ", "ERR_NICKNAMEINUSE")
}
func TestListChannels(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("lister")
client.joinAndDrain("#listtest")
lines := client.sendAndExpect("LIST", " 323 ")
assertContains(t, lines, " 323 ", "RPL_LISTEND") //nolint:misspell // IRC term
}
func TestWhois(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("whoistest")
lines := client.sendAndExpect(
"WHOIS whoistest", " 318 ",
)
assertContains(t, lines, " 311 ", "RPL_WHOISUSER")
assertContains(t, lines, " 318 ", "RPL_ENDOFWHOIS")
}
func TestQuit(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("quitter")
lines := client.sendAndExpect(
"QUIT :goodbye", "ERROR",
)
assertContains(t, lines, "goodbye", "QUIT reason")
}
func TestTopicSetAndGet(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("topicuser")
client.joinAndDrain("#topictest")
lines := client.sendAndExpect(
"TOPIC #topictest :New topic here",
"New topic here",
)
assertContains(
t, lines, "New topic here", "TOPIC echo",
)
}
func TestUnknownCommand(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("unknowncmd")
lines := client.sendAndExpect("FOOBAR", " 421 ")
assertContains(t, lines, " 421 ", "ERR_UNKNOWNCOMMAND")
}
func TestDirectMessage(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
sender := env.dial(t)
sender.register("dmsender")
receiver := env.dial(t)
receiver.register("dmreceiver")
// Give relay goroutines time to start.
time.Sleep(100 * time.Millisecond)
sender.send("PRIVMSG dmreceiver :hello privately")
lines := receiver.readUntil(func(line string) bool {
return strings.Contains(line, "hello privately")
})
assertContains(
t, lines, "hello privately", "direct PRIVMSG",
)
}
func TestCAPNegotiation(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.send("CAP LS 302")
line := client.readLine()
if !strings.Contains(line, "CAP") {
t.Errorf("expected CAP response, got: %s", line)
}
client.send("CAP END")
lines := client.register("capuser")
assertContains(t, lines, " 001 ", "registration after CAP")
}
func TestPartChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("parter")
client.joinAndDrain("#parttest")
lines := client.sendAndExpect(
"PART #parttest :leaving", "PART",
)
assertContains(t, lines, "#parttest", "PART echo")
}
func TestModeQuery(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("modeuser")
client.joinAndDrain("#modetest")
lines := client.sendAndExpect(
"MODE #modetest", " 324 ",
)
assertContains(
t, lines, " 324 ", "RPL_CHANNELMODEIS",
)
}
func TestWhoChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("whouser")
client.joinAndDrain("#whotest")
lines := client.sendAndExpect("WHO #whotest", " 315 ")
assertContains(t, lines, " 352 ", "RPL_WHOREPLY")
}
func TestLusers(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("luseruser")
lines := client.sendAndExpect("LUSERS", " 255 ")
assertContains(t, lines, " 251 ", "RPL_LUSERCLIENT")
}
func TestMotd(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("motduser")
lines := client.sendAndExpect("MOTD", " 376 ")
assertContains(t, lines, " 376 ", "RPL_ENDOFMOTD")
}
func TestAwaySetAndClear(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("awayuser")
setLines := client.sendAndExpect(
"AWAY :brb lunch", " 306 ",
)
assertContains(t, setLines, " 306 ", "RPL_NOWAWAY")
clearLines := client.sendAndExpect("AWAY", " 305 ")
assertContains(t, clearLines, " 305 ", "RPL_UNAWAY")
}
func TestHandlePassPostRegistration(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("passuser")
lines := client.sendAndExpect(
"PASS :mypassword123", "Password set",
)
assertContains(
t, lines, "Password set", "password confirmation",
)
}
func TestPreRegistrationNotRegistered(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.send("PRIVMSG #test :hello")
line := client.readLine()
if !strings.Contains(line, " 451 ") {
t.Errorf(
"expected ERR_NOTREGISTERED (451), got: %s",
line,
)
}
}
func TestNamesNonExistentChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("namesuser")
lines := client.sendAndExpect(
"NAMES #doesnotexist", " 366 ",
)
assertContains(
t, lines, " 366 ",
"RPL_ENDOFNAMES for non-existent channel",
)
}
func BenchmarkParseMessage(b *testing.B) {
line := ":nick!user@host PRIVMSG #channel :Hello, world!"
b.ResetTimer()
for range b.N {
_ = ircserver.ParseMessage(line)
}
}
func BenchmarkFormatMessage(b *testing.B) {
b.ResetTimer()
for range b.N {
_ = ircserver.FormatMessage(
"nick!user@host", "PRIVMSG",
"#channel", "Hello, world!",
)
}
}

View File

@@ -126,23 +126,18 @@ func (mware *Middleware) Logging() func(http.Handler) http.Handler {
}
// CORS returns middleware that handles Cross-Origin Resource Sharing.
// AllowCredentials is true so browsers include cookies in
// cross-origin API requests.
func (mware *Middleware) CORS() func(http.Handler) http.Handler {
return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields
AllowOriginFunc: func(
_ *http.Request, _ string,
) bool {
return true
},
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
"GET", "POST", "PUT", "DELETE", "OPTIONS",
},
AllowedHeaders: []string{
"Accept", "Content-Type", "X-CSRF-Token",
"Accept", "Authorization",
"Content-Type", "X-CSRF-Token",
},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
AllowCredentials: false,
MaxAge: corsMaxAge,
})
}

View File

@@ -75,6 +75,10 @@ func (srv *Server) setupAPIv1(router chi.Router) {
"/session",
srv.handlers.HandleCreateSession(),
)
router.Post(
"/register",
srv.handlers.HandleRegister(),
)
router.Post(
"/login",
srv.handlers.HandleLogin(),

View File

@@ -1,901 +0,0 @@
// Package service provides shared business logic for both
// the IRC wire protocol and HTTP/JSON transports.
package service
import (
"context"
"crypto/subtle"
"encoding/json"
"fmt"
"log/slog"
"strconv"
"strings"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/pkg/irc"
"go.uber.org/fx"
)
// Params defines the dependencies for creating a Service.
type Params struct {
fx.In
Logger *logger.Logger
Config *config.Config
Database *db.Database
Broker *broker.Broker
}
// Service provides shared business logic for IRC commands.
type Service struct {
db *db.Database
broker *broker.Broker
config *config.Config
log *slog.Logger
}
// New creates a new Service.
func New(params Params) *Service {
return &Service{
db: params.Database,
broker: params.Broker,
config: params.Config,
log: params.Logger.Get(),
}
}
// NewTestService creates a Service for use in tests
// outside the service package.
func NewTestService(
database *db.Database,
brk *broker.Broker,
cfg *config.Config,
log *slog.Logger,
) *Service {
return &Service{
db: database,
broker: brk,
config: cfg,
log: log,
}
}
// IRCError represents an IRC protocol-level error with a
// numeric code that both transports can map to responses.
type IRCError struct {
Code irc.IRCMessageType
Params []string
Message string
}
func (e *IRCError) Error() string { return e.Message }
// JoinResult contains the outcome of a channel join.
type JoinResult struct {
ChannelID int64
IsCreator bool
}
// DirectMsgResult contains the outcome of a direct message.
type DirectMsgResult struct {
UUID string
AwayMsg string
}
// FanOut inserts a message and enqueues it to all given
// session IDs, notifying each via the broker.
func (s *Service) FanOut(
ctx context.Context,
command, from, to string,
params, body, meta json.RawMessage,
sessionIDs []int64,
) (int64, string, error) {
dbID, msgUUID, err := s.db.InsertMessage(
ctx, command, from, to, params, body, meta,
)
if err != nil {
return 0, "", fmt.Errorf("insert message: %w", err)
}
for _, sid := range sessionIDs {
_ = s.db.EnqueueToSession(ctx, sid, dbID)
s.broker.Notify(sid)
}
return dbID, msgUUID, nil
}
// excludeSession returns a copy of ids without the given
// session.
func excludeSession(
ids []int64,
exclude int64,
) []int64 {
out := make([]int64, 0, len(ids))
for _, id := range ids {
if id != exclude {
out = append(out, id)
}
}
return out
}
// SendChannelMessage validates membership and moderation,
// then fans out a message to all channel members except
// the sender. Returns the database row ID, message UUID,
// and any error. The dbID lets callers enqueue the same
// message to the sender when echo is needed (HTTP
// transport).
func (s *Service) SendChannelMessage(
ctx context.Context,
sessionID int64,
nick, command, channel string,
body, meta json.RawMessage,
) (int64, string, error) {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return 0, "", &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID,
)
if !isMember {
return 0, "", &IRCError{
irc.ErrCannotSendToChan,
[]string{channel},
"Cannot send to channel",
}
}
// Ban check — banned users cannot send messages.
isBanned, banErr := s.db.IsSessionBanned(
ctx, chID, sessionID,
)
if banErr == nil && isBanned {
return 0, "", &IRCError{
irc.ErrCannotSendToChan,
[]string{channel},
"Cannot send to channel (+b)",
}
}
moderated, _ := s.db.IsChannelModerated(ctx, chID)
if moderated {
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
isVoiced, _ := s.db.IsChannelVoiced(
ctx, chID, sessionID,
)
if !isOp && !isVoiced {
return 0, "", &IRCError{
irc.ErrCannotSendToChan,
[]string{channel},
"Cannot send to channel (+m)",
}
}
}
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
recipients := excludeSession(memberIDs, sessionID)
dbID, uuid, fanErr := s.FanOut(
ctx, command, nick, channel,
nil, body, meta, recipients,
)
if fanErr != nil {
return 0, "", fanErr
}
return dbID, uuid, nil
}
// SendDirectMessage validates the target and sends a
// direct message, returning the message UUID and any away
// message set on the target.
func (s *Service) SendDirectMessage(
ctx context.Context,
sessionID int64,
nick, command, target string,
body, meta json.RawMessage,
) (*DirectMsgResult, error) {
targetSID, err := s.db.GetSessionByNick(ctx, target)
if err != nil {
return nil, &IRCError{
irc.ErrNoSuchNick,
[]string{target},
"No such nick",
}
}
away, _ := s.db.GetAway(ctx, targetSID)
recipients := []int64{targetSID}
if targetSID != sessionID {
recipients = append(recipients, sessionID)
}
_, uuid, fanErr := s.FanOut(
ctx, command, nick, target,
nil, body, meta, recipients,
)
if fanErr != nil {
return nil, fanErr
}
return &DirectMsgResult{UUID: uuid, AwayMsg: away}, nil
}
// JoinChannel creates or joins a channel, making the
// first joiner the operator. Fans out the JOIN to all
// channel members.
func (s *Service) JoinChannel(
ctx context.Context,
sessionID int64,
nick, channel, suppliedKey string,
) (*JoinResult, error) {
chID, err := s.db.GetOrCreateChannel(ctx, channel)
if err != nil {
return nil, fmt.Errorf("get/create channel: %w", err)
}
memberCount, countErr := s.db.CountChannelMembers(
ctx, chID,
)
isCreator := countErr == nil && memberCount == 0
if !isCreator {
if joinErr := checkJoinRestrictions(
ctx, s.db, chID, sessionID,
channel, suppliedKey, memberCount,
); joinErr != nil {
return nil, joinErr
}
}
if isCreator {
err = s.db.JoinChannelAsOperator(
ctx, chID, sessionID,
)
} else {
err = s.db.JoinChannel(ctx, chID, sessionID)
}
if err != nil {
return nil, fmt.Errorf("join channel: %w", err)
}
// Clear invite after successful join.
_ = s.db.ClearChannelInvite(ctx, chID, sessionID)
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{channel}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdJoin, nick, channel,
nil, body, nil, memberIDs,
)
return &JoinResult{
ChannelID: chID,
IsCreator: isCreator,
}, nil
}
// PartChannel validates membership, broadcasts PART to
// remaining members, removes the user, and cleans up empty
// channels.
func (s *Service) PartChannel(
ctx context.Context,
sessionID int64,
nick, channel, reason string,
) error {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID,
)
if !isMember {
return &IRCError{
irc.ErrNotOnChannel,
[]string{channel},
"You're not on that channel",
}
}
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
recipients := excludeSession(memberIDs, sessionID)
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdPart, nick, channel,
nil, body, nil, recipients,
)
s.db.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec
s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec
return nil
}
// SetTopic validates membership and topic-lock, sets the
// topic, and broadcasts the change.
func (s *Service) SetTopic(
ctx context.Context,
sessionID int64,
nick, channel, topic string,
) error {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID,
)
if !isMember {
return &IRCError{
irc.ErrNotOnChannel,
[]string{channel},
"You're not on that channel",
}
}
topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID)
if topicLocked {
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
if !isOp {
return &IRCError{
irc.ErrChanOpPrivsNeeded,
[]string{channel},
"You're not channel operator",
}
}
}
if setErr := s.db.SetTopic(
ctx, channel, topic,
); setErr != nil {
return fmt.Errorf("set topic: %w", setErr)
}
_ = s.db.SetTopicMeta(ctx, channel, topic, nick)
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{topic}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdTopic, nick, channel,
nil, body, nil, memberIDs,
)
return nil
}
// KickUser validates operator status and target
// membership, broadcasts the KICK, removes the target,
// and cleans up empty channels.
func (s *Service) KickUser(
ctx context.Context,
sessionID int64,
nick, channel, targetNick, reason string,
) error {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
if !isOp {
return &IRCError{
irc.ErrChanOpPrivsNeeded,
[]string{channel},
"You're not channel operator",
}
}
targetSID, err := s.db.GetSessionByNick(
ctx, targetNick,
)
if err != nil {
return &IRCError{
irc.ErrNoSuchNick,
[]string{targetNick},
"No such nick/channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, targetSID,
)
if !isMember {
return &IRCError{
irc.ErrUserNotInChannel,
[]string{targetNick, channel},
"They aren't on that channel",
}
}
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
params, _ := json.Marshal( //nolint:errchkjson
[]string{targetNick},
)
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdKick, nick, channel,
params, body, nil, memberIDs,
)
s.db.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec
s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec
return nil
}
// ChangeNick changes a user's nickname and broadcasts the
// change to all users sharing channels.
func (s *Service) ChangeNick(
ctx context.Context,
sessionID int64,
oldNick, newNick string,
) error {
err := s.db.ChangeNick(ctx, sessionID, newNick)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE") ||
db.IsUniqueConstraintError(err) {
return &IRCError{
irc.ErrNicknameInUse,
[]string{newNick},
"Nickname is already in use",
}
}
return &IRCError{
irc.ErrErroneusNickname,
[]string{newNick},
"Erroneous nickname",
}
}
s.broadcastNickChange(ctx, sessionID, oldNick, newNick)
return nil
}
// BroadcastQuit broadcasts a QUIT to all channel peers,
// parts all channels, and deletes the session. Uses the
// FanOut pattern: one message row fanned out to all unique
// peer sessions.
func (s *Service) BroadcastQuit(
ctx context.Context,
sessionID int64,
nick, reason string,
) {
channels, err := s.db.GetSessionChannels(
ctx, sessionID,
)
if err != nil {
return
}
notified := make(map[int64]bool)
for _, ch := range channels {
memberIDs, memErr := s.db.GetChannelMemberIDs(
ctx, ch.ID,
)
if memErr != nil {
continue
}
for _, mid := range memberIDs {
if mid == sessionID || notified[mid] {
continue
}
notified[mid] = true
}
}
if len(notified) > 0 {
recipients := make([]int64, 0, len(notified))
for sid := range notified {
recipients = append(recipients, sid)
}
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
_, _, _ = s.FanOut(
ctx, irc.CmdQuit, nick, "",
nil, body, nil, recipients,
)
}
for _, ch := range channels {
s.db.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec
s.db.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec
}
s.db.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec
}
// SetAway sets or clears the away message. Returns true
// if the message was cleared (empty string).
func (s *Service) SetAway(
ctx context.Context,
sessionID int64,
message string,
) (bool, error) {
err := s.db.SetAway(ctx, sessionID, message)
if err != nil {
return false, fmt.Errorf("set away: %w", err)
}
return message == "", nil
}
// Oper validates operator credentials and grants oper
// status to the session.
func (s *Service) Oper(
ctx context.Context,
sessionID int64,
name, password string,
) error {
cfgName := s.config.OperName
cfgPassword := s.config.OperPassword
// Use constant-time comparison and return the same
// error for all failures to prevent information
// leakage about valid operator names.
if cfgName == "" || cfgPassword == "" ||
subtle.ConstantTimeCompare(
[]byte(name), []byte(cfgName),
) != 1 ||
subtle.ConstantTimeCompare(
[]byte(password), []byte(cfgPassword),
) != 1 {
return &IRCError{
irc.ErrNoOperHost,
nil,
"No O-lines for your host",
}
}
_ = s.db.SetSessionOper(ctx, sessionID, true)
return nil
}
// ValidateChannelOp checks that the session is a channel
// operator. Returns the channel ID.
func (s *Service) ValidateChannelOp(
ctx context.Context,
sessionID int64,
channel string,
) (int64, error) {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return 0, &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
if !isOp {
return 0, &IRCError{
irc.ErrChanOpPrivsNeeded,
[]string{channel},
"You're not channel operator",
}
}
return chID, nil
}
// ApplyMemberMode applies +o/-o or +v/-v on a channel
// member after validating the target.
func (s *Service) ApplyMemberMode(
ctx context.Context,
chID int64,
channel, targetNick string,
mode rune,
adding bool,
) error {
targetSID, err := s.db.GetSessionByNick(
ctx, targetNick,
)
if err != nil {
return &IRCError{
irc.ErrNoSuchNick,
[]string{targetNick},
"No such nick/channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, targetSID,
)
if !isMember {
return &IRCError{
irc.ErrUserNotInChannel,
[]string{targetNick, channel},
"They aren't on that channel",
}
}
switch mode {
case 'o':
_ = s.db.SetChannelMemberOperator(
ctx, chID, targetSID, adding,
)
case 'v':
_ = s.db.SetChannelMemberVoiced(
ctx, chID, targetSID, adding,
)
}
return nil
}
// SetChannelFlag applies a simple boolean channel mode
// (+m/-m, +t/-t, +i/-i, +s/-s, +n/-n).
func (s *Service) SetChannelFlag(
ctx context.Context,
chID int64,
flag rune,
setting bool,
) error {
switch flag {
case 'm':
if err := s.db.SetChannelModerated(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set moderated: %w", err)
}
case 't':
if err := s.db.SetChannelTopicLocked(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set topic locked: %w", err)
}
case 'i':
if err := s.db.SetChannelInviteOnly(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set invite only: %w", err)
}
case 's':
if err := s.db.SetChannelSecret(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set secret: %w", err)
}
case 'n':
if err := s.db.SetChannelNoExternal(
ctx, chID, setting,
); err != nil {
return fmt.Errorf(
"set no external: %w", err,
)
}
}
return nil
}
// BroadcastMode fans out a MODE change to all channel
// members.
func (s *Service) BroadcastMode(
ctx context.Context,
nick, channel string,
chID int64,
modeText string,
) {
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{modeText}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdMode, nick, channel,
nil, body, nil, memberIDs,
)
}
// QueryChannelMode returns the complete channel mode
// string including all flags and parameterized modes.
func (s *Service) QueryChannelMode(
ctx context.Context,
chID int64,
) string {
modes := "+"
noExternal, _ := s.db.IsChannelNoExternal(ctx, chID)
if noExternal {
modes += "n"
}
inviteOnly, _ := s.db.IsChannelInviteOnly(ctx, chID)
if inviteOnly {
modes += "i"
}
moderated, _ := s.db.IsChannelModerated(ctx, chID)
if moderated {
modes += "m"
}
secret, _ := s.db.IsChannelSecret(ctx, chID)
if secret {
modes += "s"
}
topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID)
if topicLocked {
modes += "t"
}
var modeParams string
key, _ := s.db.GetChannelKey(ctx, chID)
if key != "" {
modes += "k"
modeParams += " " + key
}
limit, _ := s.db.GetChannelUserLimit(ctx, chID)
if limit > 0 {
modes += "l"
modeParams += " " + strconv.Itoa(limit)
}
bits, _ := s.db.GetChannelHashcashBits(ctx, chID)
if bits > 0 {
modes += "H"
modeParams += " " + strconv.Itoa(bits)
}
return modes + modeParams
}
// broadcastNickChange notifies channel peers of a nick
// change.
func (s *Service) broadcastNickChange(
ctx context.Context,
sessionID int64,
oldNick, newNick string,
) {
channels, err := s.db.GetSessionChannels(
ctx, sessionID,
)
if err != nil {
return
}
body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson
notified := make(map[int64]bool)
dbID, _, insErr := s.db.InsertMessage(
ctx, irc.CmdNick, oldNick, "",
nil, body, nil,
)
if insErr != nil {
return
}
// Notify the user themselves (for multi-client sync).
_ = s.db.EnqueueToSession(ctx, sessionID, dbID)
s.broker.Notify(sessionID)
notified[sessionID] = true
for _, ch := range channels {
memberIDs, memErr := s.db.GetChannelMemberIDs(
ctx, ch.ID,
)
if memErr != nil {
continue
}
for _, mid := range memberIDs {
if notified[mid] {
continue
}
notified[mid] = true
_ = s.db.EnqueueToSession(ctx, mid, dbID)
s.broker.Notify(mid)
}
}
}
// checkJoinRestrictions validates Tier 2 join conditions:
// bans, invite-only, channel key, and user limit.
func checkJoinRestrictions(
ctx context.Context,
database *db.Database,
chID, sessionID int64,
channel, suppliedKey string,
memberCount int64,
) error {
isBanned, banErr := database.IsSessionBanned(
ctx, chID, sessionID,
)
if banErr == nil && isBanned {
return &IRCError{
Code: irc.ErrBannedFromChan,
Params: []string{channel},
Message: "Cannot join channel (+b)",
}
}
isInviteOnly, ioErr := database.IsChannelInviteOnly(
ctx, chID,
)
if ioErr == nil && isInviteOnly {
hasInvite, invErr := database.HasChannelInvite(
ctx, chID, sessionID,
)
if invErr != nil || !hasInvite {
return &IRCError{
Code: irc.ErrInviteOnlyChan,
Params: []string{channel},
Message: "Cannot join channel (+i)",
}
}
}
key, keyErr := database.GetChannelKey(ctx, chID)
if keyErr == nil && key != "" && suppliedKey != key {
return &IRCError{
Code: irc.ErrBadChannelKey,
Params: []string{channel},
Message: "Cannot join channel (+k)",
}
}
limit, limErr := database.GetChannelUserLimit(ctx, chID)
if limErr == nil && limit > 0 &&
memberCount >= int64(limit) {
return &IRCError{
Code: irc.ErrChannelIsFull,
Params: []string{channel},
Message: "Cannot join channel (+l)",
}
}
return nil
}

View File

@@ -1,365 +0,0 @@
// Tests use a global viper instance for configuration,
// making parallel execution unsafe.
//
//nolint:paralleltest
package service_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"testing"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/globals"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/pkg/irc"
"go.uber.org/fx"
"go.uber.org/fx/fxtest"
"golang.org/x/crypto/bcrypt"
)
func TestMain(m *testing.M) {
db.SetBcryptCost(bcrypt.MinCost)
os.Exit(m.Run())
}
// testEnv holds all dependencies for a service test.
type testEnv struct {
svc *service.Service
db *db.Database
broker *broker.Broker
app *fxtest.App
}
func newTestEnv(t *testing.T) *testEnv {
t.Helper()
dbURL := fmt.Sprintf(
"file:svc_test_%p?mode=memory&cache=shared",
t,
)
var (
database *db.Database
svc *service.Service
)
brk := broker.New()
app := fxtest.New(t,
fx.Provide(
func() *globals.Globals {
return &globals.Globals{ //nolint:exhaustruct
Appname: "neoirc-test",
Version: "test",
}
},
logger.New,
func(
lifecycle fx.Lifecycle,
globs *globals.Globals,
log *logger.Logger,
) (*config.Config, error) {
cfg, err := config.New(
lifecycle, config.Params{ //nolint:exhaustruct
Globals: globs, Logger: log,
},
)
if err != nil {
return nil, fmt.Errorf(
"test config: %w", err,
)
}
cfg.DBURL = dbURL
cfg.Port = 0
cfg.OperName = "admin"
cfg.OperPassword = "secret"
return cfg, nil
},
func(
lifecycle fx.Lifecycle,
log *logger.Logger,
cfg *config.Config,
) (*db.Database, error) {
return db.New(lifecycle, db.Params{ //nolint:exhaustruct
Logger: log, Config: cfg,
})
},
func() *broker.Broker { return brk },
service.New,
),
fx.Populate(&database, &svc),
)
app.RequireStart()
t.Cleanup(func() {
app.RequireStop()
})
return &testEnv{
svc: svc,
db: database,
broker: brk,
app: app,
}
}
// createSession is a test helper that creates a session
// and returns the session ID.
func createSession(
ctx context.Context,
t *testing.T,
database *db.Database,
nick string,
) int64 {
t.Helper()
sessionID, _, _, err := database.CreateSession(
ctx, nick, nick, "localhost", "127.0.0.1",
)
if err != nil {
t.Fatalf("create session %s: %v", nick, err)
}
return sessionID
}
func TestFanOut(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
body, _ := json.Marshal([]string{"hello"}) //nolint:errchkjson
dbID, uuid, err := env.svc.FanOut(
ctx, irc.CmdPrivmsg, "alice", "#test",
nil, body, nil,
[]int64{sid1, sid2},
)
if err != nil {
t.Fatalf("FanOut: %v", err)
}
if dbID == 0 {
t.Error("expected non-zero dbID")
}
if uuid == "" {
t.Error("expected non-empty UUID")
}
}
func TestJoinChannel(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid := createSession(ctx, t, env.db, "alice")
result, err := env.svc.JoinChannel(
ctx, sid, "alice", "#general", "",
)
if err != nil {
t.Fatalf("JoinChannel: %v", err)
}
if result.ChannelID == 0 {
t.Error("expected non-zero channel ID")
}
if !result.IsCreator {
t.Error("first joiner should be creator")
}
// Second user joins — not creator.
sid2 := createSession(ctx, t, env.db, "bob")
result2, err := env.svc.JoinChannel(
ctx, sid2, "bob", "#general", "",
)
if err != nil {
t.Fatalf("JoinChannel bob: %v", err)
}
if result2.IsCreator {
t.Error("second joiner should not be creator")
}
if result2.ChannelID != result.ChannelID {
t.Error("both should join the same channel")
}
}
func TestPartChannel(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid := createSession(ctx, t, env.db, "alice")
_, err := env.svc.JoinChannel(
ctx, sid, "alice", "#general", "",
)
if err != nil {
t.Fatalf("JoinChannel: %v", err)
}
err = env.svc.PartChannel(
ctx, sid, "alice", "#general", "bye",
)
if err != nil {
t.Fatalf("PartChannel: %v", err)
}
// Parting a non-existent channel returns error.
err = env.svc.PartChannel(
ctx, sid, "alice", "#nonexistent", "",
)
if err == nil {
t.Error("expected error for non-existent channel")
}
var ircErr *service.IRCError
if !errors.As(err, &ircErr) {
t.Errorf("expected IRCError, got %T", err)
}
}
func TestSendChannelMessage(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
_, err := env.svc.JoinChannel(
ctx, sid1, "alice", "#chat", "",
)
if err != nil {
t.Fatalf("join alice: %v", err)
}
_, err = env.svc.JoinChannel(
ctx, sid2, "bob", "#chat", "",
)
if err != nil {
t.Fatalf("join bob: %v", err)
}
body, _ := json.Marshal([]string{"hello world"}) //nolint:errchkjson
dbID, uuid, err := env.svc.SendChannelMessage(
ctx, sid1, "alice",
irc.CmdPrivmsg, "#chat", body, nil,
)
if err != nil {
t.Fatalf("SendChannelMessage: %v", err)
}
if dbID == 0 {
t.Error("expected non-zero dbID")
}
if uuid == "" {
t.Error("expected non-empty UUID")
}
// Non-member cannot send.
sid3 := createSession(ctx, t, env.db, "charlie")
_, _, err = env.svc.SendChannelMessage(
ctx, sid3, "charlie",
irc.CmdPrivmsg, "#chat", body, nil,
)
if err == nil {
t.Error("expected error for non-member send")
}
}
func TestBroadcastQuit(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
_, err := env.svc.JoinChannel(
ctx, sid1, "alice", "#room", "",
)
if err != nil {
t.Fatalf("join alice: %v", err)
}
_, err = env.svc.JoinChannel(
ctx, sid2, "bob", "#room", "",
)
if err != nil {
t.Fatalf("join bob: %v", err)
}
// BroadcastQuit should not panic and should clean up.
env.svc.BroadcastQuit(
ctx, sid1, "alice", "Goodbye",
)
// Session should be deleted.
_, lookupErr := env.db.GetSessionByNick(ctx, "alice")
if lookupErr == nil {
t.Error("expected session to be deleted after quit")
}
}
func TestSendChannelMessage_Moderated(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
result, err := env.svc.JoinChannel(
ctx, sid1, "alice", "#modchat", "",
)
if err != nil {
t.Fatalf("join alice: %v", err)
}
_, err = env.svc.JoinChannel(
ctx, sid2, "bob", "#modchat", "",
)
if err != nil {
t.Fatalf("join bob: %v", err)
}
// Set channel to moderated.
chID := result.ChannelID
_ = env.svc.SetChannelFlag(ctx, chID, 'm', true)
body, _ := json.Marshal([]string{"test"}) //nolint:errchkjson
// Bob (non-op, non-voiced) should fail to send.
_, _, err = env.svc.SendChannelMessage(
ctx, sid2, "bob",
irc.CmdPrivmsg, "#modchat", body, nil,
)
if err == nil {
t.Error("expected error for non-voiced user in moderated channel")
}
// Alice (operator) should succeed.
_, _, err = env.svc.SendChannelMessage(
ctx, sid1, "alice",
irc.CmdPrivmsg, "#modchat", body, nil,
)
if err != nil {
t.Errorf("operator should be able to send in moderated channel: %v", err)
}
}

View File

@@ -3,9 +3,7 @@ package irc
// IRC command names (RFC 1459 / RFC 2812).
const (
CmdAway = "AWAY"
CmdInvite = "INVITE"
CmdJoin = "JOIN"
CmdKick = "KICK"
CmdList = "LIST"
CmdLusers = "LUSERS"
CmdMode = "MODE"
@@ -13,15 +11,12 @@ const (
CmdNames = "NAMES"
CmdNick = "NICK"
CmdNotice = "NOTICE"
CmdOper = "OPER"
CmdPass = "PASS"
CmdPart = "PART"
CmdPing = "PING"
CmdPong = "PONG"
CmdPrivmsg = "PRIVMSG"
CmdQuit = "QUIT"
CmdTopic = "TOPIC"
CmdUser = "USER"
CmdWho = "WHO"
CmdWhois = "WHOIS"
)

View File

@@ -132,7 +132,6 @@ const (
RplNoTopic IRCMessageType = 331
RplTopic IRCMessageType = 332
RplTopicWhoTime IRCMessageType = 333
RplWhoisActually IRCMessageType = 338
RplInviting IRCMessageType = 341
RplSummoning IRCMessageType = 342
RplInviteList IRCMessageType = 346
@@ -296,7 +295,6 @@ var names = map[IRCMessageType]string{
RplNoTopic: "RPL_NOTOPIC",
RplTopic: "RPL_TOPIC",
RplTopicWhoTime: "RPL_TOPICWHOTIME",
RplWhoisActually: "RPL_WHOISACTUALLY",
RplInviting: "RPL_INVITING",
RplSummoning: "RPL_SUMMONING",
RplInviteList: "RPL_INVITELIST",