2 Commits

Author SHA1 Message Date
clawbot
d0e925bf70 docs: add reverse proxy security note to login rate limiting section
All checks were successful
check / check (push) Successful in 1m8s
2026-03-17 04:49:49 -07:00
user
e519ffa1e6 feat: add per-IP rate limiting to login endpoint
Add a token-bucket rate limiter (golang.org/x/time/rate) that limits
login attempts per client IP on POST /api/v1/login. Returns 429 Too
Many Requests with a Retry-After header when the limit is exceeded.

Configurable via LOGIN_RATE_LIMIT (requests/sec, default 1) and
LOGIN_RATE_BURST (burst size, default 5). Stale per-IP entries are
automatically cleaned up every 10 minutes.

Only the login endpoint is rate-limited per sneak's instruction —
session creation and registration use hashcash proof-of-work instead.
2026-03-17 04:48:46 -07:00
9 changed files with 508 additions and 58 deletions

102
README.md
View File

@@ -301,8 +301,8 @@ The server implements HTTP long-polling for real-time message delivery:
- The client disconnects (connection closed, no response needed) - The client disconnects (connection closed, no response needed)
**Implementation detail:** The server maintains an in-memory broker with **Implementation detail:** The server maintains an in-memory broker with
per-client notification channels. When a message is enqueued for a client, the per-user notification channels. When a message is enqueued for a user, the
broker closes all waiting channels for that client, waking up any blocked broker closes all waiting channels for that user, waking up any blocked
long-poll handlers. This is O(1) notification — no polling loops, no database long-poll handlers. This is O(1) notification — no polling loops, no database
scanning. scanning.
@@ -460,28 +460,28 @@ The entire read/write loop for a client is two endpoints. Everything else
│ │ │ │ │ │
┌─────────▼──┐ ┌───────▼────┐ ┌──────▼─────┐ ┌─────────▼──┐ ┌───────▼────┐ ┌──────▼─────┐
│client_queue│ │client_queue│ │client_queue│ │client_queue│ │client_queue│ │client_queue│
client_id=1│ │ client_id=2│ │ client_id=3│ user_id=1 │ │ user_id=2 │ │ user_id=3
│ msg_id=N │ │ msg_id=N │ │ msg_id=N │ │ msg_id=N │ │ msg_id=N │ │ msg_id=N │
└────────────┘ └────────────┘ └────────────┘ └────────────┘ └────────────┘ └────────────┘
alice bob carol alice bob carol
Each message is stored ONCE. One queue entry per recipient client. Each message is stored ONCE. One queue entry per recipient.
``` ```
The `client_queues` table contains `(client_id, message_id)` pairs. When a The `client_queues` table contains `(user_id, message_id)` pairs. When a
client polls with `GET /messages?after=<queue_id>`, the server queries for client polls with `GET /messages?after=<queue_id>`, the server queries for
queue entries with `id > after` for that client, joins against the messages queue entries with `id > after` for that user, joins against the messages
table, and returns the results. The `queue_id` (auto-incrementing primary table, and returns the results. The `queue_id` (auto-incrementing primary
key of `client_queues`) serves as a monotonically increasing cursor. key of `client_queues`) serves as a monotonically increasing cursor.
### In-Memory Broker ### In-Memory Broker
The server maintains an in-memory notification broker to avoid database The server maintains an in-memory notification broker to avoid database
polling. The broker is a map of `client_id → []chan struct{}`. When a message polling. The broker is a map of `user_id → []chan struct{}`. When a message
is enqueued for a client: is enqueued for a user:
1. The handler calls `broker.Notify(clientID)` 1. The handler calls `broker.Notify(userID)`
2. The broker closes all waiting channels for that client 2. The broker closes all waiting channels for that user
3. Any goroutines blocked in `select` on those channels wake up 3. Any goroutines blocked in `select` on those channels wake up
4. The woken handler queries the database for new queue entries 4. The woken handler queries the database for new queue entries
5. Messages are returned to the client 5. Messages are returned to the client
@@ -1937,50 +1937,44 @@ The database schema is managed via embedded SQL migration files in
#### `sessions` #### `sessions`
| Column | Type | Description | | Column | Type | Description |
|-----------------|----------|-------------| |----------------|----------|-------------|
| `id` | INTEGER | Primary key (auto-increment) | | `id` | INTEGER | Primary key (auto-increment) |
| `uuid` | TEXT | Unique session UUID | | `uuid` | TEXT | Unique session UUID |
| `nick` | TEXT | Unique nick | | `nick` | TEXT | Unique nick |
| `password_hash` | TEXT | bcrypt hash (empty string for anonymous sessions) | | `password_hash`| TEXT | bcrypt hash (empty string for anonymous sessions) |
| `signing_key` | TEXT | Public signing key (empty string if unset) | | `signing_key` | TEXT | Public signing key (empty string if unset) |
| `away_message` | TEXT | Away message (empty string if not away) | | `away_message` | TEXT | Away message (empty string if not away) |
| `created_at` | DATETIME | Session creation time | | `created_at` | DATETIME | Session creation time |
| `last_seen` | DATETIME | Last API request time | | `last_seen` | DATETIME | Last API request time |
Index on `(uuid)`.
#### `clients` #### `clients`
| Column | Type | Description | | Column | Type | Description |
|--------------|----------|-------------| |-------------|----------|-------------|
| `id` | INTEGER | Primary key (auto-increment) | | `id` | INTEGER | Primary key (auto-increment) |
| `uuid` | TEXT | Unique client UUID | | `uuid` | TEXT | Unique client UUID |
| `session_id` | INTEGER | FK → sessions.id (cascade delete) | | `session_id`| INTEGER | FK → sessions.id (cascade delete) |
| `token` | TEXT | Unique auth token (SHA-256 hash of 64 hex chars) | | `token` | TEXT | Unique auth token (SHA-256 hash of 64 hex chars) |
| `created_at` | DATETIME | Client creation time | | `created_at`| DATETIME | Client creation time |
| `last_seen` | DATETIME | Last API request time | | `last_seen` | DATETIME | Last API request time |
Indexes on `(token)` and `(session_id)`.
#### `channels` #### `channels`
| Column | Type | Description | | Column | Type | Description |
|---------------|----------|-------------| |-------------|----------|-------------|
| `id` | INTEGER | Primary key (auto-increment) | | `id` | INTEGER | Primary key (auto-increment) |
| `name` | TEXT | Unique channel name (e.g., `#general`) | | `name` | TEXT | Unique channel name (e.g., `#general`) |
| `topic` | TEXT | Channel topic (default empty) | | `topic` | TEXT | Channel topic (default empty) |
| `topic_set_by`| TEXT | Nick of the user who set the topic (default empty) | | `created_at`| DATETIME | Channel creation time |
| `topic_set_at`| DATETIME | When the topic was last set | | `updated_at`| DATETIME | Last modification time |
| `created_at` | DATETIME | Channel creation time |
| `updated_at` | DATETIME | Last modification time |
#### `channel_members` #### `channel_members`
| Column | Type | Description | | Column | Type | Description |
|--------------|----------|-------------| |-------------|----------|-------------|
| `id` | INTEGER | Primary key (auto-increment) | | `id` | INTEGER | Primary key (auto-increment) |
| `channel_id` | INTEGER | FK → channels.id (cascade delete) | | `channel_id`| INTEGER | FK → channels.id |
| `session_id` | INTEGER | FK → sessions.id (cascade delete) | | `user_id` | INTEGER | FK → users.id |
| `joined_at` | DATETIME | When the user joined | | `joined_at` | DATETIME | When the user joined |
Unique constraint on `(channel_id, session_id)`. Unique constraint on `(channel_id, user_id)`.
#### `messages` #### `messages`
| Column | Type | Description | | Column | Type | Description |
@@ -1990,7 +1984,6 @@ Unique constraint on `(channel_id, session_id)`.
| `command` | TEXT | IRC command (`PRIVMSG`, `JOIN`, etc.) | | `command` | TEXT | IRC command (`PRIVMSG`, `JOIN`, etc.) |
| `msg_from` | TEXT | Sender nick | | `msg_from` | TEXT | Sender nick |
| `msg_to` | TEXT | Target (`#channel` or nick) | | `msg_to` | TEXT | Target (`#channel` or nick) |
| `params` | TEXT | JSON-encoded IRC-style positional parameters |
| `body` | TEXT | JSON-encoded body (array or object) | | `body` | TEXT | JSON-encoded body (array or object) |
| `meta` | TEXT | JSON-encoded metadata | | `meta` | TEXT | JSON-encoded metadata |
| `created_at`| DATETIME | Server timestamp | | `created_at`| DATETIME | Server timestamp |
@@ -2001,11 +1994,11 @@ Indexes on `(msg_to, id)` and `(created_at)`.
| Column | Type | Description | | Column | Type | Description |
|-------------|----------|-------------| |-------------|----------|-------------|
| `id` | INTEGER | Primary key (auto-increment). Used as the poll cursor. | | `id` | INTEGER | Primary key (auto-increment). Used as the poll cursor. |
| `client_id` | INTEGER | FK → clients.id (cascade delete) | | `user_id` | INTEGER | FK → users.id |
| `message_id`| INTEGER | FK → messages.id (cascade delete) | | `message_id`| INTEGER | FK → messages.id |
| `created_at`| DATETIME | When the entry was queued | | `created_at`| DATETIME | When the entry was queued |
Unique constraint on `(client_id, message_id)`. Index on `(client_id, id)`. Unique constraint on `(user_id, message_id)`. Index on `(user_id, id)`.
The `client_queues.id` is the monotonically increasing cursor used by The `client_queues.id` is the monotonically increasing cursor used by
`GET /messages?after=<id>`. This is more reliable than timestamps (no clock `GET /messages?after=<id>`. This is more reliable than timestamps (no clock
@@ -2058,6 +2051,8 @@ directory is also loaded automatically via
| `METRICS_USERNAME` | string | `""` | Basic auth username for `/metrics` endpoint. If empty, metrics endpoint is disabled. | | `METRICS_USERNAME` | string | `""` | Basic auth username for `/metrics` endpoint. If empty, metrics endpoint is disabled. |
| `METRICS_PASSWORD` | string | `""` | Basic auth password for `/metrics` endpoint | | `METRICS_PASSWORD` | string | `""` | Basic auth password for `/metrics` endpoint |
| `NEOIRC_HASHCASH_BITS` | int | `20` | Required hashcash proof-of-work difficulty (leading zero bits in SHA-256) for session creation. Set to `0` to disable. | | `NEOIRC_HASHCASH_BITS` | int | `20` | Required hashcash proof-of-work difficulty (leading zero bits in SHA-256) for session creation. Set to `0` to disable. |
| `LOGIN_RATE_LIMIT` | float | `1` | Allowed login attempts per second per IP address. |
| `LOGIN_RATE_BURST` | int | `5` | Maximum burst of login attempts per IP before rate limiting kicks in. |
| `MAINTENANCE_MODE` | bool | `false` | Maintenance mode flag (reserved) | | `MAINTENANCE_MODE` | bool | `false` | Maintenance mode flag (reserved) |
### Example `.env` file ### Example `.env` file
@@ -2462,6 +2457,49 @@ creating one session pays once and keeps their session.
- **Language-agnostic**: SHA-256 is available in every programming language. - **Language-agnostic**: SHA-256 is available in every programming language.
The proof computation is trivially implementable in any client. The proof computation is trivially implementable in any client.
### Login Rate Limiting
The login endpoint (`POST /api/v1/login`) has per-IP rate limiting to prevent
brute-force password attacks. This uses a token-bucket algorithm
(`golang.org/x/time/rate`) with configurable rate and burst.
| Environment Variable | Default | Description |
|---------------------|---------|-------------|
| `LOGIN_RATE_LIMIT` | `1` | Allowed login attempts per second per IP |
| `LOGIN_RATE_BURST` | `5` | Maximum burst of login attempts per IP |
When the limit is exceeded, the server returns **429 Too Many Requests** with a
`Retry-After: 1` header. Stale per-IP entries are automatically cleaned up
every 10 minutes.
> **⚠️ Security: Reverse Proxy Required for Production Use**
>
> The rate limiter extracts the client IP by checking the `X-Forwarded-For`
> header first, then `X-Real-IP`, and finally falling back to the TCP
> `RemoteAddr`. Both `X-Forwarded-For` and `X-Real-IP` are **client-controlled
> request headers** — any client can set them to arbitrary values.
>
> Without a properly configured reverse proxy in front of this server:
>
> - An attacker can **bypass rate limiting entirely** by rotating
> `X-Forwarded-For` values on each request (each value is treated as a
> distinct IP).
> - An attacker can **deny service to a specific user** by spoofing that user's
> IP in the `X-Forwarded-For` header, exhausting their rate limit bucket.
>
> **Recommendation:** Always deploy behind a reverse proxy (e.g. nginx, Caddy,
> Traefik) that strips or overwrites incoming `X-Forwarded-For` and `X-Real-IP`
> headers with the actual client IP. If running without a reverse proxy, be
> aware that the rate limiting provides no meaningful protection against a
> targeted attack.
**Why rate limits here but not on session creation?** Session creation is
protected by hashcash proof-of-work (stateless, no IP tracking needed). Login
involves bcrypt password verification against a registered account — a
fundamentally different threat model where an attacker targets a specific
account. Per-IP rate limiting is appropriate here because the cost of a wrong
guess is borne by the server (bcrypt), not the client.
--- ---
## Roadmap ## Roadmap

1
go.mod
View File

@@ -16,6 +16,7 @@ require (
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/crypto v0.48.0 golang.org/x/crypto v0.48.0
golang.org/x/time v0.6.0
modernc.org/sqlite v1.45.0 modernc.org/sqlite v1.45.0
) )

2
go.sum
View File

@@ -151,6 +151,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View File

@@ -46,6 +46,8 @@ type Config struct {
FederationKey string FederationKey string
SessionIdleTimeout string SessionIdleTimeout string
HashcashBits int HashcashBits int
LoginRateLimit float64
LoginRateBurst int
params *Params params *Params
log *slog.Logger log *slog.Logger
} }
@@ -78,6 +80,8 @@ func New(
viper.SetDefault("FEDERATION_KEY", "") viper.SetDefault("FEDERATION_KEY", "")
viper.SetDefault("SESSION_IDLE_TIMEOUT", "720h") viper.SetDefault("SESSION_IDLE_TIMEOUT", "720h")
viper.SetDefault("NEOIRC_HASHCASH_BITS", "20") viper.SetDefault("NEOIRC_HASHCASH_BITS", "20")
viper.SetDefault("LOGIN_RATE_LIMIT", "1")
viper.SetDefault("LOGIN_RATE_BURST", "5")
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {
@@ -104,6 +108,8 @@ func New(
FederationKey: viper.GetString("FEDERATION_KEY"), FederationKey: viper.GetString("FEDERATION_KEY"),
SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"), SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"),
HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"), HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"),
LoginRateLimit: viper.GetFloat64("LOGIN_RATE_LIMIT"),
LoginRateBurst: viper.GetInt("LOGIN_RATE_BURST"),
log: log, log: log,
params: &params, params: &params,
} }

View File

@@ -2130,6 +2130,121 @@ func TestSessionStillWorks(t *testing.T) {
} }
} }
func TestLoginRateLimitExceeded(t *testing.T) {
tserver := newTestServer(t)
// Exhaust the burst (default: 5 per IP) using
// nonexistent users. These fail fast (no bcrypt),
// preventing token replenishment between requests.
for range 5 {
loginBody, mErr := json.Marshal(
map[string]string{
"nick": "nosuchuser",
"password": "doesnotmatter",
},
)
if mErr != nil {
t.Fatal(mErr)
}
loginResp, rErr := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/login"),
bytes.NewReader(loginBody),
)
if rErr != nil {
t.Fatal(rErr)
}
_ = loginResp.Body.Close()
}
// The next request should be rate-limited.
loginBody, err := json.Marshal(map[string]string{
"nick": "nosuchuser", "password": "doesnotmatter",
})
if err != nil {
t.Fatal(err)
}
resp, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/login"),
bytes.NewReader(loginBody),
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusTooManyRequests {
t.Fatalf(
"expected 429, got %d",
resp.StatusCode,
)
}
retryAfter := resp.Header.Get("Retry-After")
if retryAfter == "" {
t.Fatal("expected Retry-After header")
}
}
func TestLoginRateLimitAllowsNormalUse(t *testing.T) {
tserver := newTestServer(t)
// Register a user.
regBody, err := json.Marshal(map[string]string{
"nick": "normaluser", "password": "password123",
})
if err != nil {
t.Fatal(err)
}
resp, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/register"),
bytes.NewReader(regBody),
)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()
// A single login should succeed without rate limiting.
loginBody, err := json.Marshal(map[string]string{
"nick": "normaluser", "password": "password123",
})
if err != nil {
t.Fatal(err)
}
resp2, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/login"),
bytes.NewReader(loginBody),
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp2.Body.Close() }()
if resp2.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp2.Body)
t.Fatalf(
"expected 200, got %d: %s",
resp2.StatusCode, respBody,
)
}
}
func TestNickBroadcastToChannels(t *testing.T) { func TestNickBroadcastToChannels(t *testing.T) {
tserver := newTestServer(t) tserver := newTestServer(t)
aliceToken := tserver.createSession("nick_a") aliceToken := tserver.createSession("nick_a")

View File

@@ -2,6 +2,7 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"net"
"net/http" "net/http"
"strings" "strings"
@@ -10,6 +11,33 @@ import (
const minPasswordLength = 8 const minPasswordLength = 8
// clientIP extracts the client IP address from the request.
// It checks X-Forwarded-For and X-Real-IP headers before
// falling back to RemoteAddr.
func clientIP(request *http.Request) string {
if forwarded := request.Header.Get("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For may contain a comma-separated list;
// the first entry is the original client.
parts := strings.SplitN(forwarded, ",", 2) //nolint:mnd // split into two parts
ip := strings.TrimSpace(parts[0])
if ip != "" {
return ip
}
}
if realIP := request.Header.Get("X-Real-IP"); realIP != "" {
return strings.TrimSpace(realIP)
}
host, _, err := net.SplitHostPort(request.RemoteAddr)
if err != nil {
return request.RemoteAddr
}
return host
}
// HandleRegister creates a new user with a password. // HandleRegister creates a new user with a password.
func (hdlr *Handlers) HandleRegister() http.HandlerFunc { func (hdlr *Handlers) HandleRegister() http.HandlerFunc {
return func( return func(
@@ -137,6 +165,21 @@ func (hdlr *Handlers) handleLogin(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
ip := clientIP(request)
if !hdlr.loginLimiter.Allow(ip) {
writer.Header().Set(
"Retry-After", "1",
)
hdlr.respondError(
writer, request,
"too many login attempts, try again later",
http.StatusTooManyRequests,
)
return
}
type loginRequest struct { type loginRequest struct {
Nick string `json:"nick"` Nick string `json:"nick"`
Password string `json:"password"` Password string `json:"password"`

View File

@@ -16,6 +16,7 @@ import (
"git.eeqj.de/sneak/neoirc/internal/hashcash" "git.eeqj.de/sneak/neoirc/internal/hashcash"
"git.eeqj.de/sneak/neoirc/internal/healthcheck" "git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/logger" "git.eeqj.de/sneak/neoirc/internal/logger"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
"git.eeqj.de/sneak/neoirc/internal/stats" "git.eeqj.de/sneak/neoirc/internal/stats"
"go.uber.org/fx" "go.uber.org/fx"
) )
@@ -43,6 +44,7 @@ type Handlers struct {
hc *healthcheck.Healthcheck hc *healthcheck.Healthcheck
broker *broker.Broker broker *broker.Broker
hashcashVal *hashcash.Validator hashcashVal *hashcash.Validator
loginLimiter *ratelimit.Limiter
stats *stats.Tracker stats *stats.Tracker
cancelCleanup context.CancelFunc cancelCleanup context.CancelFunc
} }
@@ -57,12 +59,23 @@ func New(
resource = "neoirc" resource = "neoirc"
} }
loginRate := params.Config.LoginRateLimit
if loginRate <= 0 {
loginRate = ratelimit.DefaultRate
}
loginBurst := params.Config.LoginRateBurst
if loginBurst <= 0 {
loginBurst = ratelimit.DefaultBurst
}
hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
params: &params, params: &params,
log: params.Logger.Get(), log: params.Logger.Get(),
hc: params.Healthcheck, hc: params.Healthcheck,
broker: broker.New(), broker: broker.New(),
hashcashVal: hashcash.NewValidator(resource), hashcashVal: hashcash.NewValidator(resource),
loginLimiter: ratelimit.New(loginRate, loginBurst),
stats: params.Stats, stats: params.Stats,
} }
@@ -155,6 +168,10 @@ func (hdlr *Handlers) stopCleanup() {
if hdlr.cancelCleanup != nil { if hdlr.cancelCleanup != nil {
hdlr.cancelCleanup() hdlr.cancelCleanup()
} }
if hdlr.loginLimiter != nil {
hdlr.loginLimiter.Stop()
}
} }
func (hdlr *Handlers) cleanupLoop(ctx context.Context) { func (hdlr *Handlers) cleanupLoop(ctx context.Context) {

View File

@@ -0,0 +1,122 @@
// Package ratelimit provides per-IP rate limiting for HTTP endpoints.
package ratelimit
import (
"sync"
"time"
"golang.org/x/time/rate"
)
const (
// DefaultRate is the default number of allowed requests per second.
DefaultRate = 1.0
// DefaultBurst is the default maximum burst size.
DefaultBurst = 5
// DefaultSweepInterval controls how often stale entries are pruned.
DefaultSweepInterval = 10 * time.Minute
// DefaultEntryTTL is how long an unused entry lives before eviction.
DefaultEntryTTL = 15 * time.Minute
)
// entry tracks a per-IP rate limiter and when it was last used.
type entry struct {
limiter *rate.Limiter
lastSeen time.Time
}
// Limiter manages per-key rate limiters with automatic cleanup
// of stale entries.
type Limiter struct {
mu sync.Mutex
entries map[string]*entry
rate rate.Limit
burst int
entryTTL time.Duration
stopCh chan struct{}
}
// New creates a new per-key rate Limiter.
// The ratePerSec parameter sets how many requests per second are
// allowed per key. The burst parameter sets the maximum number of
// requests that can be made in a single burst.
func New(ratePerSec float64, burst int) *Limiter {
limiter := &Limiter{
mu: sync.Mutex{},
entries: make(map[string]*entry),
rate: rate.Limit(ratePerSec),
burst: burst,
entryTTL: DefaultEntryTTL,
stopCh: make(chan struct{}),
}
go limiter.sweepLoop()
return limiter
}
// Allow reports whether a request from the given key should be
// allowed. It consumes one token from the key's rate limiter.
func (l *Limiter) Allow(key string) bool {
l.mu.Lock()
ent, exists := l.entries[key]
if !exists {
ent = &entry{
limiter: rate.NewLimiter(l.rate, l.burst),
lastSeen: time.Now(),
}
l.entries[key] = ent
} else {
ent.lastSeen = time.Now()
}
l.mu.Unlock()
return ent.limiter.Allow()
}
// Stop terminates the background sweep goroutine.
func (l *Limiter) Stop() {
close(l.stopCh)
}
// Len returns the number of tracked keys (for testing).
func (l *Limiter) Len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.entries)
}
// sweepLoop periodically removes entries that haven't been seen
// within the TTL.
func (l *Limiter) sweepLoop() {
ticker := time.NewTicker(DefaultSweepInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.sweep()
case <-l.stopCh:
return
}
}
}
// sweep removes stale entries.
func (l *Limiter) sweep() {
l.mu.Lock()
defer l.mu.Unlock()
cutoff := time.Now().Add(-l.entryTTL)
for key, ent := range l.entries {
if ent.lastSeen.Before(cutoff) {
delete(l.entries, key)
}
}
}

View File

@@ -0,0 +1,106 @@
package ratelimit_test
import (
"testing"
"git.eeqj.de/sneak/neoirc/internal/ratelimit"
)
func TestNewCreatesLimiter(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter == nil {
t.Fatal("expected non-nil limiter")
}
}
func TestAllowWithinBurst(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 3)
defer limiter.Stop()
for i := range 3 {
if !limiter.Allow("192.168.1.1") {
t.Fatalf(
"request %d should be allowed within burst",
i+1,
)
}
}
}
func TestAllowExceedsBurst(t *testing.T) {
t.Parallel()
// Rate of 0 means no token replenishment, only burst.
limiter := ratelimit.New(0, 3)
defer limiter.Stop()
for range 3 {
limiter.Allow("10.0.0.1")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("fourth request should be denied after burst exhausted")
}
}
func TestAllowSeparateKeys(t *testing.T) {
t.Parallel()
// Rate of 0, burst of 1 — only one request per key.
limiter := ratelimit.New(0, 1)
defer limiter.Stop()
if !limiter.Allow("10.0.0.1") {
t.Fatal("first request for key A should be allowed")
}
if !limiter.Allow("10.0.0.2") {
t.Fatal("first request for key B should be allowed")
}
if limiter.Allow("10.0.0.1") {
t.Fatal("second request for key A should be denied")
}
if limiter.Allow("10.0.0.2") {
t.Fatal("second request for key B should be denied")
}
}
func TestLenTracksKeys(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
defer limiter.Stop()
if limiter.Len() != 0 {
t.Fatalf("expected 0 entries, got %d", limiter.Len())
}
limiter.Allow("10.0.0.1")
limiter.Allow("10.0.0.2")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
// Same key again should not increase count.
limiter.Allow("10.0.0.1")
if limiter.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", limiter.Len())
}
}
func TestStopDoesNotPanic(t *testing.T) {
t.Parallel()
limiter := ratelimit.New(1.0, 5)
limiter.Stop()
}