2 Commits

Author SHA1 Message Date
clawbot
d0e925bf70 docs: add reverse proxy security note to login rate limiting section
All checks were successful
check / check (push) Successful in 1m8s
2026-03-17 04:49:49 -07:00
user
e519ffa1e6 feat: add per-IP rate limiting to login endpoint
Add a token-bucket rate limiter (golang.org/x/time/rate) that limits
login attempts per client IP on POST /api/v1/login. Returns 429 Too
Many Requests with a Retry-After header when the limit is exceeded.

Configurable via LOGIN_RATE_LIMIT (requests/sec, default 1) and
LOGIN_RATE_BURST (burst size, default 5). Stale per-IP entries are
automatically cleaned up every 10 minutes.

Only the login endpoint is rate-limited per sneak's instruction —
session creation and registration use hashcash proof-of-work instead.
2026-03-17 04:48:46 -07:00
23 changed files with 1287 additions and 2160 deletions

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks ensure-web-dist .PHONY: all build lint fmt fmt-check test check clean run debug docker hooks
BINARY := neoircd BINARY := neoircd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -7,21 +7,10 @@ LDFLAGS := -X main.Version=$(VERSION) -X main.Buildarch=$(BUILDARCH)
all: check build all: check build
# ensure-web-dist creates placeholder files so //go:embed dist/* in build:
# web/embed.go resolves without a full Node.js build. The real SPA is
# built by the web-builder Docker stage; these placeholders let
# "make test" and "make build" work outside Docker.
ensure-web-dist:
@if [ ! -d web/dist ]; then \
mkdir -p web/dist && \
touch web/dist/index.html web/dist/style.css web/dist/app.js && \
echo "==> Created placeholder web/dist/ for go:embed"; \
fi
build: ensure-web-dist
go build -ldflags "$(LDFLAGS)" -o bin/$(BINARY) ./cmd/neoircd go build -ldflags "$(LDFLAGS)" -o bin/$(BINARY) ./cmd/neoircd
lint: ensure-web-dist lint:
golangci-lint run --config .golangci.yml ./... golangci-lint run --config .golangci.yml ./...
fmt: fmt:
@@ -31,7 +20,7 @@ fmt:
fmt-check: fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1) @test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test: ensure-web-dist test:
go test -timeout 30s -v -race -cover ./... go test -timeout 30s -v -race -cover ./...
# check runs all validation without making changes # check runs all validation without making changes

663
README.md

File diff suppressed because it is too large Load Diff

1
go.mod
View File

@@ -16,6 +16,7 @@ require (
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/crypto v0.48.0 golang.org/x/crypto v0.48.0
golang.org/x/time v0.6.0
modernc.org/sqlite v1.45.0 modernc.org/sqlite v1.45.0
) )

2
go.sum
View File

@@ -151,6 +151,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View File

@@ -9,7 +9,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/cookiejar"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
@@ -29,19 +28,16 @@ var errHTTP = errors.New("HTTP error")
// Client wraps HTTP calls to the neoirc server API. // Client wraps HTTP calls to the neoirc server API.
type Client struct { type Client struct {
BaseURL string BaseURL string
Token string
HTTPClient *http.Client HTTPClient *http.Client
} }
// NewClient creates a new API client with a cookie jar // NewClient creates a new API client.
// for automatic auth cookie management.
func NewClient(baseURL string) *Client { func NewClient(baseURL string) *Client {
jar, _ := cookiejar.New(nil) return &Client{ //nolint:exhaustruct // Token set after CreateSession
return &Client{
BaseURL: baseURL, BaseURL: baseURL,
HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine
Timeout: httpTimeout, Timeout: httpTimeout,
Jar: jar,
}, },
} }
} }
@@ -83,6 +79,8 @@ func (client *Client) CreateSession(
return nil, fmt.Errorf("decode session: %w", err) return nil, fmt.Errorf("decode session: %w", err)
} }
client.Token = resp.Token
return &resp, nil return &resp, nil
} }
@@ -123,7 +121,6 @@ func (client *Client) PollMessages(
Timeout: time.Duration( Timeout: time.Duration(
timeout+pollExtraTime, timeout+pollExtraTime,
) * time.Second, ) * time.Second,
Jar: client.HTTPClient.Jar,
} }
params := url.Values{} params := url.Values{}
@@ -148,6 +145,10 @@ func (client *Client) PollMessages(
return nil, fmt.Errorf("new request: %w", err) return nil, fmt.Errorf("new request: %w", err)
} }
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
resp, err := pollClient.Do(request) resp, err := pollClient.Do(request)
if err != nil { if err != nil {
return nil, fmt.Errorf("poll request: %w", err) return nil, fmt.Errorf("poll request: %w", err)
@@ -303,6 +304,12 @@ func (client *Client) do(
"Content-Type", "application/json", "Content-Type", "application/json",
) )
if client.Token != "" {
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
}
resp, err := client.HTTPClient.Do(request) resp, err := client.HTTPClient.Do(request)
if err != nil { if err != nil {
return nil, fmt.Errorf("http: %w", err) return nil, fmt.Errorf("http: %w", err)

View File

@@ -7,8 +7,6 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"time" "time"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
) )
const ( const (
@@ -39,23 +37,6 @@ func MintHashcash(bits int, resource string) string {
} }
} }
// MintChannelHashcash computes a hashcash stamp bound to
// a specific channel and message body. The stamp format
// is 1:bits:YYMMDD:channel:bodyhash:counter where
// bodyhash is the hex-encoded SHA-256 of the message
// body bytes. Delegates to the internal/hashcash package.
func MintChannelHashcash(
bits int,
channel string,
body []byte,
) string {
bodyHash := hashcash.BodyHash(body)
return hashcash.MintChannelStamp(
bits, channel, bodyHash,
)
}
// hasLeadingZeroBits checks if hash has at least numBits // hasLeadingZeroBits checks if hash has at least numBits
// leading zero bits. // leading zero bits.
func hasLeadingZeroBits( func hasLeadingZeroBits(

View File

@@ -10,8 +10,9 @@ type SessionRequest struct {
// SessionResponse is the response from session creation. // SessionResponse is the response from session creation.
type SessionResponse struct { type SessionResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Nick string `json:"nick"` Nick string `json:"nick"`
Token string `json:"token"`
} }
// StateResponse is the response from GET /api/v1/state. // StateResponse is the response from GET /api/v1/state.

View File

@@ -46,6 +46,8 @@ type Config struct {
FederationKey string FederationKey string
SessionIdleTimeout string SessionIdleTimeout string
HashcashBits int HashcashBits int
LoginRateLimit float64
LoginRateBurst int
params *Params params *Params
log *slog.Logger log *slog.Logger
} }
@@ -78,6 +80,8 @@ func New(
viper.SetDefault("FEDERATION_KEY", "") viper.SetDefault("FEDERATION_KEY", "")
viper.SetDefault("SESSION_IDLE_TIMEOUT", "720h") viper.SetDefault("SESSION_IDLE_TIMEOUT", "720h")
viper.SetDefault("NEOIRC_HASHCASH_BITS", "20") viper.SetDefault("NEOIRC_HASHCASH_BITS", "20")
viper.SetDefault("LOGIN_RATE_LIMIT", "1")
viper.SetDefault("LOGIN_RATE_BURST", "5")
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {
@@ -104,6 +108,8 @@ func New(
FederationKey: viper.GetString("FEDERATION_KEY"), FederationKey: viper.GetString("FEDERATION_KEY"),
SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"), SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"),
HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"), HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"),
LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"),
LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"),
log: log, log: log,
params: &params, params: &params,
} }

View File

@@ -16,28 +16,80 @@ var errNoPassword = errors.New(
"account has no password set", "account has no password set",
) )
// SetPassword sets a bcrypt-hashed password on a session, // RegisterUser creates a session with a hashed password
// enabling multi-client login via POST /api/v1/login. // and returns session ID, client ID, and token.
func (database *Database) SetPassword( func (database *Database) RegisterUser(
ctx context.Context, ctx context.Context,
sessionID int64, nick, password string,
password string, ) (int64, int64, string, error) {
) error {
hash, err := bcrypt.GenerateFromPassword( hash, err := bcrypt.GenerateFromPassword(
[]byte(password), bcryptCost, []byte(password), bcryptCost,
) )
if err != nil { if err != nil {
return fmt.Errorf("hash password: %w", err) return 0, 0, "", fmt.Errorf(
"hash password: %w", err,
)
} }
_, err = database.conn.ExecContext(ctx, sessionUUID := uuid.New().String()
"UPDATE sessions SET password_hash = ? WHERE id = ?", clientUUID := uuid.New().String()
string(hash), sessionID)
token, err := generateToken()
if err != nil { if err != nil {
return fmt.Errorf("set password: %w", err) return 0, 0, "", err
} }
return nil now := time.Now()
transaction, err := database.conn.BeginTx(ctx, nil)
if err != nil {
return 0, 0, "", fmt.Errorf(
"begin tx: %w", err,
)
}
res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, password_hash,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
sessionUUID, nick, string(hash), now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create session: %w", err,
)
}
sessionID, _ := res.LastInsertId()
tokenHash := hashToken(token)
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create client: %w", err,
)
}
clientID, _ := clientRes.LastInsertId()
err = transaction.Commit()
if err != nil {
return 0, 0, "", fmt.Errorf(
"commit registration: %w", err,
)
}
return sessionID, clientID, token, nil
} }
// LoginUser verifies a nick/password and creates a new // LoginUser verifies a nick/password and creates a new

View File

@@ -6,65 +6,63 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func TestSetPassword(t *testing.T) { func TestRegisterUser(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sessionID, _, _, err := sessionID, clientID, token, err :=
database.CreateSession(ctx, "passuser") database.RegisterUser(ctx, "reguser", "password123")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.SetPassword( if sessionID == 0 || clientID == 0 || token == "" {
ctx, sessionID, "password123",
)
if err != nil {
t.Fatal(err)
}
// Verify we can now log in with the password.
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "passuser", "password123")
if err != nil {
t.Fatal(err)
}
if loginSID == 0 || loginCID == 0 || loginToken == "" {
t.Fatal("expected valid ids and token") t.Fatal("expected valid ids and token")
} }
// Verify session works via token lookup.
sid, cid, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if sid != sessionID || cid != clientID {
t.Fatal("session/client id mismatch")
}
if nick != "reguser" {
t.Fatalf("expected reguser, got %s", nick)
}
} }
func TestSetPasswordThenWrongLogin(t *testing.T) { func TestRegisterUserDuplicateNick(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sessionID, _, _, err := regSID, regCID, regToken, err :=
database.CreateSession(ctx, "wrongpw") database.RegisterUser(ctx, "dupnick", "password123")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.SetPassword( _ = regSID
ctx, sessionID, "correctpass", _ = regCID
) _ = regToken
if err != nil {
t.Fatal(err) dupSID, dupCID, dupToken, dupErr :=
database.RegisterUser(ctx, "dupnick", "other12345")
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
} }
loginSID, loginCID, loginToken, loginErr := _ = dupSID
database.LoginUser(ctx, "wrongpw", "wrongpass12") _ = dupCID
if loginErr == nil { _ = dupToken
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
} }
func TestLoginUser(t *testing.T) { func TestLoginUser(t *testing.T) {
@@ -73,26 +71,23 @@ func TestLoginUser(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sessionID, _, _, err := regSID, regCID, regToken, err :=
database.CreateSession(ctx, "loginuser") database.RegisterUser(ctx, "loginuser", "mypassword")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.SetPassword( _ = regSID
ctx, sessionID, "mypassword", _ = regCID
) _ = regToken
if err != nil {
t.Fatal(err)
}
loginSID, loginCID, token, err := sessionID, clientID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword") database.LoginUser(ctx, "loginuser", "mypassword")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if loginSID == 0 || loginCID == 0 || token == "" { if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token") t.Fatal("expected valid ids and token")
} }
@@ -108,6 +103,33 @@ func TestLoginUser(t *testing.T) {
} }
} }
func TestLoginUserWrongPassword(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "wrongpw", "correctpass")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
}
func TestLoginUserNoPassword(t *testing.T) { func TestLoginUserNoPassword(t *testing.T) {
t.Parallel() t.Parallel()

View File

@@ -1305,110 +1305,3 @@ func (database *Database) GetQueueEntryCount(
return count, nil return count, nil
} }
// GetChannelHashcashBits returns the hashcash difficulty
// requirement for a channel. Returns 0 if not set.
func (database *Database) GetChannelHashcashBits(
ctx context.Context,
channelID int64,
) (int, error) {
var bits int
err := database.conn.QueryRowContext(
ctx,
"SELECT hashcash_bits FROM channels WHERE id = ?",
channelID,
).Scan(&bits)
if err != nil {
return 0, fmt.Errorf(
"get channel hashcash bits: %w", err,
)
}
return bits, nil
}
// SetChannelHashcashBits sets the hashcash difficulty
// requirement for a channel. A value of 0 disables the
// requirement.
func (database *Database) SetChannelHashcashBits(
ctx context.Context,
channelID int64,
bits int,
) error {
_, err := database.conn.ExecContext(ctx,
`UPDATE channels
SET hashcash_bits = ?, updated_at = ?
WHERE id = ?`,
bits, time.Now(), channelID)
if err != nil {
return fmt.Errorf(
"set channel hashcash bits: %w", err,
)
}
return nil
}
// RecordSpentHashcash stores a spent hashcash stamp hash
// for replay prevention.
func (database *Database) RecordSpentHashcash(
ctx context.Context,
stampHash string,
) error {
_, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO spent_hashcash
(stamp_hash, created_at)
VALUES (?, ?)`,
stampHash, time.Now())
if err != nil {
return fmt.Errorf(
"record spent hashcash: %w", err,
)
}
return nil
}
// IsHashcashSpent checks whether a hashcash stamp hash
// has already been used.
func (database *Database) IsHashcashSpent(
ctx context.Context,
stampHash string,
) (bool, error) {
var count int
err := database.conn.QueryRowContext(ctx,
`SELECT COUNT(*) FROM spent_hashcash
WHERE stamp_hash = ?`,
stampHash,
).Scan(&count)
if err != nil {
return false, fmt.Errorf(
"check spent hashcash: %w", err,
)
}
return count > 0, nil
}
// PruneSpentHashcash deletes spent hashcash tokens older
// than the cutoff and returns the number of rows removed.
func (database *Database) PruneSpentHashcash(
ctx context.Context,
cutoff time.Time,
) (int64, error) {
res, err := database.conn.ExecContext(ctx,
"DELETE FROM spent_hashcash WHERE created_at < ?",
cutoff,
)
if err != nil {
return 0, fmt.Errorf(
"prune spent hashcash: %w", err,
)
}
deleted, _ := res.RowsAffected()
return deleted, nil
}

View File

@@ -33,7 +33,6 @@ CREATE TABLE IF NOT EXISTS channels (
topic TEXT NOT NULL DEFAULT '', topic TEXT NOT NULL DEFAULT '',
topic_set_by TEXT NOT NULL DEFAULT '', topic_set_by TEXT NOT NULL DEFAULT '',
topic_set_at DATETIME, topic_set_at DATETIME,
hashcash_bits INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
); );
@@ -62,14 +61,6 @@ CREATE TABLE IF NOT EXISTS messages (
CREATE INDEX IF NOT EXISTS idx_messages_to_id ON messages(msg_to, id); CREATE INDEX IF NOT EXISTS idx_messages_to_id ON messages(msg_to, id);
CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at); CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
-- Spent hashcash tokens for replay prevention (1-year TTL)
CREATE TABLE IF NOT EXISTS spent_hashcash (
id INTEGER PRIMARY KEY AUTOINCREMENT,
stamp_hash TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_spent_hashcash_created ON spent_hashcash(created_at);
-- Per-client message queues for fan-out delivery -- Per-client message queues for fan-out delivery
CREATE TABLE IF NOT EXISTS client_queues ( CREATE TABLE IF NOT EXISTS client_queues (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,

View File

@@ -3,7 +3,6 @@ package handlers
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"regexp" "regexp"
@@ -12,16 +11,10 @@ import (
"time" "time"
"git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
"git.eeqj.de/sneak/neoirc/pkg/irc" "git.eeqj.de/sneak/neoirc/pkg/irc"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )
var (
errHashcashRequired = errors.New("hashcash required")
errHashcashReused = errors.New("hashcash reused")
)
var validNickRe = regexp.MustCompile( var validNickRe = regexp.MustCompile(
`^[a-zA-Z_][a-zA-Z0-9_\-\[\]\\^{}|` + "`" + `]{0,31}$`, `^[a-zA-Z_][a-zA-Z0-9_\-\[\]\\^{}|` + "`" + `]{0,31}$`,
) )
@@ -36,7 +29,6 @@ const (
defaultMaxBodySize = 4096 defaultMaxBodySize = 4096
defaultHistLimit = 50 defaultHistLimit = 50
maxHistLimit = 500 maxHistLimit = 500
authCookieName = "neoirc_auth"
) )
func (hdlr *Handlers) maxBodySize() int64 { func (hdlr *Handlers) maxBodySize() int64 {
@@ -47,18 +39,23 @@ func (hdlr *Handlers) maxBodySize() int64 {
return defaultMaxBodySize return defaultMaxBodySize
} }
// authSession extracts the session from the auth cookie. // authSession extracts the session from the client token.
func (hdlr *Handlers) authSession( func (hdlr *Handlers) authSession(
request *http.Request, request *http.Request,
) (int64, int64, string, error) { ) (int64, int64, string, error) {
cookie, err := request.Cookie(authCookieName) auth := request.Header.Get("Authorization")
if err != nil || cookie.Value == "" { if !strings.HasPrefix(auth, "Bearer ") {
return 0, 0, "", errUnauthorized
}
token := strings.TrimPrefix(auth, "Bearer ")
if token == "" {
return 0, 0, "", errUnauthorized return 0, 0, "", errUnauthorized
} }
sessionID, clientID, nick, err := sessionID, clientID, nick, err :=
hdlr.params.Database.GetSessionByToken( hdlr.params.Database.GetSessionByToken(
request.Context(), cookie.Value, request.Context(), token,
) )
if err != nil { if err != nil {
return 0, 0, "", fmt.Errorf("auth: %w", err) return 0, 0, "", fmt.Errorf("auth: %w", err)
@@ -67,46 +64,6 @@ func (hdlr *Handlers) authSession(
return sessionID, clientID, nick, nil return sessionID, clientID, nick, nil
} }
// setAuthCookie sets the authentication cookie on the
// response.
func (hdlr *Handlers) setAuthCookie(
writer http.ResponseWriter,
request *http.Request,
token string,
) {
secure := request.TLS != nil ||
request.Header.Get("X-Forwarded-Proto") == "https"
http.SetCookie(writer, &http.Cookie{ //nolint:exhaustruct // optional fields
Name: authCookieName,
Value: token,
Path: "/",
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteStrictMode,
})
}
// clearAuthCookie removes the authentication cookie from
// the client.
func (hdlr *Handlers) clearAuthCookie(
writer http.ResponseWriter,
request *http.Request,
) {
secure := request.TLS != nil ||
request.Header.Get("X-Forwarded-Proto") == "https"
http.SetCookie(writer, &http.Cookie{ //nolint:exhaustruct // optional fields
Name: authCookieName,
Value: "",
Path: "/",
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteStrictMode,
MaxAge: -1,
})
}
func (hdlr *Handlers) requireAuth( func (hdlr *Handlers) requireAuth(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
@@ -131,11 +88,10 @@ func (hdlr *Handlers) fanOut(
request *http.Request, request *http.Request,
command, from, target string, command, from, target string,
body json.RawMessage, body json.RawMessage,
meta json.RawMessage,
sessionIDs []int64, sessionIDs []int64,
) (string, error) { ) (string, error) {
dbID, msgUUID, err := hdlr.params.Database.InsertMessage( dbID, msgUUID, err := hdlr.params.Database.InsertMessage(
request.Context(), command, from, target, nil, body, meta, request.Context(), command, from, target, nil, body, nil,
) )
if err != nil { if err != nil {
return "", fmt.Errorf("insert message: %w", err) return "", fmt.Errorf("insert message: %w", err)
@@ -161,11 +117,10 @@ func (hdlr *Handlers) fanOutSilent(
request *http.Request, request *http.Request,
command, from, target string, command, from, target string,
body json.RawMessage, body json.RawMessage,
meta json.RawMessage,
sessionIDs []int64, sessionIDs []int64,
) error { ) error {
_, err := hdlr.fanOut( _, err := hdlr.fanOut(
request, command, from, target, body, meta, sessionIDs, request, command, from, target, body, sessionIDs,
) )
return err return err
@@ -262,11 +217,10 @@ func (hdlr *Handlers) handleCreateSession(
hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick) hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick)
hdlr.setAuthCookie(writer, request, token)
hdlr.respondJSON(writer, request, map[string]any{ hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID, "id": sessionID,
"nick": payload.Nick, "nick": payload.Nick,
"token": token,
}, http.StatusCreated) }, http.StatusCreated)
} }
@@ -340,7 +294,7 @@ func (hdlr *Handlers) deliverWelcome(
[]string{ []string{
"CHANTYPES=#", "CHANTYPES=#",
"NICKLEN=32", "NICKLEN=32",
"CHANMODES=,,H," + "imnst", "CHANMODES=,,," + "imnst",
"NETWORK=neoirc", "NETWORK=neoirc",
"CASEMAPPING=ascii", "CASEMAPPING=ascii",
}, },
@@ -871,7 +825,7 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
writer, request, writer, request,
sessionID, clientID, nick, sessionID, clientID, nick,
payload.Command, payload.To, payload.Command, payload.To,
payload.Body, payload.Meta, bodyLines, payload.Body, bodyLines,
) )
} }
} }
@@ -882,7 +836,6 @@ func (hdlr *Handlers) dispatchCommand(
sessionID, clientID int64, sessionID, clientID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
meta json.RawMessage,
bodyLines func() []string, bodyLines func() []string,
) { ) {
switch command { switch command {
@@ -895,7 +848,7 @@ func (hdlr *Handlers) dispatchCommand(
hdlr.handlePrivmsg( hdlr.handlePrivmsg(
writer, request, writer, request,
sessionID, clientID, nick, sessionID, clientID, nick,
command, target, body, meta, bodyLines, command, target, body, bodyLines,
) )
case irc.CmdJoin: case irc.CmdJoin:
hdlr.handleJoin( hdlr.handleJoin(
@@ -912,11 +865,6 @@ func (hdlr *Handlers) dispatchCommand(
writer, request, writer, request,
sessionID, clientID, nick, bodyLines, sessionID, clientID, nick, bodyLines,
) )
case irc.CmdPass:
hdlr.handlePass(
writer, request,
sessionID, clientID, nick, bodyLines,
)
case irc.CmdTopic: case irc.CmdTopic:
hdlr.handleTopic( hdlr.handleTopic(
writer, request, writer, request,
@@ -1001,7 +949,6 @@ func (hdlr *Handlers) handlePrivmsg(
sessionID, clientID int64, sessionID, clientID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
meta json.RawMessage,
bodyLines func() []string, bodyLines func() []string,
) { ) {
if target == "" { if target == "" {
@@ -1039,7 +986,7 @@ func (hdlr *Handlers) handlePrivmsg(
hdlr.handleChannelMsg( hdlr.handleChannelMsg(
writer, request, writer, request,
sessionID, clientID, nick, sessionID, clientID, nick,
command, target, body, meta, command, target, body,
) )
return return
@@ -1048,7 +995,7 @@ func (hdlr *Handlers) handlePrivmsg(
hdlr.handleDirectMsg( hdlr.handleDirectMsg(
writer, request, writer, request,
sessionID, clientID, nick, sessionID, clientID, nick,
command, target, body, meta, command, target, body,
) )
} }
@@ -1079,7 +1026,6 @@ func (hdlr *Handlers) handleChannelMsg(
sessionID, clientID int64, sessionID, clientID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
meta json.RawMessage,
) { ) {
chID, err := hdlr.params.Database.GetChannelByName( chID, err := hdlr.params.Database.GetChannelByName(
request.Context(), target, request.Context(), target,
@@ -1120,180 +1066,16 @@ func (hdlr *Handlers) handleChannelMsg(
return return
} }
hashcashErr := hdlr.validateChannelHashcash(
request, clientID, sessionID,
writer, nick, target, body, meta, chID,
)
if hashcashErr != nil {
return
}
hdlr.sendChannelMsg( hdlr.sendChannelMsg(
writer, request, command, nick, target, writer, request, command, nick, target, body, chID,
body, meta, chID,
) )
} }
// validateChannelHashcash checks whether the channel
// requires hashcash proof-of-work for messages and
// validates the stamp from the message meta field.
// Returns nil on success or if the channel has no
// hashcash requirement. On failure, it sends the
// appropriate IRC error and returns a non-nil error.
func (hdlr *Handlers) validateChannelHashcash(
request *http.Request,
clientID, sessionID int64,
writer http.ResponseWriter,
nick, target string,
body json.RawMessage,
meta json.RawMessage,
chID int64,
) error {
ctx := request.Context()
bits, bitsErr := hdlr.params.Database.GetChannelHashcashBits(
ctx, chID,
)
if bitsErr != nil {
hdlr.log.Error(
"get channel hashcash bits", "error", bitsErr,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return fmt.Errorf("channel hashcash bits: %w", bitsErr)
}
if bits <= 0 {
return nil
}
stamp := hdlr.extractHashcashFromMeta(meta)
if stamp == "" {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrCannotSendToChan, nick, []string{target},
"Channel requires hashcash proof-of-work",
)
return errHashcashRequired
}
return hdlr.verifyChannelStamp(
request, writer,
clientID, sessionID,
nick, target, body, stamp, bits,
)
}
// verifyChannelStamp validates a channel hashcash stamp
// and checks for replay attacks.
func (hdlr *Handlers) verifyChannelStamp(
request *http.Request,
writer http.ResponseWriter,
clientID, sessionID int64,
nick, target string,
body json.RawMessage,
stamp string,
bits int,
) error {
ctx := request.Context()
bodyHashStr := hashcash.BodyHash(body)
valErr := hdlr.channelHashcash.ValidateStamp(
stamp, bits, target, bodyHashStr,
)
if valErr != nil {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrCannotSendToChan, nick, []string{target},
"Invalid hashcash: "+valErr.Error(),
)
return fmt.Errorf("channel hashcash: %w", valErr)
}
stampKey := hashcash.StampHash(stamp)
spent, spentErr := hdlr.params.Database.IsHashcashSpent(
ctx, stampKey,
)
if spentErr != nil {
hdlr.log.Error(
"check spent hashcash", "error", spentErr,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return fmt.Errorf("check spent hashcash: %w", spentErr)
}
if spent {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrCannotSendToChan, nick, []string{target},
"Hashcash stamp already used",
)
return errHashcashReused
}
recordErr := hdlr.params.Database.RecordSpentHashcash(
ctx, stampKey,
)
if recordErr != nil {
hdlr.log.Error(
"record spent hashcash", "error", recordErr,
)
}
return nil
}
// extractHashcashFromMeta parses the meta JSON and
// returns the hashcash stamp string, or empty string
// if not present.
func (hdlr *Handlers) extractHashcashFromMeta(
meta json.RawMessage,
) string {
if len(meta) == 0 {
return ""
}
var metaMap map[string]json.RawMessage
err := json.Unmarshal(meta, &metaMap)
if err != nil {
return ""
}
raw, ok := metaMap["hashcash"]
if !ok {
return ""
}
var stamp string
err = json.Unmarshal(raw, &stamp)
if err != nil {
return ""
}
return stamp
}
func (hdlr *Handlers) sendChannelMsg( func (hdlr *Handlers) sendChannelMsg(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
command, nick, target string, command, nick, target string,
body json.RawMessage, body json.RawMessage,
meta json.RawMessage,
chID int64, chID int64,
) { ) {
memberIDs, err := hdlr.params.Database.GetChannelMemberIDs( memberIDs, err := hdlr.params.Database.GetChannelMemberIDs(
@@ -1313,7 +1095,7 @@ func (hdlr *Handlers) sendChannelMsg(
} }
msgUUID, err := hdlr.fanOut( msgUUID, err := hdlr.fanOut(
request, command, nick, target, body, meta, memberIDs, request, command, nick, target, body, memberIDs,
) )
if err != nil { if err != nil {
hdlr.log.Error("send message failed", "error", err) hdlr.log.Error("send message failed", "error", err)
@@ -1337,7 +1119,6 @@ func (hdlr *Handlers) handleDirectMsg(
sessionID, clientID int64, sessionID, clientID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
meta json.RawMessage,
) { ) {
targetSID, err := hdlr.params.Database.GetSessionByNick( targetSID, err := hdlr.params.Database.GetSessionByNick(
request.Context(), target, request.Context(), target,
@@ -1362,7 +1143,7 @@ func (hdlr *Handlers) handleDirectMsg(
} }
msgUUID, err := hdlr.fanOut( msgUUID, err := hdlr.fanOut(
request, command, nick, target, body, meta, recipients, request, command, nick, target, body, recipients,
) )
if err != nil { if err != nil {
hdlr.log.Error("send dm failed", "error", err) hdlr.log.Error("send dm failed", "error", err)
@@ -1473,7 +1254,7 @@ func (hdlr *Handlers) executeJoin(
) )
_ = hdlr.fanOutSilent( _ = hdlr.fanOutSilent(
request, irc.CmdJoin, nick, channel, nil, nil, memberIDs, request, irc.CmdJoin, nick, channel, nil, memberIDs,
) )
hdlr.deliverJoinNumerics( hdlr.deliverJoinNumerics(
@@ -1643,7 +1424,7 @@ func (hdlr *Handlers) handlePart(
) )
_ = hdlr.fanOutSilent( _ = hdlr.fanOutSilent(
request, irc.CmdPart, nick, channel, body, nil, memberIDs, request, irc.CmdPart, nick, channel, body, memberIDs,
) )
err = hdlr.params.Database.PartChannel( err = hdlr.params.Database.PartChannel(
@@ -1923,7 +1704,7 @@ func (hdlr *Handlers) executeTopic(
) )
_ = hdlr.fanOutSilent( _ = hdlr.fanOutSilent(
request, irc.CmdTopic, nick, channel, body, nil, memberIDs, request, irc.CmdTopic, nick, channel, body, memberIDs,
) )
hdlr.enqueueNumeric( hdlr.enqueueNumeric(
@@ -2047,8 +1828,6 @@ func (hdlr *Handlers) handleQuit(
request.Context(), sessionID, request.Context(), sessionID,
) )
hdlr.clearAuthCookie(writer, request)
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
map[string]string{"status": "quit"}, map[string]string{"status": "quit"},
http.StatusOK) http.StatusOK)
@@ -2088,10 +1867,11 @@ func (hdlr *Handlers) handleMode(
return return
} }
_ = bodyLines
hdlr.handleChannelMode( hdlr.handleChannelMode(
writer, request, writer, request,
sessionID, clientID, nick, channel, sessionID, clientID, nick, channel,
bodyLines,
) )
} }
@@ -2100,7 +1880,6 @@ func (hdlr *Handlers) handleChannelMode(
request *http.Request, request *http.Request,
sessionID, clientID int64, sessionID, clientID int64,
nick, channel string, nick, channel string,
bodyLines func() []string,
) { ) {
ctx := request.Context() ctx := request.Context()
@@ -2117,47 +1896,10 @@ func (hdlr *Handlers) handleChannelMode(
return return
} }
lines := bodyLines()
if len(lines) > 0 {
hdlr.applyChannelMode(
writer, request,
sessionID, clientID, nick,
channel, chID, lines,
)
return
}
hdlr.queryChannelMode(
writer, request,
sessionID, clientID, nick, channel, chID,
)
}
// queryChannelMode sends RPL_CHANNELMODEIS and
// RPL_CREATIONTIME for a channel. Includes +H if
// the channel has a hashcash requirement.
func (hdlr *Handlers) queryChannelMode(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, channel string,
chID int64,
) {
ctx := request.Context()
modeStr := "+n"
bits, bitsErr := hdlr.params.Database.
GetChannelHashcashBits(ctx, chID)
if bitsErr == nil && bits > 0 {
modeStr = fmt.Sprintf("+nH %d", bits)
}
// 324 RPL_CHANNELMODEIS // 324 RPL_CHANNELMODEIS
hdlr.enqueueNumeric( hdlr.enqueueNumeric(
ctx, clientID, irc.RplChannelModeIs, nick, ctx, clientID, irc.RplChannelModeIs, nick,
[]string{channel, modeStr}, "", []string{channel, "+n"}, "",
) )
// 329 RPL_CREATIONTIME // 329 RPL_CREATIONTIME
@@ -2182,156 +1924,6 @@ func (hdlr *Handlers) queryChannelMode(
http.StatusOK) http.StatusOK)
} }
// applyChannelMode handles setting channel modes.
// Currently supports +H/-H for hashcash bits.
func (hdlr *Handlers) applyChannelMode(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, channel string,
chID int64,
modeArgs []string,
) {
ctx := request.Context()
modeStr := modeArgs[0]
switch modeStr {
case "+H":
hdlr.setHashcashMode(
writer, request,
sessionID, clientID, nick,
channel, chID, modeArgs,
)
case "-H":
hdlr.clearHashcashMode(
writer, request,
sessionID, clientID, nick,
channel, chID,
)
default:
// Unknown or unsupported mode change.
hdlr.enqueueNumeric(
ctx, clientID, irc.ErrUnknownMode, nick,
[]string{modeStr},
"is unknown mode char to me",
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "error"},
http.StatusOK)
}
}
const (
// minHashcashBits is the minimum allowed hashcash
// difficulty for channels.
minHashcashBits = 1
// maxHashcashBits is the maximum allowed hashcash
// difficulty for channels.
maxHashcashBits = 40
)
// setHashcashMode handles MODE #channel +H <bits>.
func (hdlr *Handlers) setHashcashMode(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, channel string,
chID int64,
modeArgs []string,
) {
ctx := request.Context()
if len(modeArgs) < 2 { //nolint:mnd // +H requires a bits arg
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick, []string{irc.CmdMode},
"Not enough parameters (+H requires bits)",
)
return
}
bits, err := strconv.Atoi(modeArgs[1])
if err != nil || bits < minHashcashBits ||
bits > maxHashcashBits {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrUnknownMode, nick, []string{"+H"},
fmt.Sprintf(
"Invalid hashcash bits (must be %d-%d)",
minHashcashBits, maxHashcashBits,
),
)
return
}
err = hdlr.params.Database.SetChannelHashcashBits(
ctx, chID, bits,
)
if err != nil {
hdlr.log.Error(
"set channel hashcash bits", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.enqueueNumeric(
ctx, clientID, irc.RplChannelModeIs, nick,
[]string{
channel,
fmt.Sprintf("+H %d", bits),
}, "",
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// clearHashcashMode handles MODE #channel -H.
func (hdlr *Handlers) clearHashcashMode(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick, channel string,
chID int64,
) {
ctx := request.Context()
err := hdlr.params.Database.SetChannelHashcashBits(
ctx, chID, 0,
)
if err != nil {
hdlr.log.Error(
"clear channel hashcash bits", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.enqueueNumeric(
ctx, clientID, irc.RplChannelModeIs, nick,
[]string{channel, "+n"}, "",
)
hdlr.broker.Notify(sessionID)
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
// handleNames sends NAMES reply for a channel. // handleNames sends NAMES reply for a channel.
func (hdlr *Handlers) handleNames( func (hdlr *Handlers) handleNames(
writer http.ResponseWriter, writer http.ResponseWriter,
@@ -2851,8 +2443,6 @@ func (hdlr *Handlers) HandleLogout() http.HandlerFunc {
) )
} }
hdlr.clearAuthCookie(writer, request)
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"}, map[string]string{"status": "ok"},
http.StatusOK) http.StatusOK)

File diff suppressed because it is too large Load Diff

View File

@@ -2,14 +2,151 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"net"
"net/http" "net/http"
"strings" "strings"
"git.eeqj.de/sneak/neoirc/pkg/irc" "git.eeqj.de/sneak/neoirc/internal/db"
) )
const minPasswordLength = 8 const minPasswordLength = 8
// clientIP extracts the client IP address from the request.
// It checks X-Forwarded-For and X-Real-IP headers before
// falling back to RemoteAddr.
func clientIP(request *http.Request) string {
if forwarded := request.Header.Get("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For may contain a comma-separated list;
// the first entry is the original client.
parts := strings.SplitN(forwarded, ",", 2) //nolint:mnd // split into two parts
ip := strings.TrimSpace(parts[0])
if ip != "" {
return ip
}
}
if realIP := request.Header.Get("X-Real-IP"); realIP != "" {
return strings.TrimSpace(realIP)
}
host, _, err := net.SplitHostPort(request.RemoteAddr)
if err != nil {
return request.RemoteAddr
}
return host
}
// HandleRegister creates a new user with a password.
func (hdlr *Handlers) HandleRegister() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
request.Body = http.MaxBytesReader(
writer, request.Body, hdlr.maxBodySize(),
)
hdlr.handleRegister(writer, request)
}
}
func (hdlr *Handlers) handleRegister(
writer http.ResponseWriter,
request *http.Request,
) {
type registerRequest struct {
Nick string `json:"nick"`
Password string `json:"password"`
}
var payload registerRequest
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
hdlr.respondError(
writer, request,
"invalid nick format",
http.StatusBadRequest,
)
return
}
if len(payload.Password) < minPasswordLength {
hdlr.respondError(
writer, request,
"password must be at least 8 characters",
http.StatusBadRequest,
)
return
}
sessionID, clientID, token, err :=
hdlr.params.Database.RegisterUser(
request.Context(),
payload.Nick,
payload.Password,
)
if err != nil {
hdlr.handleRegisterError(
writer, request, err,
)
return
}
hdlr.stats.IncrSessions()
hdlr.stats.IncrConnections()
hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusCreated)
}
func (hdlr *Handlers) handleRegisterError(
writer http.ResponseWriter,
request *http.Request,
err error,
) {
if db.IsUniqueConstraintError(err) {
hdlr.respondError(
writer, request,
"nick already taken",
http.StatusConflict,
)
return
}
hdlr.log.Error(
"register user failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
}
// HandleLogin authenticates a user with nick and password. // HandleLogin authenticates a user with nick and password.
func (hdlr *Handlers) HandleLogin() http.HandlerFunc { func (hdlr *Handlers) HandleLogin() http.HandlerFunc {
return func( return func(
@@ -28,6 +165,21 @@ func (hdlr *Handlers) handleLogin(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
ip := clientIP(request)
if !hdlr.loginLimiter.Allow(ip) {
writer.Header().Set(
"Retry-After", "1",
)
hdlr.respondError(
writer, request,
"too many login attempts, try again later",
http.StatusTooManyRequests,
)
return
}
type loginRequest struct { type loginRequest struct {
Nick string `json:"nick"` Nick string `json:"nick"`
Password string `json:"password"` Password string `json:"password"`
@@ -86,66 +238,9 @@ func (hdlr *Handlers) handleLogin(
request, clientID, sessionID, payload.Nick, request, clientID, sessionID, payload.Nick,
) )
hdlr.setAuthCookie(writer, request, token)
hdlr.respondJSON(writer, request, map[string]any{ hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID, "id": sessionID,
"nick": payload.Nick, "nick": payload.Nick,
"token": token,
}, http.StatusOK) }, http.StatusOK)
} }
// handlePass handles the IRC PASS command to set a
// password on the authenticated session, enabling
// multi-client login via POST /api/v1/login.
func (hdlr *Handlers) handlePass(
writer http.ResponseWriter,
request *http.Request,
sessionID, clientID int64,
nick string,
bodyLines func() []string,
) {
lines := bodyLines()
if len(lines) == 0 || lines[0] == "" {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Not enough parameters",
)
return
}
password := lines[0]
if len(password) < minPasswordLength {
hdlr.respondIRCError(
writer, request, clientID, sessionID,
irc.ErrNeedMoreParams, nick,
[]string{irc.CmdPass},
"Password must be at least 8 characters",
)
return
}
err := hdlr.params.Database.SetPassword(
request.Context(), sessionID, password,
)
if err != nil {
hdlr.log.Error(
"set password failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}

View File

@@ -16,6 +16,7 @@ import (
"git.eeqj.de/sneak/neoirc/internal/hashcash" "git.eeqj.de/sneak/neoirc/internal/hashcash"
"git.eeqj.de/sneak/neoirc/internal/healthcheck" "git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
"git.eeqj.de/sneak/neoirc/internal/stats" "git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx" "go.uber.org/fx"
) )
@@ -36,21 +37,16 @@ type Params struct {
const defaultIdleTimeout = 30 * 24 * time.Hour const defaultIdleTimeout = 30 * 24 * time.Hour
// spentHashcashTTL is how long spent hashcash tokens are
// retained for replay prevention. Per issue requirements,
// this is 1 year.
const spentHashcashTTL = 365 * 24 * time.Hour
// Handlers manages HTTP request handling. // Handlers manages HTTP request handling.
type Handlers struct { type Handlers struct {
params *Params params *Params
log *slog.Logger log *slog.Logger
hc *healthcheck.Healthcheck hc *healthcheck.Healthcheck
broker *broker.Broker broker *broker.Broker
hashcashVal *hashcash.Validator hashcashVal *hashcash.Validator
channelHashcash *hashcash.ChannelValidator loginLimiter *ratelimit.Limiter
stats *stats.Tracker stats *stats.Tracker
cancelCleanup context.CancelFunc cancelCleanup context.CancelFunc
} }
// New creates a new Handlers instance. // New creates a new Handlers instance.
@@ -63,14 +59,24 @@ func New(
resource = "neoirc" resource = "neoirc"
} }
loginRate := params.Config.LoginRateLimit
if loginRate <= 0 {
loginRate = ratelimit.DefaultRate
}
loginBurst := params.Config.LoginRateBurst
if loginBurst <= 0 {
loginBurst = ratelimit.DefaultBurst
}
hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
params: &params, params: &params,
log: params.Logger.Get(), log: params.Logger.Get(),
hc: params.Healthcheck, hc: params.Healthcheck,
broker: broker.New(), broker: broker.New(),
hashcashVal: hashcash.NewValidator(resource), hashcashVal: hashcash.NewValidator(resource),
channelHashcash: hashcash.NewChannelValidator(), loginLimiter: ratelimit.New(loginRate, loginBurst),
stats: params.Stats, stats: params.Stats,
} }
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{
@@ -162,6 +168,10 @@ func (hdlr *Handlers) stopCleanup() {
if hdlr.cancelCleanup != nil { if hdlr.cancelCleanup != nil {
hdlr.cancelCleanup() hdlr.cancelCleanup()
} }
if hdlr.loginLimiter != nil {
hdlr.loginLimiter.Stop()
}
} }
func (hdlr *Handlers) cleanupLoop(ctx context.Context) { func (hdlr *Handlers) cleanupLoop(ctx context.Context) {
@@ -292,20 +302,4 @@ func (hdlr *Handlers) pruneQueuesAndMessages(
) )
} }
} }
// Prune spent hashcash tokens older than 1 year.
hashcashCutoff := time.Now().Add(-spentHashcashTTL)
pruned, err := hdlr.params.Database.
PruneSpentHashcash(ctx, hashcashCutoff)
if err != nil {
hdlr.log.Error(
"spent hashcash pruning failed", "error", err,
)
} else if pruned > 0 {
hdlr.log.Info(
"pruned spent hashcash tokens",
"deleted", pruned,
)
}
} }

View File

@@ -1,186 +0,0 @@
package hashcash
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"time"
)
var (
errBodyHashMismatch = errors.New(
"body hash mismatch",
)
errBodyHashMissing = errors.New(
"body hash missing",
)
)
// ChannelValidator checks hashcash stamps for
// per-channel PRIVMSG validation. It verifies that
// stamps are bound to a specific channel and message
// body. Replay prevention is handled externally via
// the database spent_hashcash table for persistence
// across server restarts (1-year TTL).
type ChannelValidator struct{}
// NewChannelValidator creates a ChannelValidator.
func NewChannelValidator() *ChannelValidator {
return &ChannelValidator{}
}
// BodyHash computes the hex-encoded SHA-256 hash of a
// message body for use in hashcash stamp validation.
func BodyHash(body []byte) string {
hash := sha256.Sum256(body)
return hex.EncodeToString(hash[:])
}
// ValidateStamp checks a channel hashcash stamp. It
// verifies the stamp format, difficulty, date, channel
// binding, body hash binding, and proof-of-work. Replay
// detection is NOT performed here — callers must check
// the spent_hashcash table separately.
//
// Stamp format: 1:bits:YYMMDD:channel:bodyhash:counter.
func (cv *ChannelValidator) ValidateStamp(
stamp string,
requiredBits int,
channel string,
bodyHash string,
) error {
if requiredBits <= 0 {
return nil
}
parts := strings.Split(stamp, ":")
if len(parts) != stampFields {
return fmt.Errorf(
"%w: expected %d, got %d",
errInvalidFields, stampFields, len(parts),
)
}
version := parts[0]
bitsStr := parts[1]
dateStr := parts[2]
resource := parts[3]
stampBodyHash := parts[4]
headerErr := validateChannelHeader(
version, bitsStr, resource,
requiredBits, channel,
)
if headerErr != nil {
return headerErr
}
stampTime, parseErr := parseStampDate(dateStr)
if parseErr != nil {
return parseErr
}
timeErr := validateTime(stampTime)
if timeErr != nil {
return timeErr
}
bodyErr := validateBodyHash(
stampBodyHash, bodyHash,
)
if bodyErr != nil {
return bodyErr
}
return validateProof(stamp, requiredBits)
}
// StampHash returns a deterministic hash of a stamp
// string for use as a spent-token key.
func StampHash(stamp string) string {
hash := sha256.Sum256([]byte(stamp))
return hex.EncodeToString(hash[:])
}
func validateChannelHeader(
version, bitsStr, resource string,
requiredBits int,
channel string,
) error {
if version != stampVersion {
return fmt.Errorf(
"%w: %s", errBadVersion, version,
)
}
claimedBits, err := strconv.Atoi(bitsStr)
if err != nil || claimedBits < requiredBits {
return fmt.Errorf(
"%w: need %d bits",
errInsufficientBits, requiredBits,
)
}
if resource != channel {
return fmt.Errorf(
"%w: got %q, want %q",
errWrongResource, resource, channel,
)
}
return nil
}
func validateBodyHash(
stampBodyHash, expectedBodyHash string,
) error {
if stampBodyHash == "" {
return errBodyHashMissing
}
if stampBodyHash != expectedBodyHash {
return fmt.Errorf(
"%w: got %q, want %q",
errBodyHashMismatch,
stampBodyHash, expectedBodyHash,
)
}
return nil
}
// MintChannelStamp computes a channel hashcash stamp
// with the given difficulty, channel name, and body hash.
// This is intended for clients to generate stamps before
// sending PRIVMSG to hashcash-protected channels.
//
// Stamp format: 1:bits:YYMMDD:channel:bodyhash:counter.
func MintChannelStamp(
bits int,
channel string,
bodyHash string,
) string {
date := time.Now().UTC().Format(dateFormatShort)
prefix := fmt.Sprintf(
"1:%d:%s:%s:%s:",
bits, date, channel, bodyHash,
)
counter := uint64(0)
for {
stamp := prefix + strconv.FormatUint(counter, 16)
hash := sha256.Sum256([]byte(stamp))
if hasLeadingZeroBits(hash[:], bits) {
return stamp
}
counter++
}
}

View File

@@ -1,244 +0,0 @@
package hashcash_test
import (
"crypto/sha256"
"encoding/hex"
"testing"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
)
const (
testChannel = "#general"
testBodyText = `["hello world"]`
)
func testBodyHash() string {
hash := sha256.Sum256([]byte(testBodyText))
return hex.EncodeToString(hash[:])
}
func TestChannelValidateHappyPath(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err != nil {
t.Fatalf("valid channel stamp rejected: %v", err)
}
}
func TestChannelValidateWrongChannel(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, testBits, "#other", bodyHash,
)
if err == nil {
t.Fatal("expected channel mismatch error")
}
}
func TestChannelValidateWrongBodyHash(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
wrongHash := sha256.Sum256([]byte("different body"))
wrongBodyHash := hex.EncodeToString(wrongHash[:])
err := validator.ValidateStamp(
stamp, testBits, testChannel, wrongBodyHash,
)
if err == nil {
t.Fatal("expected body hash mismatch error")
}
}
func TestChannelValidateInsufficientBits(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Mint with 2 bits but require 4.
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
err := validator.ValidateStamp(
stamp, 4, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected insufficient bits error")
}
}
func TestChannelValidateZeroBitsSkips(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
"garbage", 0, "#ch", "abc",
)
if err != nil {
t.Fatalf("zero bits should skip: %v", err)
}
}
func TestChannelValidateBadFormat(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
"not:valid", testBits, testChannel, "abc",
)
if err == nil {
t.Fatal("expected bad format error")
}
}
func TestChannelValidateBadVersion(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
stamp := "2:2:260317:#general:" + bodyHash + ":counter"
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected bad version error")
}
}
func TestChannelValidateExpiredStamp(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Mint with a very old date by manually constructing.
stamp := mintStampWithDate(
t, testBits, testChannel, "200101",
)
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected expired stamp error")
}
}
func TestChannelValidateMissingBodyHash(t *testing.T) {
t.Parallel()
validator := hashcash.NewChannelValidator()
bodyHash := testBodyHash()
// Construct a stamp with empty body hash field.
stamp := mintStampWithDate(
t, testBits, testChannel, todayDate(),
)
// This uses the session-style stamp which has empty
// ext field — body hash is missing.
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err == nil {
t.Fatal("expected missing body hash error")
}
}
func TestBodyHash(t *testing.T) {
t.Parallel()
body := []byte(`["hello world"]`)
bodyHash := hashcash.BodyHash(body)
if len(bodyHash) != 64 {
t.Fatalf(
"expected 64-char hex hash, got %d",
len(bodyHash),
)
}
// Same input should produce same hash.
bodyHash2 := hashcash.BodyHash(body)
if bodyHash != bodyHash2 {
t.Fatal("body hash not deterministic")
}
// Different input should produce different hash.
bodyHash3 := hashcash.BodyHash([]byte("different"))
if bodyHash == bodyHash3 {
t.Fatal("different inputs produced same hash")
}
}
func TestStampHash(t *testing.T) {
t.Parallel()
hash1 := hashcash.StampHash("stamp1")
hash2 := hashcash.StampHash("stamp2")
if hash1 == hash2 {
t.Fatal("different stamps produced same hash")
}
// Same input should be deterministic.
hash1b := hashcash.StampHash("stamp1")
if hash1 != hash1b {
t.Fatal("stamp hash not deterministic")
}
}
func TestMintChannelStamp(t *testing.T) {
t.Parallel()
bodyHash := testBodyHash()
stamp := hashcash.MintChannelStamp(
testBits, testChannel, bodyHash,
)
if stamp == "" {
t.Fatal("expected non-empty stamp")
}
// Validate the minted stamp.
validator := hashcash.NewChannelValidator()
err := validator.ValidateStamp(
stamp, testBits, testChannel, bodyHash,
)
if err != nil {
t.Fatalf("minted stamp failed validation: %v", err)
}
}

View File

@@ -126,23 +126,18 @@ func (mware *Middleware) Logging() func(http.Handler) http.Handler {
} }
// CORS returns middleware that handles Cross-Origin Resource Sharing. // CORS returns middleware that handles Cross-Origin Resource Sharing.
// AllowCredentials is true so browsers include cookies in
// cross-origin API requests.
func (mware *Middleware) CORS() func(http.Handler) http.Handler { func (mware *Middleware) CORS() func(http.Handler) http.Handler {
return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields
AllowOriginFunc: func( AllowedOrigins: []string{"*"},
_ *http.Request, _ string,
) bool {
return true
},
AllowedMethods: []string{ AllowedMethods: []string{
"GET", "POST", "PUT", "DELETE", "OPTIONS", "GET", "POST", "PUT", "DELETE", "OPTIONS",
}, },
AllowedHeaders: []string{ AllowedHeaders: []string{
"Accept", "Content-Type", "X-CSRF-Token", "Accept", "Authorization",
"Content-Type", "X-CSRF-Token",
}, },
ExposedHeaders: []string{"Link"}, ExposedHeaders: []string{"Link"},
AllowCredentials: true, AllowCredentials: false,
MaxAge: corsMaxAge, MaxAge: corsMaxAge,
}) })
} }

View File

@@ -0,0 +1,122 @@
// Package ratelimit provides per-IP rate limiting for HTTP endpoints.
package ratelimit
import (
"sync"
"time"
"golang.org/x/time/rate"
)
const (
// DefaultRate is the default number of allowed requests per second.
DefaultRate = 1.0
// DefaultBurst is the default maximum burst size.
DefaultBurst = 5
// DefaultSweepInterval controls how often stale entries are pruned.
DefaultSweepInterval = 10 * time.Minute
// DefaultEntryTTL is how long an unused entry lives before eviction.
DefaultEntryTTL = 15 * time.Minute
)
// entry tracks a per-IP rate limiter and when it was last used.
type entry struct {
limiter *rate.Limiter
lastSeen time.Time
}
// Limiter manages per-key rate limiters with automatic cleanup
// of stale entries.
type Limiter struct {
mu sync.Mutex
entries map[string]*entry
rate rate.Limit
burst int
entryTTL time.Duration
stopCh chan struct{}
}
// New creates a new per-key rate Limiter.
// The ratePerSec parameter sets how many requests per second are
// allowed per key. The burst parameter sets the maximum number of
// requests that can be made in a single burst.
func New(ratePerSec float64, burst int) *Limiter {
limiter := &Limiter{
mu: sync.Mutex{},
entries: make(map[string]*entry),
rate: rate.Limit(ratePerSec),
burst: burst,
entryTTL: DefaultEntryTTL,
stopCh: make(chan struct{}),
}
go limiter.sweepLoop()
return limiter
}
// Allow reports whether a request from the given key should be
// allowed. It consumes one token from the key's rate limiter.
func (l *Limiter) Allow(key string) bool {
l.mu.Lock()
ent, exists := l.entries[key]
if !exists {
ent = &entry{
limiter: rate.NewLimiter(l.rate, l.burst),
lastSeen: time.Now(),
}
l.entries[key] = ent
} else {
ent.lastSeen = time.Now()
}
l.mu.Unlock()
return ent.limiter.Allow()
}
// Stop terminates the background sweep goroutine.
func (l *Limiter) Stop() {
close(l.stopCh)
}
// Len returns the number of tracked keys (for testing).
func (l *Limiter) Len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.entries)
}
// sweepLoop periodically removes entries that haven't been seen
// within the TTL.
func (l *Limiter) sweepLoop() {
ticker := time.NewTicker(DefaultSweepInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.sweep()
case <-l.stopCh:
return
}
}
}
// sweep removes stale entries.
func (l *Limiter) sweep() {
l.mu.Lock()
defer l.mu.Unlock()
cutoff := time.Now().Add(-l.entryTTL)
for key, ent := range l.entries {
if ent.lastSeen.Before(cutoff) {
delete(l.entries, key)
}
}
}

View File

@@ -0,0 +1,106 @@
package ratelimit_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
)
func TestNewCreatesLimiter(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter == nil {
t.Fatal("expected non-nil limiter")
}
}
func TestAllowWithinBurst(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 3)
defer limiter.Stop()
for i := range 3 {
if !limiter.Allow("192.168.1.1") {
t.Fatalf(
"request %d should be allowed within burst",
i+1,
)
}
}
}
func TestAllowExceedsBurst(t *testing.T) {
t.Parallel()
// Rate of 0 means no token replenishment, only burst.
limiter := ratelimit.New(0, 3)
defer limiter.Stop()
for range 3 {
limiter.Allow("10.0.0.1")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("fourth request should be denied after burst exhausted")
}
}
func TestAllowSeparateKeys(t *testing.T) {
t.Parallel()
// Rate of 0, burst of 1 — only one request per key.
limiter := ratelimit.New(0, 1)
defer limiter.Stop()
if !limiter.Allow("10.0.0.1") {
t.Fatal("first request for key A should be allowed")
}
if !limiter.Allow("10.0.0.2") {
t.Fatal("first request for key B should be allowed")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("second request for key A should be denied")
}
if limiter.Allow("10.0.0.2") {
t.Fatal("second request for key B should be denied")
}
}
func TestLenTracksKeys(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter.Len() != 0 {
t.Fatalf("expected 0 entries, got %d", limiter.Len())
}
limiter.Allow("10.0.0.1")
limiter.Allow("10.0.0.2")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
// Same key again should not increase count.
limiter.Allow("10.0.0.1")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
}
func TestStopDoesNotPanic(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
limiter.Stop()
}

View File

@@ -75,6 +75,10 @@ func (srv *Server) setupAPIv1(router chi.Router) {
"/session", "/session",
srv.handlers.HandleCreateSession(), srv.handlers.HandleCreateSession(),
) )
router.Post(
"/register",
srv.handlers.HandleRegister(),
)
router.Post( router.Post(
"/login", "/login",
srv.handlers.HandleLogin(), srv.handlers.HandleLogin(),

View File

@@ -11,7 +11,6 @@ const (
CmdNames = "NAMES" CmdNames = "NAMES"
CmdNick = "NICK" CmdNick = "NICK"
CmdNotice = "NOTICE" CmdNotice = "NOTICE"
CmdPass = "PASS"
CmdPart = "PART" CmdPart = "PART"
CmdPing = "PING" CmdPing = "PING"
CmdPong = "PONG" CmdPong = "PONG"