1 Commits

Author SHA1 Message Date
user
bd4326ad6f docs: update README schema section to match actual database schema
All checks were successful
check / check (push) Successful in 1m11s
Update the Schema section and related references throughout README.md to
accurately reflect the current 001_initial.sql migration:

- Rename 'users' table to 'sessions' with new columns: uuid, password_hash,
  signing_key, away_message
- Add new 'clients' table (uuid, session_id FK, token, created_at, last_seen)
- Add topic_set_by and topic_set_at columns to 'channels' table
- Update channel_members FK from user_id to session_id
- Add params column to messages table
- Update client_queues FK from user_id to client_id
- Update Queue Architecture diagram labels and surrounding text
- Update In-Memory Broker description to use client_id terminology
- Update Multi-Client Model MVP note to reflect sessions/clients split
2026-03-17 02:17:34 -07:00
33 changed files with 877 additions and 11304 deletions

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks ensure-web-dist
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks
BINARY := neoircd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -7,21 +7,10 @@ LDFLAGS := -X main.Version=$(VERSION) -X main.Buildarch=$(BUILDARCH)
all: check build
# ensure-web-dist creates placeholder files so //go:embed dist/* in
# web/embed.go resolves without a full Node.js build. The real SPA is
# built by the web-builder Docker stage; these placeholders let
# "make test" and "make build" work outside Docker.
ensure-web-dist:
@if [ ! -d web/dist ]; then \
mkdir -p web/dist && \
touch web/dist/index.html web/dist/style.css web/dist/app.js && \
echo "==> Created placeholder web/dist/ for go:embed"; \
fi
build: ensure-web-dist
build:
go build -ldflags "$(LDFLAGS)" -o bin/$(BINARY) ./cmd/neoircd
lint: ensure-web-dist
lint:
golangci-lint run --config .golangci.yml ./...
fmt:
@@ -31,8 +20,8 @@ fmt:
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test: ensure-web-dist
go test -timeout 30s -race -cover ./... || go test -timeout 30s -race -v ./...
test:
go test -timeout 30s -v -race -cover ./...
# check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong

787
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -10,7 +10,6 @@ import (
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/middleware"
"git.eeqj.de/sneak/neoirc/internal/server"
"git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx"
)
@@ -36,7 +35,6 @@ func main() {
server.New,
middleware.New,
healthcheck.New,
stats.New,
),
fx.Invoke(func(*server.Server) {}),
).Run()

1
go.mod
View File

@@ -16,7 +16,6 @@ require (
github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0
golang.org/x/crypto v0.48.0
golang.org/x/time v0.6.0
modernc.org/sqlite v1.45.0
)

2
go.sum
View File

@@ -151,8 +151,6 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strconv"
"strings"
@@ -29,19 +28,16 @@ var errHTTP = errors.New("HTTP error")
// Client wraps HTTP calls to the neoirc server API.
type Client struct {
BaseURL string
Token string
HTTPClient *http.Client
}
// NewClient creates a new API client with a cookie jar
// for automatic auth cookie management.
// NewClient creates a new API client.
func NewClient(baseURL string) *Client {
jar, _ := cookiejar.New(nil)
return &Client{
return &Client{ //nolint:exhaustruct // Token set after CreateSession
BaseURL: baseURL,
HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine
Timeout: httpTimeout,
Jar: jar,
},
}
}
@@ -83,6 +79,8 @@ func (client *Client) CreateSession(
return nil, fmt.Errorf("decode session: %w", err)
}
client.Token = resp.Token
return &resp, nil
}
@@ -123,7 +121,6 @@ func (client *Client) PollMessages(
Timeout: time.Duration(
timeout+pollExtraTime,
) * time.Second,
Jar: client.HTTPClient.Jar,
}
params := url.Values{}
@@ -148,6 +145,10 @@ func (client *Client) PollMessages(
return nil, fmt.Errorf("new request: %w", err)
}
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
resp, err := pollClient.Do(request)
if err != nil {
return nil, fmt.Errorf("poll request: %w", err)
@@ -303,6 +304,12 @@ func (client *Client) do(
"Content-Type", "application/json",
)
if client.Token != "" {
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
}
resp, err := client.HTTPClient.Do(request)
if err != nil {
return nil, fmt.Errorf("http: %w", err)

View File

@@ -7,8 +7,6 @@ import (
"fmt"
"math/big"
"time"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
)
const (
@@ -39,23 +37,6 @@ func MintHashcash(bits int, resource string) string {
}
}
// MintChannelHashcash computes a hashcash stamp bound to
// a specific channel and message body. The stamp format
// is 1:bits:YYMMDD:channel:bodyhash:counter where
// bodyhash is the hex-encoded SHA-256 of the message
// body bytes. Delegates to the internal/hashcash package.
func MintChannelHashcash(
bits int,
channel string,
body []byte,
) string {
bodyHash := hashcash.BodyHash(body)
return hashcash.MintChannelStamp(
bits, channel, bodyHash,
)
}
// hasLeadingZeroBits checks if hash has at least numBits
// leading zero bits.
func hasLeadingZeroBits(

View File

@@ -12,6 +12,7 @@ type SessionRequest struct {
type SessionResponse struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
Token string `json:"token"`
}
// StateResponse is the response from GET /api/v1/state.

View File

@@ -46,10 +46,6 @@ type Config struct {
FederationKey string
SessionIdleTimeout string
HashcashBits int
OperName string
OperPassword string
LoginRateLimit float64
LoginRateBurst int
params *Params
log *slog.Logger
}
@@ -82,10 +78,6 @@ func New(
viper.SetDefault("FEDERATION_KEY", "")
viper.SetDefault("SESSION_IDLE_TIMEOUT", "720h")
viper.SetDefault("NEOIRC_HASHCASH_BITS", "20")
viper.SetDefault("NEOIRC_OPER_NAME", "")
viper.SetDefault("NEOIRC_OPER_PASSWORD", "")
viper.SetDefault("LOGIN_RATE_LIMIT", "1")
viper.SetDefault("LOGIN_RATE_BURST", "5")
err := viper.ReadInConfig()
if err != nil {
@@ -112,10 +104,6 @@ func New(
FederationKey: viper.GetString("FEDERATION_KEY"),
SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"),
HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"),
OperName: viper.GetString("NEOIRC_OPER_NAME"),
OperPassword: viper.GetString("NEOIRC_OPER_PASSWORD"),
LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"),
LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"),
log: log,
params: &params,
}

View File

@@ -10,46 +10,93 @@ import (
"golang.org/x/crypto/bcrypt"
)
//nolint:gochecknoglobals // var so tests can override via SetBcryptCost
var bcryptCost = bcrypt.DefaultCost
// SetBcryptCost overrides the bcrypt cost.
// Use bcrypt.MinCost in tests to avoid slow hashing.
func SetBcryptCost(cost int) { bcryptCost = cost }
const bcryptCost = bcrypt.DefaultCost
var errNoPassword = errors.New(
"account has no password set",
)
// SetPassword sets a bcrypt-hashed password on a session,
// enabling multi-client login via POST /api/v1/login.
func (database *Database) SetPassword(
// RegisterUser creates a session with a hashed password
// and returns session ID, client ID, and token.
func (database *Database) RegisterUser(
ctx context.Context,
sessionID int64,
password string,
) error {
nick, password string,
) (int64, int64, string, error) {
hash, err := bcrypt.GenerateFromPassword(
[]byte(password), bcryptCost,
)
if err != nil {
return fmt.Errorf("hash password: %w", err)
return 0, 0, "", fmt.Errorf(
"hash password: %w", err,
)
}
_, err = database.conn.ExecContext(ctx,
"UPDATE sessions SET password_hash = ? WHERE id = ?",
string(hash), sessionID)
sessionUUID := uuid.New().String()
clientUUID := uuid.New().String()
token, err := generateToken()
if err != nil {
return fmt.Errorf("set password: %w", err)
return 0, 0, "", err
}
return nil
now := time.Now()
transaction, err := database.conn.BeginTx(ctx, nil)
if err != nil {
return 0, 0, "", fmt.Errorf(
"begin tx: %w", err,
)
}
res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, password_hash,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
sessionUUID, nick, string(hash), now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create session: %w", err,
)
}
sessionID, _ := res.LastInsertId()
tokenHash := hashToken(token)
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create client: %w", err,
)
}
clientID, _ := clientRes.LastInsertId()
err = transaction.Commit()
if err != nil {
return 0, 0, "", fmt.Errorf(
"commit registration: %w", err,
)
}
return sessionID, clientID, token, nil
}
// LoginUser verifies a nick/password and creates a new
// client token.
func (database *Database) LoginUser(
ctx context.Context,
nick, password, remoteIP, hostname string,
nick, password string,
) (int64, int64, string, error) {
var (
sessionID int64
@@ -96,11 +143,10 @@ func (database *Database) LoginUser(
res, err := database.conn.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token, ip, hostname,
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
if err != nil {
return 0, 0, "", fmt.Errorf(
"create login client: %w", err,

View File

@@ -6,65 +6,63 @@ import (
_ "modernc.org/sqlite"
)
func TestSetPassword(t *testing.T) {
func TestRegisterUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err :=
database.CreateSession(ctx, "passuser", "", "", "")
sessionID, clientID, token, err :=
database.RegisterUser(ctx, "reguser", "password123")
if err != nil {
t.Fatal(err)
}
err = database.SetPassword(
ctx, sessionID, "password123",
)
if err != nil {
t.Fatal(err)
}
// Verify we can now log in with the password.
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "passuser", "password123", "", "")
if err != nil {
t.Fatal(err)
}
if loginSID == 0 || loginCID == 0 || loginToken == "" {
if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token")
}
// Verify session works via token lookup.
sid, cid, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if sid != sessionID || cid != clientID {
t.Fatal("session/client id mismatch")
}
if nick != "reguser" {
t.Fatalf("expected reguser, got %s", nick)
}
}
func TestSetPasswordThenWrongLogin(t *testing.T) {
func TestRegisterUserDuplicateNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err :=
database.CreateSession(ctx, "wrongpw", "", "", "")
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "dupnick", "password123")
if err != nil {
t.Fatal(err)
}
err = database.SetPassword(
ctx, sessionID, "correctpass",
)
if err != nil {
t.Fatal(err)
_ = regSID
_ = regCID
_ = regToken
dupSID, dupCID, dupToken, dupErr :=
database.RegisterUser(ctx, "dupnick", "other12345")
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
}
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
_ = dupSID
_ = dupCID
_ = dupToken
}
func TestLoginUser(t *testing.T) {
@@ -73,26 +71,23 @@ func TestLoginUser(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err :=
database.CreateSession(ctx, "loginuser", "", "", "")
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "loginuser", "mypassword")
if err != nil {
t.Fatal(err)
}
err = database.SetPassword(
ctx, sessionID, "mypassword",
)
_ = regSID
_ = regCID
_ = regToken
sessionID, clientID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword")
if err != nil {
t.Fatal(err)
}
loginSID, loginCID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword", "", "")
if err != nil {
t.Fatal(err)
}
if loginSID == 0 || loginCID == 0 || token == "" {
if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token")
}
@@ -108,6 +103,33 @@ func TestLoginUser(t *testing.T) {
}
}
func TestLoginUserWrongPassword(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "wrongpw", "correctpass")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
}
func TestLoginUserNoPassword(t *testing.T) {
t.Parallel()
@@ -116,7 +138,7 @@ func TestLoginUserNoPassword(t *testing.T) {
// Create anonymous session (no password).
anonSID, anonCID, anonToken, err :=
database.CreateSession(ctx, "anon", "", "", "")
database.CreateSession(ctx, "anon")
if err != nil {
t.Fatal(err)
}
@@ -126,7 +148,7 @@ func TestLoginUserNoPassword(t *testing.T) {
_ = anonToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "anon", "anything1", "", "")
database.LoginUser(ctx, "anon", "anything1")
if loginErr == nil {
t.Fatal(
"expected error for login on passwordless account",
@@ -145,7 +167,7 @@ func TestLoginUserNonexistent(t *testing.T) {
ctx := t.Context()
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "ghost", "password123", "", "")
database.LoginUser(ctx, "ghost", "password123")
if err == nil {
t.Fatal("expected error for nonexistent user")
}

View File

@@ -1,14 +0,0 @@
package db_test
import (
"os"
"testing"
"git.eeqj.de/sneak/neoirc/internal/db"
"golang.org/x/crypto/bcrypt"
)
func TestMain(m *testing.M) {
db.SetBcryptCost(bcrypt.MinCost)
os.Exit(m.Run())
}

File diff suppressed because it is too large Load Diff

View File

@@ -34,7 +34,7 @@ func TestCreateSession(t *testing.T) {
ctx := t.Context()
sessionID, _, token, err := database.CreateSession(
ctx, "alice", "", "", "",
ctx, "alice",
)
if err != nil {
t.Fatal(err)
@@ -45,7 +45,7 @@ func TestCreateSession(t *testing.T) {
}
_, _, dupToken, dupErr := database.CreateSession(
ctx, "alice", "", "", "",
ctx, "alice",
)
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
@@ -54,249 +54,13 @@ func TestCreateSession(t *testing.T) {
_ = dupToken
}
// assertSessionHostInfo creates a session and verifies
// the stored username and hostname match expectations.
func assertSessionHostInfo(
t *testing.T,
database *db.Database,
nick, inputUser, inputHost,
expectUser, expectHost string,
) {
t.Helper()
sessionID, _, _, err := database.CreateSession(
t.Context(), nick, inputUser, inputHost, "",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
t.Context(), sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.Username != expectUser {
t.Fatalf(
"expected username %s, got %s",
expectUser, info.Username,
)
}
if info.Hostname != expectHost {
t.Fatalf(
"expected hostname %s, got %s",
expectHost, info.Hostname,
)
}
}
func TestCreateSessionWithUserHost(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
assertSessionHostInfo(
t, database,
"hostuser", "myident", "example.com",
"myident", "example.com",
)
}
func TestCreateSessionDefaultUsername(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
// Empty username defaults to nick.
assertSessionHostInfo(
t, database,
"defaultu", "", "host.local",
"defaultu", "host.local",
)
}
func TestCreateSessionStoresIP(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, clientID, _, err := database.CreateSession(
ctx, "ipuser", "ident", "host.example.com",
"192.168.1.42",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.IP != "192.168.1.42" {
t.Fatalf(
"expected session IP 192.168.1.42, got %s",
info.IP,
)
}
clientInfo, err := database.GetClientHostInfo(
ctx, clientID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "192.168.1.42" {
t.Fatalf(
"expected client IP 192.168.1.42, got %s",
clientInfo.IP,
)
}
if clientInfo.Hostname != "host.example.com" {
t.Fatalf(
"expected client hostname host.example.com, got %s",
clientInfo.Hostname,
)
}
}
func TestGetClientHostInfoNotFound(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
_, err := database.GetClientHostInfo(
t.Context(), 99999,
)
if err == nil {
t.Fatal("expected error for nonexistent client")
}
}
func TestGetSessionHostInfoNotFound(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
_, err := database.GetSessionHostInfo(
t.Context(), 99999,
)
if err == nil {
t.Fatal("expected error for nonexistent session")
}
}
func TestFormatHostmask(t *testing.T) {
t.Parallel()
result := db.FormatHostmask(
"nick", "user", "host.com",
)
if result != "nick!user@host.com" {
t.Fatalf(
"expected nick!user@host.com, got %s",
result,
)
}
}
func TestFormatHostmaskDefaults(t *testing.T) {
t.Parallel()
result := db.FormatHostmask("nick", "", "")
if result != "nick!nick@*" {
t.Fatalf(
"expected nick!nick@*, got %s",
result,
)
}
}
func TestMemberInfoHostmask(t *testing.T) {
t.Parallel()
member := &db.MemberInfo{ //nolint:exhaustruct // test only uses hostmask fields
Nick: "alice",
Username: "aliceident",
Hostname: "alice.example.com",
}
hostmask := member.Hostmask()
expected := "alice!aliceident@alice.example.com"
if hostmask != expected {
t.Fatalf(
"expected %s, got %s", expected, hostmask,
)
}
}
func TestChannelMembersIncludeUserHost(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "memuser", "myuser", "myhost.net", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(
ctx, "#hostchan",
)
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
members, err := database.ChannelMembers(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(members) != 1 {
t.Fatalf(
"expected 1 member, got %d", len(members),
)
}
if members[0].Username != "myuser" {
t.Fatalf(
"expected username myuser, got %s",
members[0].Username,
)
}
if members[0].Hostname != "myhost.net" {
t.Fatalf(
"expected hostname myhost.net, got %s",
members[0].Hostname,
)
}
}
func TestGetSessionByToken(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
_, _, token, err := database.CreateSession(ctx, "bob", "", "", "")
_, _, token, err := database.CreateSession(ctx, "bob")
if err != nil {
t.Fatal(err)
}
@@ -329,7 +93,7 @@ func TestGetSessionByNick(t *testing.T) {
ctx := t.Context()
charlieID, charlieClientID, charlieToken, err :=
database.CreateSession(ctx, "charlie", "", "", "")
database.CreateSession(ctx, "charlie")
if err != nil {
t.Fatal(err)
}
@@ -386,7 +150,7 @@ func TestJoinAndPart(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, "user1", "", "", "")
sid, _, _, err := database.CreateSession(ctx, "user1")
if err != nil {
t.Fatal(err)
}
@@ -435,7 +199,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
t.Fatal(err)
}
sid, _, _, err := database.CreateSession(ctx, "temp", "", "", "")
sid, _, _, err := database.CreateSession(ctx, "temp")
if err != nil {
t.Fatal(err)
}
@@ -470,7 +234,7 @@ func createSessionWithChannels(
ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, nick, "", "", "")
sid, _, _, err := database.CreateSession(ctx, nick)
if err != nil {
t.Fatal(err)
}
@@ -553,7 +317,7 @@ func TestChangeNick(t *testing.T) {
ctx := t.Context()
sid, _, token, err := database.CreateSession(
ctx, "old", "", "", "",
ctx, "old",
)
if err != nil {
t.Fatal(err)
@@ -637,7 +401,7 @@ func TestPollMessages(t *testing.T) {
ctx := t.Context()
sid, _, token, err := database.CreateSession(
ctx, "poller", "", "", "",
ctx, "poller",
)
if err != nil {
t.Fatal(err)
@@ -744,7 +508,7 @@ func TestDeleteSession(t *testing.T) {
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "deleteme", "", "", "",
ctx, "deleteme",
)
if err != nil {
t.Fatal(err)
@@ -784,12 +548,12 @@ func TestChannelMembers(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sid1, _, _, err := database.CreateSession(ctx, "m1", "", "", "")
sid1, _, _, err := database.CreateSession(ctx, "m1")
if err != nil {
t.Fatal(err)
}
sid2, _, _, err := database.CreateSession(ctx, "m2", "", "", "")
sid2, _, _, err := database.CreateSession(ctx, "m2")
if err != nil {
t.Fatal(err)
}
@@ -847,7 +611,7 @@ func TestEnqueueToClient(t *testing.T) {
ctx := t.Context()
_, _, token, err := database.CreateSession(
ctx, "enqclient", "", "", "",
ctx, "enqclient",
)
if err != nil {
t.Fatal(err)
@@ -887,604 +651,3 @@ func TestEnqueueToClient(t *testing.T) {
t.Fatalf("expected 1, got %d", len(msgs))
}
}
func TestSetAndCheckSessionOper(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.CreateSession(
ctx, "opernick", "", "", "",
)
if err != nil {
t.Fatal(err)
}
// Initially not oper.
isOper, err := database.IsSessionOper(ctx, sessionID)
if err != nil {
t.Fatal(err)
}
if isOper {
t.Fatal("expected session not to be oper")
}
// Set oper.
err = database.SetSessionOper(ctx, sessionID, true)
if err != nil {
t.Fatal(err)
}
isOper, err = database.IsSessionOper(ctx, sessionID)
if err != nil {
t.Fatal(err)
}
if !isOper {
t.Fatal("expected session to be oper")
}
// Unset oper.
err = database.SetSessionOper(ctx, sessionID, false)
if err != nil {
t.Fatal(err)
}
isOper, err = database.IsSessionOper(ctx, sessionID)
if err != nil {
t.Fatal(err)
}
if isOper {
t.Fatal("expected session not to be oper")
}
}
func TestGetLatestClientForSession(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.CreateSession(
ctx, "clientnick", "", "", "10.0.0.1",
)
if err != nil {
t.Fatal(err)
}
clientInfo, err := database.GetLatestClientForSession(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "10.0.0.1" {
t.Fatalf(
"expected IP 10.0.0.1, got %s",
clientInfo.IP,
)
}
}
func TestGetOperCount(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
// Create two sessions.
sid1, _, _, err := database.CreateSession(
ctx, "user1", "", "", "",
)
if err != nil {
t.Fatal(err)
}
sid2, _, _, err := database.CreateSession(
ctx, "user2", "", "", "",
)
_ = sid2
if err != nil {
t.Fatal(err)
}
// Initially zero opers.
count, err := database.GetOperCount(ctx)
if err != nil {
t.Fatal(err)
}
if count != 0 {
t.Fatalf("expected 0 opers, got %d", count)
}
// Set one as oper.
err = database.SetSessionOper(ctx, sid1, true)
if err != nil {
t.Fatal(err)
}
count, err = database.GetOperCount(ctx)
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatalf("expected 1 oper, got %d", count)
}
}
// --- Tier 2 Tests ---
func TestWildcardMatch(t *testing.T) {
t.Parallel()
tests := []struct {
pattern string
input string
match bool
}{
{"*!*@*", "nick!user@host", true},
{"*!*@*.example.com", "nick!user@foo.example.com", true},
{"*!*@*.example.com", "nick!user@other.net", false},
{"badnick!*@*", "badnick!user@host", true},
{"badnick!*@*", "goodnick!user@host", false},
{"nick!user@host", "nick!user@host", true},
{"nick!user@host", "nick!user@other", false},
{"*", "anything", true},
{"?ick!*@*", "nick!user@host", true},
{"?ick!*@*", "nn!user@host", false},
// Case-insensitive.
{"Nick!*@*", "nick!user@host", true},
}
for _, tc := range tests {
result := db.MatchBanMask(tc.pattern, tc.input)
if result != tc.match {
t.Errorf(
"MatchBanMask(%q, %q) = %v, want %v",
tc.pattern, tc.input, result, tc.match,
)
}
}
}
func TestChannelBanCRUD(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#test")
if err != nil {
t.Fatal(err)
}
// No bans initially.
bans, err := database.ListChannelBans(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(bans) != 0 {
t.Fatalf("expected 0 bans, got %d", len(bans))
}
// Add a ban.
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op",
)
if err != nil {
t.Fatal(err)
}
bans, err = database.ListChannelBans(ctx, chID)
if err != nil {
t.Fatal(err)
}
if len(bans) != 1 {
t.Fatalf("expected 1 ban, got %d", len(bans))
}
if bans[0].Mask != "*!*@evil.com" {
t.Fatalf("wrong mask: %s", bans[0].Mask)
}
// Duplicate add is ignored (OR IGNORE).
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op2",
)
if err != nil {
t.Fatal(err)
}
bans, _ = database.ListChannelBans(ctx, chID)
if len(bans) != 1 {
t.Fatalf("expected 1 ban after dup, got %d", len(bans))
}
// Remove ban.
err = database.RemoveChannelBan(
ctx, chID, "*!*@evil.com",
)
if err != nil {
t.Fatal(err)
}
bans, _ = database.ListChannelBans(ctx, chID)
if len(bans) != 0 {
t.Fatalf("expected 0 bans after remove, got %d", len(bans))
}
}
func TestIsSessionBanned(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "victim", "victim", "evil.com", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(ctx, "#bantest")
if err != nil {
t.Fatal(err)
}
// Not banned initially.
banned, err := database.IsSessionBanned(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if banned {
t.Fatal("should not be banned initially")
}
// Add ban matching the hostmask.
err = database.AddChannelBan(
ctx, chID, "*!*@evil.com", "op",
)
if err != nil {
t.Fatal(err)
}
banned, err = database.IsSessionBanned(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if !banned {
t.Fatal("should be banned")
}
}
func TestChannelInviteOnly(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#invite")
if err != nil {
t.Fatal(err)
}
// Default: not invite-only.
isIO, err := database.IsChannelInviteOnly(ctx, chID)
if err != nil {
t.Fatal(err)
}
if isIO {
t.Fatal("should not be invite-only by default")
}
// Set invite-only.
err = database.SetChannelInviteOnly(ctx, chID, true)
if err != nil {
t.Fatal(err)
}
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
if !isIO {
t.Fatal("should be invite-only")
}
// Unset.
err = database.SetChannelInviteOnly(ctx, chID, false)
if err != nil {
t.Fatal(err)
}
isIO, _ = database.IsChannelInviteOnly(ctx, chID)
if isIO {
t.Fatal("should not be invite-only")
}
}
func TestChannelInviteCRUD(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "invited", "", "", "",
)
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(ctx, "#inv")
if err != nil {
t.Fatal(err)
}
// No invite initially.
has, err := database.HasChannelInvite(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
if has {
t.Fatal("should not have invite")
}
// Add invite.
err = database.AddChannelInvite(ctx, chID, sid, "op")
if err != nil {
t.Fatal(err)
}
has, _ = database.HasChannelInvite(ctx, chID, sid)
if !has {
t.Fatal("should have invite")
}
// Clear invite.
err = database.ClearChannelInvite(ctx, chID, sid)
if err != nil {
t.Fatal(err)
}
has, _ = database.HasChannelInvite(ctx, chID, sid)
if has {
t.Fatal("invite should be cleared")
}
}
func TestChannelSecret(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#secret")
if err != nil {
t.Fatal(err)
}
// Default: not secret.
isSec, err := database.IsChannelSecret(ctx, chID)
if err != nil {
t.Fatal(err)
}
if isSec {
t.Fatal("should not be secret by default")
}
err = database.SetChannelSecret(ctx, chID, true)
if err != nil {
t.Fatal(err)
}
isSec, _ = database.IsChannelSecret(ctx, chID)
if !isSec {
t.Fatal("should be secret")
}
}
// createTestSession is a helper to create a session and
// return only the session ID.
func createTestSession(
t *testing.T,
database *db.Database,
nick string,
) int64 {
t.Helper()
sid, _, _, err := database.CreateSession(
t.Context(), nick, "", "", "",
)
if err != nil {
t.Fatalf("create session %s: %v", nick, err)
}
return sid
}
func TestSecretChannelFiltering(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
// Create two sessions.
sid1 := createTestSession(t, database, "member")
sid2 := createTestSession(t, database, "outsider")
// Create a secret channel.
chID, _ := database.GetOrCreateChannel(ctx, "#secret")
_ = database.SetChannelSecret(ctx, chID, true)
_ = database.JoinChannel(ctx, chID, sid1)
// Create a non-secret channel.
chID2, _ := database.GetOrCreateChannel(ctx, "#public")
_ = database.JoinChannel(ctx, chID2, sid1)
// Member should see both.
list, err := database.ListAllChannelsWithCountsFiltered(
ctx, sid1,
)
if err != nil {
t.Fatal(err)
}
if len(list) != 2 {
t.Fatalf("member should see 2 channels, got %d", len(list))
}
// Outsider should only see public.
list, _ = database.ListAllChannelsWithCountsFiltered(
ctx, sid2,
)
if len(list) != 1 {
t.Fatalf("outsider should see 1 channel, got %d", len(list))
}
if list[0].Name != "#public" {
t.Fatalf("outsider should see #public, got %s", list[0].Name)
}
}
func TestWhoisChannelFiltering(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sid1 := createTestSession(t, database, "target")
sid2 := createTestSession(t, database, "querier")
// Create secret channel, target joins it.
chID, _ := database.GetOrCreateChannel(ctx, "#hidden")
_ = database.SetChannelSecret(ctx, chID, true)
_ = database.JoinChannel(ctx, chID, sid1)
// Querier (non-member) should not see the channel.
channels, err := database.GetSessionChannelsFiltered(
ctx, sid1, sid2,
)
if err != nil {
t.Fatal(err)
}
if len(channels) != 0 {
t.Fatalf(
"querier should see 0 channels, got %d",
len(channels),
)
}
// Target querying self should see it.
channels, _ = database.GetSessionChannelsFiltered(
ctx, sid1, sid1,
)
if len(channels) != 1 {
t.Fatalf(
"self-query should see 1 channel, got %d",
len(channels),
)
}
}
//nolint:dupl // structurally similar to TestChannelUserLimit
func TestChannelKey(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#keyed")
if err != nil {
t.Fatal(err)
}
// Default: no key.
key, err := database.GetChannelKey(ctx, chID)
if err != nil {
t.Fatal(err)
}
if key != "" {
t.Fatalf("expected empty key, got %q", key)
}
// Set key.
err = database.SetChannelKey(ctx, chID, "secret123")
if err != nil {
t.Fatal(err)
}
key, _ = database.GetChannelKey(ctx, chID)
if key != "secret123" {
t.Fatalf("expected secret123, got %q", key)
}
// Clear key.
err = database.SetChannelKey(ctx, chID, "")
if err != nil {
t.Fatal(err)
}
key, _ = database.GetChannelKey(ctx, chID)
if key != "" {
t.Fatalf("expected empty key, got %q", key)
}
}
//nolint:dupl // structurally similar to TestChannelKey
func TestChannelUserLimit(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
chID, err := database.GetOrCreateChannel(ctx, "#limited")
if err != nil {
t.Fatal(err)
}
// Default: no limit.
limit, err := database.GetChannelUserLimit(ctx, chID)
if err != nil {
t.Fatal(err)
}
if limit != 0 {
t.Fatalf("expected 0 limit, got %d", limit)
}
// Set limit.
err = database.SetChannelUserLimit(ctx, chID, 50)
if err != nil {
t.Fatal(err)
}
limit, _ = database.GetChannelUserLimit(ctx, chID)
if limit != 50 {
t.Fatalf("expected 50, got %d", limit)
}
// Clear limit.
err = database.SetChannelUserLimit(ctx, chID, 0)
if err != nil {
t.Fatal(err)
}
limit, _ = database.GetChannelUserLimit(ctx, chID)
if limit != 0 {
t.Fatalf("expected 0, got %d", limit)
}
}

View File

@@ -6,11 +6,6 @@ CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT NOT NULL UNIQUE,
nick TEXT NOT NULL UNIQUE,
username TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '',
ip TEXT NOT NULL DEFAULT '',
is_oper INTEGER NOT NULL DEFAULT 0,
is_wallops INTEGER NOT NULL DEFAULT 0,
password_hash TEXT NOT NULL DEFAULT '',
signing_key TEXT NOT NULL DEFAULT '',
away_message TEXT NOT NULL DEFAULT '',
@@ -25,8 +20,6 @@ CREATE TABLE IF NOT EXISTS clients (
uuid TEXT NOT NULL UNIQUE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE,
ip TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);
@@ -40,46 +33,15 @@ CREATE TABLE IF NOT EXISTS channels (
topic TEXT NOT NULL DEFAULT '',
topic_set_by TEXT NOT NULL DEFAULT '',
topic_set_at DATETIME,
hashcash_bits INTEGER NOT NULL DEFAULT 0,
is_moderated INTEGER NOT NULL DEFAULT 0,
is_topic_locked INTEGER NOT NULL DEFAULT 1,
is_invite_only INTEGER NOT NULL DEFAULT 0,
is_secret INTEGER NOT NULL DEFAULT 0,
channel_key TEXT NOT NULL DEFAULT '',
user_limit INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- Channel bans
CREATE TABLE IF NOT EXISTS channel_bans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
mask TEXT NOT NULL,
set_by TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, mask)
);
CREATE INDEX IF NOT EXISTS idx_channel_bans_channel ON channel_bans(channel_id);
-- Channel invites (in-memory would be simpler but DB survives restarts)
CREATE TABLE IF NOT EXISTS channel_invites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
invited_by TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, session_id)
);
CREATE INDEX IF NOT EXISTS idx_channel_invites_channel ON channel_invites(channel_id);
-- Channel members
CREATE TABLE IF NOT EXISTS channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
is_operator INTEGER NOT NULL DEFAULT 0,
is_voiced INTEGER NOT NULL DEFAULT 0,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, session_id)
);
@@ -99,14 +61,6 @@ CREATE TABLE IF NOT EXISTS messages (
CREATE INDEX IF NOT EXISTS idx_messages_to_id ON messages(msg_to, id);
CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
-- Spent hashcash tokens for replay prevention (1-year TTL)
CREATE TABLE IF NOT EXISTS spent_hashcash (
id INTEGER PRIMARY KEY AUTOINCREMENT,
stamp_hash TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_spent_hashcash_created ON spent_hashcash(created_at);
-- Per-client message queues for fan-out delivery
CREATE TABLE IF NOT EXISTS client_queues (
id INTEGER PRIMARY KEY AUTOINCREMENT,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -5,11 +5,117 @@ import (
"net/http"
"strings"
"git.eeqj.de/sneak/neoirc/pkg/irc"
"git.eeqj.de/sneak/neoirc/internal/db"
)
const minPasswordLength = 8
// HandleRegister creates a new user with a password.
func (hdlr *Handlers) HandleRegister() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
request.Body = http.MaxBytesReader(
writer, request.Body, hdlr.maxBodySize(),
)
hdlr.handleRegister(writer, request)
}
}
func (hdlr *Handlers) handleRegister(
writer http.ResponseWriter,
request *http.Request,
) {
type registerRequest struct {
Nick string `json:"nick"`
Password string `json:"password"`
}
var payload registerRequest
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
hdlr.respondError(
writer, request,
"invalid nick format",
http.StatusBadRequest,
)
return
}
if len(payload.Password) < minPasswordLength {
hdlr.respondError(
writer, request,
"password must be at least 8 characters",
http.StatusBadRequest,
)
return
}
sessionID, clientID, token, err :=
hdlr.params.Database.RegisterUser(
request.Context(),
payload.Nick,
payload.Password,
)
if err != nil {
hdlr.handleRegisterError(
writer, request, err,
)
return
}
hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusCreated)
}
func (hdlr *Handlers) handleRegisterError(
writer http.ResponseWriter,
request *http.Request,
err error,
) {
if db.IsUniqueConstraintError(err) {
hdlr.respondError(
writer, request,
"nick already taken",
http.StatusConflict,
)
return
}
hdlr.log.Error(
"register user failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
}
// HandleLogin authenticates a user with nick and password.
func (hdlr *Handlers) HandleLogin() http.HandlerFunc {
return func(
@@ -28,21 +134,6 @@ func (hdlr *Handlers) handleLogin(
writer http.ResponseWriter,
request *http.Request,
) {
ip := clientIP(request)
if !hdlr.loginLimiter.Allow(ip) {
writer.Header().Set(
"Retry-After", "1",
)
hdlr.respondError(
writer, request,
"too many login attempts, try again later",
http.StatusTooManyRequests,
)
return
}
type loginRequest struct {
Nick string `json:"nick"`
Password string `json:"password"`
@@ -73,27 +164,11 @@ func (hdlr *Handlers) handleLogin(
return
}
hdlr.executeLogin(
writer, request, payload.Nick, payload.Password,
)
}
func (hdlr *Handlers) executeLogin(
writer http.ResponseWriter,
request *http.Request,
nick, password string,
) {
remoteIP := clientIP(request)
hostname := resolveHostname(
request.Context(), remoteIP,
)
sessionID, clientID, token, err :=
hdlr.params.Database.LoginUser(
request.Context(),
nick, password,
remoteIP, hostname,
payload.Nick,
payload.Password,
)
if err != nil {
hdlr.respondError(
@@ -105,78 +180,19 @@ func (hdlr *Handlers) executeLogin(
return
}
hdlr.stats.IncrConnections()
hdlr.deliverMOTD(
request, clientID, sessionID, nick,
request, clientID, sessionID, payload.Nick,
)
// Initialize channel state so the new client knows
// which channels the session already belongs to.
hdlr.initChannelState(
request, clientID, sessionID, nick,
request, clientID, sessionID, payload.Nick,
)
hdlr.setAuthCookie(writer, request, token)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": nick,
"nick": payload.Nick,
"token": token,
}, http.StatusOK)
}
// handlePass handles the IRC PASS command to set a
// password on the authenticated session, enabling
// multi-client login via POST /api/v1/login.
func (hdlr *Handlers) handlePass(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
bodyLines func() []string,
) {
lines := bodyLines()
if len(lines) == 0 || lines[0] == "" {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Not enough parameters",
)
return
}
password := lines[0]
if len(password) < minPasswordLength {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Password must be at least 8 characters",
)
return
}
err := hdlr.params.Database.SetPassword(
request.Context(), sessionID, password,
)
if err != nil {
hdlr.log.Error(
"set password failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}

View File

@@ -16,8 +16,6 @@ import (
"git.eeqj.de/sneak/neoirc/internal/hashcash"
"git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
"git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx"
)
@@ -32,16 +30,10 @@ type Params struct {
Config *config.Config
Database *db.Database
Healthcheck *healthcheck.Healthcheck
Stats *stats.Tracker
}
const defaultIdleTimeout = 30 * 24 * time.Hour
// spentHashcashTTL is how long spent hashcash tokens are
// retained for replay prevention. Per issue requirements,
// this is 1 year.
const spentHashcashTTL = 365 * 24 * time.Hour
// Handlers manages HTTP request handling.
type Handlers struct {
params *Params
@@ -49,9 +41,6 @@ type Handlers struct {
hc *healthcheck.Healthcheck
broker *broker.Broker
hashcashVal *hashcash.Validator
channelHashcash *hashcash.ChannelValidator
loginLimiter *ratelimit.Limiter
stats *stats.Tracker
cancelCleanup context.CancelFunc
}
@@ -65,25 +54,12 @@ func New(
resource = "neoirc"
}
loginRate := params.Config.LoginRateLimit
if loginRate <= 0 {
loginRate = ratelimit.DefaultRate
}
loginBurst := params.Config.LoginRateBurst
if loginBurst <= 0 {
loginBurst = ratelimit.DefaultBurst
}
hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
params: &params,
log: params.Logger.Get(),
hc: params.Healthcheck,
broker: broker.New(),
hashcashVal: hashcash.NewValidator(resource),
channelHashcash: hashcash.NewChannelValidator(),
loginLimiter: ratelimit.New(loginRate, loginBurst),
stats: params.Stats,
}
lifecycle.Append(fx.Hook{
@@ -175,10 +151,6 @@ func (hdlr *Handlers) stopCleanup() {
if hdlr.cancelCleanup != nil {
hdlr.cancelCleanup()
}
if hdlr.loginLimiter != nil {
hdlr.loginLimiter.Stop()
}
}
func (hdlr *Handlers) cleanupLoop(ctx context.Context) {
@@ -309,20 +281,4 @@ func (hdlr *Handlers) pruneQueuesAndMessages(
)
}
}
// Prune spent hashcash tokens older than 1 year.
hashcashCutoff := time.Now().Add(-spentHashcashTTL)
pruned, err := hdlr.params.Database.
PruneSpentHashcash(ctx, hashcashCutoff)
if err != nil {
hdlr.log.Error(
"spent hashcash pruning failed", "error", err,
)
} else if pruned > 0 {
hdlr.log.Info(
"pruned spent hashcash tokens",
"deleted", pruned,
)
}
}

View File

@@ -12,7 +12,7 @@ func (hdlr *Handlers) HandleHealthCheck() http.HandlerFunc {
writer http.ResponseWriter,
request *http.Request,
) {
resp := hdlr.hc.Healthcheck(request.Context())
resp := hdlr.hc.Healthcheck()
hdlr.respondJSON(writer, request, resp, httpStatusOK)
}
}

View File

@@ -1,727 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
// maxUserhostNicks is the maximum number of nicks allowed
// in a single USERHOST query (RFC 2812).
const maxUserhostNicks = 5
// dispatchBodyOnlyCommand routes commands that take
// (writer, request, sessionID, clientID, nick, bodyLines).
func (hdlr *Handlers) dispatchBodyOnlyCommand(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, command string,
bodyLines func() []string,
) {
switch command {
case irc.CmdAway:
hdlr.handleAway(
writer, request,
sessionID, clientID, nick, bodyLines,
)
case irc.CmdNick:
hdlr.handleNick(
writer, request,
sessionID, clientID, nick, bodyLines,
)
case irc.CmdPass:
hdlr.handlePass(
writer, request,
sessionID, clientID, nick, bodyLines,
)
case irc.CmdInvite:
hdlr.handleInvite(
writer, request,
sessionID, clientID, nick, bodyLines,
)
}
}
// dispatchOperCommand routes oper-related commands (OPER,
// KILL, WALLOPS) to their handlers.
func (hdlr *Handlers) dispatchOperCommand(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, command string,
bodyLines func() []string,
) {
switch command {
case irc.CmdOper:
hdlr.handleOper(
writer, request,
sessionID, clientID, nick, bodyLines,
)
case irc.CmdKill:
hdlr.handleKill(
writer, request,
sessionID, clientID, nick, bodyLines,
)
case irc.CmdWallops:
hdlr.handleWallops(
writer, request,
sessionID, clientID, nick, bodyLines,
)
}
}
// handleUserhost handles the USERHOST command.
// Returns user@host info for up to 5 nicks.
func (hdlr *Handlers) handleUserhost(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
bodyLines func() []string,
) {
ctx := request.Context()
lines := bodyLines()
if len(lines) == 0 {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdUserhost},
"Not enough parameters",
)
return
}
// Limit to 5 nicks per RFC 2812.
nicks := lines
if len(nicks) > maxUserhostNicks {
nicks = nicks[:maxUserhostNicks]
}
infos, err := hdlr.params.Database.GetUserhostInfo(
ctx, nicks,
)
if err != nil {
hdlr.log.Error(
"userhost query failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
replyStr := hdlr.buildUserhostReply(infos)
hdlr.enqueueNumeric(
ctx, clientID, irc.RplUserHost, nick, nil,
replyStr,
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// buildUserhostReply builds the RPL_USERHOST reply
// string per RFC 2812.
func (hdlr *Handlers) buildUserhostReply(
infos []db.UserhostInfo,
) string {
replies := make([]string, 0, len(infos))
for idx := range infos {
info := &infos[idx]
username := info.Username
if username == "" {
username = info.Nick
}
hostname := info.Hostname
if hostname == "" {
hostname = hdlr.serverName()
}
operStar := ""
if info.IsOper {
operStar = "*"
}
awayPrefix := "+"
if info.AwayMessage != "" {
awayPrefix = "-"
}
replies = append(replies,
info.Nick+operStar+"="+
awayPrefix+username+"@"+hostname,
)
}
return strings.Join(replies, " ")
}
// handleVersion handles the VERSION command.
// Returns the server version string.
func (hdlr *Handlers) handleVersion(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
) {
ctx := request.Context()
srvName := hdlr.serverName()
version := hdlr.serverVersion()
// 351 RPL_VERSION
hdlr.enqueueNumeric(
ctx, clientID, irc.RplVersion, nick,
[]string{version + ".", srvName},
"",
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// handleAdmin handles the ADMIN command.
// Returns server admin contact info.
func (hdlr *Handlers) handleAdmin(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
) {
ctx := request.Context()
srvName := hdlr.serverName()
// 256 RPL_ADMINME
hdlr.enqueueNumeric(
ctx, clientID, irc.RplAdminMe, nick,
[]string{srvName},
"Administrative info",
)
// 257 RPL_ADMINLOC1
hdlr.enqueueNumeric(
ctx, clientID, irc.RplAdminLoc1, nick, nil,
"neoirc server",
)
// 258 RPL_ADMINLOC2
hdlr.enqueueNumeric(
ctx, clientID, irc.RplAdminLoc2, nick, nil,
"IRC over HTTP",
)
// 259 RPL_ADMINEMAIL
hdlr.enqueueNumeric(
ctx, clientID, irc.RplAdminEmail, nick, nil,
"admin@"+srvName,
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// handleInfo handles the INFO command.
// Returns server software information.
func (hdlr *Handlers) handleInfo(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
) {
ctx := request.Context()
version := hdlr.serverVersion()
infoLines := []string{
"neoirc — IRC semantics over HTTP",
"Version: " + version,
"Written in Go",
"Started: " +
hdlr.params.Globals.StartTime.
Format(time.RFC1123),
}
for _, line := range infoLines {
// 371 RPL_INFO
hdlr.enqueueNumeric(
ctx, clientID, irc.RplInfo, nick, nil,
line,
)
}
// 374 RPL_ENDOFINFO
hdlr.enqueueNumeric(
ctx, clientID, irc.RplEndOfInfo, nick, nil,
"End of /INFO list",
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// handleTime handles the TIME command.
// Returns the server's local time in RFC format.
func (hdlr *Handlers) handleTime(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
) {
ctx := request.Context()
srvName := hdlr.serverName()
// 391 RPL_TIME
hdlr.enqueueNumeric(
ctx, clientID, irc.RplTime, nick,
[]string{srvName},
time.Now().Format(time.RFC1123),
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// handleKill handles the KILL command.
// Forcibly disconnects a user (oper only).
func (hdlr *Handlers) handleKill(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
bodyLines func() []string,
) {
ctx := request.Context()
// Check oper status.
isOper, err := hdlr.params.Database.IsSessionOper(
ctx, sessionID,
)
if err != nil || !isOper {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNoPrivileges, nick, nil,
"Permission Denied- You're not an IRC operator",
)
return
}
lines := bodyLines()
if len(lines) == 0 {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdKill},
"Not enough parameters",
)
return
}
targetNick := strings.TrimSpace(lines[0])
if targetNick == "" {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdKill},
"Not enough parameters",
)
return
}
reason := "KILLed"
if len(lines) > 1 {
reason = lines[1]
}
targetSID, lookupErr := hdlr.params.Database.
GetSessionByNick(ctx, targetNick)
if lookupErr != nil {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNoSuchNick, nick,
[]string{targetNick},
"No such nick/channel",
)
return
}
// Do not allow killing yourself.
if targetSID == sessionID {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrCantKillServer, nick, nil,
"You cannot KILL yourself",
)
return
}
hdlr.executeKillUser(
request, targetSID, targetNick, nick, reason,
)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// executeKillUser forcibly disconnects a user: broadcasts
// QUIT to their channels, parts all channels, and deletes
// the session.
func (hdlr *Handlers) executeKillUser(
request *http.Request,
targetSID int64,
targetNick, killerNick, reason string,
) {
ctx := request.Context()
quitMsg := "Killed (" + killerNick + " (" + reason + "))"
quitBody, err := json.Marshal([]string{quitMsg})
if err != nil {
hdlr.log.Error(
"marshal kill quit body", "error", err,
)
return
}
channels, _ := hdlr.params.Database.
GetSessionChannels(ctx, targetSID)
notified := map[int64]bool{}
var dbID int64
if len(channels) > 0 {
dbID, _, _ = hdlr.params.Database.InsertMessage(
ctx, irc.CmdQuit, targetNick, "",
nil, json.RawMessage(quitBody), nil,
)
}
for _, chanInfo := range channels {
memberIDs, _ := hdlr.params.Database.
GetChannelMemberIDs(ctx, chanInfo.ID)
for _, mid := range memberIDs {
if mid != targetSID && !notified[mid] {
notified[mid] = true
_ = hdlr.params.Database.EnqueueToSession(
ctx, mid, dbID,
)
hdlr.broker.Notify(mid)
}
}
_ = hdlr.params.Database.PartChannel(
ctx, chanInfo.ID, targetSID,
)
_ = hdlr.params.Database.DeleteChannelIfEmpty(
ctx, chanInfo.ID,
)
}
_ = hdlr.params.Database.DeleteSession(
ctx, targetSID,
)
}
// handleWallops handles the WALLOPS command.
// Broadcasts a message to all users with +w usermode
// (oper only).
func (hdlr *Handlers) handleWallops(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
bodyLines func() []string,
) {
ctx := request.Context()
// Check oper status.
isOper, err := hdlr.params.Database.IsSessionOper(
ctx, sessionID,
)
if err != nil || !isOper {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNoPrivileges, nick, nil,
"Permission Denied- You're not an IRC operator",
)
return
}
lines := bodyLines()
if len(lines) == 0 {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdWallops},
"Not enough parameters",
)
return
}
message := strings.Join(lines, " ")
wallopsSIDs, err := hdlr.params.Database.
GetWallopsSessionIDs(ctx)
if err != nil {
hdlr.log.Error(
"get wallops sessions failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
if len(wallopsSIDs) > 0 {
body, mErr := json.Marshal([]string{message})
if mErr != nil {
hdlr.log.Error(
"marshal wallops body", "error", mErr,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
_ = hdlr.fanOutSilent(
request, irc.CmdWallops, nick, "*",
json.RawMessage(body), wallopsSIDs,
)
}
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// handleUserMode handles user mode queries and changes
// (e.g., MODE nick, MODE nick +w).
func (hdlr *Handlers) handleUserMode(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, target string,
bodyLines func() []string,
) {
ctx := request.Context()
lines := bodyLines()
// Mode change requested.
if len(lines) > 0 {
// Users can only change their own modes.
if target != nick && target != "" {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrUsersDoNotMatch, nick, nil,
"Can't change mode for other users",
)
return
}
hdlr.applyUserModeChange(
writer, request,
sessionID, clientID, nick, lines[0],
)
return
}
// Mode query — build the current mode string.
modeStr := hdlr.buildUserModeString(ctx, sessionID)
hdlr.enqueueNumeric(
ctx, clientID, irc.RplUmodeIs, nick, nil,
modeStr,
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// buildUserModeString constructs the mode string for a
// user (e.g., "+ow" for oper+wallops).
func (hdlr *Handlers) buildUserModeString(
ctx context.Context,
sessionID int64,
) string {
modes := "+"
isOper, err := hdlr.params.Database.IsSessionOper(
ctx, sessionID,
)
if err == nil && isOper {
modes += "o"
}
isWallops, err := hdlr.params.Database.IsSessionWallops(
ctx, sessionID,
)
if err == nil && isWallops {
modes += "w"
}
return modes
}
// applyUserModeChange applies a user mode change string
// (e.g., "+w", "-w").
func (hdlr *Handlers) applyUserModeChange(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, modeStr string,
) {
ctx := request.Context()
if len(modeStr) < 2 { //nolint:mnd // +/- and mode char
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrUmodeUnknownFlag, nick, nil,
"Unknown MODE flag",
)
return
}
adding := modeStr[0] == '+'
modeChar := modeStr[1:]
applied, err := hdlr.applyModeChar(
ctx, writer, request,
sessionID, clientID, nick,
modeChar, adding,
)
if err != nil || !applied {
return
}
newModes := hdlr.buildUserModeString(ctx, sessionID)
hdlr.enqueueNumeric(
ctx, clientID, irc.RplUmodeIs, nick, nil,
newModes,
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// applyModeChar applies a single user mode character.
// Returns (applied, error).
func (hdlr *Handlers) applyModeChar(
ctx context.Context,
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, modeChar string,
adding bool,
) (bool, error) {
switch modeChar {
case "w":
err := hdlr.params.Database.SetSessionWallops(
ctx, sessionID, adding,
)
if err != nil {
hdlr.log.Error(
"set wallops mode failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return false, fmt.Errorf(
"set wallops: %w", err,
)
}
case "o":
// +o cannot be set via MODE, only via OPER.
if adding {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrUmodeUnknownFlag, nick, nil,
"Unknown MODE flag",
)
return false, nil
}
err := hdlr.params.Database.SetSessionOper(
ctx, sessionID, false,
)
if err != nil {
hdlr.log.Error(
"clear oper mode failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return false, fmt.Errorf(
"clear oper: %w", err,
)
}
default:
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrUmodeUnknownFlag, nick, nil,
"Unknown MODE flag",
)
return false, nil
}
return true, nil
}

View File

@@ -1,982 +0,0 @@
// Tests for Tier 3 utility IRC commands: USERHOST,
// VERSION, ADMIN, INFO, TIME, KILL, WALLOPS.
//
//nolint:paralleltest
package handlers_test
import (
"strings"
"testing"
)
// --- USERHOST ---
func TestUserhostSingleNick(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("alice")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "USERHOST",
bodyKey: []string{"alice"},
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 302 RPL_USERHOST.
msg := findNumericWithParams(msgs, "302")
if msg == nil {
t.Fatalf(
"expected RPL_USERHOST (302), got %v",
msgs,
)
}
// Body should contain "alice" with the
// nick=+user@host format.
body := getNumericBody(msg)
if !strings.Contains(body, "alice") {
t.Fatalf(
"expected body to contain 'alice', got %q",
body,
)
}
// '+' means not away.
if !strings.Contains(body, "=+") {
t.Fatalf(
"expected not-away prefix '=+', got %q",
body,
)
}
}
func TestUserhostMultipleNicks(t *testing.T) {
tserver := newTestServer(t)
token1 := tserver.createSession("bob")
token2 := tserver.createSession("carol")
_ = token2
_, lastID := tserver.pollMessages(token1, 0)
tserver.sendCommand(token1, map[string]any{
commandKey: "USERHOST",
bodyKey: []string{"bob", "carol"},
})
msgs, _ := tserver.pollMessages(token1, lastID)
msg := findNumericWithParams(msgs, "302")
if msg == nil {
t.Fatalf(
"expected RPL_USERHOST (302), got %v",
msgs,
)
}
body := getNumericBody(msg)
if !strings.Contains(body, "bob") {
t.Fatalf(
"expected body to contain 'bob', got %q",
body,
)
}
if !strings.Contains(body, "carol") {
t.Fatalf(
"expected body to contain 'carol', got %q",
body,
)
}
}
func TestUserhostNonexistentNick(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("dave")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "USERHOST",
bodyKey: []string{"nobody"},
})
msgs, _ := tserver.pollMessages(token, lastID)
// Should still get 302 but with empty body.
msg := findNumericWithParams(msgs, "302")
if msg == nil {
t.Fatalf(
"expected RPL_USERHOST (302), got %v",
msgs,
)
}
}
func TestUserhostNoParams(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("eve")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "USERHOST",
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 461 ERR_NEEDMOREPARAMS.
if !findNumeric(msgs, "461") {
t.Fatalf(
"expected ERR_NEEDMOREPARAMS (461), got %v",
msgs,
)
}
}
func TestUserhostShowsOper(t *testing.T) {
tserver := newTestServerWithOper(t)
token := tserver.createSession("opernick")
_, lastID := tserver.pollMessages(token, 0)
// Authenticate as oper.
tserver.sendCommand(token, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
_, lastID = tserver.pollMessages(token, lastID)
// USERHOST should show '*' for oper.
tserver.sendCommand(token, map[string]any{
commandKey: "USERHOST",
bodyKey: []string{"opernick"},
})
msgs, _ := tserver.pollMessages(token, lastID)
msg := findNumericWithParams(msgs, "302")
if msg == nil {
t.Fatalf(
"expected RPL_USERHOST (302), got %v",
msgs,
)
}
body := getNumericBody(msg)
if !strings.Contains(body, "opernick*=") {
t.Fatalf(
"expected oper '*' in reply, got %q",
body,
)
}
}
func TestUserhostShowsAway(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("awaynick")
_, lastID := tserver.pollMessages(token, 0)
// Set away.
tserver.sendCommand(token, map[string]any{
commandKey: "AWAY",
bodyKey: []string{"gone fishing"},
})
_, lastID = tserver.pollMessages(token, lastID)
// USERHOST should show '-' for away.
tserver.sendCommand(token, map[string]any{
commandKey: "USERHOST",
bodyKey: []string{"awaynick"},
})
msgs, _ := tserver.pollMessages(token, lastID)
msg := findNumericWithParams(msgs, "302")
if msg == nil {
t.Fatalf(
"expected RPL_USERHOST (302), got %v",
msgs,
)
}
body := getNumericBody(msg)
if !strings.Contains(body, "=-") {
t.Fatalf(
"expected away prefix '=-' in reply, got %q",
body,
)
}
}
// --- VERSION ---
func TestVersion(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("frank")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "VERSION",
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 351 RPL_VERSION.
msg := findNumericWithParams(msgs, "351")
if msg == nil {
t.Fatalf(
"expected RPL_VERSION (351), got %v",
msgs,
)
}
params := getNumericParams(msg)
if len(params) == 0 {
t.Fatal("expected VERSION params, got none")
}
// First param should contain version string.
if !strings.Contains(params[0], "test") {
t.Fatalf(
"expected version to contain 'test', got %q",
params[0],
)
}
}
// --- ADMIN ---
func TestAdmin(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("grace")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "ADMIN",
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 256 RPL_ADMINME.
if !findNumeric(msgs, "256") {
t.Fatalf(
"expected RPL_ADMINME (256), got %v",
msgs,
)
}
// Expect 257 RPL_ADMINLOC1.
if !findNumeric(msgs, "257") {
t.Fatalf(
"expected RPL_ADMINLOC1 (257), got %v",
msgs,
)
}
// Expect 258 RPL_ADMINLOC2.
if !findNumeric(msgs, "258") {
t.Fatalf(
"expected RPL_ADMINLOC2 (258), got %v",
msgs,
)
}
// Expect 259 RPL_ADMINEMAIL.
if !findNumeric(msgs, "259") {
t.Fatalf(
"expected RPL_ADMINEMAIL (259), got %v",
msgs,
)
}
}
// --- INFO ---
func TestInfo(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("hank")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "INFO",
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 371 RPL_INFO (at least one).
if !findNumeric(msgs, "371") {
t.Fatalf(
"expected RPL_INFO (371), got %v",
msgs,
)
}
// Expect 374 RPL_ENDOFINFO.
if !findNumeric(msgs, "374") {
t.Fatalf(
"expected RPL_ENDOFINFO (374), got %v",
msgs,
)
}
}
// --- TIME ---
func TestTime(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("iris")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "TIME",
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 391 RPL_TIME.
msg := findNumericWithParams(msgs, "391")
if msg == nil {
t.Fatalf(
"expected RPL_TIME (391), got %v",
msgs,
)
}
}
// --- KILL ---
func TestKillSuccess(t *testing.T) {
tserver := newTestServerWithOper(t)
// Create the victim first.
victimToken := tserver.createSession("victim")
_ = victimToken
// Create oper user.
operToken := tserver.createSession("killer")
_, lastID := tserver.pollMessages(operToken, 0)
// Authenticate as oper.
tserver.sendCommand(operToken, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
_, lastID = tserver.pollMessages(operToken, lastID)
// Kill the victim.
status, result := tserver.sendCommand(
operToken, map[string]any{
commandKey: "KILL",
bodyKey: []string{"victim", "go away"},
},
)
if status != 200 {
t.Fatalf("expected 200, got %d: %v", status, result)
}
resultStatus, _ := result[statusKey].(string)
if resultStatus != "ok" {
t.Fatalf(
"expected status ok, got %v",
result,
)
}
// Verify the victim's session is gone by trying
// to WHOIS them.
tserver.sendCommand(operToken, map[string]any{
commandKey: "WHOIS",
toKey: "victim",
})
msgs, _ := tserver.pollMessages(operToken, lastID)
// Should get 401 ERR_NOSUCHNICK.
if !findNumeric(msgs, "401") {
t.Fatalf(
"expected victim to be gone (401), got %v",
msgs,
)
}
}
func TestKillNotOper(t *testing.T) {
tserver := newTestServer(t)
_ = tserver.createSession("target")
token := tserver.createSession("notoper")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "KILL",
bodyKey: []string{"target", "no reason"},
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 481 ERR_NOPRIVILEGES.
if !findNumeric(msgs, "481") {
t.Fatalf(
"expected ERR_NOPRIVILEGES (481), got %v",
msgs,
)
}
}
func TestKillNoParams(t *testing.T) {
tserver := newTestServerWithOper(t)
token := tserver.createSession("opertest")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
_, lastID = tserver.pollMessages(token, lastID)
tserver.sendCommand(token, map[string]any{
commandKey: "KILL",
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 461 ERR_NEEDMOREPARAMS.
if !findNumeric(msgs, "461") {
t.Fatalf(
"expected ERR_NEEDMOREPARAMS (461), got %v",
msgs,
)
}
}
// sendOperKillCommand is a helper that creates an oper
// session, authenticates, then sends KILL with the given
// target nick, and returns the resulting messages.
func sendOperKillCommand(
t *testing.T,
tserver *testServer,
operNick, targetNick string,
) []map[string]any {
t.Helper()
token := tserver.createSession(operNick)
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
_, lastID = tserver.pollMessages(token, lastID)
tserver.sendCommand(token, map[string]any{
commandKey: "KILL",
bodyKey: []string{targetNick},
})
msgs, _ := tserver.pollMessages(token, lastID)
return msgs
}
func TestKillNonexistentUser(t *testing.T) {
tserver := newTestServerWithOper(t)
msgs := sendOperKillCommand(
t, tserver, "opertest2", "ghost",
)
// Expect 401 ERR_NOSUCHNICK.
if !findNumeric(msgs, "401") {
t.Fatalf(
"expected ERR_NOSUCHNICK (401), got %v",
msgs,
)
}
}
func TestKillSelf(t *testing.T) {
tserver := newTestServerWithOper(t)
msgs := sendOperKillCommand(
t, tserver, "selfkiller", "selfkiller",
)
// Expect 483 ERR_CANTKILLSERVER.
if !findNumeric(msgs, "483") {
t.Fatalf(
"expected ERR_CANTKILLSERVER (483), got %v",
msgs,
)
}
}
func TestKillBroadcastsQuit(t *testing.T) {
tserver := newTestServerWithOper(t)
// Create victim and join a channel.
victimToken := tserver.createSession("vuser")
tserver.sendCommand(victimToken, map[string]any{
commandKey: joinCmd,
toKey: "#killtest",
})
// Create observer and join same channel.
observerToken := tserver.createSession("observer")
tserver.sendCommand(observerToken, map[string]any{
commandKey: joinCmd,
toKey: "#killtest",
})
_, lastObs := tserver.pollMessages(observerToken, 0)
// Create oper.
operToken := tserver.createSession("theoper2")
tserver.pollMessages(operToken, 0)
tserver.sendCommand(operToken, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
tserver.pollMessages(operToken, 0)
// Kill the victim.
tserver.sendCommand(operToken, map[string]any{
commandKey: "KILL",
bodyKey: []string{"vuser", "testing kill"},
})
// Observer should see a QUIT message.
msgs, _ := tserver.pollMessages(observerToken, lastObs)
foundQuit := false
for _, msg := range msgs {
cmd, _ := msg["command"].(string)
if cmd == "QUIT" {
from, _ := msg["from"].(string)
if from == "vuser" {
foundQuit = true
break
}
}
}
if !foundQuit {
t.Fatalf(
"expected QUIT from vuser, got %v",
msgs,
)
}
}
// --- WALLOPS ---
func TestWallopsSuccess(t *testing.T) {
tserver := newTestServerWithOper(t)
// Create receiver with +w.
receiverToken := tserver.createSession("receiver")
tserver.sendCommand(receiverToken, map[string]any{
commandKey: "MODE",
toKey: "receiver",
bodyKey: []string{"+w"},
})
_, lastRecv := tserver.pollMessages(receiverToken, 0)
// Create oper.
operToken := tserver.createSession("walloper")
tserver.pollMessages(operToken, 0)
tserver.sendCommand(operToken, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
tserver.pollMessages(operToken, 0)
// Also set +w on oper so they receive it too.
tserver.sendCommand(operToken, map[string]any{
commandKey: "MODE",
toKey: "walloper",
bodyKey: []string{"+w"},
})
tserver.pollMessages(operToken, 0)
// Send WALLOPS.
tserver.sendCommand(operToken, map[string]any{
commandKey: "WALLOPS",
bodyKey: []string{"server going down"},
})
// Receiver should get the WALLOPS message.
msgs, _ := tserver.pollMessages(receiverToken, lastRecv)
foundWallops := false
for _, msg := range msgs {
cmd, _ := msg["command"].(string)
if cmd == "WALLOPS" {
foundWallops = true
break
}
}
if !foundWallops {
t.Fatalf(
"expected WALLOPS message, got %v",
msgs,
)
}
}
func TestWallopsNotOper(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("notoper2")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "WALLOPS",
bodyKey: []string{"hello"},
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 481 ERR_NOPRIVILEGES.
if !findNumeric(msgs, "481") {
t.Fatalf(
"expected ERR_NOPRIVILEGES (481), got %v",
msgs,
)
}
}
func TestWallopsNoParams(t *testing.T) {
tserver := newTestServerWithOper(t)
token := tserver.createSession("operempty")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
_, lastID = tserver.pollMessages(token, lastID)
tserver.sendCommand(token, map[string]any{
commandKey: "WALLOPS",
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 461 ERR_NEEDMOREPARAMS.
if !findNumeric(msgs, "461") {
t.Fatalf(
"expected ERR_NEEDMOREPARAMS (461), got %v",
msgs,
)
}
}
func TestWallopsNotReceivedWithoutW(t *testing.T) {
tserver := newTestServerWithOper(t)
// Create receiver WITHOUT +w.
receiverToken := tserver.createSession("nowallops")
_, lastRecv := tserver.pollMessages(receiverToken, 0)
// Create oper.
operToken := tserver.createSession("walloper2")
tserver.pollMessages(operToken, 0)
tserver.sendCommand(operToken, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
tserver.pollMessages(operToken, 0)
// Send WALLOPS.
tserver.sendCommand(operToken, map[string]any{
commandKey: "WALLOPS",
bodyKey: []string{"secret message"},
})
// Receiver should NOT get the WALLOPS message.
msgs, _ := tserver.pollMessages(receiverToken, lastRecv)
for _, msg := range msgs {
cmd, _ := msg["command"].(string)
if cmd == "WALLOPS" {
t.Fatalf(
"did not expect WALLOPS for user "+
"without +w, got %v",
msgs,
)
}
}
}
// --- User Mode +w ---
func TestUserModeSetW(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("wmoder")
_, lastID := tserver.pollMessages(token, 0)
// Set +w.
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "wmoder",
bodyKey: []string{"+w"},
})
msgs, lastID := tserver.pollMessages(token, lastID)
// Expect 221 RPL_UMODEIS with "+w".
msg := findNumericWithParams(msgs, "221")
if msg == nil {
t.Fatalf(
"expected RPL_UMODEIS (221), got %v",
msgs,
)
}
body := getNumericBody(msg)
if !strings.Contains(body, "w") {
t.Fatalf(
"expected mode string to contain 'w', got %q",
body,
)
}
// Now query mode.
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "wmoder",
})
msgs, _ = tserver.pollMessages(token, lastID)
msg = findNumericWithParams(msgs, "221")
if msg == nil {
t.Fatalf(
"expected RPL_UMODEIS (221) on query, got %v",
msgs,
)
}
body = getNumericBody(msg)
if !strings.Contains(body, "w") {
t.Fatalf(
"expected mode '+w' in query, got %q",
body,
)
}
}
func TestUserModeUnsetW(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("wunsetter")
_, lastID := tserver.pollMessages(token, 0)
// Set +w first.
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "wunsetter",
bodyKey: []string{"+w"},
})
_, lastID = tserver.pollMessages(token, lastID)
// Unset -w.
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "wunsetter",
bodyKey: []string{"-w"},
})
msgs, _ := tserver.pollMessages(token, lastID)
msg := findNumericWithParams(msgs, "221")
if msg == nil {
t.Fatalf(
"expected RPL_UMODEIS (221), got %v",
msgs,
)
}
body := getNumericBody(msg)
if strings.Contains(body, "w") {
t.Fatalf(
"expected 'w' to be removed, got %q",
body,
)
}
}
func TestUserModeUnknownFlag(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("badmode")
_, lastID := tserver.pollMessages(token, 0)
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "badmode",
bodyKey: []string{"+z"},
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 501 ERR_UMODEUNKNOWNFLAG.
if !findNumeric(msgs, "501") {
t.Fatalf(
"expected ERR_UMODEUNKNOWNFLAG (501), got %v",
msgs,
)
}
}
func TestUserModeCannotSetO(t *testing.T) {
tserver := newTestServer(t)
token := tserver.createSession("tryoper")
_, lastID := tserver.pollMessages(token, 0)
// Try to set +o via MODE (should fail).
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "tryoper",
bodyKey: []string{"+o"},
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 501 ERR_UMODEUNKNOWNFLAG.
if !findNumeric(msgs, "501") {
t.Fatalf(
"expected ERR_UMODEUNKNOWNFLAG (501), got %v",
msgs,
)
}
}
func TestUserModeDeoper(t *testing.T) {
tserver := newTestServerWithOper(t)
token := tserver.createSession("deoper")
_, lastID := tserver.pollMessages(token, 0)
// Authenticate as oper.
tserver.sendCommand(token, map[string]any{
commandKey: "OPER",
bodyKey: []string{testOperName, testOperPassword},
})
_, lastID = tserver.pollMessages(token, lastID)
// Use MODE -o to de-oper.
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "deoper",
bodyKey: []string{"-o"},
})
msgs, _ := tserver.pollMessages(token, lastID)
msg := findNumericWithParams(msgs, "221")
if msg == nil {
t.Fatalf(
"expected RPL_UMODEIS (221), got %v",
msgs,
)
}
body := getNumericBody(msg)
if strings.Contains(body, "o") {
t.Fatalf(
"expected 'o' to be removed, got %q",
body,
)
}
}
func TestUserModeCannotChangeOtherUser(t *testing.T) {
tserver := newTestServer(t)
_ = tserver.createSession("other")
token := tserver.createSession("changer")
_, lastID := tserver.pollMessages(token, 0)
// Try to change another user's mode.
tserver.sendCommand(token, map[string]any{
commandKey: "MODE",
toKey: "other",
bodyKey: []string{"+w"},
})
msgs, _ := tserver.pollMessages(token, lastID)
// Expect 502 ERR_USERSDONTMATCH.
if !findNumeric(msgs, "502") {
t.Fatalf(
"expected ERR_USERSDONTMATCH (502), got %v",
msgs,
)
}
}
// getNumericBody extracts the body text from a numeric
// message. The body is stored as a JSON array; this
// returns the first element.
func getNumericBody(msg map[string]any) string {
raw, exists := msg["body"]
if !exists || raw == nil {
return ""
}
arr, isArr := raw.([]any)
if !isArr || len(arr) == 0 {
return ""
}
str, isStr := arr[0].(string)
if !isStr {
return ""
}
return str
}

View File

@@ -1,186 +0,0 @@
package hashcash
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"time"
)
var (
errBodyHashMismatch = errors.New(
"body hash mismatch",
)
errBodyHashMissing = errors.New(
"body hash missing",
)
)
// ChannelValidator checks hashcash stamps for
// per-channel PRIVMSG validation. It verifies that
// stamps are bound to a specific channel and message
// body. Replay prevention is handled externally via
// the database spent_hashcash table for persistence
// across server restarts (1-year TTL).
type ChannelValidator struct{}
// NewChannelValidator creates a ChannelValidator.
func NewChannelValidator() *ChannelValidator {
return &ChannelValidator{}
}
// BodyHash computes the hex-encoded SHA-256 hash of a
// message body for use in hashcash stamp validation.
func BodyHash(body []byte) string {
hash := sha256.Sum256(body)
return hex.EncodeToString(hash[:])
}
// ValidateStamp checks a channel hashcash stamp. It
// verifies the stamp format, difficulty, date, channel
// binding, body hash binding, and proof-of-work. Replay
// detection is NOT performed here — callers must check
// the spent_hashcash table separately.
//
// Stamp format: 1:bits:YYMMDD:channel:bodyhash:counter.
func (cv *ChannelValidator) ValidateStamp(
stamp string,
requiredBits int,
channel string,
bodyHash string,
) error {
if requiredBits <= 0 {
return nil
}
parts := strings.Split(stamp, ":")
if len(parts) != stampFields {
return fmt.Errorf(
"%w: expected %d, got %d",
errInvalidFields, stampFields, len(parts),
)
}
version := parts[0]
bitsStr := parts[1]
dateStr := parts[2]
resource := parts[3]
stampBodyHash := parts[4]
headerErr := validateChannelHeader(
version, bitsStr, resource,
requiredBits, channel,
)
if headerErr != nil {
return headerErr
}
stampTime, parseErr := parseStampDate(dateStr)
if parseErr != nil {
return parseErr
}
timeErr := validateTime(stampTime)
if timeErr != nil {
return timeErr
}
bodyErr := validateBodyHash(
stampBodyHash, bodyHash,
)
if bodyErr != nil {
return bodyErr
}
return validateProof(stamp, requiredBits)
}
// StampHash returns a deterministic hash of a stamp
// string for use as a spent-token key.
func StampHash(stamp string) string {
hash := sha256.Sum256([]byte(stamp))
return hex.EncodeToString(hash[:])
}
func validateChannelHeader(
version, bitsStr, resource string,
requiredBits int,
channel string,
) error {
if version != stampVersion {
return fmt.Errorf(
"%w: %s", errBadVersion, version,
)
}
claimedBits, err := strconv.Atoi(bitsStr)
if err != nil || claimedBits < requiredBits {
return fmt.Errorf(
"%w: need %d bits",
errInsufficientBits, requiredBits,
)
}
if resource != channel {
return fmt.Errorf(
"%w: got %q, want %q",
errWrongResource, resource, channel,
)
}
return nil
}
func validateBodyHash(
stampBodyHash, expectedBodyHash string,
) error {
if stampBodyHash == "" {
return errBodyHashMissing
}
if stampBodyHash != expectedBodyHash {
return fmt.Errorf(
"%w: got %q, want %q",
errBodyHashMismatch,
stampBodyHash, expectedBodyHash,
)
}
return nil
}
// MintChannelStamp computes a channel hashcash stamp
// with the given difficulty, channel name, and body hash.
// This is intended for clients to generate stamps before
// sending PRIVMSG to hashcash-protected channels.
//
// Stamp format: 1:bits:YYMMDD:channel:bodyhash:counter.
func MintChannelStamp(
bits int,
channel string,
bodyHash string,
) string {
date := time.Now().UTC().Format(dateFormatShort)
prefix := fmt.Sprintf(
"1:%d:%s:%s:%s:",
bits, date, channel, bodyHash,
)
counter := uint64(0)
for {
stamp := prefix + strconv.FormatUint(counter, 16)
hash := sha256.Sum256([]byte(stamp))
if hasLeadingZeroBits(hash[:], bits) {
return stamp
}
counter++
}
}

View File

@@ -1,244 +0,0 @@
package hashcash_test
import (
"crypto/sha256"
"encoding/hex"
"testing"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
)
const (
testChannel = "#general"
testBodyText = `["hello world"]`
)
func testBodyHash() string {
hash := sha256.Sum256([]byte(testBodyText))
return hex.EncodeToString(hash[:])
}
func TestChannelValidateHappyPath(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err != nil {
t.Fatalf("valid channel stamp rejected: %v", err)
}
}
func TestChannelValidateWrongChannel(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, testBits, "#other", bodyHash,
)
if err == nil {
t.Fatal("expected channel mismatch error")
}
}
func TestChannelValidateWrongBodyHash(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
wrongHash := sha256.Sum256([]byte("different body"))
wrongBodyHash := hex.EncodeToString(wrongHash[:])
err := validator.ValidateStamp(
stamp, testBits, testChannel, wrongBodyHash,
)
if err == nil {
t.Fatal("expected body hash mismatch error")
}
}
func TestChannelValidateInsufficientBits(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Mint with 2 bits but require 4.
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, 4, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected insufficient bits error")
}
}
func TestChannelValidateZeroBitsSkips(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
"garbage", 0, "#ch", "abc",
)
if err != nil {
t.Fatalf("zero bits should skip: %v", err)
}
}
func TestChannelValidateBadFormat(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
"not:valid", testBits, testChannel, "abc",
)
if err == nil {
t.Fatal("expected bad format error")
}
}
func TestChannelValidateBadVersion(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := "2:2:260317:#general:" + bodyHash + ":counter"
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected bad version error")
}
}
func TestChannelValidateExpiredStamp(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Mint with a very old date by manually constructing.
stamp := mintStampWithDate(
t, testBits, testChannel, "200101",
)
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected expired stamp error")
}
}
func TestChannelValidateMissingBodyHash(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Construct a stamp with empty body hash field.
stamp := mintStampWithDate(
t, testBits, testChannel, todayDate(),
)
// This uses the session-style stamp which has empty
// ext field — body hash is missing.
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected missing body hash error")
}
}
func TestBodyHash(t *testing.T) {
t.Parallel()
body := []byte(`["hello world"]`)
bodyHash := hashcash.BodyHash(body)
if len(bodyHash) != 64 {
t.Fatalf(
"expected 64-char hex hash, got %d",
len(bodyHash),
)
}
// Same input should produce same hash.
bodyHash2 := hashcash.BodyHash(body)
if bodyHash != bodyHash2 {
t.Fatal("body hash not deterministic")
}
// Different input should produce different hash.
bodyHash3 := hashcash.BodyHash([]byte("different"))
if bodyHash == bodyHash3 {
t.Fatal("different inputs produced same hash")
}
}
func TestStampHash(t *testing.T) {
t.Parallel()
hash1 := hashcash.StampHash("stamp1")
hash2 := hashcash.StampHash("stamp2")
if hash1 == hash2 {
t.Fatal("different stamps produced same hash")
}
// Same input should be deterministic.
hash1b := hashcash.StampHash("stamp1")
if hash1 != hash1b {
t.Fatal("stamp hash not deterministic")
}
}
func TestMintChannelStamp(t *testing.T) {
t.Parallel()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
if stamp == "" {
t.Fatal("expected non-empty stamp")
}
// Validate the minted stamp.
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err != nil {
t.Fatalf("minted stamp failed validation: %v", err)
}
}

View File

@@ -10,7 +10,6 @@ import (
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/globals"
"git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx"
)
@@ -22,7 +21,6 @@ type Params struct {
Config *config.Config
Logger *logger.Logger
Database *db.Database
Stats *stats.Tracker
}
// Healthcheck tracks server uptime and provides health status.
@@ -66,22 +64,11 @@ type Response struct {
Version string `json:"version"`
Appname string `json:"appname"`
Maintenance bool `json:"maintenanceMode"`
// Runtime statistics.
Sessions int64 `json:"sessions"`
Clients int64 `json:"clients"`
QueuedLines int64 `json:"queuedLines"`
Channels int64 `json:"channels"`
ConnectionsSinceBoot int64 `json:"connectionsSinceBoot"`
SessionsSinceBoot int64 `json:"sessionsSinceBoot"`
MessagesSinceBoot int64 `json:"messagesSinceBoot"`
}
// Healthcheck returns the current health status of the server.
func (hcheck *Healthcheck) Healthcheck(
ctx context.Context,
) *Response {
resp := &Response{
func (hcheck *Healthcheck) Healthcheck() *Response {
return &Response{
Status: "ok",
Now: time.Now().UTC().Format(time.RFC3339Nano),
UptimeSeconds: int64(hcheck.uptime().Seconds()),
@@ -89,64 +76,6 @@ func (hcheck *Healthcheck) Healthcheck(
Appname: hcheck.params.Globals.Appname,
Version: hcheck.params.Globals.Version,
Maintenance: hcheck.params.Config.MaintenanceMode,
Sessions: 0,
Clients: 0,
QueuedLines: 0,
Channels: 0,
ConnectionsSinceBoot: hcheck.params.Stats.ConnectionsSinceBoot(),
SessionsSinceBoot: hcheck.params.Stats.SessionsSinceBoot(),
MessagesSinceBoot: hcheck.params.Stats.MessagesSinceBoot(),
}
hcheck.populateDBStats(ctx, resp)
return resp
}
// populateDBStats fills in database-derived counters.
func (hcheck *Healthcheck) populateDBStats(
ctx context.Context,
resp *Response,
) {
sessions, err := hcheck.params.Database.GetUserCount(ctx)
if err != nil {
hcheck.log.Error(
"healthcheck: session count failed",
"error", err,
)
} else {
resp.Sessions = sessions
}
clients, err := hcheck.params.Database.GetClientCount(ctx)
if err != nil {
hcheck.log.Error(
"healthcheck: client count failed",
"error", err,
)
} else {
resp.Clients = clients
}
queued, err := hcheck.params.Database.GetQueueEntryCount(ctx)
if err != nil {
hcheck.log.Error(
"healthcheck: queue entry count failed",
"error", err,
)
} else {
resp.QueuedLines = queued
}
channels, err := hcheck.params.Database.GetChannelCount(ctx)
if err != nil {
hcheck.log.Error(
"healthcheck: channel count failed",
"error", err,
)
} else {
resp.Channels = channels
}
}

View File

@@ -126,23 +126,18 @@ func (mware *Middleware) Logging() func(http.Handler) http.Handler {
}
// CORS returns middleware that handles Cross-Origin Resource Sharing.
// AllowCredentials is true so browsers include cookies in
// cross-origin API requests.
func (mware *Middleware) CORS() func(http.Handler) http.Handler {
return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields
AllowOriginFunc: func(
_ *http.Request, _ string,
) bool {
return true
},
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
"GET", "POST", "PUT", "DELETE", "OPTIONS",
},
AllowedHeaders: []string{
"Accept", "Content-Type", "X-CSRF-Token",
"Accept", "Authorization",
"Content-Type", "X-CSRF-Token",
},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
AllowCredentials: false,
MaxAge: corsMaxAge,
})
}

View File

@@ -1,122 +0,0 @@
// Package ratelimit provides per-IP rate limiting for HTTP endpoints.
package ratelimit
import (
"sync"
"time"
"golang.org/x/time/rate"
)
const (
// DefaultRate is the default number of allowed requests per second.
DefaultRate = 1.0
// DefaultBurst is the default maximum burst size.
DefaultBurst = 5
// DefaultSweepInterval controls how often stale entries are pruned.
DefaultSweepInterval = 10 * time.Minute
// DefaultEntryTTL is how long an unused entry lives before eviction.
DefaultEntryTTL = 15 * time.Minute
)
// entry tracks a per-IP rate limiter and when it was last used.
type entry struct {
limiter *rate.Limiter
lastSeen time.Time
}
// Limiter manages per-key rate limiters with automatic cleanup
// of stale entries.
type Limiter struct {
mu sync.Mutex
entries map[string]*entry
rate rate.Limit
burst int
entryTTL time.Duration
stopCh chan struct{}
}
// New creates a new per-key rate Limiter.
// The ratePerSec parameter sets how many requests per second are
// allowed per key. The burst parameter sets the maximum number of
// requests that can be made in a single burst.
func New(ratePerSec float64, burst int) *Limiter {
limiter := &Limiter{
mu: sync.Mutex{},
entries: make(map[string]*entry),
rate: rate.Limit(ratePerSec),
burst: burst,
entryTTL: DefaultEntryTTL,
stopCh: make(chan struct{}),
}
go limiter.sweepLoop()
return limiter
}
// Allow reports whether a request from the given key should be
// allowed. It consumes one token from the key's rate limiter.
func (l *Limiter) Allow(key string) bool {
l.mu.Lock()
ent, exists := l.entries[key]
if !exists {
ent = &entry{
limiter: rate.NewLimiter(l.rate, l.burst),
lastSeen: time.Now(),
}
l.entries[key] = ent
} else {
ent.lastSeen = time.Now()
}
l.mu.Unlock()
return ent.limiter.Allow()
}
// Stop terminates the background sweep goroutine.
func (l *Limiter) Stop() {
close(l.stopCh)
}
// Len returns the number of tracked keys (for testing).
func (l *Limiter) Len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.entries)
}
// sweepLoop periodically removes entries that haven't been seen
// within the TTL.
func (l *Limiter) sweepLoop() {
ticker := time.NewTicker(DefaultSweepInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.sweep()
case <-l.stopCh:
return
}
}
}
// sweep removes stale entries.
func (l *Limiter) sweep() {
l.mu.Lock()
defer l.mu.Unlock()
cutoff := time.Now().Add(-l.entryTTL)
for key, ent := range l.entries {
if ent.lastSeen.Before(cutoff) {
delete(l.entries, key)
}
}
}

View File

@@ -1,106 +0,0 @@
package ratelimit_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
)
func TestNewCreatesLimiter(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter == nil {
t.Fatal("expected non-nil limiter")
}
}
func TestAllowWithinBurst(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 3)
defer limiter.Stop()
for i := range 3 {
if !limiter.Allow("192.168.1.1") {
t.Fatalf(
"request %d should be allowed within burst",
i+1,
)
}
}
}
func TestAllowExceedsBurst(t *testing.T) {
t.Parallel()
// Rate of 0 means no token replenishment, only burst.
limiter := ratelimit.New(0, 3)
defer limiter.Stop()
for range 3 {
limiter.Allow("10.0.0.1")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("fourth request should be denied after burst exhausted")
}
}
func TestAllowSeparateKeys(t *testing.T) {
t.Parallel()
// Rate of 0, burst of 1 — only one request per key.
limiter := ratelimit.New(0, 1)
defer limiter.Stop()
if !limiter.Allow("10.0.0.1") {
t.Fatal("first request for key A should be allowed")
}
if !limiter.Allow("10.0.0.2") {
t.Fatal("first request for key B should be allowed")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("second request for key A should be denied")
}
if limiter.Allow("10.0.0.2") {
t.Fatal("second request for key B should be denied")
}
}
func TestLenTracksKeys(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter.Len() != 0 {
t.Fatalf("expected 0 entries, got %d", limiter.Len())
}
limiter.Allow("10.0.0.1")
limiter.Allow("10.0.0.2")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
// Same key again should not increase count.
limiter.Allow("10.0.0.1")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
}
func TestStopDoesNotPanic(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
limiter.Stop()
}

View File

@@ -75,6 +75,10 @@ func (srv *Server) setupAPIv1(router chi.Router) {
"/session",
srv.handlers.HandleCreateSession(),
)
router.Post(
"/register",
srv.handlers.HandleRegister(),
)
router.Post(
"/login",
srv.handlers.HandleLogin(),

View File

@@ -1,52 +0,0 @@
// Package stats tracks runtime statistics since server boot.
package stats
import (
"sync/atomic"
)
// Tracker holds atomic counters for runtime statistics
// that accumulate since the server started.
type Tracker struct {
connectionsSinceBoot atomic.Int64
sessionsSinceBoot atomic.Int64
messagesSinceBoot atomic.Int64
}
// New creates a new Tracker with all counters at zero.
func New() *Tracker {
return &Tracker{} //nolint:exhaustruct // atomic fields have zero-value defaults
}
// IncrConnections increments the total connection count.
func (t *Tracker) IncrConnections() {
t.connectionsSinceBoot.Add(1)
}
// IncrSessions increments the total session count.
func (t *Tracker) IncrSessions() {
t.sessionsSinceBoot.Add(1)
}
// IncrMessages increments the total PRIVMSG/NOTICE count.
func (t *Tracker) IncrMessages() {
t.messagesSinceBoot.Add(1)
}
// ConnectionsSinceBoot returns the total number of
// client connections since boot.
func (t *Tracker) ConnectionsSinceBoot() int64 {
return t.connectionsSinceBoot.Load()
}
// SessionsSinceBoot returns the total number of sessions
// created since boot.
func (t *Tracker) SessionsSinceBoot() int64 {
return t.sessionsSinceBoot.Load()
}
// MessagesSinceBoot returns the total number of
// PRIVMSG/NOTICE messages sent since boot.
func (t *Tracker) MessagesSinceBoot() int64 {
return t.messagesSinceBoot.Load()
}

View File

@@ -1,117 +0,0 @@
package stats_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/stats"
)
func TestNew(t *testing.T) {
t.Parallel()
tracker := stats.New()
if tracker == nil {
t.Fatal("expected non-nil tracker")
}
if tracker.ConnectionsSinceBoot() != 0 {
t.Errorf(
"expected 0 connections, got %d",
tracker.ConnectionsSinceBoot(),
)
}
if tracker.SessionsSinceBoot() != 0 {
t.Errorf(
"expected 0 sessions, got %d",
tracker.SessionsSinceBoot(),
)
}
if tracker.MessagesSinceBoot() != 0 {
t.Errorf(
"expected 0 messages, got %d",
tracker.MessagesSinceBoot(),
)
}
}
func TestIncrConnections(t *testing.T) {
t.Parallel()
tracker := stats.New()
tracker.IncrConnections()
tracker.IncrConnections()
tracker.IncrConnections()
got := tracker.ConnectionsSinceBoot()
if got != 3 {
t.Errorf(
"expected 3 connections, got %d", got,
)
}
}
func TestIncrSessions(t *testing.T) {
t.Parallel()
tracker := stats.New()
tracker.IncrSessions()
tracker.IncrSessions()
got := tracker.SessionsSinceBoot()
if got != 2 {
t.Errorf(
"expected 2 sessions, got %d", got,
)
}
}
func TestIncrMessages(t *testing.T) {
t.Parallel()
tracker := stats.New()
tracker.IncrMessages()
got := tracker.MessagesSinceBoot()
if got != 1 {
t.Errorf(
"expected 1 message, got %d", got,
)
}
}
func TestCountersAreIndependent(t *testing.T) {
t.Parallel()
tracker := stats.New()
tracker.IncrConnections()
tracker.IncrSessions()
tracker.IncrMessages()
tracker.IncrMessages()
if tracker.ConnectionsSinceBoot() != 1 {
t.Errorf(
"expected 1 connection, got %d",
tracker.ConnectionsSinceBoot(),
)
}
if tracker.SessionsSinceBoot() != 1 {
t.Errorf(
"expected 1 session, got %d",
tracker.SessionsSinceBoot(),
)
}
if tracker.MessagesSinceBoot() != 2 {
t.Errorf(
"expected 2 messages, got %d",
tracker.MessagesSinceBoot(),
)
}
}

View File

@@ -2,13 +2,8 @@ package irc
// IRC command names (RFC 1459 / RFC 2812).
const (
CmdAdmin = "ADMIN"
CmdAway = "AWAY"
CmdInfo = "INFO"
CmdInvite = "INVITE"
CmdJoin = "JOIN"
CmdKick = "KICK"
CmdKill = "KILL"
CmdList = "LIST"
CmdLusers = "LUSERS"
CmdMode = "MODE"
@@ -16,18 +11,12 @@ const (
CmdNames = "NAMES"
CmdNick = "NICK"
CmdNotice = "NOTICE"
CmdOper = "OPER"
CmdPass = "PASS"
CmdPart = "PART"
CmdPing = "PING"
CmdPong = "PONG"
CmdPrivmsg = "PRIVMSG"
CmdQuit = "QUIT"
CmdTime = "TIME"
CmdTopic = "TOPIC"
CmdUserhost = "USERHOST"
CmdVersion = "VERSION"
CmdWallops = "WALLOPS"
CmdWho = "WHO"
CmdWhois = "WHOIS"
)

View File

@@ -132,7 +132,6 @@ const (
RplNoTopic IRCMessageType = 331
RplTopic IRCMessageType = 332
RplTopicWhoTime IRCMessageType = 333
RplWhoisActually IRCMessageType = 338
RplInviting IRCMessageType = 341
RplSummoning IRCMessageType = 342
RplInviteList IRCMessageType = 346
@@ -296,7 +295,6 @@ var names = map[IRCMessageType]string{
RplNoTopic: "RPL_NOTOPIC",
RplTopic: "RPL_TOPIC",
RplTopicWhoTime: "RPL_TOPICWHOTIME",
RplWhoisActually: "RPL_WHOISACTUALLY",
RplInviting: "RPL_INVITING",
RplSummoning: "RPL_SUMMONING",
RplInviteList: "RPL_INVITELIST",