6 Commits

Author SHA1 Message Date
user
67460ea6b2 fix: use timing-safe comparison for OPER credentials
Some checks failed
check / check (push) Failing after 1m49s
Replace plain != string comparison with crypto/subtle.ConstantTimeCompare
for both operator name and password checks in handleOper to prevent
timing-based side-channel attacks.

Closes review feedback on PR #82.
2026-03-17 19:42:44 -07:00
user
b7999b201f fix: correct misplaced doc comments for handleOper and handleAway 2026-03-17 19:42:44 -07:00
clawbot
d85956ca1a feat: add OPER command and oper-only WHOIS client info
- Add OPER command with NEOIRC_OPER_NAME/NEOIRC_OPER_PASSWORD config
- Add is_oper column to sessions table
- Add RPL_WHOISACTUALLY (338): show client IP/hostname to opers
- Add RPL_WHOISOPERATOR (313): show oper status in WHOIS
- Add GetOperCount for accurate LUSERS oper count
- Fix README schema: add ip/is_oper to sessions, ip/hostname to clients
- Add OPER command documentation and numeric references to README
- Refactor executeWhois to stay under funlen limit
- Add comprehensive tests for OPER auth, oper WHOIS, non-oper WHOIS

Closes #81
2026-03-17 19:42:44 -07:00
user
58f958c8d3 fix: include hostmask in NAMES replies (RPL_NAMREPLY) 2026-03-17 19:42:26 -07:00
user
0fef7929ad add IP to sessions, IP+hostname to clients
- Add ip column to sessions table (real client IP of session creator)
- Add ip and hostname columns to clients table (per-connection tracking)
- Update CreateSession, RegisterUser, LoginUser to store new fields
- Add GetClientHostInfo query method
- Update SessionHostInfo to include IP
- Extract executeCreateSession to fix funlen lint
- Add tests for session IP, client IP/hostname, login client tracking
- Update README with new field documentation
2026-03-17 19:41:25 -07:00
user
c4652728b8 feat: add username/hostname support with IRC hostmask format
- Add username and hostname columns to sessions table (001_initial.sql)
- Accept optional username field in session creation and registration
  endpoints; defaults to nick if not provided
- Resolve hostname via reverse DNS of connecting client IP at session
  creation time (supports X-Forwarded-For and X-Real-IP headers)
- Display real username and hostname in WHOIS (311 RPL_WHOISUSER) and
  WHO (352 RPL_WHOREPLY) responses instead of nick/servername
- Add FormatHostmask helper for nick!user@host format
- Add SessionHostInfo type and GetSessionHostInfo query
- Include username/hostname in MemberInfo and ChannelMembers results
- Extract validateHashcash and resolveUsername helpers to stay under
  funlen limits
- Add comprehensive unit tests for all new DB functions, hostmask
  formatting, and integration tests for WHOIS/WHO responses
- Update README with hostmask documentation, new API fields, and
  updated schema reference
2026-03-17 19:41:25 -07:00
38 changed files with 1630 additions and 11101 deletions

View File

@@ -53,7 +53,7 @@ RUN apk add --no-cache ca-certificates \
COPY --from=builder /neoircd /usr/local/bin/neoircd COPY --from=builder /neoircd /usr/local/bin/neoircd
USER neoirc USER neoirc
EXPOSE 8080 6667 EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget -qO- http://localhost:8080/.well-known/healthcheck.json || exit 1 CMD wget -qO- http://localhost:8080/.well-known/healthcheck.json || exit 1
ENTRYPOINT ["neoircd"] ENTRYPOINT ["neoircd"]

View File

@@ -32,7 +32,7 @@ fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1) @test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test: ensure-web-dist test: ensure-web-dist
go test -timeout 30s -race -cover ./... || go test -timeout 30s -race -v ./... go test -timeout 30s -v -race -cover ./...
# check runs all validation without making changes # check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong # Used by CI and Docker build — fails if anything is wrong

711
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -2,17 +2,14 @@
package main package main
import ( import (
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/globals" "git.eeqj.de/sneak/neoirc/internal/globals"
"git.eeqj.de/sneak/neoirc/internal/handlers" "git.eeqj.de/sneak/neoirc/internal/handlers"
"git.eeqj.de/sneak/neoirc/internal/healthcheck" "git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
"git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/middleware" "git.eeqj.de/sneak/neoirc/internal/middleware"
"git.eeqj.de/sneak/neoirc/internal/server" "git.eeqj.de/sneak/neoirc/internal/server"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/internal/stats" "git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx" "go.uber.org/fx"
) )
@@ -31,23 +28,16 @@ func main() {
fx.New( fx.New(
fx.Provide( fx.Provide(
broker.New,
config.New, config.New,
db.New, db.New,
globals.New, globals.New,
handlers.New, handlers.New,
ircserver.New,
logger.New, logger.New,
server.New, server.New,
middleware.New, middleware.New,
healthcheck.New, healthcheck.New,
service.New,
stats.New, stats.New,
), ),
fx.Invoke(func( fx.Invoke(func(*server.Server) {}),
_ *server.Server,
_ *ircserver.Server,
) {
}),
).Run() ).Run()
} }

1
go.mod
View File

@@ -16,7 +16,6 @@ require (
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/crypto v0.48.0 golang.org/x/crypto v0.48.0
golang.org/x/time v0.6.0
modernc.org/sqlite v1.45.0 modernc.org/sqlite v1.45.0
) )

2
go.sum
View File

@@ -151,8 +151,6 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View File

@@ -9,7 +9,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/cookiejar"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
@@ -29,19 +28,16 @@ var errHTTP = errors.New("HTTP error")
// Client wraps HTTP calls to the neoirc server API. // Client wraps HTTP calls to the neoirc server API.
type Client struct { type Client struct {
BaseURL string BaseURL string
Token string
HTTPClient *http.Client HTTPClient *http.Client
} }
// NewClient creates a new API client with a cookie jar // NewClient creates a new API client.
// for automatic auth cookie management.
func NewClient(baseURL string) *Client { func NewClient(baseURL string) *Client {
jar, _ := cookiejar.New(nil) return &Client{ //nolint:exhaustruct // Token set after CreateSession
return &Client{
BaseURL: baseURL, BaseURL: baseURL,
HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine
Timeout: httpTimeout, Timeout: httpTimeout,
Jar: jar,
}, },
} }
} }
@@ -83,6 +79,8 @@ func (client *Client) CreateSession(
return nil, fmt.Errorf("decode session: %w", err) return nil, fmt.Errorf("decode session: %w", err)
} }
client.Token = resp.Token
return &resp, nil return &resp, nil
} }
@@ -123,7 +121,6 @@ func (client *Client) PollMessages(
Timeout: time.Duration( Timeout: time.Duration(
timeout+pollExtraTime, timeout+pollExtraTime,
) * time.Second, ) * time.Second,
Jar: client.HTTPClient.Jar,
} }
params := url.Values{} params := url.Values{}
@@ -148,6 +145,10 @@ func (client *Client) PollMessages(
return nil, fmt.Errorf("new request: %w", err) return nil, fmt.Errorf("new request: %w", err)
} }
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
resp, err := pollClient.Do(request) resp, err := pollClient.Do(request)
if err != nil { if err != nil {
return nil, fmt.Errorf("poll request: %w", err) return nil, fmt.Errorf("poll request: %w", err)
@@ -303,6 +304,12 @@ func (client *Client) do(
"Content-Type", "application/json", "Content-Type", "application/json",
) )
if client.Token != "" {
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
}
resp, err := client.HTTPClient.Do(request) resp, err := client.HTTPClient.Do(request)
if err != nil { if err != nil {
return nil, fmt.Errorf("http: %w", err) return nil, fmt.Errorf("http: %w", err)

View File

@@ -12,6 +12,7 @@ type SessionRequest struct {
type SessionResponse struct { type SessionResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Nick string `json:"nick"` Nick string `json:"nick"`
Token string `json:"token"`
} }
// StateResponse is the response from GET /api/v1/state. // StateResponse is the response from GET /api/v1/state.

View File

@@ -48,9 +48,6 @@ type Config struct {
HashcashBits int HashcashBits int
OperName string OperName string
OperPassword string OperPassword string
LoginRateLimit float64
LoginRateBurst int
IRCListenAddr string
params *Params params *Params
log *slog.Logger log *slog.Logger
} }
@@ -85,9 +82,6 @@ func New(
viper.SetDefault("NEOIRC_HASHCASH_BITS", "20") viper.SetDefault("NEOIRC_HASHCASH_BITS", "20")
viper.SetDefault("NEOIRC_OPER_NAME", "") viper.SetDefault("NEOIRC_OPER_NAME", "")
viper.SetDefault("NEOIRC_OPER_PASSWORD", "") viper.SetDefault("NEOIRC_OPER_PASSWORD", "")
viper.SetDefault("LOGIN_RATE_LIMIT", "1")
viper.SetDefault("LOGIN_RATE_BURST", "5")
viper.SetDefault("IRC_LISTEN_ADDR", ":6667")
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {
@@ -116,9 +110,6 @@ func New(
HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"), HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"),
OperName: viper.GetString("NEOIRC_OPER_NAME"), OperName: viper.GetString("NEOIRC_OPER_NAME"),
OperPassword: viper.GetString("NEOIRC_OPER_PASSWORD"), OperPassword: viper.GetString("NEOIRC_OPER_PASSWORD"),
LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"),
LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"),
IRCListenAddr: viper.GetString("IRC_LISTEN_ADDR"),
log: log, log: log,
params: &params, params: &params,
} }

View File

@@ -10,39 +10,92 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
//nolint:gochecknoglobals // var so tests can override via SetBcryptCost const bcryptCost = bcrypt.DefaultCost
var bcryptCost = bcrypt.DefaultCost
// SetBcryptCost overrides the bcrypt cost.
// Use bcrypt.MinCost in tests to avoid slow hashing.
func SetBcryptCost(cost int) { bcryptCost = cost }
var errNoPassword = errors.New( var errNoPassword = errors.New(
"account has no password set", "account has no password set",
) )
// SetPassword sets a bcrypt-hashed password on a session, // RegisterUser creates a session with a hashed password
// enabling multi-client login via POST /api/v1/login. // and returns session ID, client ID, and token.
func (database *Database) SetPassword( func (database *Database) RegisterUser(
ctx context.Context, ctx context.Context,
sessionID int64, nick, password, username, hostname, remoteIP string,
password string, ) (int64, int64, string, error) {
) error { if username == "" {
username = nick
}
hash, err := bcrypt.GenerateFromPassword( hash, err := bcrypt.GenerateFromPassword(
[]byte(password), bcryptCost, []byte(password), bcryptCost,
) )
if err != nil { if err != nil {
return fmt.Errorf("hash password: %w", err) return 0, 0, "", fmt.Errorf(
"hash password: %w", err,
)
} }
_, err = database.conn.ExecContext(ctx, sessionUUID := uuid.New().String()
"UPDATE sessions SET password_hash = ? WHERE id = ?", clientUUID := uuid.New().String()
string(hash), sessionID)
token, err := generateToken()
if err != nil { if err != nil {
return fmt.Errorf("set password: %w", err) return 0, 0, "", err
} }
return nil now := time.Now()
transaction, err := database.conn.BeginTx(ctx, nil)
if err != nil {
return 0, 0, "", fmt.Errorf(
"begin tx: %w", err,
)
}
res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, username, hostname, ip,
password_hash, created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
sessionUUID, nick, username, hostname,
remoteIP, string(hash), now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create session: %w", err,
)
}
sessionID, _ := res.LastInsertId()
tokenHash := hashToken(token)
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token, ip, hostname,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create client: %w", err,
)
}
clientID, _ := clientRes.LastInsertId()
err = transaction.Commit()
if err != nil {
return 0, 0, "", fmt.Errorf(
"commit registration: %w", err,
)
}
return sessionID, clientID, token, nil
} }
// LoginUser verifies a nick/password and creates a new // LoginUser verifies a nick/password and creates a new

View File

@@ -6,65 +6,126 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func TestSetPassword(t *testing.T) { func TestRegisterUser(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sessionID, _, _, err := sessionID, clientID, token, err :=
database.CreateSession(ctx, "passuser", "", "", "") database.RegisterUser(ctx, "reguser", "password123", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.SetPassword( if sessionID == 0 || clientID == 0 || token == "" {
ctx, sessionID, "password123",
)
if err != nil {
t.Fatal(err)
}
// Verify we can now log in with the password.
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "passuser", "password123", "", "")
if err != nil {
t.Fatal(err)
}
if loginSID == 0 || loginCID == 0 || loginToken == "" {
t.Fatal("expected valid ids and token") t.Fatal("expected valid ids and token")
} }
// Verify session works via token lookup.
sid, cid, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if sid != sessionID || cid != clientID {
t.Fatal("session/client id mismatch")
}
if nick != "reguser" {
t.Fatalf("expected reguser, got %s", nick)
}
} }
func TestSetPasswordThenWrongLogin(t *testing.T) { func TestRegisterUserWithUserHost(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sessionID, _, _, err := sessionID, _, _, err := database.RegisterUser(
database.CreateSession(ctx, "wrongpw", "", "", "") ctx, "reguhost", "password123",
if err != nil { "myident", "example.org", "",
t.Fatal(err)
}
err = database.SetPassword(
ctx, sessionID, "correctpass",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
loginSID, loginCID, loginToken, loginErr := info, err := database.GetSessionHostInfo(
database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "") ctx, sessionID,
if loginErr == nil { )
t.Fatal("expected error for wrong password") if err != nil {
t.Fatal(err)
} }
_ = loginSID if info.Username != "myident" {
_ = loginCID t.Fatalf(
_ = loginToken "expected myident, got %s", info.Username,
)
}
if info.Hostname != "example.org" {
t.Fatalf(
"expected example.org, got %s",
info.Hostname,
)
}
}
func TestRegisterUserDefaultUsername(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.RegisterUser(
ctx, "regdefault", "password123", "", "", "",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.Username != "regdefault" {
t.Fatalf(
"expected regdefault, got %s",
info.Username,
)
}
}
func TestRegisterUserDuplicateNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "dupnick", "password123", "", "", "")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
dupSID, dupCID, dupToken, dupErr :=
database.RegisterUser(ctx, "dupnick", "other12345", "", "", "")
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
}
_ = dupSID
_ = dupCID
_ = dupToken
} }
func TestLoginUser(t *testing.T) { func TestLoginUser(t *testing.T) {
@@ -73,26 +134,23 @@ func TestLoginUser(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sessionID, _, _, err := regSID, regCID, regToken, err :=
database.CreateSession(ctx, "loginuser", "", "", "") database.RegisterUser(ctx, "loginuser", "mypassword", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.SetPassword( _ = regSID
ctx, sessionID, "mypassword", _ = regCID
) _ = regToken
if err != nil {
t.Fatal(err)
}
loginSID, loginCID, token, err := sessionID, clientID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword", "", "") database.LoginUser(ctx, "loginuser", "mypassword", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if loginSID == 0 || loginCID == 0 || token == "" { if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token") t.Fatal("expected valid ids and token")
} }
@@ -108,6 +166,110 @@ func TestLoginUser(t *testing.T) {
} }
} }
func TestLoginUserStoresClientIPHostname(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err := database.RegisterUser(
ctx, "loginipuser", "password123",
"", "", "10.0.0.1",
)
_ = regSID
_ = regCID
_ = regToken
if err != nil {
t.Fatal(err)
}
_, clientID, _, err := database.LoginUser(
ctx, "loginipuser", "password123",
"10.0.0.99", "newhost.example.com",
)
if err != nil {
t.Fatal(err)
}
clientInfo, err := database.GetClientHostInfo(
ctx, clientID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "10.0.0.99" {
t.Fatalf(
"expected client IP 10.0.0.99, got %s",
clientInfo.IP,
)
}
if clientInfo.Hostname != "newhost.example.com" {
t.Fatalf(
"expected hostname newhost.example.com, got %s",
clientInfo.Hostname,
)
}
}
func TestRegisterUserStoresSessionIP(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.RegisterUser(
ctx, "regipuser", "password123",
"ident", "host.local", "172.16.0.5",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.IP != "172.16.0.5" {
t.Fatalf(
"expected session IP 172.16.0.5, got %s",
info.IP,
)
}
}
func TestLoginUserWrongPassword(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "wrongpw", "correctpass", "", "", "")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
}
func TestLoginUserNoPassword(t *testing.T) { func TestLoginUserNoPassword(t *testing.T) {
t.Parallel() t.Parallel()

View File

@@ -135,21 +135,13 @@ type migration struct {
func (database *Database) runMigrations( func (database *Database) runMigrations(
ctx context.Context, ctx context.Context,
) error { ) error {
bootstrap, err := SchemaFiles.ReadFile( _, err := database.conn.ExecContext(ctx,
"schema/000.sql", `CREATE TABLE IF NOT EXISTS schema_migrations (
) version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)`)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
"read bootstrap migration: %w", err, "create schema_migrations: %w", err,
)
}
_, err = database.conn.ExecContext(
ctx, string(bootstrap),
)
if err != nil {
return fmt.Errorf(
"execute bootstrap migration: %w", err,
) )
} }
@@ -278,11 +270,6 @@ func (database *Database) loadMigrations() (
continue continue
} }
// Skip bootstrap migration; it is executed separately.
if version == 0 {
continue
}
content, readErr := SchemaFiles.ReadFile( content, readErr := SchemaFiles.ReadFile(
"schema/" + entry.Name(), "schema/" + entry.Name(),
) )

View File

@@ -1,14 +0,0 @@
package db_test
import (
"os"
"testing"
"git.eeqj.de/sneak/neoirc/internal/db"
"golang.org/x/crypto/bcrypt"
)
func TestMain(m *testing.M) {
db.SetBcryptCost(bcrypt.MinCost)
os.Exit(m.Run())
}

View File

@@ -9,7 +9,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"time" "time"
"git.eeqj.de/sneak/neoirc/pkg/irc" "git.eeqj.de/sneak/neoirc/pkg/irc"
@@ -77,8 +76,6 @@ type MemberInfo struct {
Nick string `json:"nick"` Nick string `json:"nick"`
Username string `json:"username"` Username string `json:"username"`
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
IsOperator bool `json:"isOperator"`
IsVoiced bool `json:"isVoiced"`
LastSeen time.Time `json:"lastSeen"` LastSeen time.Time `json:"lastSeen"`
} }
@@ -439,237 +436,6 @@ func (database *Database) JoinChannel(
return nil return nil
} }
// JoinChannelAsOperator adds a session to a channel with
// operator status. Used when a user creates a new channel.
func (database *Database) JoinChannelAsOperator(
ctx context.Context,
channelID, sessionID int64,
) error {
_, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO channel_members
(channel_id, session_id, is_operator, joined_at)
VALUES (?, ?, 1, ?)`,
channelID, sessionID, time.Now())
if err != nil {
return fmt.Errorf(
"join channel as operator: %w", err,
)
}
return nil
}
// CountChannelMembers returns the number of members in
// a channel.
func (database *Database) CountChannelMembers(
ctx context.Context,
channelID int64,
) (int64, error) {
var count int64
err := database.conn.QueryRowContext(ctx,
`SELECT COUNT(*) FROM channel_members
WHERE channel_id = ?`,
channelID,
).Scan(&count)
if err != nil {
return 0, fmt.Errorf(
"count channel members: %w", err,
)
}
return count, nil
}
// IsChannelOperator checks if a session has operator
// status in a channel.
func (database *Database) IsChannelOperator(
ctx context.Context,
channelID, sessionID int64,
) (bool, error) {
var isOp int
err := database.conn.QueryRowContext(ctx,
`SELECT is_operator FROM channel_members
WHERE channel_id = ? AND session_id = ?`,
channelID, sessionID,
).Scan(&isOp)
if err != nil {
return false, fmt.Errorf(
"check channel operator: %w", err,
)
}
return isOp != 0, nil
}
// IsChannelVoiced checks if a session has voice status
// in a channel.
func (database *Database) IsChannelVoiced(
ctx context.Context,
channelID, sessionID int64,
) (bool, error) {
var isVoiced int
err := database.conn.QueryRowContext(ctx,
`SELECT is_voiced FROM channel_members
WHERE channel_id = ? AND session_id = ?`,
channelID, sessionID,
).Scan(&isVoiced)
if err != nil {
return false, fmt.Errorf(
"check channel voiced: %w", err,
)
}
return isVoiced != 0, nil
}
// SetChannelMemberOperator sets or clears operator status
// for a session in a channel.
func (database *Database) SetChannelMemberOperator(
ctx context.Context,
channelID, sessionID int64,
isOp bool,
) error {
val := 0
if isOp {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channel_members
SET is_operator = ?
WHERE channel_id = ? AND session_id = ?`,
val, channelID, sessionID)
if err != nil {
return fmt.Errorf(
"set channel member operator: %w", err,
)
}
return nil
}
// SetChannelMemberVoiced sets or clears voice status
// for a session in a channel.
func (database *Database) SetChannelMemberVoiced(
ctx context.Context,
channelID, sessionID int64,
isVoiced bool,
) error {
val := 0
if isVoiced {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channel_members
SET is_voiced = ?
WHERE channel_id = ? AND session_id = ?`,
val, channelID, sessionID)
if err != nil {
return fmt.Errorf(
"set channel member voiced: %w", err,
)
}
return nil
}
// IsChannelModerated returns whether a channel has +m set.
func (database *Database) IsChannelModerated(
ctx context.Context,
channelID int64,
) (bool, error) {
var isMod int
err := database.conn.QueryRowContext(ctx,
`SELECT is_moderated FROM channels
WHERE id = ?`,
channelID,
).Scan(&isMod)
if err != nil {
return false, fmt.Errorf(
"check channel moderated: %w", err,
)
}
return isMod != 0, nil
}
// SetChannelModerated sets or clears +m on a channel.
func (database *Database) SetChannelModerated(
ctx context.Context,
channelID int64,
moderated bool,
) error {
val := 0
if moderated {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET is_moderated = ?, updated_at = ?
WHERE id = ?`,
val, time.Now(), channelID)
if err != nil {
return fmt.Errorf(
"set channel moderated: %w", err,
)
}
return nil
}
// IsChannelTopicLocked returns whether a channel has
// +t set.
func (database *Database) IsChannelTopicLocked(
ctx context.Context,
channelID int64,
) (bool, error) {
var isLocked int
err := database.conn.QueryRowContext(ctx,
`SELECT is_topic_locked FROM channels
WHERE id = ?`,
channelID,
).Scan(&isLocked)
if err != nil {
return false, fmt.Errorf(
"check channel topic locked: %w", err,
)
}
return isLocked != 0, nil
}
// SetChannelTopicLocked sets or clears +t on a channel.
func (database *Database) SetChannelTopicLocked(
ctx context.Context,
channelID int64,
locked bool,
) error {
val := 0
if locked {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET is_topic_locked = ?, updated_at = ?
WHERE id = ?`,
val, time.Now(), channelID)
if err != nil {
return fmt.Errorf(
"set channel topic locked: %w", err,
)
}
return nil
}
// PartChannel removes a session from a channel. // PartChannel removes a session from a channel.
func (database *Database) PartChannel( func (database *Database) PartChannel(
ctx context.Context, ctx context.Context,
@@ -781,8 +547,7 @@ func (database *Database) ChannelMembers(
) ([]MemberInfo, error) { ) ([]MemberInfo, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT s.id, s.nick, s.username, `SELECT s.id, s.nick, s.username,
s.hostname, cm.is_operator, cm.is_voiced, s.hostname, s.last_seen
s.last_seen
FROM sessions s FROM sessions s
INNER JOIN channel_members cm INNER JOIN channel_members cm
ON cm.session_id = s.id ON cm.session_id = s.id
@@ -799,16 +564,11 @@ func (database *Database) ChannelMembers(
var members []MemberInfo var members []MemberInfo
for rows.Next() { for rows.Next() {
var ( var member MemberInfo
member MemberInfo
isOp int
isV int
)
err = rows.Scan( err = rows.Scan(
&member.ID, &member.Nick, &member.ID, &member.Nick,
&member.Username, &member.Hostname, &member.Username, &member.Hostname,
&isOp, &isV,
&member.LastSeen, &member.LastSeen,
) )
if err != nil { if err != nil {
@@ -817,9 +577,6 @@ func (database *Database) ChannelMembers(
) )
} }
member.IsOperator = isOp != 0
member.IsVoiced = isV != 0
members = append(members, member) members = append(members, member)
} }
@@ -1836,580 +1593,3 @@ func (database *Database) PruneSpentHashcash(
return deleted, nil return deleted, nil
} }
// --- Tier 2: Ban system (+b) ---
// BanInfo represents a channel ban entry.
type BanInfo struct {
Mask string
SetBy string
CreatedAt time.Time
}
// AddChannelBan inserts a ban mask for a channel.
func (database *Database) AddChannelBan(
ctx context.Context,
channelID int64,
mask, setBy string,
) error {
_, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO channel_bans
(channel_id, mask, set_by, created_at)
VALUES (?, ?, ?, ?)`,
channelID, mask, setBy, time.Now())
if err != nil {
return fmt.Errorf("add channel ban: %w", err)
}
return nil
}
// RemoveChannelBan removes a ban mask from a channel.
func (database *Database) RemoveChannelBan(
ctx context.Context,
channelID int64,
mask string,
) error {
_, err := database.conn.ExecContext(ctx,
`DELETE FROM channel_bans
WHERE channel_id = ? AND mask = ?`,
channelID, mask)
if err != nil {
return fmt.Errorf("remove channel ban: %w", err)
}
return nil
}
// ListChannelBans returns all bans for a channel.
//
//nolint:dupl // different query+type vs filtered variant
func (database *Database) ListChannelBans(
ctx context.Context,
channelID int64,
) ([]BanInfo, error) {
rows, err := database.conn.QueryContext(ctx,
`SELECT mask, set_by, created_at
FROM channel_bans
WHERE channel_id = ?
ORDER BY created_at ASC`,
channelID)
if err != nil {
return nil, fmt.Errorf("list channel bans: %w", err)
}
defer func() { _ = rows.Close() }()
var bans []BanInfo
for rows.Next() {
var ban BanInfo
if scanErr := rows.Scan(
&ban.Mask, &ban.SetBy, &ban.CreatedAt,
); scanErr != nil {
return nil, fmt.Errorf(
"scan channel ban: %w", scanErr,
)
}
bans = append(bans, ban)
}
if rowErr := rows.Err(); rowErr != nil {
return nil, fmt.Errorf(
"iterate channel bans: %w", rowErr,
)
}
return bans, nil
}
// IsSessionBanned checks if a session's hostmask matches
// any ban in the channel. Returns true if banned.
func (database *Database) IsSessionBanned(
ctx context.Context,
channelID, sessionID int64,
) (bool, error) {
// Get the session's hostmask parts.
var nick, username, hostname string
err := database.conn.QueryRowContext(ctx,
`SELECT nick, username, hostname
FROM sessions WHERE id = ?`,
sessionID,
).Scan(&nick, &username, &hostname)
if err != nil {
return false, fmt.Errorf(
"get session hostmask: %w", err,
)
}
hostmask := FormatHostmask(nick, username, hostname)
// Get all ban masks for the channel.
bans, banErr := database.ListChannelBans(ctx, channelID)
if banErr != nil {
return false, banErr
}
for _, ban := range bans {
if MatchBanMask(ban.Mask, hostmask) {
return true, nil
}
}
return false, nil
}
// MatchBanMask checks if hostmask matches a ban pattern
// using IRC-style wildcard matching (* and ?).
func MatchBanMask(pattern, hostmask string) bool {
return wildcardMatch(
strings.ToLower(pattern),
strings.ToLower(hostmask),
)
}
// wildcardMatch implements simple glob-style matching
// with * (any sequence) and ? (any single character).
func wildcardMatch(pattern, str string) bool {
for len(pattern) > 0 {
switch pattern[0] {
case '*':
// Skip consecutive asterisks.
for len(pattern) > 0 && pattern[0] == '*' {
pattern = pattern[1:]
}
if len(pattern) == 0 {
return true
}
for i := 0; i <= len(str); i++ {
if wildcardMatch(pattern, str[i:]) {
return true
}
}
return false
case '?':
if len(str) == 0 {
return false
}
pattern = pattern[1:]
str = str[1:]
default:
if len(str) == 0 || pattern[0] != str[0] {
return false
}
pattern = pattern[1:]
str = str[1:]
}
}
return len(str) == 0
}
// --- Tier 2: Invite-only (+i) ---
// IsChannelInviteOnly checks if a channel has +i mode.
func (database *Database) IsChannelInviteOnly(
ctx context.Context,
channelID int64,
) (bool, error) {
var isInviteOnly int
err := database.conn.QueryRowContext(ctx,
`SELECT is_invite_only FROM channels
WHERE id = ?`,
channelID,
).Scan(&isInviteOnly)
if err != nil {
return false, fmt.Errorf(
"check invite only: %w", err,
)
}
return isInviteOnly != 0, nil
}
// SetChannelInviteOnly sets or unsets +i mode.
func (database *Database) SetChannelInviteOnly(
ctx context.Context,
channelID int64,
inviteOnly bool,
) error {
val := 0
if inviteOnly {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET is_invite_only = ?, updated_at = ?
WHERE id = ?`,
val, time.Now(), channelID)
if err != nil {
return fmt.Errorf(
"set invite only: %w", err,
)
}
return nil
}
// AddChannelInvite records that a session has been
// invited to a channel.
func (database *Database) AddChannelInvite(
ctx context.Context,
channelID, sessionID int64,
invitedBy string,
) error {
_, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO channel_invites
(channel_id, session_id, invited_by, created_at)
VALUES (?, ?, ?, ?)`,
channelID, sessionID, invitedBy, time.Now())
if err != nil {
return fmt.Errorf("add channel invite: %w", err)
}
return nil
}
// HasChannelInvite checks if a session has been invited
// to a channel.
func (database *Database) HasChannelInvite(
ctx context.Context,
channelID, sessionID int64,
) (bool, error) {
var count int
err := database.conn.QueryRowContext(ctx,
`SELECT COUNT(*) FROM channel_invites
WHERE channel_id = ? AND session_id = ?`,
channelID, sessionID,
).Scan(&count)
if err != nil {
return false, fmt.Errorf(
"check invite: %w", err,
)
}
return count > 0, nil
}
// ClearChannelInvite removes a session's invite to a
// channel (called after successful JOIN).
func (database *Database) ClearChannelInvite(
ctx context.Context,
channelID, sessionID int64,
) error {
_, err := database.conn.ExecContext(ctx,
`DELETE FROM channel_invites
WHERE channel_id = ? AND session_id = ?`,
channelID, sessionID)
if err != nil {
return fmt.Errorf("clear invite: %w", err)
}
return nil
}
// --- Tier 2: Secret (+s) ---
// IsChannelSecret checks if a channel has +s mode.
func (database *Database) IsChannelSecret(
ctx context.Context,
channelID int64,
) (bool, error) {
var isSecret int
err := database.conn.QueryRowContext(ctx,
`SELECT is_secret FROM channels
WHERE id = ?`,
channelID,
).Scan(&isSecret)
if err != nil {
return false, fmt.Errorf(
"check secret: %w", err,
)
}
return isSecret != 0, nil
}
// SetChannelSecret sets or unsets +s mode.
func (database *Database) SetChannelSecret(
ctx context.Context,
channelID int64,
secret bool,
) error {
val := 0
if secret {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET is_secret = ?, updated_at = ?
WHERE id = ?`,
val, time.Now(), channelID)
if err != nil {
return fmt.Errorf("set secret: %w", err)
}
return nil
}
// --- No External Messages (+n) ---
// IsChannelNoExternal checks if a channel has +n mode.
func (database *Database) IsChannelNoExternal(
ctx context.Context,
channelID int64,
) (bool, error) {
var isNoExternal int
err := database.conn.QueryRowContext(ctx,
`SELECT is_no_external FROM channels
WHERE id = ?`,
channelID,
).Scan(&isNoExternal)
if err != nil {
return false, fmt.Errorf(
"check no external: %w", err,
)
}
return isNoExternal != 0, nil
}
// SetChannelNoExternal sets or unsets +n mode.
func (database *Database) SetChannelNoExternal(
ctx context.Context,
channelID int64,
noExternal bool,
) error {
val := 0
if noExternal {
val = 1
}
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET is_no_external = ?, updated_at = ?
WHERE id = ?`,
val, time.Now(), channelID)
if err != nil {
return fmt.Errorf("set no external: %w", err)
}
return nil
}
// ListAllChannelsWithCountsFiltered returns all channels
// with member counts, excluding secret channels that
// the given session is not a member of.
//
//nolint:dupl // different query+type vs ListChannelBans
func (database *Database) ListAllChannelsWithCountsFiltered(
ctx context.Context,
sessionID int64,
) ([]ChannelInfoFull, error) {
rows, err := database.conn.QueryContext(ctx,
`SELECT c.name, COUNT(cm.id) AS member_count,
c.topic
FROM channels c
LEFT JOIN channel_members cm
ON cm.channel_id = c.id
WHERE c.is_secret = 0
OR c.id IN (
SELECT channel_id FROM channel_members
WHERE session_id = ?
)
GROUP BY c.id
ORDER BY c.name ASC`,
sessionID)
if err != nil {
return nil, fmt.Errorf(
"list channels filtered: %w", err,
)
}
defer func() { _ = rows.Close() }()
var channels []ChannelInfoFull
for rows.Next() {
var chanInfo ChannelInfoFull
if scanErr := rows.Scan(
&chanInfo.Name,
&chanInfo.MemberCount,
&chanInfo.Topic,
); scanErr != nil {
return nil, fmt.Errorf(
"scan channel: %w", scanErr,
)
}
channels = append(channels, chanInfo)
}
if rowErr := rows.Err(); rowErr != nil {
return nil, fmt.Errorf(
"iterate channels: %w", rowErr,
)
}
return channels, nil
}
// GetSessionChannelsFiltered returns channels a session
// belongs to, optionally excluding secret channels for
// WHOIS (when the querier is not in the same channel).
// If querierID == targetID, returns all channels.
func (database *Database) GetSessionChannelsFiltered(
ctx context.Context,
targetSID, querierSID int64,
) ([]ChannelInfo, error) {
// If querying yourself, return all channels.
if targetSID == querierSID {
return database.GetSessionChannels(ctx, targetSID)
}
rows, err := database.conn.QueryContext(ctx,
`SELECT c.id, c.name, c.topic
FROM channels c
JOIN channel_members cm
ON cm.channel_id = c.id
WHERE cm.session_id = ?
AND (c.is_secret = 0
OR c.id IN (
SELECT channel_id FROM channel_members
WHERE session_id = ?
))
ORDER BY c.name ASC`,
targetSID, querierSID)
if err != nil {
return nil, fmt.Errorf(
"get session channels filtered: %w", err,
)
}
defer func() { _ = rows.Close() }()
var channels []ChannelInfo
for rows.Next() {
var chanInfo ChannelInfo
if scanErr := rows.Scan(
&chanInfo.ID,
&chanInfo.Name,
&chanInfo.Topic,
); scanErr != nil {
return nil, fmt.Errorf(
"scan channel: %w", scanErr,
)
}
channels = append(channels, chanInfo)
}
if rowErr := rows.Err(); rowErr != nil {
return nil, fmt.Errorf(
"iterate channels: %w", rowErr,
)
}
return channels, nil
}
// --- Tier 2: Channel Key (+k) ---
// GetChannelKey returns the key for a channel (empty
// string means no key set).
func (database *Database) GetChannelKey(
ctx context.Context,
channelID int64,
) (string, error) {
var key string
err := database.conn.QueryRowContext(ctx,
`SELECT channel_key FROM channels
WHERE id = ?`,
channelID,
).Scan(&key)
if err != nil {
return "", fmt.Errorf("get channel key: %w", err)
}
return key, nil
}
// SetChannelKey sets or clears the key for a channel.
func (database *Database) SetChannelKey(
ctx context.Context,
channelID int64,
key string,
) error {
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET channel_key = ?, updated_at = ?
WHERE id = ?`,
key, time.Now(), channelID)
if err != nil {
return fmt.Errorf("set channel key: %w", err)
}
return nil
}
// --- Tier 2: User Limit (+l) ---
// GetChannelUserLimit returns the user limit for a
// channel (0 means no limit).
func (database *Database) GetChannelUserLimit(
ctx context.Context,
channelID int64,
) (int, error) {
var limit int
err := database.conn.QueryRowContext(ctx,
`SELECT user_limit FROM channels
WHERE id = ?`,
channelID,
).Scan(&limit)
if err != nil {
return 0, fmt.Errorf(
"get channel user limit: %w", err,
)
}
return limit, nil
}
// SetChannelUserLimit sets the user limit for a channel.
func (database *Database) SetChannelUserLimit(
ctx context.Context,
channelID int64,
limit int,
) error {
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET user_limit = ?, updated_at = ?
WHERE id = ?`,
limit, time.Now(), channelID)
if err != nil {
return fmt.Errorf(
"set channel user limit: %w", err,
)
}
return nil
}

View File

@@ -1017,474 +1017,3 @@ func TestGetOperCount(t *testing.T) {
t.Fatalf("expected 1 oper, got %d", count) t.Fatalf("expected 1 oper, got %d", count)
} }
} }
// --- Tier 2 Tests ---
func TestWildcardMatch(t *testing.T) {
t.Parallel()
tests := []struct {
pattern string
input string
match bool
}{
{"*!*@*", "nick!user@host", true},
{"*!*@*.example.com", "nick!user@foo.example.com", true},
{"*!*@*.example.com", "nick!user@other.net", false},
{"badnick!*@*", "badnick!user@host", true},
{"badnick!*@*", "goodnick!user@host", false},
{"nick!user@host", "nick!user@host", true},
{"nick!user@host", "nick!user@other", false},
{"*", "anything", true},
{"?ick!*@*", "nick!user@host", true},
{"?ick!*@*", "nn!user@host", false},
// Case-insensitive.
{"Nick!*@*", "nick!user@host", true},
}
for _, tc := range tests {
result := db.MatchBanMask(tc.pattern, tc.input)
if result != tc.match {
t.Errorf(
"MatchBanMask(%q, %q) = %v, want %v",
tc.pattern, tc.input, result, tc.match,
)
}
}
}
func TestChannelBanCRUD(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#test")
if err != nil {
t.Fatal(err)
}
// No bans initially.
bans, err := database.ListChannelBans(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(bans) != 0 {
t.Fatalf("expected 0 bans, got %d", len(bans))
}
// Add a ban.
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op",
)
if err != nil {
t.Fatal(err)
}
bans, err = database.ListChannelBans(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(bans) != 1 {
t.Fatalf("expected 1 ban, got %d", len(bans))
}
if bans[0].Mask != "*!*@evil.com" {
t.Fatalf("wrong mask: %s", bans[0].Mask)
}
// Duplicate add is ignored (OR IGNORE).
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op2",
)
if err != nil {
t.Fatal(err)
}
bans, _ = database.ListChannelBans(ctx, chID)
if len(bans) != 1 {
t.Fatalf("expected 1 ban after dup, got %d", len(bans))
}
// Remove ban.
err = database.RemoveChannelBan(
ctx, chID, "*!*@evil.com",
)
if err != nil {
t.Fatal(err)
}
bans, _ = database.ListChannelBans(ctx, chID)
if len(bans) != 0 {
t.Fatalf("expected 0 bans after remove, got %d", len(bans))
}
}
func TestIsSessionBanned(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "victim", "victim", "evil.com", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(ctx, "#bantest")
if err != nil {
t.Fatal(err)
}
// Not banned initially.
banned, err := database.IsSessionBanned(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if banned {
t.Fatal("should not be banned initially")
}
// Add ban matching the hostmask.
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op",
)
if err != nil {
t.Fatal(err)
}
banned, err = database.IsSessionBanned(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if !banned {
t.Fatal("should be banned")
}
}
func TestChannelInviteOnly(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#invite")
if err != nil {
t.Fatal(err)
}
// Default: not invite-only.
isIO, err := database.IsChannelInviteOnly(ctx, chID)
if err != nil {
t.Fatal(err)
}
if isIO {
t.Fatal("should not be invite-only by default")
}
// Set invite-only.
err = database.SetChannelInviteOnly(ctx, chID, true)
if err != nil {
t.Fatal(err)
}
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
if !isIO {
t.Fatal("should be invite-only")
}
// Unset.
err = database.SetChannelInviteOnly(ctx, chID, false)
if err != nil {
t.Fatal(err)
}
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
if isIO {
t.Fatal("should not be invite-only")
}
}
func TestChannelInviteCRUD(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "invited", "", "", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(ctx, "#inv")
if err != nil {
t.Fatal(err)
}
// No invite initially.
has, err := database.HasChannelInvite(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if has {
t.Fatal("should not have invite")
}
// Add invite.
err = database.AddChannelInvite(ctx, chID, sid, "op")
if err != nil {
t.Fatal(err)
}
has, _ = database.HasChannelInvite(ctx, chID, sid)
if !has {
t.Fatal("should have invite")
}
// Clear invite.
err = database.ClearChannelInvite(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
has, _ = database.HasChannelInvite(ctx, chID, sid)
if has {
t.Fatal("invite should be cleared")
}
}
func TestChannelSecret(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#secret")
if err != nil {
t.Fatal(err)
}
// Default: not secret.
isSec, err := database.IsChannelSecret(ctx, chID)
if err != nil {
t.Fatal(err)
}
if isSec {
t.Fatal("should not be secret by default")
}
err = database.SetChannelSecret(ctx, chID, true)
if err != nil {
t.Fatal(err)
}
isSec, _ = database.IsChannelSecret(ctx, chID)
if !isSec {
t.Fatal("should be secret")
}
}
// createTestSession is a helper to create a session and
// return only the session ID.
func createTestSession(
t *testing.T,
database *db.Database,
nick string,
) int64 {
t.Helper()
sid, _, _, err := database.CreateSession(
t.Context(), nick, "", "", "",
)
if err != nil {
t.Fatalf("create session %s: %v", nick, err)
}
return sid
}
func TestSecretChannelFiltering(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
// Create two sessions.
sid1 := createTestSession(t, database, "member")
sid2 := createTestSession(t, database, "outsider")
// Create a secret channel.
chID, _ := database.GetOrCreateChannel(ctx, "#secret")
_ = database.SetChannelSecret(ctx, chID, true)
_ = database.JoinChannel(ctx, chID, sid1)
// Create a non-secret channel.
chID2, _ := database.GetOrCreateChannel(ctx, "#public")
_ = database.JoinChannel(ctx, chID2, sid1)
// Member should see both.
list, err := database.ListAllChannelsWithCountsFiltered(
ctx, sid1,
)
if err != nil {
t.Fatal(err)
}
if len(list) != 2 {
t.Fatalf("member should see 2 channels, got %d", len(list))
}
// Outsider should only see public.
list, _ = database.ListAllChannelsWithCountsFiltered(
ctx, sid2,
)
if len(list) != 1 {
t.Fatalf("outsider should see 1 channel, got %d", len(list))
}
if list[0].Name != "#public" {
t.Fatalf("outsider should see #public, got %s", list[0].Name)
}
}
func TestWhoisChannelFiltering(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid1 := createTestSession(t, database, "target")
sid2 := createTestSession(t, database, "querier")
// Create secret channel, target joins it.
chID, _ := database.GetOrCreateChannel(ctx, "#hidden")
_ = database.SetChannelSecret(ctx, chID, true)
_ = database.JoinChannel(ctx, chID, sid1)
// Querier (non-member) should not see the channel.
channels, err := database.GetSessionChannelsFiltered(
ctx, sid1, sid2,
)
if err != nil {
t.Fatal(err)
}
if len(channels) != 0 {
t.Fatalf(
"querier should see 0 channels, got %d",
len(channels),
)
}
// Target querying self should see it.
channels, _ = database.GetSessionChannelsFiltered(
ctx, sid1, sid1,
)
if len(channels) != 1 {
t.Fatalf(
"self-query should see 1 channel, got %d",
len(channels),
)
}
}
//nolint:dupl // structurally similar to TestChannelUserLimit
func TestChannelKey(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#keyed")
if err != nil {
t.Fatal(err)
}
// Default: no key.
key, err := database.GetChannelKey(ctx, chID)
if err != nil {
t.Fatal(err)
}
if key != "" {
t.Fatalf("expected empty key, got %q", key)
}
// Set key.
err = database.SetChannelKey(ctx, chID, "secret123")
if err != nil {
t.Fatal(err)
}
key, _ = database.GetChannelKey(ctx, chID)
if key != "secret123" {
t.Fatalf("expected secret123, got %q", key)
}
// Clear key.
err = database.SetChannelKey(ctx, chID, "")
if err != nil {
t.Fatal(err)
}
key, _ = database.GetChannelKey(ctx, chID)
if key != "" {
t.Fatalf("expected empty key, got %q", key)
}
}
//nolint:dupl // structurally similar to TestChannelKey
func TestChannelUserLimit(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#limited")
if err != nil {
t.Fatal(err)
}
// Default: no limit.
limit, err := database.GetChannelUserLimit(ctx, chID)
if err != nil {
t.Fatal(err)
}
if limit != 0 {
t.Fatalf("expected 0 limit, got %d", limit)
}
// Set limit.
err = database.SetChannelUserLimit(ctx, chID, 50)
if err != nil {
t.Fatal(err)
}
limit, _ = database.GetChannelUserLimit(ctx, chID)
if limit != 50 {
t.Fatalf("expected 50, got %d", limit)
}
// Clear limit.
err = database.SetChannelUserLimit(ctx, chID, 0)
if err != nil {
t.Fatal(err)
}
limit, _ = database.GetChannelUserLimit(ctx, chID)
if limit != 0 {
t.Fatalf("expected 0, got %d", limit)
}
}

View File

@@ -1,6 +0,0 @@
-- Bootstrap: create the schema_migrations table itself.
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
INSERT OR IGNORE INTO schema_migrations (version) VALUES (0);

View File

@@ -40,46 +40,15 @@ CREATE TABLE IF NOT EXISTS channels (
topic_set_by TEXT NOT NULL DEFAULT '', topic_set_by TEXT NOT NULL DEFAULT '',
topic_set_at DATETIME, topic_set_at DATETIME,
hashcash_bits INTEGER NOT NULL DEFAULT 0, hashcash_bits INTEGER NOT NULL DEFAULT 0,
is_moderated INTEGER NOT NULL DEFAULT 0,
is_topic_locked INTEGER NOT NULL DEFAULT 1,
is_invite_only INTEGER NOT NULL DEFAULT 0,
is_secret INTEGER NOT NULL DEFAULT 0,
is_no_external INTEGER NOT NULL DEFAULT 1,
channel_key TEXT NOT NULL DEFAULT '',
user_limit INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
); );
-- Channel bans
CREATE TABLE IF NOT EXISTS channel_bans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
mask TEXT NOT NULL,
set_by TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, mask)
);
CREATE INDEX IF NOT EXISTS idx_channel_bans_channel ON channel_bans(channel_id);
-- Channel invites (in-memory would be simpler but DB survives restarts)
CREATE TABLE IF NOT EXISTS channel_invites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
invited_by TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, session_id)
);
CREATE INDEX IF NOT EXISTS idx_channel_invites_channel ON channel_invites(channel_id);
-- Channel members -- Channel members
CREATE TABLE IF NOT EXISTS channel_members ( CREATE TABLE IF NOT EXISTS channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
is_operator INTEGER NOT NULL DEFAULT 0,
is_voiced INTEGER NOT NULL DEFAULT 0,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP, joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, session_id) UNIQUE(channel_id, session_id)
); );

View File

@@ -1,25 +0,0 @@
package db
import (
"context"
"database/sql"
"log/slog"
)
// NewTestDatabaseFromConn creates a Database wrapping an
// existing *sql.DB connection. Intended for integration
// tests in other packages.
func NewTestDatabaseFromConn(conn *sql.DB) *Database {
return &Database{ //nolint:exhaustruct
conn: conn,
log: slog.Default(),
}
}
// RunMigrations applies all schema migrations. Exposed
// for integration tests in other packages.
func (database *Database) RunMigrations(
ctx context.Context,
) error {
return database.runMigrations(ctx)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -5,11 +5,151 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.eeqj.de/sneak/neoirc/pkg/irc" "git.eeqj.de/sneak/neoirc/internal/db"
) )
const minPasswordLength = 8 const minPasswordLength = 8
// HandleRegister creates a new user with a password.
func (hdlr *Handlers) HandleRegister() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
request.Body = http.MaxBytesReader(
writer, request.Body, hdlr.maxBodySize(),
)
hdlr.handleRegister(writer, request)
}
}
func (hdlr *Handlers) handleRegister(
writer http.ResponseWriter,
request *http.Request,
) {
type registerRequest struct {
Nick string `json:"nick"`
Username string `json:"username,omitempty"`
Password string `json:"password"`
}
var payload registerRequest
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
hdlr.respondError(
writer, request,
"invalid nick format",
http.StatusBadRequest,
)
return
}
username := resolveUsername(
payload.Username, payload.Nick,
)
if !validUsernameRe.MatchString(username) {
hdlr.respondError(
writer, request,
"invalid username format",
http.StatusBadRequest,
)
return
}
if len(payload.Password) < minPasswordLength {
hdlr.respondError(
writer, request,
"password must be at least 8 characters",
http.StatusBadRequest,
)
return
}
hdlr.executeRegister(
writer, request,
payload.Nick, payload.Password, username,
)
}
func (hdlr *Handlers) executeRegister(
writer http.ResponseWriter,
request *http.Request,
nick, password, username string,
) {
remoteIP := clientIP(request)
hostname := resolveHostname(
request.Context(), remoteIP,
)
sessionID, clientID, token, err :=
hdlr.params.Database.RegisterUser(
request.Context(),
nick, password, username, hostname, remoteIP,
)
if err != nil {
hdlr.handleRegisterError(
writer, request, err,
)
return
}
hdlr.stats.IncrSessions()
hdlr.stats.IncrConnections()
hdlr.deliverMOTD(request, clientID, sessionID, nick)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": nick,
"token": token,
}, http.StatusCreated)
}
func (hdlr *Handlers) handleRegisterError(
writer http.ResponseWriter,
request *http.Request,
err error,
) {
if db.IsUniqueConstraintError(err) {
hdlr.respondError(
writer, request,
"nick already taken",
http.StatusConflict,
)
return
}
hdlr.log.Error(
"register user failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
}
// HandleLogin authenticates a user with nick and password. // HandleLogin authenticates a user with nick and password.
func (hdlr *Handlers) HandleLogin() http.HandlerFunc { func (hdlr *Handlers) HandleLogin() http.HandlerFunc {
return func( return func(
@@ -28,21 +168,6 @@ func (hdlr *Handlers) handleLogin(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
ip := clientIP(request)
if !hdlr.loginLimiter.Allow(ip) {
writer.Header().Set(
"Retry-After", "1",
)
hdlr.respondError(
writer, request,
"too many login attempts, try again later",
http.StatusTooManyRequests,
)
return
}
type loginRequest struct { type loginRequest struct {
Nick string `json:"nick"` Nick string `json:"nick"`
Password string `json:"password"` Password string `json:"password"`
@@ -73,16 +198,6 @@ func (hdlr *Handlers) handleLogin(
return return
} }
hdlr.executeLogin(
writer, request, payload.Nick, payload.Password,
)
}
func (hdlr *Handlers) executeLogin(
writer http.ResponseWriter,
request *http.Request,
nick, password string,
) {
remoteIP := clientIP(request) remoteIP := clientIP(request)
hostname := resolveHostname( hostname := resolveHostname(
@@ -92,7 +207,8 @@ func (hdlr *Handlers) executeLogin(
sessionID, clientID, token, err := sessionID, clientID, token, err :=
hdlr.params.Database.LoginUser( hdlr.params.Database.LoginUser(
request.Context(), request.Context(),
nick, password, payload.Nick,
payload.Password,
remoteIP, hostname, remoteIP, hostname,
) )
if err != nil { if err != nil {
@@ -108,75 +224,18 @@ func (hdlr *Handlers) executeLogin(
hdlr.stats.IncrConnections() hdlr.stats.IncrConnections()
hdlr.deliverMOTD( hdlr.deliverMOTD(
request, clientID, sessionID, nick, request, clientID, sessionID, payload.Nick,
) )
// Initialize channel state so the new client knows // Initialize channel state so the new client knows
// which channels the session already belongs to. // which channels the session already belongs to.
hdlr.initChannelState( hdlr.initChannelState(
request, clientID, sessionID, nick, request, clientID, sessionID, payload.Nick,
) )
hdlr.setAuthCookie(writer, request, token)
hdlr.respondJSON(writer, request, map[string]any{ hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID, "id": sessionID,
"nick": nick, "nick": payload.Nick,
"token": token,
}, http.StatusOK) }, http.StatusOK)
} }
// handlePass handles the IRC PASS command to set a
// password on the authenticated session, enabling
// multi-client login via POST /api/v1/login.
func (hdlr *Handlers) handlePass(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
bodyLines func() []string,
) {
lines := bodyLines()
if len(lines) == 0 || lines[0] == "" {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Not enough parameters",
)
return
}
password := lines[0]
if len(password) < minPasswordLength {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Password must be at least 8 characters",
)
return
}
err := hdlr.params.Database.SetPassword(
request.Context(), sessionID, password,
)
if err != nil {
hdlr.log.Error(
"set password failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}

View File

@@ -16,8 +16,6 @@ import (
"git.eeqj.de/sneak/neoirc/internal/hashcash" "git.eeqj.de/sneak/neoirc/internal/hashcash"
"git.eeqj.de/sneak/neoirc/internal/healthcheck" "git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/internal/stats" "git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx" "go.uber.org/fx"
) )
@@ -34,8 +32,6 @@ type Params struct {
Database *db.Database Database *db.Database
Healthcheck *healthcheck.Healthcheck Healthcheck *healthcheck.Healthcheck
Stats *stats.Tracker Stats *stats.Tracker
Broker *broker.Broker
Service *service.Service
} }
const defaultIdleTimeout = 30 * 24 * time.Hour const defaultIdleTimeout = 30 * 24 * time.Hour
@@ -51,10 +47,8 @@ type Handlers struct {
log *slog.Logger log *slog.Logger
hc *healthcheck.Healthcheck hc *healthcheck.Healthcheck
broker *broker.Broker broker *broker.Broker
svc *service.Service
hashcashVal *hashcash.Validator hashcashVal *hashcash.Validator
channelHashcash *hashcash.ChannelValidator channelHashcash *hashcash.ChannelValidator
loginLimiter *ratelimit.Limiter
stats *stats.Tracker stats *stats.Tracker
cancelCleanup context.CancelFunc cancelCleanup context.CancelFunc
} }
@@ -69,25 +63,13 @@ func New(
resource = "neoirc" resource = "neoirc"
} }
loginRate := params.Config.LoginRateLimit
if loginRate <= 0 {
loginRate = ratelimit.DefaultRate
}
loginBurst := params.Config.LoginRateBurst
if loginBurst <= 0 {
loginBurst = ratelimit.DefaultBurst
}
hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
params: &params, params: &params,
log: params.Logger.Get(), log: params.Logger.Get(),
hc: params.Healthcheck, hc: params.Healthcheck,
broker: params.Broker, broker: broker.New(),
svc: params.Service,
hashcashVal: hashcash.NewValidator(resource), hashcashVal: hashcash.NewValidator(resource),
channelHashcash: hashcash.NewChannelValidator(), channelHashcash: hashcash.NewChannelValidator(),
loginLimiter: ratelimit.New(loginRate, loginBurst),
stats: params.Stats, stats: params.Stats,
} }
@@ -180,10 +162,6 @@ func (hdlr *Handlers) stopCleanup() {
if hdlr.cancelCleanup != nil { if hdlr.cancelCleanup != nil {
hdlr.cancelCleanup() hdlr.cancelCleanup()
} }
if hdlr.loginLimiter != nil {
hdlr.loginLimiter.Stop()
}
} }
func (hdlr *Handlers) cleanupLoop(ctx context.Context) { func (hdlr *Handlers) cleanupLoop(ctx context.Context) {

File diff suppressed because it is too large Load Diff

View File

@@ -1,502 +0,0 @@
package ircserver
import (
"bufio"
"context"
"fmt"
"log/slog"
"net"
"strconv"
"strings"
"sync"
"time"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
const (
maxLineLen = 512
readTimeout = 5 * time.Minute
writeTimeout = 30 * time.Second
dnsTimeout = 3 * time.Second
pollInterval = 100 * time.Millisecond
pingInterval = 90 * time.Second
pongDeadline = 30 * time.Second
maxNickLen = 32
minPasswordLen = 8
maxHashcashBits = 40
)
// cmdHandler is the signature for registered IRC command
// handlers.
type cmdHandler func(ctx context.Context, msg *Message)
// Conn represents a single IRC client TCP connection.
type Conn struct {
conn net.Conn
log *slog.Logger
database *db.Database
brk *broker.Broker
cfg *config.Config
svc *service.Service
serverSfx string
commands map[string]cmdHandler
mu sync.Mutex
nick string
username string
realname string
hostname string
remoteIP string
sessionID int64
clientID int64
registered bool
gotNick bool
gotUser bool
passWord string
lastQueueID int64
closed bool
cancel context.CancelFunc
}
func newConn(
ctx context.Context,
tcpConn net.Conn,
log *slog.Logger,
database *db.Database,
brk *broker.Broker,
cfg *config.Config,
svc *service.Service,
) *Conn {
host, _, _ := net.SplitHostPort(tcpConn.RemoteAddr().String())
srvName := cfg.ServerName
if srvName == "" {
srvName = "neoirc"
}
conn := &Conn{ //nolint:exhaustruct // zero-value defaults
conn: tcpConn,
log: log,
database: database,
brk: brk,
cfg: cfg,
svc: svc,
serverSfx: srvName,
remoteIP: host,
hostname: resolveHost(ctx, host),
}
conn.commands = conn.buildCommandMap()
return conn
}
// buildCommandMap returns a map from IRC command strings
// to handler functions.
func (c *Conn) buildCommandMap() map[string]cmdHandler {
return map[string]cmdHandler{
irc.CmdPing: func(_ context.Context, msg *Message) {
c.handlePing(msg)
},
"PONG": func(context.Context, *Message) {},
irc.CmdNick: c.handleNick,
irc.CmdPrivmsg: c.handlePrivmsg,
irc.CmdNotice: c.handlePrivmsg,
irc.CmdJoin: c.handleJoin,
irc.CmdPart: c.handlePart,
irc.CmdQuit: func(_ context.Context, msg *Message) {
c.handleQuit(msg)
},
irc.CmdTopic: c.handleTopic,
irc.CmdMode: c.handleMode,
irc.CmdNames: c.handleNames,
irc.CmdList: func(ctx context.Context, _ *Message) { c.handleList(ctx) },
irc.CmdWhois: c.handleWhois,
irc.CmdWho: c.handleWho,
irc.CmdLusers: func(ctx context.Context, _ *Message) { c.handleLusers(ctx) },
irc.CmdMotd: func(context.Context, *Message) { c.deliverMOTD() },
irc.CmdOper: c.handleOper,
irc.CmdAway: c.handleAway,
irc.CmdKick: c.handleKick,
irc.CmdPass: c.handlePassPostReg,
"INVITE": c.handleInvite,
"CAP": func(_ context.Context, msg *Message) {
c.handleCAP(msg)
},
"USERHOST": c.handleUserhost,
}
}
// resolveHost does a reverse DNS lookup, returning the IP
// on failure.
func resolveHost(ctx context.Context, addr string) string {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
resolver := &net.Resolver{} //nolint:exhaustruct
names, err := resolver.LookupAddr(ctx, addr)
if err != nil || len(names) == 0 {
return addr
}
return strings.TrimSuffix(names[0], ".")
}
// serve is the main loop for a single IRC client connection.
func (c *Conn) serve(ctx context.Context) {
ctx, c.cancel = context.WithCancel(ctx)
defer c.cleanup(ctx)
scanner := bufio.NewScanner(c.conn)
scanner.Buffer(make([]byte, maxLineLen), maxLineLen)
for {
_ = c.conn.SetReadDeadline(
time.Now().Add(readTimeout),
)
if !scanner.Scan() {
return
}
line := scanner.Text()
if line == "" {
continue
}
msg := ParseMessage(line)
if msg == nil {
continue
}
c.handleMessage(ctx, msg)
if c.closed {
return
}
}
}
func (c *Conn) cleanup(ctx context.Context) {
c.mu.Lock()
wasRegistered := c.registered
sessID := c.sessionID
nick := c.nick
c.closed = true
c.mu.Unlock()
if wasRegistered && sessID > 0 {
c.svc.BroadcastQuit(
ctx, sessID, nick, "Connection closed",
)
}
c.conn.Close() //nolint:errcheck,gosec
}
// send writes a formatted IRC line to the connection.
func (c *Conn) send(line string) {
_ = c.conn.SetWriteDeadline(
time.Now().Add(writeTimeout),
)
_, _ = fmt.Fprintf(c.conn, "%s\r\n", line)
}
// sendNumeric sends a numeric reply from the server.
func (c *Conn) sendNumeric(
code irc.IRCMessageType,
params ...string,
) {
nick := c.nick
if nick == "" {
nick = "*"
}
allParams := make([]string, 0, 1+len(params))
allParams = append(allParams, nick)
allParams = append(allParams, params...)
c.send(FormatMessage(
c.serverSfx, code.Code(), allParams...,
))
}
// sendFromServer sends a message from the server.
func (c *Conn) sendFromServer(
command string, params ...string,
) {
c.send(FormatMessage(c.serverSfx, command, params...))
}
// hostmask returns the client's full hostmask
// (nick!user@host).
func (c *Conn) hostmask() string {
user := c.username
if user == "" {
user = c.nick
}
host := c.hostname
if host == "" {
host = c.remoteIP
}
return c.nick + "!" + user + "@" + host
}
// handleMessage dispatches a parsed IRC message using
// the command handler map.
func (c *Conn) handleMessage(
ctx context.Context,
msg *Message,
) {
// Before registration, only NICK, USER, PASS, PING,
// QUIT, and CAP are accepted.
if !c.registered {
c.handlePreRegistration(ctx, msg)
return
}
handler, ok := c.commands[msg.Command]
if !ok {
c.sendNumeric(
irc.ErrUnknownCommand,
msg.Command, "Unknown command",
)
return
}
handler(ctx, msg)
}
// handlePreRegistration handles messages before the
// connection is registered (NICK+USER received).
func (c *Conn) handlePreRegistration(
ctx context.Context,
msg *Message,
) {
switch msg.Command {
case irc.CmdPass:
if len(msg.Params) < 1 {
c.sendNumeric(
irc.ErrNeedMoreParams,
"PASS", "Not enough parameters",
)
return
}
c.passWord = msg.Params[0]
case irc.CmdNick:
if len(msg.Params) < 1 {
c.sendNumeric(
irc.ErrNoNicknameGiven,
"No nickname given",
)
return
}
c.nick = msg.Params[0]
if len(c.nick) > maxNickLen {
c.nick = c.nick[:maxNickLen]
}
c.gotNick = true
case irc.CmdUser:
if len(msg.Params) < 4 { //nolint:mnd
c.sendNumeric(
irc.ErrNeedMoreParams,
"USER", "Not enough parameters",
)
return
}
c.username = msg.Params[0]
c.realname = msg.Params[3]
c.gotUser = true
case irc.CmdPing:
c.handlePing(msg)
return
case irc.CmdQuit:
c.handleQuit(msg)
return
case "CAP":
c.handleCAP(msg)
return
default:
c.sendNumeric(
irc.ErrNotRegistered,
"You have not registered",
)
return
}
// Try to complete registration once we have both
// NICK and USER.
if c.gotNick && c.gotUser {
c.completeRegistration(ctx)
}
}
// completeRegistration creates a session and sends the
// welcome burst.
func (c *Conn) completeRegistration(ctx context.Context) {
// Check if nick is valid.
if c.nick == "" {
c.sendNumeric(
irc.ErrNoNicknameGiven, "No nickname given",
)
return
}
// Create session in DB.
sessionID, clientID, _, err := c.database.CreateSession(
ctx, c.nick, c.username, c.hostname, c.remoteIP,
)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") ||
strings.Contains(err.Error(), "nick") {
c.sendNumeric(
irc.ErrNicknameInUse,
c.nick, "Nickname is already in use",
)
return
}
c.log.Error(
"failed to create session", "error", err,
)
c.send("ERROR :Internal server error")
c.closed = true
return
}
c.mu.Lock()
c.sessionID = sessionID
c.clientID = clientID
c.registered = true
c.mu.Unlock()
// If PASS was provided before registration, set the
// session password.
if c.passWord != "" && len(c.passWord) >= minPasswordLen {
c.setPassword(ctx, c.passWord)
}
// Send welcome burst.
c.deliverWelcome()
c.deliverLusers(ctx)
c.deliverMOTD()
// Start the message relay goroutine.
go c.relayMessages(ctx)
}
// deliverWelcome sends 001-005 welcome numerics.
func (c *Conn) deliverWelcome() {
c.sendNumeric(irc.RplWelcome, fmt.Sprintf(
"Welcome to the %s Network, %s",
c.serverSfx, c.hostmask(),
))
c.sendNumeric(irc.RplYourHost, fmt.Sprintf(
"Your host is %s, running version neoirc",
c.serverSfx,
))
c.sendNumeric(
irc.RplCreated,
"This server was created recently",
)
c.sendNumeric(
irc.RplMyInfo,
c.serverSfx, "neoirc", "", "mnst",
)
c.sendNumeric(
irc.RplIsupport,
"CHANTYPES=#",
"NICKLEN=32",
"PREFIX=(ov)@+",
"CHANMODES=,,H,imnst",
"NETWORK="+c.serverSfx,
"are supported by this server",
)
}
// deliverLusers sends 251/252/254/255 server statistics.
func (c *Conn) deliverLusers(ctx context.Context) {
users, _ := c.database.GetUserCount(ctx)
opers, _ := c.database.GetOperCount(ctx)
channels, _ := c.database.GetChannelCount(ctx)
c.sendNumeric(irc.RplLuserClient, fmt.Sprintf(
"There are %d users and 0 invisible on 1 servers",
users,
))
c.sendNumeric(
irc.RplLuserOp,
strconv.FormatInt(opers, 10),
"operator(s) online",
)
c.sendNumeric(
irc.RplLuserChannels,
strconv.FormatInt(channels, 10),
"channels formed",
)
c.sendNumeric(irc.RplLuserMe, fmt.Sprintf(
"I have %d clients and 1 servers", users,
))
}
// deliverMOTD sends 375/372/376 MOTD lines.
func (c *Conn) deliverMOTD() {
motd := c.cfg.MOTD
if motd == "" {
c.sendNumeric(
irc.ErrNoMotd, "MOTD File is missing",
)
return
}
c.sendNumeric(irc.RplMotdStart, fmt.Sprintf(
"- %s Message of the Day -", c.serverSfx,
))
for _, line := range strings.Split(motd, "\n") {
c.sendNumeric(irc.RplMotd, "- "+line)
}
c.sendNumeric(
irc.RplEndOfMotd, "End of /MOTD command",
)
}
// setPassword sets a bcrypt password on the session.
func (c *Conn) setPassword(ctx context.Context, pw string) {
// Use the database's auth module to hash and store.
err := c.database.SetPassword(ctx, c.sessionID, pw)
if err != nil {
c.log.Error(
"failed to set password", "error", err,
)
}
}

View File

@@ -1,49 +0,0 @@
package ircserver
import (
"context"
"log/slog"
"net"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/service"
)
// NewTestServer creates a Server suitable for testing.
// The caller must call Stop() when finished.
func NewTestServer(
log *slog.Logger,
cfg *config.Config,
database *db.Database,
brk *broker.Broker,
) *Server {
svc := service.NewTestService(
database, brk, cfg, log,
)
return &Server{ //nolint:exhaustruct
log: log,
cfg: cfg,
database: database,
brk: brk,
svc: svc,
conns: make(map[*Conn]struct{}),
}
}
// Start exposes the unexported start method for tests.
func (s *Server) Start(addr string) error {
return s.start(context.Background(), addr)
}
// Stop exposes the unexported stop method for tests.
func (s *Server) Stop() {
s.stop()
}
// Listener returns the server's net.Listener for tests.
func (s *Server) Listener() net.Listener {
return s.listener
}

View File

@@ -1,913 +0,0 @@
package ircserver_test
import (
"strings"
"testing"
"time"
)
// TestIntegrationTwoClients is a comprehensive integration
// test that spawns the IRC server programmatically, connects
// two real TCP clients, and verifies all major IRC features
// including cross-client message delivery.
//
// The test runs sequentially through IRC features because
// both clients share the same channel state. Each section
// builds on the previous one (e.g. alice and bob must be
// JOINed before PRIVMSG can be tested).
//
//nolint:cyclop,funlen,maintidx // integration test
func TestIntegrationTwoClients(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
bob := env.dial(t)
// ── Registration ──────────────────────────────────
aliceWelcome := alice.register("alice")
assertContains(
t, aliceWelcome, " 001 ", "RPL_WELCOME alice",
)
assertContains(
t, aliceWelcome, " 002 ", "RPL_YOURHOST alice",
)
assertContains(
t, aliceWelcome, " 003 ", "RPL_CREATED alice",
)
assertContains(
t, aliceWelcome, " 004 ", "RPL_MYINFO alice",
)
assertContains(
t, aliceWelcome, "alice",
"nick in welcome burst",
)
bobWelcome := bob.register("bob")
assertContains(
t, bobWelcome, " 001 ", "RPL_WELCOME bob",
)
assertContains(
t, bobWelcome, "bob",
"nick in welcome burst",
)
// ── JOIN and cross-client visibility ──────────────
alice.send("JOIN #integration")
aliceJoinLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
assertContains(
t, aliceJoinLines, "JOIN",
"alice receives JOIN echo",
)
assertContains(
t, aliceJoinLines, " 366 ",
"RPL_ENDOFNAMES for alice",
)
bob.send("JOIN #integration")
bobJoinLines := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
assertContains(
t, bobJoinLines, "JOIN",
"bob receives JOIN echo",
)
// Alice should see bob's JOIN via relay.
aliceSeesBob := alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
assertContains(
t, aliceSeesBob, "bob",
"alice sees bob's JOIN",
)
// ── PRIVMSG (channel) — alice to bob ──────────────
alice.send("PRIVMSG #integration :hello from alice")
bobGetsMsg := bob.readUntil(func(l string) bool {
return strings.Contains(l, "hello from alice")
})
assertContains(
t, bobGetsMsg, "hello from alice",
"bob receives alice's channel message",
)
// ── PRIVMSG (channel) — bob to alice ──────────────
bob.send("PRIVMSG #integration :hello from bob")
aliceGetsMsg := alice.readUntil(func(l string) bool {
return strings.Contains(l, "hello from bob")
})
assertContains(
t, aliceGetsMsg, "hello from bob",
"alice receives bob's channel message",
)
// ── PRIVMSG (DM) — alice to bob ──────────────────
alice.send("PRIVMSG bob :secret message")
bobDM := bob.readUntil(func(l string) bool {
return strings.Contains(l, "secret message")
})
assertContains(
t, bobDM, "secret message",
"bob receives alice's DM",
)
assertContains(
t, bobDM, "alice",
"DM from field is alice",
)
// ── PRIVMSG (DM) — bob to alice ──────────────────
bob.send("PRIVMSG alice :reply to you")
aliceDM := alice.readUntil(func(l string) bool {
return strings.Contains(l, "reply to you")
})
assertContains(
t, aliceDM, "reply to you",
"alice receives bob's DM",
)
// ── NOTICE (channel) ──────────────────────────────
alice.send("NOTICE #integration :notice msg")
bobNotice := bob.readUntil(func(l string) bool {
return strings.Contains(l, "notice msg")
})
assertContains(
t, bobNotice, "NOTICE",
"bob receives NOTICE command",
)
assertContains(
t, bobNotice, "notice msg",
"bob receives NOTICE text",
)
// ── NOTICE (DM) ──────────────────────────────────
bob.send("NOTICE alice :dm notice")
aliceNotice := alice.readUntil(func(l string) bool {
return strings.Contains(l, "dm notice")
})
assertContains(
t, aliceNotice, "dm notice",
"alice receives DM NOTICE",
)
// ── TOPIC ─────────────────────────────────────────
// alice is the channel creator so she is +o.
alice.send("TOPIC #integration :Integration Test Topic")
aliceTopic := alice.readUntil(func(l string) bool {
return strings.Contains(
l, "Integration Test Topic",
)
})
assertContains(
t, aliceTopic, "Integration Test Topic",
"alice sees TOPIC echo",
)
bobTopic := bob.readUntil(func(l string) bool {
return strings.Contains(
l, "Integration Test Topic",
)
})
assertContains(
t, bobTopic, "Integration Test Topic",
"bob receives TOPIC change",
)
// ── MODE (query) ──────────────────────────────────
alice.send("MODE #integration")
aliceMode := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 324 ")
})
assertContains(
t, aliceMode, " 324 ",
"RPL_CHANNELMODEIS",
)
// ── MODE (+m moderated, then -m) ──────────────────
alice.send("MODE #integration +m")
aliceModeM := alice.readUntil(func(l string) bool {
return strings.Contains(l, "MODE") &&
strings.Contains(l, "+m")
})
assertContains(
t, aliceModeM, "+m",
"alice sees MODE +m echo",
)
bobModeM := bob.readUntil(func(l string) bool {
return strings.Contains(l, "+m")
})
assertContains(
t, bobModeM, "+m",
"bob sees MODE +m relay",
)
// Revert moderated mode.
alice.send("MODE #integration -m")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "-m")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "-m")
})
// ── MODE (+v voice, then -v) ──────────────────────
alice.send("MODE #integration +v bob")
aliceVoice := alice.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
assertContains(
t, aliceVoice, "+v",
"alice sees +v echo",
)
bobVoice := bob.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
assertContains(
t, bobVoice, "+v",
"bob receives +v relay",
)
// Remove voice.
alice.send("MODE #integration -v bob")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "-v")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "-v")
})
// ── NAMES ─────────────────────────────────────────
alice.send("NAMES #integration")
aliceNames := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
assertContains(
t, aliceNames, " 353 ",
"RPL_NAMREPLY",
)
assertContains(
t, aliceNames, " 366 ",
"RPL_ENDOFNAMES",
)
// Both nicks should appear in the name list.
foundBothNames := false
for _, line := range aliceNames {
if strings.Contains(line, " 353 ") &&
strings.Contains(line, "alice") &&
strings.Contains(line, "bob") {
foundBothNames = true
break
}
}
if !foundBothNames {
t.Error("NAMES reply should list both alice and bob")
}
// ── LIST ──────────────────────────────────────────
alice.send("LIST")
aliceList := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 323 ")
})
assertContains(
t, aliceList, " 322 ",
"RPL_LIST entry",
)
assertContains(
t, aliceList, "#integration",
"LIST includes #integration",
)
assertContains(
t, aliceList, " 323 ", //nolint:misspell // IRC RPL_LISTEND
"RPL_LISTEND", //nolint:misspell // IRC term
)
// ── WHO ───────────────────────────────────────────
bob.send("WHO #integration")
bobWho := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 315 ")
})
assertContains(
t, bobWho, " 352 ",
"RPL_WHOREPLY",
)
assertContains(
t, bobWho, " 315 ",
"RPL_ENDOFWHO",
)
// ── WHOIS ─────────────────────────────────────────
alice.send("WHOIS bob")
aliceWhois := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 318 ")
})
assertContains(
t, aliceWhois, " 311 ",
"RPL_WHOISUSER",
)
assertContains(
t, aliceWhois, " 312 ",
"RPL_WHOISSERVER",
)
assertContains(
t, aliceWhois, " 318 ",
"RPL_ENDOFWHOIS",
)
// ── WHOIS with channels ───────────────────────────
bob.send("WHOIS alice")
bobWhois := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 318 ")
})
assertContains(
t, bobWhois, " 319 ",
"RPL_WHOISCHANNELS",
)
assertContains(
t, bobWhois, "#integration",
"WHOIS shows #integration channel",
)
// ── LUSERS ────────────────────────────────────────
alice.send("LUSERS")
aliceLusers := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 255 ")
})
assertContains(
t, aliceLusers, " 251 ",
"RPL_LUSERCLIENT",
)
assertContains(
t, aliceLusers, " 255 ",
"RPL_LUSERME",
)
// ── NICK change ───────────────────────────────────
bob.send("NICK bobby")
bobNick := bob.readUntil(func(l string) bool {
return strings.Contains(l, "NICK") &&
strings.Contains(l, "bobby")
})
assertContains(
t, bobNick, "bobby",
"bob sees NICK change to bobby",
)
// alice should see the nick change relayed.
aliceNick := alice.readUntil(func(l string) bool {
return strings.Contains(l, "bobby")
})
assertContains(
t, aliceNick, "NICK",
"alice sees NICK command",
)
assertContains(
t, aliceNick, "bobby",
"alice sees new nick bobby",
)
// Change it back for remaining tests.
bob.send("NICK bob")
bob.readUntil(func(l string) bool {
return strings.Contains(l, "bob")
})
alice.readUntil(func(l string) bool {
return strings.Contains(l, "NICK") &&
strings.Contains(l, "bob")
})
// ── Duplicate NICK ────────────────────────────────
bob.send("NICK alice")
bobDupNick := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 433 ")
})
assertContains(
t, bobDupNick, " 433 ",
"ERR_NICKNAMEINUSE",
)
// ── KICK ──────────────────────────────────────────
// alice is op; she kicks bob.
alice.send("KICK #integration bob :testing kick")
aliceKick := alice.readUntil(func(l string) bool {
return strings.Contains(l, "KICK")
})
assertContains(
t, aliceKick, "KICK",
"alice sees KICK echo",
)
assertContains(
t, aliceKick, "bob",
"KICK mentions bob",
)
bobKick := bob.readUntil(func(l string) bool {
return strings.Contains(l, "KICK")
})
assertContains(
t, bobKick, "KICK",
"bob receives KICK",
)
assertContains(
t, bobKick, "testing kick",
"KICK reason is relayed",
)
// bob rejoins.
bob.joinAndDrain("#integration")
// Drain alice's view of the rejoin.
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
// ── KICK non-op should fail ───────────────────────
bob.send("KICK #integration alice :nope")
bobKickFail := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 482 ")
})
assertContains(
t, bobKickFail, " 482 ",
"ERR_CHANOPRIVSNEEDED",
)
// ── TOPIC lock (+t default) ───────────────────────
// +t is default, so bob should not be able to set
// topic.
bob.send("TOPIC #integration :bob tries topic")
bobTopicFail := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 482 ")
})
assertContains(
t, bobTopicFail, " 482 ",
"ERR_CHANOPRIVSNEEDED for topic",
)
// ── PING / PONG ───────────────────────────────────
alice.send("PING :testtoken")
alicePong := alice.readUntil(func(l string) bool {
return strings.Contains(l, "PONG")
})
assertContains(
t, alicePong, "PONG",
"PONG response received",
)
// ── Unknown command ───────────────────────────────
bob.send("FOOBAR")
bobUnknown := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 421 ")
})
assertContains(
t, bobUnknown, " 421 ",
"ERR_UNKNOWNCOMMAND",
)
// ── MOTD ──────────────────────────────────────────
alice.send("MOTD")
aliceMOTD := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 376 ")
})
assertContains(
t, aliceMOTD, " 376 ",
"RPL_ENDOFMOTD",
)
// ── AWAY (set, check via DM, clear) ───────────────
alice.send("AWAY :gone fishing")
aliceAway := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 306 ")
})
assertContains(
t, aliceAway, " 306 ",
"RPL_NOWAWAY",
)
// bob DMs alice — should get RPL_AWAY.
bob.send("PRIVMSG alice :are you there?")
bobAwayReply := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 301 ")
})
assertContains(
t, bobAwayReply, " 301 ",
"RPL_AWAY for bob when messaging alice",
)
assertContains(
t, bobAwayReply, "gone fishing",
"away message relayed",
)
// Clear away.
alice.send("AWAY")
alice.readUntil(func(l string) bool {
return strings.Contains(l, " 305 ")
})
// ── PASS (set password post-registration) ─────────
alice.send("PASS :mypassword123")
alicePass := alice.readUntil(func(l string) bool {
return strings.Contains(l, "Password set")
})
assertContains(
t, alicePass, "Password set",
"password set confirmation",
)
// ── MODE -t/+t topic lock toggle ──────────────────
alice.send("MODE #integration -t")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "-t")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "-t")
})
// Now bob should be able to set topic.
bob.send("TOPIC #integration :bob sets topic now")
bobTopicOK := bob.readUntil(func(l string) bool {
return strings.Contains(l, "bob sets topic now")
})
assertContains(
t, bobTopicOK, "bob sets topic now",
"bob can set topic after -t",
)
// alice sees the topic change.
aliceTopicRelay := alice.readUntil(func(l string) bool {
return strings.Contains(l, "bob sets topic now")
})
assertContains(
t, aliceTopicRelay, "bob sets topic now",
"alice sees bob's topic after -t",
)
// Restore +t.
alice.send("MODE #integration +t")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "+t")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "+t")
})
// ── DM to nonexistent nick ────────────────────────
alice.send("PRIVMSG nobody123 :hello")
aliceNoSuch := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 401 ")
})
assertContains(
t, aliceNoSuch, " 401 ",
"ERR_NOSUCHNICK",
)
// ── PART with reason ──────────────────────────────
bob.send("PART #integration :bye for now")
bobPart := bob.readUntil(func(l string) bool {
return strings.Contains(l, "PART")
})
assertContains(
t, bobPart, "PART",
"bob sees PART echo",
)
// alice sees bob PART via relay.
alicePart := alice.readUntil(func(l string) bool {
return strings.Contains(l, "PART") &&
strings.Contains(l, "bob")
})
assertContains(
t, alicePart, "bob",
"alice sees bob's PART",
)
assertContains(
t, alicePart, "bye for now",
"PART reason is relayed",
)
// bob rejoins for remaining tests.
bob.joinAndDrain("#integration")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
// ── PART non-existent channel ─────────────────────
bob.send("PART #nonexistent")
bobPartFail := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 403 ") ||
strings.Contains(l, " 442 ")
})
foundPartErr := false
for _, line := range bobPartFail {
if strings.Contains(line, " 403 ") ||
strings.Contains(line, " 442 ") {
foundPartErr = true
break
}
}
if !foundPartErr {
t.Error(
"expected ERR_NOSUCHCHANNEL or " +
"ERR_NOTONCHANNEL",
)
}
// ── User MODE query ───────────────────────────────
alice.send("MODE alice")
aliceUMode := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 221 ")
})
assertContains(
t, aliceUMode, " 221 ",
"RPL_UMODEIS",
)
// ── Multiple channel operation ────────────────────
alice.send("JOIN #second")
alice.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
bob.send("JOIN #second")
bob.readUntil(func(l string) bool {
return strings.Contains(l, " 366 ")
})
// Drain alice seeing bob join.
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
alice.send("PRIVMSG #second :cross-channel test")
bobCross := bob.readUntil(func(l string) bool {
return strings.Contains(l, "cross-channel test")
})
assertContains(
t, bobCross, "cross-channel test",
"bob receives message in #second",
)
// Clean up #second.
alice.send("PART #second")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "PART")
})
bob.send("PART #second")
bob.readUntil(func(l string) bool {
return strings.Contains(l, "PART")
})
// ── QUIT ──────────────────────────────────────────
bob.send("QUIT :integration test done")
bobQuit := bob.readUntil(func(l string) bool {
return strings.Contains(l, "ERROR")
})
assertContains(
t, bobQuit, "integration test done",
"QUIT reason echoed",
)
// alice should see bob's QUIT via relay.
aliceQuit := alice.readUntil(func(l string) bool {
return strings.Contains(l, "QUIT") &&
strings.Contains(l, "bob")
})
assertContains(
t, aliceQuit, "bob",
"alice sees bob's QUIT",
)
}
// TestIntegrationModeSecret tests +s (secret) channel
// mode — verifies that +s can be set and the mode is
// reflected in MODE queries.
func TestIntegrationModeSecret(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice")
alice.joinAndDrain("#secretroom")
// Set +s.
alice.send("MODE #secretroom +s")
aliceLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, "+s")
})
assertContains(
t, aliceLines, "+s",
"alice sees MODE +s confirmation",
)
// Verify mode is reflected in query.
alice.send("MODE #secretroom")
modeLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, " 324 ")
})
assertContains(
t, modeLines, "s",
"channel mode includes s",
)
}
// TestIntegrationModeModerated tests +m (moderated) mode
// — non-voiced users cannot send.
func TestIntegrationModeModerated(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice")
bob := env.dial(t)
bob.register("bob")
alice.joinAndDrain("#modtest")
bob.joinAndDrain("#modtest")
// Drain alice's view of bob's join.
alice.readUntil(func(l string) bool {
return strings.Contains(l, "JOIN") &&
strings.Contains(l, "bob")
})
// Set +m.
alice.send("MODE #modtest +m")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "+m")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "+m")
})
// bob should get an error trying to send.
bob.send("PRIVMSG #modtest :should fail")
bobLines := bob.readUntil(func(l string) bool {
return strings.Contains(l, " 404 ") ||
strings.Contains(l, " 482 ")
})
foundModErr := false
for _, line := range bobLines {
if strings.Contains(line, " 404 ") ||
strings.Contains(line, " 482 ") {
foundModErr = true
break
}
}
if !foundModErr {
t.Error(
"non-voiced user should not be able to send " +
"in +m channel",
)
}
// Grant +v to bob, then he should be able to send.
alice.send("MODE #modtest +v bob")
alice.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
bob.readUntil(func(l string) bool {
return strings.Contains(l, "+v")
})
bob.send("PRIVMSG #modtest :voiced message")
aliceLines := alice.readUntil(func(l string) bool {
return strings.Contains(l, "voiced message")
})
assertContains(
t, aliceLines, "voiced message",
"alice receives voiced bob's message",
)
}
// TestIntegrationThirdClientObserver verifies that a third
// client observing the same channel receives messages from
// the other two.
func TestIntegrationThirdClientObserver(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice")
bob := env.dial(t)
bob.register("bob")
carol := env.dial(t)
carol.register("carol")
alice.joinAndDrain("#trio")
bob.joinAndDrain("#trio")
carol.joinAndDrain("#trio")
// Drain join notifications.
time.Sleep(100 * time.Millisecond)
// alice sends; both bob and carol should receive.
alice.send("PRIVMSG #trio :hello trio")
bobLines := bob.readUntil(func(l string) bool {
return strings.Contains(l, "hello trio")
})
assertContains(
t, bobLines, "hello trio",
"bob receives trio message",
)
carolLines := carol.readUntil(func(l string) bool {
return strings.Contains(l, "hello trio")
})
assertContains(
t, carolLines, "hello trio",
"carol receives trio message",
)
}

View File

@@ -1,123 +0,0 @@
// Package ircserver implements a traditional IRC wire protocol
// listener (RFC 1459/2812) that bridges to the neoirc HTTP/JSON
// server internals.
package ircserver
import "strings"
// Message represents a parsed IRC wire protocol message.
type Message struct {
// Prefix is the optional :prefix at the start (may be
// empty for client-to-server messages).
Prefix string
// Command is the IRC command (e.g., "PRIVMSG", "NICK").
Command string
// Params holds the positional parameters, including the
// trailing parameter (which was preceded by ':' on the
// wire).
Params []string
}
// ParseMessage parses a single IRC wire protocol line
// (without the trailing CR-LF) into a Message.
// Returns nil if the line is empty.
//
// IRC message format (RFC 1459 §2.3.1):
//
// [":" prefix SPACE] command { SPACE param } [SPACE ":" trailing]
func ParseMessage(line string) *Message {
if line == "" {
return nil
}
msg := &Message{} //nolint:exhaustruct // fields set below
// Extract prefix if present.
if line[0] == ':' {
idx := strings.IndexByte(line, ' ')
if idx < 0 {
// Only a prefix, no command — invalid.
return nil
}
msg.Prefix = line[1:idx]
line = line[idx+1:]
}
// Skip leading spaces.
line = strings.TrimLeft(line, " ")
if line == "" {
return nil
}
// Extract command.
idx := strings.IndexByte(line, ' ')
if idx < 0 {
msg.Command = strings.ToUpper(line)
return msg
}
msg.Command = strings.ToUpper(line[:idx])
line = line[idx+1:]
// Extract parameters.
for line != "" {
line = strings.TrimLeft(line, " ")
if line == "" {
break
}
// Trailing parameter (everything after ':').
if line[0] == ':' {
msg.Params = append(msg.Params, line[1:])
break
}
idx = strings.IndexByte(line, ' ')
if idx < 0 {
msg.Params = append(msg.Params, line)
break
}
msg.Params = append(msg.Params, line[:idx])
line = line[idx+1:]
}
return msg
}
// FormatMessage formats an IRC message into wire protocol
// format (without the trailing CR-LF).
func FormatMessage(
prefix, command string,
params ...string,
) string {
var buf strings.Builder
if prefix != "" {
buf.WriteByte(':')
buf.WriteString(prefix)
buf.WriteByte(' ')
}
buf.WriteString(command)
for i, param := range params {
buf.WriteByte(' ')
isLast := i == len(params)-1
needsColon := strings.Contains(param, " ") ||
param == "" || param[0] == ':'
if isLast && needsColon {
buf.WriteByte(':')
}
buf.WriteString(param)
}
return buf.String()
}

View File

@@ -1,328 +0,0 @@
package ircserver_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
)
//nolint:funlen // table-driven test
func TestParseMessage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want *ircserver.Message
wantNil bool
}{
{
name: "empty",
input: "",
want: nil,
wantNil: true,
},
{
name: "simple command",
input: "PING",
want: &ircserver.Message{
Prefix: "",
Command: "PING",
Params: nil,
},
wantNil: false,
},
{
name: "command with one param",
input: "NICK alice",
want: &ircserver.Message{
Prefix: "",
Command: "NICK",
Params: []string{"alice"},
},
wantNil: false,
},
{
name: "command case insensitive",
input: "nick Alice",
want: &ircserver.Message{
Prefix: "",
Command: "NICK",
Params: []string{"Alice"},
},
wantNil: false,
},
{
name: "privmsg with trailing",
input: "PRIVMSG #general :hello world",
want: &ircserver.Message{
Prefix: "",
Command: "PRIVMSG",
Params: []string{"#general", "hello world"},
},
wantNil: false,
},
{
name: "with prefix",
input: ":server.example.com 001 alice :Welcome to IRC",
want: &ircserver.Message{
Prefix: "server.example.com",
Command: "001",
Params: []string{"alice", "Welcome to IRC"},
},
wantNil: false,
},
{
name: "user command",
input: "USER alice 0 * :Alice Smith",
want: &ircserver.Message{
Prefix: "",
Command: "USER",
Params: []string{
"alice", "0", "*", "Alice Smith",
},
},
wantNil: false,
},
{
name: "join channel",
input: "JOIN #general",
want: &ircserver.Message{
Prefix: "",
Command: "JOIN",
Params: []string{"#general"},
},
wantNil: false,
},
{
name: "quit with trailing",
input: "QUIT :leaving now",
want: &ircserver.Message{
Prefix: "",
Command: "QUIT",
Params: []string{"leaving now"},
},
wantNil: false,
},
{
name: "quit without reason",
input: "QUIT",
want: &ircserver.Message{
Prefix: "",
Command: "QUIT",
Params: nil,
},
wantNil: false,
},
{
name: "mode query",
input: "MODE #general",
want: &ircserver.Message{
Prefix: "",
Command: "MODE",
Params: []string{"#general"},
},
wantNil: false,
},
{
name: "kick with reason",
input: "KICK #general bob :misbehaving",
want: &ircserver.Message{
Prefix: "",
Command: "KICK",
Params: []string{
"#general", "bob", "misbehaving",
},
},
wantNil: false,
},
{
name: "empty trailing",
input: "PRIVMSG #general :",
want: &ircserver.Message{
Prefix: "",
Command: "PRIVMSG",
Params: []string{"#general", ""},
},
wantNil: false,
},
{
name: "pass command",
input: "PASS mysecret",
want: &ircserver.Message{
Prefix: "",
Command: "PASS",
Params: []string{"mysecret"},
},
wantNil: false,
},
{
name: "ping with server",
input: "PING :irc.example.com",
want: &ircserver.Message{
Prefix: "",
Command: "PING",
Params: []string{"irc.example.com"},
},
wantNil: false,
},
{
name: "topic with trailing spaces",
input: "TOPIC #general :Welcome to the channel!",
want: &ircserver.Message{
Prefix: "",
Command: "TOPIC",
Params: []string{
"#general",
"Welcome to the channel!",
},
},
wantNil: false,
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
got := ircserver.ParseMessage(testCase.input)
if testCase.wantNil {
if got != nil {
t.Fatalf("expected nil, got %+v", got)
}
return
}
if got == nil {
t.Fatal("expected non-nil message")
}
if got.Prefix != testCase.want.Prefix {
t.Errorf(
"prefix: got %q, want %q",
got.Prefix, testCase.want.Prefix,
)
}
if got.Command != testCase.want.Command {
t.Errorf(
"command: got %q, want %q",
got.Command, testCase.want.Command,
)
}
if len(got.Params) != len(testCase.want.Params) {
t.Fatalf(
"params length: got %d, want %d (%v vs %v)",
len(got.Params),
len(testCase.want.Params),
got.Params,
testCase.want.Params,
)
}
for i, p := range got.Params {
if p != testCase.want.Params[i] {
t.Errorf(
"param[%d]: got %q, want %q",
i, p, testCase.want.Params[i],
)
}
}
})
}
}
func TestFormatMessage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prefix string
command string
params []string
want string
}{
{
name: "simple command",
prefix: "",
command: "PING",
params: nil,
want: "PING",
},
{
name: "with prefix",
prefix: "server",
command: "PONG",
params: []string{"server"},
want: ":server PONG server",
},
{
name: "privmsg with trailing",
prefix: "alice!alice@host",
command: "PRIVMSG",
params: []string{"#general", "hello world"},
want: ":alice!alice@host PRIVMSG #general :hello world",
},
{
name: "numeric reply",
prefix: "server",
command: "001",
params: []string{"alice", "Welcome to IRC"},
want: ":server 001 alice :Welcome to IRC",
},
{
name: "empty trailing",
prefix: "server",
command: "PRIVMSG",
params: []string{"#chan", ""},
want: ":server PRIVMSG #chan :",
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
got := ircserver.FormatMessage(
testCase.prefix, testCase.command, testCase.params...,
)
if got != testCase.want {
t.Errorf("got %q, want %q", got, testCase.want)
}
})
}
}
func TestParseFormatRoundTrip(t *testing.T) {
t.Parallel()
// Round-trip only works for lines where the last
// parameter either contains a space (gets ':' prefix
// on format) or is a non-trailing single token.
lines := []string{
"PING",
"NICK alice",
"PRIVMSG #general :hello world",
"JOIN #general",
"MODE #general",
}
for _, line := range lines {
msg := ircserver.ParseMessage(line)
if msg == nil {
t.Fatalf("failed to parse: %q", line)
}
formatted := ircserver.FormatMessage(
msg.Prefix, msg.Command, msg.Params...,
)
if formatted != line {
t.Errorf(
"round-trip failed: input %q, got %q",
line, formatted,
)
}
}
}

View File

@@ -1,319 +0,0 @@
package ircserver
import (
"context"
"encoding/json"
"strings"
"time"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
// relayMessages polls the client output queue and delivers
// IRC-formatted messages to the TCP connection. It runs
// in a goroutine for the lifetime of the connection.
func (c *Conn) relayMessages(ctx context.Context) {
// Use a ticker as a fallback; primary wakeup is via
// broker notification.
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
default:
}
// Drain any available messages.
delivered := c.drainQueue(ctx)
if delivered {
// Tight loop while there are messages.
continue
}
// Wait for notification or timeout.
waitCh := c.brk.Wait(c.sessionID)
select {
case <-waitCh:
// New message notification — loop back.
case <-ticker.C:
// Periodic check.
case <-ctx.Done():
c.brk.Remove(c.sessionID, waitCh)
return
}
}
}
const relayPollLimit = 100
// drainQueue polls the output queue and delivers all
// pending messages. Returns true if at least one message
// was delivered.
func (c *Conn) drainQueue(ctx context.Context) bool {
msgs, lastID, err := c.database.PollMessages(
ctx, c.clientID, c.lastQueueID, relayPollLimit,
)
if err != nil {
return false
}
if len(msgs) == 0 {
return false
}
for i := range msgs {
c.deliverIRCMessage(ctx, &msgs[i])
}
if lastID > c.lastQueueID {
c.lastQueueID = lastID
}
return true
}
// deliverIRCMessage converts a db.IRCMessage to wire
// protocol and sends it.
//
//nolint:cyclop // dispatch table
func (c *Conn) deliverIRCMessage(
_ context.Context,
msg *db.IRCMessage,
) {
command := msg.Command
// Decode body as []string for the trailing text.
var bodyLines []string
if msg.Body != nil {
_ = json.Unmarshal(msg.Body, &bodyLines)
}
text := ""
if len(bodyLines) > 0 {
text = bodyLines[0]
}
// Route by command type.
switch {
case isNumeric(command):
c.deliverNumeric(msg, text)
case command == irc.CmdPrivmsg || command == irc.CmdNotice:
c.deliverTextMessage(msg, command, text)
case command == irc.CmdJoin:
c.deliverJoin(msg)
case command == irc.CmdPart:
c.deliverPart(msg, text)
case command == irc.CmdNick:
c.deliverNickChange(msg, text)
case command == irc.CmdQuit:
c.deliverQuitMsg(msg, text)
case command == irc.CmdTopic:
c.deliverTopicChange(msg, text)
case command == irc.CmdKick:
c.deliverKickMsg(msg, text)
case command == "INVITE":
c.deliverInviteMsg(msg, text)
case command == irc.CmdMode:
c.deliverMode(msg, text)
case command == irc.CmdPing:
// Server-originated PING — reply with PONG.
c.sendFromServer("PING", c.serverSfx)
default:
// Unknown command — deliver as server notice.
if text != "" {
c.sendFromServer("NOTICE", c.nick, text)
}
}
}
// isNumeric returns true if the command is a 3-digit
// numeric code.
func isNumeric(cmd string) bool {
return len(cmd) == 3 &&
cmd[0] >= '0' && cmd[0] <= '9' &&
cmd[1] >= '0' && cmd[1] <= '9' &&
cmd[2] >= '0' && cmd[2] <= '9'
}
// deliverNumeric sends a numeric reply.
func (c *Conn) deliverNumeric(
msg *db.IRCMessage,
text string,
) {
from := msg.From
if from == "" {
from = c.serverSfx
}
var params []string
if msg.Params != nil {
_ = json.Unmarshal(msg.Params, &params)
}
allParams := make([]string, 0, 1+len(params)+1)
allParams = append(allParams, c.nick)
allParams = append(allParams, params...)
if text != "" {
allParams = append(allParams, text)
}
c.send(FormatMessage(from, msg.Command, allParams...))
}
// deliverTextMessage sends PRIVMSG or NOTICE.
func (c *Conn) deliverTextMessage(
msg *db.IRCMessage,
command, text string,
) {
from := msg.From
target := msg.To
// Don't echo our own messages back.
if strings.EqualFold(from, c.nick) {
return
}
prefix := from
if !strings.Contains(prefix, "!") {
prefix = from + "!" + from + "@*"
}
c.send(FormatMessage(prefix, command, target, text))
}
// deliverJoin sends a JOIN notification.
func (c *Conn) deliverJoin(msg *db.IRCMessage) {
// Don't echo our own JOINs (we already sent them
// during joinChannel).
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
c.send(FormatMessage(prefix, "JOIN", channel))
}
// deliverPart sends a PART notification.
func (c *Conn) deliverPart(msg *db.IRCMessage, text string) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
if text != "" {
c.send(FormatMessage(
prefix, "PART", channel, text,
))
} else {
c.send(FormatMessage(prefix, "PART", channel))
}
}
// deliverNickChange sends a NICK change notification.
func (c *Conn) deliverNickChange(
msg *db.IRCMessage,
newNick string,
) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
c.send(FormatMessage(prefix, "NICK", newNick))
}
// deliverQuitMsg sends a QUIT notification.
func (c *Conn) deliverQuitMsg(
msg *db.IRCMessage,
text string,
) {
if strings.EqualFold(msg.From, c.nick) {
return
}
prefix := msg.From + "!" + msg.From + "@*"
if text != "" {
c.send(FormatMessage(
prefix, "QUIT", "Quit: "+text,
))
} else {
c.send(FormatMessage(prefix, "QUIT", "Quit"))
}
}
// deliverTopicChange sends a TOPIC change notification.
func (c *Conn) deliverTopicChange(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
c.send(FormatMessage(prefix, "TOPIC", channel, text))
}
// deliverKickMsg sends a KICK notification.
func (c *Conn) deliverKickMsg(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
channel := msg.To
var params []string
if msg.Params != nil {
_ = json.Unmarshal(msg.Params, &params)
}
kickTarget := ""
if len(params) > 0 {
kickTarget = params[0]
}
if kickTarget != "" {
c.send(FormatMessage(
prefix, "KICK", channel, kickTarget, text,
))
} else {
c.send(FormatMessage(
prefix, "KICK", channel, "?", text,
))
}
}
// deliverInviteMsg sends an INVITE notification.
func (c *Conn) deliverInviteMsg(
_ *db.IRCMessage,
text string,
) {
c.sendFromServer("NOTICE", c.nick, text)
}
// deliverMode sends a MODE change notification.
func (c *Conn) deliverMode(
msg *db.IRCMessage,
text string,
) {
prefix := msg.From + "!" + msg.From + "@*"
target := msg.To
if text != "" {
c.send(FormatMessage(prefix, "MODE", target, text))
}
}

View File

@@ -1,157 +0,0 @@
package ircserver
import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/service"
"go.uber.org/fx"
)
// Params defines the dependencies for creating an IRC
// Server.
type Params struct {
fx.In
Logger *logger.Logger
Config *config.Config
Database *db.Database
Broker *broker.Broker
Service *service.Service
}
// Server is the TCP IRC protocol server.
type Server struct {
log *slog.Logger
cfg *config.Config
database *db.Database
brk *broker.Broker
svc *service.Service
listener net.Listener
mu sync.Mutex
conns map[*Conn]struct{}
cancel context.CancelFunc
}
// New creates a new IRC Server and registers its lifecycle
// hooks. The listener is only started if IRC_LISTEN_ADDR
// is configured; otherwise the server is inert.
func New(
lifecycle fx.Lifecycle,
params Params,
) *Server {
srv := &Server{
log: params.Logger.Get(),
cfg: params.Config,
database: params.Database,
brk: params.Broker,
svc: params.Service,
conns: make(map[*Conn]struct{}),
listener: nil,
cancel: nil,
mu: sync.Mutex{},
}
listenAddr := params.Config.IRCListenAddr
if listenAddr == "" {
return srv
}
lifecycle.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return srv.start(ctx, listenAddr)
},
OnStop: func(_ context.Context) error {
srv.stop()
return nil
},
})
return srv
}
// start begins listening for TCP connections.
//
//nolint:contextcheck // long-lived server ctx, not the short Fx one
func (s *Server) start(_ context.Context, addr string) error {
ln, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("irc listen: %w", err)
}
s.listener = ln
ctx, cancel := context.WithCancel(context.Background())
s.cancel = cancel
s.log.Info(
"irc server listening", "addr", addr,
)
go s.acceptLoop(ctx)
return nil
}
// stop shuts down the listener and all connections.
func (s *Server) stop() {
if s.cancel != nil {
s.cancel()
}
if s.listener != nil {
s.listener.Close() //nolint:errcheck,gosec
}
s.mu.Lock()
for c := range s.conns {
c.conn.Close() //nolint:errcheck,gosec
}
s.mu.Unlock()
}
// acceptLoop accepts new connections.
func (s *Server) acceptLoop(ctx context.Context) {
for {
tcpConn, err := s.listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return
default:
s.log.Error(
"irc accept error", "error", err,
)
continue
}
}
client := newConn(
ctx, tcpConn, s.log,
s.database, s.brk, s.cfg, s.svc,
)
s.mu.Lock()
s.conns[client] = struct{}{}
s.mu.Unlock()
go func() {
defer func() {
s.mu.Lock()
delete(s.conns, client)
s.mu.Unlock()
}()
client.serve(ctx)
}()
}
}

View File

@@ -1,625 +0,0 @@
package ircserver_test
import (
"bufio"
"database/sql"
"fmt"
"log/slog"
"net"
"os"
"strings"
"testing"
"time"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/ircserver"
_ "modernc.org/sqlite"
)
const testTimeout = 5 * time.Second
func TestMain(m *testing.M) {
db.SetBcryptCost(4)
os.Exit(m.Run())
}
// testEnv holds the shared test infrastructure.
type testEnv struct {
database *db.Database
brk *broker.Broker
cfg *config.Config
srv *ircserver.Server
}
func newTestEnv(t *testing.T) *testEnv {
t.Helper()
dsn := fmt.Sprintf(
"file:%s?mode=memory&cache=shared&_journal_mode=WAL",
t.Name(),
)
conn, err := sql.Open("sqlite", dsn)
if err != nil {
t.Fatalf("open db: %v", err)
}
conn.SetMaxOpenConns(1)
_, err = conn.ExecContext(
t.Context(), "PRAGMA foreign_keys = ON",
)
if err != nil {
t.Fatalf("pragma: %v", err)
}
database := db.NewTestDatabaseFromConn(conn)
err = database.RunMigrations(t.Context())
if err != nil {
t.Fatalf("migrate: %v", err)
}
brk := broker.New()
cfg := &config.Config{ //nolint:exhaustruct
ServerName: "test.irc",
MOTD: "Welcome to test IRC",
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
addr := listener.Addr().String()
err = listener.Close()
if err != nil {
t.Fatalf("close listener: %v", err)
}
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelError}, //nolint:exhaustruct
))
srv := ircserver.NewTestServer(log, cfg, database, brk)
err = srv.Start(addr)
if err != nil {
t.Fatalf("start irc server: %v", err)
}
t.Cleanup(func() {
srv.Stop()
err := conn.Close()
if err != nil {
t.Logf("close db: %v", err)
}
})
return &testEnv{
database: database,
brk: brk,
cfg: cfg,
srv: srv,
}
}
// dial connects to the test server.
func (env *testEnv) dial(t *testing.T) *testClient {
t.Helper()
conn, err := net.DialTimeout(
"tcp",
env.srv.Listener().Addr().String(),
testTimeout,
)
if err != nil {
t.Fatalf("dial: %v", err)
}
t.Cleanup(func() {
err := conn.Close()
if err != nil {
t.Logf("close conn: %v", err)
}
})
return &testClient{
t: t,
conn: conn,
scanner: bufio.NewScanner(conn),
}
}
// testClient wraps a raw TCP connection with helpers.
type testClient struct {
t *testing.T
conn net.Conn
scanner *bufio.Scanner
}
func (tc *testClient) send(line string) {
tc.t.Helper()
_ = tc.conn.SetWriteDeadline(
time.Now().Add(testTimeout),
)
_, err := fmt.Fprintf(tc.conn, "%s\r\n", line)
if err != nil {
tc.t.Fatalf("send: %v", err)
}
}
func (tc *testClient) readLine() string {
tc.t.Helper()
_ = tc.conn.SetReadDeadline(
time.Now().Add(testTimeout),
)
if !tc.scanner.Scan() {
err := tc.scanner.Err()
if err != nil {
tc.t.Fatalf("read: %v", err)
}
tc.t.Fatal("connection closed unexpectedly")
}
return tc.scanner.Text()
}
// readUntil reads lines until one matches the predicate.
func (tc *testClient) readUntil(
pred func(string) bool,
) []string {
tc.t.Helper()
var lines []string
for {
line := tc.readLine()
lines = append(lines, line)
if pred(line) {
return lines
}
}
}
// register sends NICK + USER and reads through the welcome
// burst.
func (tc *testClient) register(nick string) []string {
tc.t.Helper()
tc.send("NICK " + nick)
tc.send("USER " + nick + " 0 * :Test User")
return tc.readUntil(func(line string) bool {
return strings.Contains(line, " 376 ") ||
strings.Contains(line, " 422 ")
})
}
// assertContains checks that at least one line matches the
// given substring.
func assertContains(
t *testing.T,
lines []string,
substr, description string,
) {
t.Helper()
for _, line := range lines {
if strings.Contains(line, substr) {
return
}
}
t.Errorf("did not find %q in output: %s", substr, description)
}
// joinAndDrain joins a channel and reads until
// RPL_ENDOFNAMES.
func (tc *testClient) joinAndDrain(channel string) {
tc.t.Helper()
tc.send("JOIN " + channel)
tc.readUntil(func(line string) bool {
return strings.Contains(line, " 366 ")
})
}
// sendAndExpect sends a command and reads until a line
// containing the expected substring is found.
func (tc *testClient) sendAndExpect(
cmd, expect string,
) []string {
tc.t.Helper()
tc.send(cmd)
return tc.readUntil(func(line string) bool {
return strings.Contains(line, expect)
})
}
func TestRegistration(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
lines := client.register("alice")
assertContains(t, lines, " 001 ", "RPL_WELCOME")
}
func TestWelcomeContainsNick(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
lines := client.register("bob")
for _, line := range lines {
if strings.Contains(line, " 001 ") &&
!strings.Contains(line, "bob") {
t.Errorf("001 should contain nick: %s", line)
}
}
}
func TestPingPong(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("pingtest")
lines := client.sendAndExpect("PING :hello", "PONG")
assertContains(t, lines, "PONG", "PONG response")
}
func TestJoinChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("joiner")
client.send("JOIN #test")
lines := client.readUntil(func(line string) bool {
return strings.Contains(line, " 366 ")
})
assertContains(t, lines, "JOIN", "JOIN echo")
assertContains(t, lines, " 366 ", "RPL_ENDOFNAMES")
}
func TestPrivmsgBetweenClients(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
alice := env.dial(t)
alice.register("alice_pm")
bob := env.dial(t)
bob.register("bob_pm")
alice.joinAndDrain("#chat")
bob.joinAndDrain("#chat")
alice.send("PRIVMSG #chat :hello bob!")
lines := bob.sendAndExpect("PING :sync", "hello bob!")
assertContains(t, lines, "hello bob!", "channel PRIVMSG")
}
func TestNickChange(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("oldnick")
lines := client.sendAndExpect("NICK newnick", "newnick")
assertContains(t, lines, "NICK", "NICK change echo")
}
func TestDuplicateNick(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
first := env.dial(t)
first.register("taken")
second := env.dial(t)
second.send("NICK taken")
second.send("USER taken 0 * :Test")
lines := second.readUntil(func(line string) bool {
return strings.Contains(line, " 433 ")
})
assertContains(t, lines, " 433 ", "ERR_NICKNAMEINUSE")
}
func TestListChannels(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("lister")
client.joinAndDrain("#listtest")
lines := client.sendAndExpect("LIST", " 323 ")
assertContains(t, lines, " 323 ", "RPL_LISTEND") //nolint:misspell // IRC term
}
func TestWhois(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("whoistest")
lines := client.sendAndExpect(
"WHOIS whoistest", " 318 ",
)
assertContains(t, lines, " 311 ", "RPL_WHOISUSER")
assertContains(t, lines, " 318 ", "RPL_ENDOFWHOIS")
}
func TestQuit(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("quitter")
lines := client.sendAndExpect(
"QUIT :goodbye", "ERROR",
)
assertContains(t, lines, "goodbye", "QUIT reason")
}
func TestTopicSetAndGet(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("topicuser")
client.joinAndDrain("#topictest")
lines := client.sendAndExpect(
"TOPIC #topictest :New topic here",
"New topic here",
)
assertContains(
t, lines, "New topic here", "TOPIC echo",
)
}
func TestUnknownCommand(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("unknowncmd")
lines := client.sendAndExpect("FOOBAR", " 421 ")
assertContains(t, lines, " 421 ", "ERR_UNKNOWNCOMMAND")
}
func TestDirectMessage(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
sender := env.dial(t)
sender.register("dmsender")
receiver := env.dial(t)
receiver.register("dmreceiver")
// Give relay goroutines time to start.
time.Sleep(100 * time.Millisecond)
sender.send("PRIVMSG dmreceiver :hello privately")
lines := receiver.readUntil(func(line string) bool {
return strings.Contains(line, "hello privately")
})
assertContains(
t, lines, "hello privately", "direct PRIVMSG",
)
}
func TestCAPNegotiation(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.send("CAP LS 302")
line := client.readLine()
if !strings.Contains(line, "CAP") {
t.Errorf("expected CAP response, got: %s", line)
}
client.send("CAP END")
lines := client.register("capuser")
assertContains(t, lines, " 001 ", "registration after CAP")
}
func TestPartChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("parter")
client.joinAndDrain("#parttest")
lines := client.sendAndExpect(
"PART #parttest :leaving", "PART",
)
assertContains(t, lines, "#parttest", "PART echo")
}
func TestModeQuery(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("modeuser")
client.joinAndDrain("#modetest")
lines := client.sendAndExpect(
"MODE #modetest", " 324 ",
)
assertContains(
t, lines, " 324 ", "RPL_CHANNELMODEIS",
)
}
func TestWhoChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("whouser")
client.joinAndDrain("#whotest")
lines := client.sendAndExpect("WHO #whotest", " 315 ")
assertContains(t, lines, " 352 ", "RPL_WHOREPLY")
}
func TestLusers(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("luseruser")
lines := client.sendAndExpect("LUSERS", " 255 ")
assertContains(t, lines, " 251 ", "RPL_LUSERCLIENT")
}
func TestMotd(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("motduser")
lines := client.sendAndExpect("MOTD", " 376 ")
assertContains(t, lines, " 376 ", "RPL_ENDOFMOTD")
}
func TestAwaySetAndClear(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("awayuser")
setLines := client.sendAndExpect(
"AWAY :brb lunch", " 306 ",
)
assertContains(t, setLines, " 306 ", "RPL_NOWAWAY")
clearLines := client.sendAndExpect("AWAY", " 305 ")
assertContains(t, clearLines, " 305 ", "RPL_UNAWAY")
}
func TestHandlePassPostRegistration(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("passuser")
lines := client.sendAndExpect(
"PASS :mypassword123", "Password set",
)
assertContains(
t, lines, "Password set", "password confirmation",
)
}
func TestPreRegistrationNotRegistered(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.send("PRIVMSG #test :hello")
line := client.readLine()
if !strings.Contains(line, " 451 ") {
t.Errorf(
"expected ERR_NOTREGISTERED (451), got: %s",
line,
)
}
}
func TestNamesNonExistentChannel(t *testing.T) {
t.Parallel()
env := newTestEnv(t)
client := env.dial(t)
client.register("namesuser")
lines := client.sendAndExpect(
"NAMES #doesnotexist", " 366 ",
)
assertContains(
t, lines, " 366 ",
"RPL_ENDOFNAMES for non-existent channel",
)
}
func BenchmarkParseMessage(b *testing.B) {
line := ":nick!user@host PRIVMSG #channel :Hello, world!"
b.ResetTimer()
for range b.N {
_ = ircserver.ParseMessage(line)
}
}
func BenchmarkFormatMessage(b *testing.B) {
b.ResetTimer()
for range b.N {
_ = ircserver.FormatMessage(
"nick!user@host", "PRIVMSG",
"#channel", "Hello, world!",
)
}
}

View File

@@ -126,23 +126,18 @@ func (mware *Middleware) Logging() func(http.Handler) http.Handler {
} }
// CORS returns middleware that handles Cross-Origin Resource Sharing. // CORS returns middleware that handles Cross-Origin Resource Sharing.
// AllowCredentials is true so browsers include cookies in
// cross-origin API requests.
func (mware *Middleware) CORS() func(http.Handler) http.Handler { func (mware *Middleware) CORS() func(http.Handler) http.Handler {
return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields
AllowOriginFunc: func( AllowedOrigins: []string{"*"},
_ *http.Request, _ string,
) bool {
return true
},
AllowedMethods: []string{ AllowedMethods: []string{
"GET", "POST", "PUT", "DELETE", "OPTIONS", "GET", "POST", "PUT", "DELETE", "OPTIONS",
}, },
AllowedHeaders: []string{ AllowedHeaders: []string{
"Accept", "Content-Type", "X-CSRF-Token", "Accept", "Authorization",
"Content-Type", "X-CSRF-Token",
}, },
ExposedHeaders: []string{"Link"}, ExposedHeaders: []string{"Link"},
AllowCredentials: true, AllowCredentials: false,
MaxAge: corsMaxAge, MaxAge: corsMaxAge,
}) })
} }

View File

@@ -1,122 +0,0 @@
// Package ratelimit provides per-IP rate limiting for HTTP endpoints.
package ratelimit
import (
"sync"
"time"
"golang.org/x/time/rate"
)
const (
// DefaultRate is the default number of allowed requests per second.
DefaultRate = 1.0
// DefaultBurst is the default maximum burst size.
DefaultBurst = 5
// DefaultSweepInterval controls how often stale entries are pruned.
DefaultSweepInterval = 10 * time.Minute
// DefaultEntryTTL is how long an unused entry lives before eviction.
DefaultEntryTTL = 15 * time.Minute
)
// entry tracks a per-IP rate limiter and when it was last used.
type entry struct {
limiter *rate.Limiter
lastSeen time.Time
}
// Limiter manages per-key rate limiters with automatic cleanup
// of stale entries.
type Limiter struct {
mu sync.Mutex
entries map[string]*entry
rate rate.Limit
burst int
entryTTL time.Duration
stopCh chan struct{}
}
// New creates a new per-key rate Limiter.
// The ratePerSec parameter sets how many requests per second are
// allowed per key. The burst parameter sets the maximum number of
// requests that can be made in a single burst.
func New(ratePerSec float64, burst int) *Limiter {
limiter := &Limiter{
mu: sync.Mutex{},
entries: make(map[string]*entry),
rate: rate.Limit(ratePerSec),
burst: burst,
entryTTL: DefaultEntryTTL,
stopCh: make(chan struct{}),
}
go limiter.sweepLoop()
return limiter
}
// Allow reports whether a request from the given key should be
// allowed. It consumes one token from the key's rate limiter.
func (l *Limiter) Allow(key string) bool {
l.mu.Lock()
ent, exists := l.entries[key]
if !exists {
ent = &entry{
limiter: rate.NewLimiter(l.rate, l.burst),
lastSeen: time.Now(),
}
l.entries[key] = ent
} else {
ent.lastSeen = time.Now()
}
l.mu.Unlock()
return ent.limiter.Allow()
}
// Stop terminates the background sweep goroutine.
func (l *Limiter) Stop() {
close(l.stopCh)
}
// Len returns the number of tracked keys (for testing).
func (l *Limiter) Len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.entries)
}
// sweepLoop periodically removes entries that haven't been seen
// within the TTL.
func (l *Limiter) sweepLoop() {
ticker := time.NewTicker(DefaultSweepInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.sweep()
case <-l.stopCh:
return
}
}
}
// sweep removes stale entries.
func (l *Limiter) sweep() {
l.mu.Lock()
defer l.mu.Unlock()
cutoff := time.Now().Add(-l.entryTTL)
for key, ent := range l.entries {
if ent.lastSeen.Before(cutoff) {
delete(l.entries, key)
}
}
}

View File

@@ -1,106 +0,0 @@
package ratelimit_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
)
func TestNewCreatesLimiter(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter == nil {
t.Fatal("expected non-nil limiter")
}
}
func TestAllowWithinBurst(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 3)
defer limiter.Stop()
for i := range 3 {
if !limiter.Allow("192.168.1.1") {
t.Fatalf(
"request %d should be allowed within burst",
i+1,
)
}
}
}
func TestAllowExceedsBurst(t *testing.T) {
t.Parallel()
// Rate of 0 means no token replenishment, only burst.
limiter := ratelimit.New(0, 3)
defer limiter.Stop()
for range 3 {
limiter.Allow("10.0.0.1")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("fourth request should be denied after burst exhausted")
}
}
func TestAllowSeparateKeys(t *testing.T) {
t.Parallel()
// Rate of 0, burst of 1 — only one request per key.
limiter := ratelimit.New(0, 1)
defer limiter.Stop()
if !limiter.Allow("10.0.0.1") {
t.Fatal("first request for key A should be allowed")
}
if !limiter.Allow("10.0.0.2") {
t.Fatal("first request for key B should be allowed")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("second request for key A should be denied")
}
if limiter.Allow("10.0.0.2") {
t.Fatal("second request for key B should be denied")
}
}
func TestLenTracksKeys(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter.Len() != 0 {
t.Fatalf("expected 0 entries, got %d", limiter.Len())
}
limiter.Allow("10.0.0.1")
limiter.Allow("10.0.0.2")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
// Same key again should not increase count.
limiter.Allow("10.0.0.1")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
}
func TestStopDoesNotPanic(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
limiter.Stop()
}

View File

@@ -75,6 +75,10 @@ func (srv *Server) setupAPIv1(router chi.Router) {
"/session", "/session",
srv.handlers.HandleCreateSession(), srv.handlers.HandleCreateSession(),
) )
router.Post(
"/register",
srv.handlers.HandleRegister(),
)
router.Post( router.Post(
"/login", "/login",
srv.handlers.HandleLogin(), srv.handlers.HandleLogin(),

View File

@@ -1,901 +0,0 @@
// Package service provides shared business logic for both
// the IRC wire protocol and HTTP/JSON transports.
package service
import (
"context"
"crypto/subtle"
"encoding/json"
"fmt"
"log/slog"
"strconv"
"strings"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/pkg/irc"
"go.uber.org/fx"
)
// Params defines the dependencies for creating a Service.
type Params struct {
fx.In
Logger *logger.Logger
Config *config.Config
Database *db.Database
Broker *broker.Broker
}
// Service provides shared business logic for IRC commands.
type Service struct {
db *db.Database
broker *broker.Broker
config *config.Config
log *slog.Logger
}
// New creates a new Service.
func New(params Params) *Service {
return &Service{
db: params.Database,
broker: params.Broker,
config: params.Config,
log: params.Logger.Get(),
}
}
// NewTestService creates a Service for use in tests
// outside the service package.
func NewTestService(
database *db.Database,
brk *broker.Broker,
cfg *config.Config,
log *slog.Logger,
) *Service {
return &Service{
db: database,
broker: brk,
config: cfg,
log: log,
}
}
// IRCError represents an IRC protocol-level error with a
// numeric code that both transports can map to responses.
type IRCError struct {
Code irc.IRCMessageType
Params []string
Message string
}
func (e *IRCError) Error() string { return e.Message }
// JoinResult contains the outcome of a channel join.
type JoinResult struct {
ChannelID int64
IsCreator bool
}
// DirectMsgResult contains the outcome of a direct message.
type DirectMsgResult struct {
UUID string
AwayMsg string
}
// FanOut inserts a message and enqueues it to all given
// session IDs, notifying each via the broker.
func (s *Service) FanOut(
ctx context.Context,
command, from, to string,
params, body, meta json.RawMessage,
sessionIDs []int64,
) (int64, string, error) {
dbID, msgUUID, err := s.db.InsertMessage(
ctx, command, from, to, params, body, meta,
)
if err != nil {
return 0, "", fmt.Errorf("insert message: %w", err)
}
for _, sid := range sessionIDs {
_ = s.db.EnqueueToSession(ctx, sid, dbID)
s.broker.Notify(sid)
}
return dbID, msgUUID, nil
}
// excludeSession returns a copy of ids without the given
// session.
func excludeSession(
ids []int64,
exclude int64,
) []int64 {
out := make([]int64, 0, len(ids))
for _, id := range ids {
if id != exclude {
out = append(out, id)
}
}
return out
}
// SendChannelMessage validates membership and moderation,
// then fans out a message to all channel members except
// the sender. Returns the database row ID, message UUID,
// and any error. The dbID lets callers enqueue the same
// message to the sender when echo is needed (HTTP
// transport).
func (s *Service) SendChannelMessage(
ctx context.Context,
sessionID int64,
nick, command, channel string,
body, meta json.RawMessage,
) (int64, string, error) {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return 0, "", &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID,
)
if !isMember {
return 0, "", &IRCError{
irc.ErrCannotSendToChan,
[]string{channel},
"Cannot send to channel",
}
}
// Ban check — banned users cannot send messages.
isBanned, banErr := s.db.IsSessionBanned(
ctx, chID, sessionID,
)
if banErr == nil && isBanned {
return 0, "", &IRCError{
irc.ErrCannotSendToChan,
[]string{channel},
"Cannot send to channel (+b)",
}
}
moderated, _ := s.db.IsChannelModerated(ctx, chID)
if moderated {
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
isVoiced, _ := s.db.IsChannelVoiced(
ctx, chID, sessionID,
)
if !isOp && !isVoiced {
return 0, "", &IRCError{
irc.ErrCannotSendToChan,
[]string{channel},
"Cannot send to channel (+m)",
}
}
}
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
recipients := excludeSession(memberIDs, sessionID)
dbID, uuid, fanErr := s.FanOut(
ctx, command, nick, channel,
nil, body, meta, recipients,
)
if fanErr != nil {
return 0, "", fanErr
}
return dbID, uuid, nil
}
// SendDirectMessage validates the target and sends a
// direct message, returning the message UUID and any away
// message set on the target.
func (s *Service) SendDirectMessage(
ctx context.Context,
sessionID int64,
nick, command, target string,
body, meta json.RawMessage,
) (*DirectMsgResult, error) {
targetSID, err := s.db.GetSessionByNick(ctx, target)
if err != nil {
return nil, &IRCError{
irc.ErrNoSuchNick,
[]string{target},
"No such nick",
}
}
away, _ := s.db.GetAway(ctx, targetSID)
recipients := []int64{targetSID}
if targetSID != sessionID {
recipients = append(recipients, sessionID)
}
_, uuid, fanErr := s.FanOut(
ctx, command, nick, target,
nil, body, meta, recipients,
)
if fanErr != nil {
return nil, fanErr
}
return &DirectMsgResult{UUID: uuid, AwayMsg: away}, nil
}
// JoinChannel creates or joins a channel, making the
// first joiner the operator. Fans out the JOIN to all
// channel members.
func (s *Service) JoinChannel(
ctx context.Context,
sessionID int64,
nick, channel, suppliedKey string,
) (*JoinResult, error) {
chID, err := s.db.GetOrCreateChannel(ctx, channel)
if err != nil {
return nil, fmt.Errorf("get/create channel: %w", err)
}
memberCount, countErr := s.db.CountChannelMembers(
ctx, chID,
)
isCreator := countErr == nil && memberCount == 0
if !isCreator {
if joinErr := checkJoinRestrictions(
ctx, s.db, chID, sessionID,
channel, suppliedKey, memberCount,
); joinErr != nil {
return nil, joinErr
}
}
if isCreator {
err = s.db.JoinChannelAsOperator(
ctx, chID, sessionID,
)
} else {
err = s.db.JoinChannel(ctx, chID, sessionID)
}
if err != nil {
return nil, fmt.Errorf("join channel: %w", err)
}
// Clear invite after successful join.
_ = s.db.ClearChannelInvite(ctx, chID, sessionID)
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{channel}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdJoin, nick, channel,
nil, body, nil, memberIDs,
)
return &JoinResult{
ChannelID: chID,
IsCreator: isCreator,
}, nil
}
// PartChannel validates membership, broadcasts PART to
// remaining members, removes the user, and cleans up empty
// channels.
func (s *Service) PartChannel(
ctx context.Context,
sessionID int64,
nick, channel, reason string,
) error {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID,
)
if !isMember {
return &IRCError{
irc.ErrNotOnChannel,
[]string{channel},
"You're not on that channel",
}
}
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
recipients := excludeSession(memberIDs, sessionID)
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdPart, nick, channel,
nil, body, nil, recipients,
)
s.db.PartChannel(ctx, chID, sessionID) //nolint:errcheck,gosec
s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec
return nil
}
// SetTopic validates membership and topic-lock, sets the
// topic, and broadcasts the change.
func (s *Service) SetTopic(
ctx context.Context,
sessionID int64,
nick, channel, topic string,
) error {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, sessionID,
)
if !isMember {
return &IRCError{
irc.ErrNotOnChannel,
[]string{channel},
"You're not on that channel",
}
}
topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID)
if topicLocked {
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
if !isOp {
return &IRCError{
irc.ErrChanOpPrivsNeeded,
[]string{channel},
"You're not channel operator",
}
}
}
if setErr := s.db.SetTopic(
ctx, channel, topic,
); setErr != nil {
return fmt.Errorf("set topic: %w", setErr)
}
_ = s.db.SetTopicMeta(ctx, channel, topic, nick)
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{topic}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdTopic, nick, channel,
nil, body, nil, memberIDs,
)
return nil
}
// KickUser validates operator status and target
// membership, broadcasts the KICK, removes the target,
// and cleans up empty channels.
func (s *Service) KickUser(
ctx context.Context,
sessionID int64,
nick, channel, targetNick, reason string,
) error {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
if !isOp {
return &IRCError{
irc.ErrChanOpPrivsNeeded,
[]string{channel},
"You're not channel operator",
}
}
targetSID, err := s.db.GetSessionByNick(
ctx, targetNick,
)
if err != nil {
return &IRCError{
irc.ErrNoSuchNick,
[]string{targetNick},
"No such nick/channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, targetSID,
)
if !isMember {
return &IRCError{
irc.ErrUserNotInChannel,
[]string{targetNick, channel},
"They aren't on that channel",
}
}
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
params, _ := json.Marshal( //nolint:errchkjson
[]string{targetNick},
)
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdKick, nick, channel,
params, body, nil, memberIDs,
)
s.db.PartChannel(ctx, chID, targetSID) //nolint:errcheck,gosec
s.db.DeleteChannelIfEmpty(ctx, chID) //nolint:errcheck,gosec
return nil
}
// ChangeNick changes a user's nickname and broadcasts the
// change to all users sharing channels.
func (s *Service) ChangeNick(
ctx context.Context,
sessionID int64,
oldNick, newNick string,
) error {
err := s.db.ChangeNick(ctx, sessionID, newNick)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE") ||
db.IsUniqueConstraintError(err) {
return &IRCError{
irc.ErrNicknameInUse,
[]string{newNick},
"Nickname is already in use",
}
}
return &IRCError{
irc.ErrErroneusNickname,
[]string{newNick},
"Erroneous nickname",
}
}
s.broadcastNickChange(ctx, sessionID, oldNick, newNick)
return nil
}
// BroadcastQuit broadcasts a QUIT to all channel peers,
// parts all channels, and deletes the session. Uses the
// FanOut pattern: one message row fanned out to all unique
// peer sessions.
func (s *Service) BroadcastQuit(
ctx context.Context,
sessionID int64,
nick, reason string,
) {
channels, err := s.db.GetSessionChannels(
ctx, sessionID,
)
if err != nil {
return
}
notified := make(map[int64]bool)
for _, ch := range channels {
memberIDs, memErr := s.db.GetChannelMemberIDs(
ctx, ch.ID,
)
if memErr != nil {
continue
}
for _, mid := range memberIDs {
if mid == sessionID || notified[mid] {
continue
}
notified[mid] = true
}
}
if len(notified) > 0 {
recipients := make([]int64, 0, len(notified))
for sid := range notified {
recipients = append(recipients, sid)
}
body, _ := json.Marshal([]string{reason}) //nolint:errchkjson
_, _, _ = s.FanOut(
ctx, irc.CmdQuit, nick, "",
nil, body, nil, recipients,
)
}
for _, ch := range channels {
s.db.PartChannel(ctx, ch.ID, sessionID) //nolint:errcheck,gosec
s.db.DeleteChannelIfEmpty(ctx, ch.ID) //nolint:errcheck,gosec
}
s.db.DeleteSession(ctx, sessionID) //nolint:errcheck,gosec
}
// SetAway sets or clears the away message. Returns true
// if the message was cleared (empty string).
func (s *Service) SetAway(
ctx context.Context,
sessionID int64,
message string,
) (bool, error) {
err := s.db.SetAway(ctx, sessionID, message)
if err != nil {
return false, fmt.Errorf("set away: %w", err)
}
return message == "", nil
}
// Oper validates operator credentials and grants oper
// status to the session.
func (s *Service) Oper(
ctx context.Context,
sessionID int64,
name, password string,
) error {
cfgName := s.config.OperName
cfgPassword := s.config.OperPassword
// Use constant-time comparison and return the same
// error for all failures to prevent information
// leakage about valid operator names.
if cfgName == "" || cfgPassword == "" ||
subtle.ConstantTimeCompare(
[]byte(name), []byte(cfgName),
) != 1 ||
subtle.ConstantTimeCompare(
[]byte(password), []byte(cfgPassword),
) != 1 {
return &IRCError{
irc.ErrNoOperHost,
nil,
"No O-lines for your host",
}
}
_ = s.db.SetSessionOper(ctx, sessionID, true)
return nil
}
// ValidateChannelOp checks that the session is a channel
// operator. Returns the channel ID.
func (s *Service) ValidateChannelOp(
ctx context.Context,
sessionID int64,
channel string,
) (int64, error) {
chID, err := s.db.GetChannelByName(ctx, channel)
if err != nil {
return 0, &IRCError{
irc.ErrNoSuchChannel,
[]string{channel},
"No such channel",
}
}
isOp, _ := s.db.IsChannelOperator(
ctx, chID, sessionID,
)
if !isOp {
return 0, &IRCError{
irc.ErrChanOpPrivsNeeded,
[]string{channel},
"You're not channel operator",
}
}
return chID, nil
}
// ApplyMemberMode applies +o/-o or +v/-v on a channel
// member after validating the target.
func (s *Service) ApplyMemberMode(
ctx context.Context,
chID int64,
channel, targetNick string,
mode rune,
adding bool,
) error {
targetSID, err := s.db.GetSessionByNick(
ctx, targetNick,
)
if err != nil {
return &IRCError{
irc.ErrNoSuchNick,
[]string{targetNick},
"No such nick/channel",
}
}
isMember, _ := s.db.IsChannelMember(
ctx, chID, targetSID,
)
if !isMember {
return &IRCError{
irc.ErrUserNotInChannel,
[]string{targetNick, channel},
"They aren't on that channel",
}
}
switch mode {
case 'o':
_ = s.db.SetChannelMemberOperator(
ctx, chID, targetSID, adding,
)
case 'v':
_ = s.db.SetChannelMemberVoiced(
ctx, chID, targetSID, adding,
)
}
return nil
}
// SetChannelFlag applies a simple boolean channel mode
// (+m/-m, +t/-t, +i/-i, +s/-s, +n/-n).
func (s *Service) SetChannelFlag(
ctx context.Context,
chID int64,
flag rune,
setting bool,
) error {
switch flag {
case 'm':
if err := s.db.SetChannelModerated(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set moderated: %w", err)
}
case 't':
if err := s.db.SetChannelTopicLocked(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set topic locked: %w", err)
}
case 'i':
if err := s.db.SetChannelInviteOnly(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set invite only: %w", err)
}
case 's':
if err := s.db.SetChannelSecret(
ctx, chID, setting,
); err != nil {
return fmt.Errorf("set secret: %w", err)
}
case 'n':
if err := s.db.SetChannelNoExternal(
ctx, chID, setting,
); err != nil {
return fmt.Errorf(
"set no external: %w", err,
)
}
}
return nil
}
// BroadcastMode fans out a MODE change to all channel
// members.
func (s *Service) BroadcastMode(
ctx context.Context,
nick, channel string,
chID int64,
modeText string,
) {
memberIDs, _ := s.db.GetChannelMemberIDs(ctx, chID)
body, _ := json.Marshal([]string{modeText}) //nolint:errchkjson
_, _, _ = s.FanOut( //nolint:dogsled // fire-and-forget broadcast
ctx, irc.CmdMode, nick, channel,
nil, body, nil, memberIDs,
)
}
// QueryChannelMode returns the complete channel mode
// string including all flags and parameterized modes.
func (s *Service) QueryChannelMode(
ctx context.Context,
chID int64,
) string {
modes := "+"
noExternal, _ := s.db.IsChannelNoExternal(ctx, chID)
if noExternal {
modes += "n"
}
inviteOnly, _ := s.db.IsChannelInviteOnly(ctx, chID)
if inviteOnly {
modes += "i"
}
moderated, _ := s.db.IsChannelModerated(ctx, chID)
if moderated {
modes += "m"
}
secret, _ := s.db.IsChannelSecret(ctx, chID)
if secret {
modes += "s"
}
topicLocked, _ := s.db.IsChannelTopicLocked(ctx, chID)
if topicLocked {
modes += "t"
}
var modeParams string
key, _ := s.db.GetChannelKey(ctx, chID)
if key != "" {
modes += "k"
modeParams += " " + key
}
limit, _ := s.db.GetChannelUserLimit(ctx, chID)
if limit > 0 {
modes += "l"
modeParams += " " + strconv.Itoa(limit)
}
bits, _ := s.db.GetChannelHashcashBits(ctx, chID)
if bits > 0 {
modes += "H"
modeParams += " " + strconv.Itoa(bits)
}
return modes + modeParams
}
// broadcastNickChange notifies channel peers of a nick
// change.
func (s *Service) broadcastNickChange(
ctx context.Context,
sessionID int64,
oldNick, newNick string,
) {
channels, err := s.db.GetSessionChannels(
ctx, sessionID,
)
if err != nil {
return
}
body, _ := json.Marshal([]string{newNick}) //nolint:errchkjson
notified := make(map[int64]bool)
dbID, _, insErr := s.db.InsertMessage(
ctx, irc.CmdNick, oldNick, "",
nil, body, nil,
)
if insErr != nil {
return
}
// Notify the user themselves (for multi-client sync).
_ = s.db.EnqueueToSession(ctx, sessionID, dbID)
s.broker.Notify(sessionID)
notified[sessionID] = true
for _, ch := range channels {
memberIDs, memErr := s.db.GetChannelMemberIDs(
ctx, ch.ID,
)
if memErr != nil {
continue
}
for _, mid := range memberIDs {
if notified[mid] {
continue
}
notified[mid] = true
_ = s.db.EnqueueToSession(ctx, mid, dbID)
s.broker.Notify(mid)
}
}
}
// checkJoinRestrictions validates Tier 2 join conditions:
// bans, invite-only, channel key, and user limit.
func checkJoinRestrictions(
ctx context.Context,
database *db.Database,
chID, sessionID int64,
channel, suppliedKey string,
memberCount int64,
) error {
isBanned, banErr := database.IsSessionBanned(
ctx, chID, sessionID,
)
if banErr == nil && isBanned {
return &IRCError{
Code: irc.ErrBannedFromChan,
Params: []string{channel},
Message: "Cannot join channel (+b)",
}
}
isInviteOnly, ioErr := database.IsChannelInviteOnly(
ctx, chID,
)
if ioErr == nil && isInviteOnly {
hasInvite, invErr := database.HasChannelInvite(
ctx, chID, sessionID,
)
if invErr != nil || !hasInvite {
return &IRCError{
Code: irc.ErrInviteOnlyChan,
Params: []string{channel},
Message: "Cannot join channel (+i)",
}
}
}
key, keyErr := database.GetChannelKey(ctx, chID)
if keyErr == nil && key != "" && suppliedKey != key {
return &IRCError{
Code: irc.ErrBadChannelKey,
Params: []string{channel},
Message: "Cannot join channel (+k)",
}
}
limit, limErr := database.GetChannelUserLimit(ctx, chID)
if limErr == nil && limit > 0 &&
memberCount >= int64(limit) {
return &IRCError{
Code: irc.ErrChannelIsFull,
Params: []string{channel},
Message: "Cannot join channel (+l)",
}
}
return nil
}

View File

@@ -1,365 +0,0 @@
// Tests use a global viper instance for configuration,
// making parallel execution unsafe.
//
//nolint:paralleltest
package service_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"testing"
"git.eeqj.de/sneak/neoirc/internal/broker"
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/globals"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/service"
"git.eeqj.de/sneak/neoirc/pkg/irc"
"go.uber.org/fx"
"go.uber.org/fx/fxtest"
"golang.org/x/crypto/bcrypt"
)
func TestMain(m *testing.M) {
db.SetBcryptCost(bcrypt.MinCost)
os.Exit(m.Run())
}
// testEnv holds all dependencies for a service test.
type testEnv struct {
svc *service.Service
db *db.Database
broker *broker.Broker
app *fxtest.App
}
func newTestEnv(t *testing.T) *testEnv {
t.Helper()
dbURL := fmt.Sprintf(
"file:svc_test_%p?mode=memory&cache=shared",
t,
)
var (
database *db.Database
svc *service.Service
)
brk := broker.New()
app := fxtest.New(t,
fx.Provide(
func() *globals.Globals {
return &globals.Globals{ //nolint:exhaustruct
Appname: "neoirc-test",
Version: "test",
}
},
logger.New,
func(
lifecycle fx.Lifecycle,
globs *globals.Globals,
log *logger.Logger,
) (*config.Config, error) {
cfg, err := config.New(
lifecycle, config.Params{ //nolint:exhaustruct
Globals: globs, Logger: log,
},
)
if err != nil {
return nil, fmt.Errorf(
"test config: %w", err,
)
}
cfg.DBURL = dbURL
cfg.Port = 0
cfg.OperName = "admin"
cfg.OperPassword = "secret"
return cfg, nil
},
func(
lifecycle fx.Lifecycle,
log *logger.Logger,
cfg *config.Config,
) (*db.Database, error) {
return db.New(lifecycle, db.Params{ //nolint:exhaustruct
Logger: log, Config: cfg,
})
},
func() *broker.Broker { return brk },
service.New,
),
fx.Populate(&database, &svc),
)
app.RequireStart()
t.Cleanup(func() {
app.RequireStop()
})
return &testEnv{
svc: svc,
db: database,
broker: brk,
app: app,
}
}
// createSession is a test helper that creates a session
// and returns the session ID.
func createSession(
ctx context.Context,
t *testing.T,
database *db.Database,
nick string,
) int64 {
t.Helper()
sessionID, _, _, err := database.CreateSession(
ctx, nick, nick, "localhost", "127.0.0.1",
)
if err != nil {
t.Fatalf("create session %s: %v", nick, err)
}
return sessionID
}
func TestFanOut(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
body, _ := json.Marshal([]string{"hello"}) //nolint:errchkjson
dbID, uuid, err := env.svc.FanOut(
ctx, irc.CmdPrivmsg, "alice", "#test",
nil, body, nil,
[]int64{sid1, sid2},
)
if err != nil {
t.Fatalf("FanOut: %v", err)
}
if dbID == 0 {
t.Error("expected non-zero dbID")
}
if uuid == "" {
t.Error("expected non-empty UUID")
}
}
func TestJoinChannel(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid := createSession(ctx, t, env.db, "alice")
result, err := env.svc.JoinChannel(
ctx, sid, "alice", "#general", "",
)
if err != nil {
t.Fatalf("JoinChannel: %v", err)
}
if result.ChannelID == 0 {
t.Error("expected non-zero channel ID")
}
if !result.IsCreator {
t.Error("first joiner should be creator")
}
// Second user joins — not creator.
sid2 := createSession(ctx, t, env.db, "bob")
result2, err := env.svc.JoinChannel(
ctx, sid2, "bob", "#general", "",
)
if err != nil {
t.Fatalf("JoinChannel bob: %v", err)
}
if result2.IsCreator {
t.Error("second joiner should not be creator")
}
if result2.ChannelID != result.ChannelID {
t.Error("both should join the same channel")
}
}
func TestPartChannel(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid := createSession(ctx, t, env.db, "alice")
_, err := env.svc.JoinChannel(
ctx, sid, "alice", "#general", "",
)
if err != nil {
t.Fatalf("JoinChannel: %v", err)
}
err = env.svc.PartChannel(
ctx, sid, "alice", "#general", "bye",
)
if err != nil {
t.Fatalf("PartChannel: %v", err)
}
// Parting a non-existent channel returns error.
err = env.svc.PartChannel(
ctx, sid, "alice", "#nonexistent", "",
)
if err == nil {
t.Error("expected error for non-existent channel")
}
var ircErr *service.IRCError
if !errors.As(err, &ircErr) {
t.Errorf("expected IRCError, got %T", err)
}
}
func TestSendChannelMessage(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
_, err := env.svc.JoinChannel(
ctx, sid1, "alice", "#chat", "",
)
if err != nil {
t.Fatalf("join alice: %v", err)
}
_, err = env.svc.JoinChannel(
ctx, sid2, "bob", "#chat", "",
)
if err != nil {
t.Fatalf("join bob: %v", err)
}
body, _ := json.Marshal([]string{"hello world"}) //nolint:errchkjson
dbID, uuid, err := env.svc.SendChannelMessage(
ctx, sid1, "alice",
irc.CmdPrivmsg, "#chat", body, nil,
)
if err != nil {
t.Fatalf("SendChannelMessage: %v", err)
}
if dbID == 0 {
t.Error("expected non-zero dbID")
}
if uuid == "" {
t.Error("expected non-empty UUID")
}
// Non-member cannot send.
sid3 := createSession(ctx, t, env.db, "charlie")
_, _, err = env.svc.SendChannelMessage(
ctx, sid3, "charlie",
irc.CmdPrivmsg, "#chat", body, nil,
)
if err == nil {
t.Error("expected error for non-member send")
}
}
func TestBroadcastQuit(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
_, err := env.svc.JoinChannel(
ctx, sid1, "alice", "#room", "",
)
if err != nil {
t.Fatalf("join alice: %v", err)
}
_, err = env.svc.JoinChannel(
ctx, sid2, "bob", "#room", "",
)
if err != nil {
t.Fatalf("join bob: %v", err)
}
// BroadcastQuit should not panic and should clean up.
env.svc.BroadcastQuit(
ctx, sid1, "alice", "Goodbye",
)
// Session should be deleted.
_, lookupErr := env.db.GetSessionByNick(ctx, "alice")
if lookupErr == nil {
t.Error("expected session to be deleted after quit")
}
}
func TestSendChannelMessage_Moderated(t *testing.T) {
env := newTestEnv(t)
ctx := t.Context()
sid1 := createSession(ctx, t, env.db, "alice")
sid2 := createSession(ctx, t, env.db, "bob")
result, err := env.svc.JoinChannel(
ctx, sid1, "alice", "#modchat", "",
)
if err != nil {
t.Fatalf("join alice: %v", err)
}
_, err = env.svc.JoinChannel(
ctx, sid2, "bob", "#modchat", "",
)
if err != nil {
t.Fatalf("join bob: %v", err)
}
// Set channel to moderated.
chID := result.ChannelID
_ = env.svc.SetChannelFlag(ctx, chID, 'm', true)
body, _ := json.Marshal([]string{"test"}) //nolint:errchkjson
// Bob (non-op, non-voiced) should fail to send.
_, _, err = env.svc.SendChannelMessage(
ctx, sid2, "bob",
irc.CmdPrivmsg, "#modchat", body, nil,
)
if err == nil {
t.Error("expected error for non-voiced user in moderated channel")
}
// Alice (operator) should succeed.
_, _, err = env.svc.SendChannelMessage(
ctx, sid1, "alice",
irc.CmdPrivmsg, "#modchat", body, nil,
)
if err != nil {
t.Errorf("operator should be able to send in moderated channel: %v", err)
}
}

View File

@@ -3,9 +3,7 @@ package irc
// IRC command names (RFC 1459 / RFC 2812). // IRC command names (RFC 1459 / RFC 2812).
const ( const (
CmdAway = "AWAY" CmdAway = "AWAY"
CmdInvite = "INVITE"
CmdJoin = "JOIN" CmdJoin = "JOIN"
CmdKick = "KICK"
CmdList = "LIST" CmdList = "LIST"
CmdLusers = "LUSERS" CmdLusers = "LUSERS"
CmdMode = "MODE" CmdMode = "MODE"
@@ -14,14 +12,12 @@ const (
CmdNick = "NICK" CmdNick = "NICK"
CmdNotice = "NOTICE" CmdNotice = "NOTICE"
CmdOper = "OPER" CmdOper = "OPER"
CmdPass = "PASS"
CmdPart = "PART" CmdPart = "PART"
CmdPing = "PING" CmdPing = "PING"
CmdPong = "PONG" CmdPong = "PONG"
CmdPrivmsg = "PRIVMSG" CmdPrivmsg = "PRIVMSG"
CmdQuit = "QUIT" CmdQuit = "QUIT"
CmdTopic = "TOPIC" CmdTopic = "TOPIC"
CmdUser = "USER"
CmdWho = "WHO" CmdWho = "WHO"
CmdWhois = "WHOIS" CmdWhois = "WHOIS"
) )