18 Commits

Author SHA1 Message Date
clawbot
3b42620749 feat: split Dockerfile into dedicated lint stage
All checks were successful
check / check (push) Successful in 6s
Use pre-built golangci/golangci-lint:v2.1.6 image for fast lint feedback
instead of installing golangci-lint from source on every build.

- Lint stage: runs fmt-check and lint using pre-built image
- Build stage: runs tests and compiles binaries
- COPY --from=lint forces BuildKit to execute the lint stage
- All images pinned by sha256 digest
- Runtime stage unchanged
2026-03-02 00:03:04 -08:00
cd909d59c4 Merge pull request 'feat: logout, users/me, user count, session timeout' (#24) from feature/mvp-remaining into main
All checks were successful
check / check (push) Successful in 1m58s
Reviewed-on: #24
2026-03-01 15:47:03 +01:00
clawbot
f5cc098b7b docs: update README for new endpoints, fix config name, remove dead field
All checks were successful
check / check (push) Successful in 1m24s
- Document POST /api/v1/logout endpoint
- Document GET /api/v1/users/me endpoint
- Add 'users' field to GET /api/v1/server response docs
- Fix config: SESSION_TIMEOUT -> SESSION_IDLE_TIMEOUT
- Update storage section: session expiry is implemented
- Update roadmap: move session expiry to implemented
- Remove dead SessionTimeout config field from Go code
2026-03-01 06:41:10 -08:00
user
4d7b7618b2 fix: send QUIT notifications for background idle cleanup
All checks were successful
check / check (push) Successful in 2m2s
The background idle cleanup (DeleteStaleUsers) was removing stale
clients/sessions directly via SQL without sending QUIT notifications
to channel members. This caused timed-out users to silently disappear
from channels.

Now runCleanup identifies sessions that will be orphaned by the stale
client deletion and calls cleanupUser for each one first, ensuring
QUIT messages are sent to all channel members — matching the explicit
logout behavior.

Also refactored cleanupUser to accept context.Context instead of
*http.Request so it can be called from both HTTP handlers and the
background cleanup goroutine.
2026-03-01 06:33:15 -08:00
user
910a5c2606 fix: OnStart ctx bug, rename session→user, full logout cleanup
All checks were successful
check / check (push) Successful in 1m57s
- Use context.Background() for cleanup goroutine instead of
  OnStart ctx which is cancelled after startup completes
- Rename GetSessionCount→GetUserCount, DeleteStaleSessions→
  DeleteStaleUsers to reflect that sessions represent users
- HandleLogout now fully cleans up when last client disconnects:
  parts all channels (notifying members via QUIT), removes
  empty channels, and deletes the session/user record
- docker build passes, all tests green, 0 lint issues
2026-02-28 11:14:23 -08:00
bdc243224b feat: add session idle timeout cleanup goroutine
All checks were successful
check / check (push) Successful in 1m58s
- Periodic cleanup loop deletes stale clients based on SESSION_IDLE_TIMEOUT
- Orphaned sessions (no clients) are cleaned up automatically
- last_seen already updated on each authenticated request via GetSessionByToken
2026-02-28 10:59:09 -08:00
5981c750a4 feat: add SESSION_IDLE_TIMEOUT config
- New env var SESSION_IDLE_TIMEOUT (default 24h)
- Parsed as time.Duration in handlers
2026-02-28 10:59:09 -08:00
6cfab21eaa feat: add logout endpoint and users/me endpoint
- POST /api/v1/logout: deletes client token, returns {status: ok}
- GET /api/v1/users/me: returns session info (delegates to HandleState)
- Add DeleteClient, GetSessionCount, ClientCountForSession, DeleteStaleSessions to db layer
- Add user count to GET /api/v1/server response
- Extract setupAPIv1 to fix funlen lint issue
2026-02-28 10:59:09 -08:00
4a0ed57fc0 Merge pull request 'feat: password-based registration and login (closes #1)' (#23) from feature/auth-passwords into main
All checks were successful
check / check (push) Successful in 9s
Reviewed-on: #23
Reviewed-by: Jeffrey Paul <sneak@noreply.example.org>
2026-02-28 19:57:40 +01:00
user
52c85724a7 fix: remove unused //nolint:gosec directives on password fields
All checks were successful
check / check (push) Successful in 2m4s
2026-02-28 10:33:59 -08:00
clawbot
69c9550bb2 consolidate password_hash into 001 migration
Some checks failed
check / check (push) Failing after 1m27s
Pre-1.0, no installed base — merge 002_add_passwords.sql into
001_initial.sql and remove the separate migration file.
2026-02-28 07:59:01 -08:00
7047167dc8 Add tests for register and login endpoints
Some checks failed
check / check (push) Failing after 1m39s
2026-02-27 05:00:51 -08:00
3cd942ffa5 Add /api/v1/register and /api/v1/login routes 2026-02-27 04:55:40 -08:00
b8794c2587 Add register and login HTTP handlers 2026-02-27 04:55:31 -08:00
70aa15e758 Add RegisterUser and LoginUser DB functions with bcrypt 2026-02-27 04:55:06 -08:00
5e26e53187 Add migration 002: add password_hash column to sessions 2026-02-27 04:54:31 -08:00
02b906badb Merge pull request 'feat: MVP two-user chat via embedded SPA (closes #9)' (#22) from feat/mvp-two-user-chat into main
All checks were successful
check / check (push) Successful in 10s
Reviewed-on: #22
2026-02-27 13:51:20 +01:00
clawbot
32419fb1f7 feat: MVP two-user chat via embedded SPA (#9)
All checks were successful
check / check (push) Successful in 1m51s
Backend:
- Session/client UUID model: sessions table (uuid, nick, signing_key),
  clients table (uuid, session_id, token) with per-client message queues
- MOTD delivery as IRC numeric messages (375/372/376) on connect
- EnqueueToSession fans out to all clients of a session
- EnqueueToClient for targeted delivery (MOTD)
- All queries updated for session/client model

SPA client:
- Long-poll loop (15s timeout) instead of setInterval
- IRC message envelope parsing (command/from/to/body)
- Display JOIN/PART/NICK/TOPIC/QUIT system messages
- Nick change via /nick command
- Topic display in header bar
- Unread count badges on inactive tabs
- Auto-rejoin channels on reconnect (localStorage)
- Connection status indicator
- Message deduplication by UUID
- Channel history loaded on join
- /topic command support

Closes #9
2026-02-27 02:21:48 -08:00
19 changed files with 2320 additions and 996 deletions

View File

@@ -1,18 +1,28 @@
# Lint stage — fast feedback on formatting and lint issues
# golangci/golangci-lint:v2.1.6
FROM golangci/golangci-lint@sha256:568ee1c1c53493575fa9494e280e579ac9ca865787bafe4df3023ae59ecf299b AS lint
WORKDIR /src
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN make fmt-check
RUN make lint
# Build stage — tests and compilation
# golang:1.24-alpine, 2026-02-26 # golang:1.24-alpine, 2026-02-26
FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder
WORKDIR /src WORKDIR /src
RUN apk add --no-cache git build-base make RUN apk add --no-cache git build-base make
# golangci-lint v2.1.6 (eabc2638a66d), 2026-02-26 # Force BuildKit to run the lint stage by creating a stage dependency
RUN CGO_ENABLED=0 go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d COPY --from=lint /src/go.sum /dev/null
COPY go.mod go.sum ./ COPY go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
# Run all checks — build fails if branch is not green RUN make test
RUN make check
# Build static binaries (no cgo needed at runtime — modernc.org/sqlite is pure Go) # Build static binaries (no cgo needed at runtime — modernc.org/sqlite is pure Go)
ARG VERSION=dev ARG VERSION=dev

View File

@@ -1158,6 +1158,55 @@ curl -s http://localhost:8080/api/v1/channels/general/members \
-H "Authorization: Bearer $TOKEN" | jq . -H "Authorization: Bearer $TOKEN" | jq .
``` ```
### POST /api/v1/logout — Logout
Destroy the current client's auth token. If no other clients remain on the
session, the user is fully cleaned up: parted from all channels (with QUIT
broadcast to members), session deleted, nick released.
**Request:** No body. Requires auth.
**Response:** `200 OK`
```json
{"status": "ok"}
```
**Errors:**
| Status | Error | When |
|--------|-------|------|
| 401 | `unauthorized` | Missing or invalid auth token |
**curl example:**
```bash
curl -s -X POST http://localhost:8080/api/v1/logout \
-H "Authorization: Bearer $TOKEN" | jq .
```
### GET /api/v1/users/me — Current User Info
Return the current user's session state. This is an alias for
`GET /api/v1/state`.
**Request:** No body. Requires auth.
**Response:** `200 OK`
```json
{
"id": 1,
"nick": "alice",
"channels": [
{"id": 1, "name": "#general", "topic": "Welcome!"}
]
}
```
**curl example:**
```bash
curl -s http://localhost:8080/api/v1/users/me \
-H "Authorization: Bearer $TOKEN" | jq .
```
### GET /api/v1/server — Server Info ### GET /api/v1/server — Server Info
Return server metadata. No authentication required. Return server metadata. No authentication required.
@@ -1166,10 +1215,17 @@ Return server metadata. No authentication required.
```json ```json
{ {
"name": "My Chat Server", "name": "My Chat Server",
"motd": "Welcome! Be nice." "motd": "Welcome! Be nice.",
"users": 42
} }
``` ```
| Field | Type | Description |
|---------|---------|-------------|
| `name` | string | Server display name |
| `motd` | string | Message of the day |
| `users` | integer | Number of currently active user sessions |
### GET /.well-known/healthcheck.json — Health Check ### GET /.well-known/healthcheck.json — Health Check
Standard health check endpoint. No authentication required. Standard health check endpoint. No authentication required.
@@ -1572,8 +1628,10 @@ skew issues) and simpler than UUIDs (integer comparison vs. string comparison).
- **Queue entries**: Stored until pruned. Pruning by `QUEUE_MAX_AGE` is - **Queue entries**: Stored until pruned. Pruning by `QUEUE_MAX_AGE` is
planned. planned.
- **Channels**: Deleted when the last member leaves (ephemeral). - **Channels**: Deleted when the last member leaves (ephemeral).
- **Users/sessions**: Deleted on `QUIT`. Session expiry by `SESSION_TIMEOUT` - **Users/sessions**: Deleted on `QUIT` or `POST /api/v1/logout`. Idle
is planned. sessions are automatically expired after `SESSION_IDLE_TIMEOUT` (default
24h) — the server runs a background cleanup loop that parts idle users
from all channels, broadcasts QUIT, and releases their nicks.
--- ---
@@ -1590,7 +1648,7 @@ directory is also loaded automatically via
| `DBURL` | string | `file:./data.db?_journal_mode=WAL` | SQLite connection string. For file-based: `file:./path.db?_journal_mode=WAL`. For in-memory (testing): `file::memory:?cache=shared`. | | `DBURL` | string | `file:./data.db?_journal_mode=WAL` | SQLite connection string. For file-based: `file:./path.db?_journal_mode=WAL`. For in-memory (testing): `file::memory:?cache=shared`. |
| `DEBUG` | bool | `false` | Enable debug logging (verbose request/response logging) | | `DEBUG` | bool | `false` | Enable debug logging (verbose request/response logging) |
| `MAX_HISTORY` | int | `10000` | Maximum messages retained per channel before rotation (planned) | | `MAX_HISTORY` | int | `10000` | Maximum messages retained per channel before rotation (planned) |
| `SESSION_TIMEOUT` | int | `86400` | Session idle timeout in seconds (planned). Sessions with no activity for this long are expired and the nick is released. | | `SESSION_IDLE_TIMEOUT` | string | `24h` | Session idle timeout as a Go duration string (e.g. `24h`, `30m`). Sessions with no activity for this long are expired and the nick is released. |
| `QUEUE_MAX_AGE` | int | `172800` | Maximum age of client queue entries in seconds (48h). Entries older than this are pruned (planned). | | `QUEUE_MAX_AGE` | int | `172800` | Maximum age of client queue entries in seconds (48h). Entries older than this are pruned (planned). |
| `MAX_MESSAGE_SIZE` | int | `4096` | Maximum message body size in bytes (planned enforcement) | | `MAX_MESSAGE_SIZE` | int | `4096` | Maximum message body size in bytes (planned enforcement) |
| `LONG_POLL_TIMEOUT`| int | `15` | Default long-poll timeout in seconds (client can override via query param, server caps at 30) | | `LONG_POLL_TIMEOUT`| int | `15` | Default long-poll timeout in seconds (client can override via query param, server caps at 30) |
@@ -1610,7 +1668,7 @@ SERVER_NAME=My Chat Server
MOTD=Welcome! Be excellent to each other. MOTD=Welcome! Be excellent to each other.
DEBUG=false DEBUG=false
DBURL=file:./data.db?_journal_mode=WAL DBURL=file:./data.db?_journal_mode=WAL
SESSION_TIMEOUT=86400 SESSION_IDLE_TIMEOUT=24h
``` ```
--- ---
@@ -2008,11 +2066,14 @@ GET /api/v1/challenge
- [x] Docker deployment - [x] Docker deployment
- [x] Prometheus metrics endpoint - [x] Prometheus metrics endpoint
- [x] Health check endpoint - [x] Health check endpoint
- [x] Session expiry — auto-expire idle sessions, release nicks
- [x] Logout endpoint (`POST /api/v1/logout`)
- [x] Current user endpoint (`GET /api/v1/users/me`)
- [x] User count in server info (`GET /api/v1/server`)
### Post-MVP (Planned) ### Post-MVP (Planned)
- [ ] **Hashcash proof-of-work** for session creation (abuse prevention) - [ ] **Hashcash proof-of-work** for session creation (abuse prevention)
- [ ] **Session expiry** — auto-expire idle sessions, release nicks
- [ ] **Queue pruning** — delete old queue entries per `QUEUE_MAX_AGE` - [ ] **Queue pruning** — delete old queue entries per `QUEUE_MAX_AGE`
- [ ] **Message rotation** — enforce `MAX_HISTORY` per channel - [ ] **Message rotation** — enforce `MAX_HISTORY` per channel
- [ ] **Channel modes** — enforce `+i`, `+m`, `+s`, `+t`, `+n` - [ ] **Channel modes** — enforce `+i`, `+m`, `+s`, `+t`, `+n`

13
go.mod
View File

@@ -4,14 +4,18 @@ go 1.24.0
require ( require (
github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8 github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8
github.com/gdamore/tcell/v2 v2.13.8
github.com/getsentry/sentry-go v0.42.0 github.com/getsentry/sentry-go v0.42.0
github.com/go-chi/chi v1.5.5 github.com/go-chi/chi v1.5.5
github.com/go-chi/cors v1.2.2 github.com/go-chi/cors v1.2.2
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/rivo/tview v0.42.0
github.com/slok/go-http-metrics v0.13.0 github.com/slok/go-http-metrics v0.13.0
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.45.0 modernc.org/sqlite v1.45.0
) )
@@ -21,9 +25,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gdamore/encoding v1.0.1 // indirect github.com/gdamore/encoding v1.0.1 // indirect
github.com/gdamore/tcell/v2 v2.13.8 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
@@ -33,7 +35,6 @@ require (
github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/tview v0.42.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
@@ -47,9 +48,9 @@ require (
go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.38.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/term v0.37.0 // indirect golang.org/x/term v0.40.0 // indirect
golang.org/x/text v0.31.0 // indirect golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect google.golang.org/protobuf v1.36.8 // indirect
modernc.org/libc v1.67.6 // indirect modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

30
go.sum
View File

@@ -113,12 +113,14 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@@ -126,8 +128,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -135,30 +137,26 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=

View File

@@ -31,11 +31,11 @@ type Config struct {
Port int Port int
SentryDSN string SentryDSN string
MaxHistory int MaxHistory int
SessionTimeout int
MaxMessageSize int MaxMessageSize int
MOTD string MOTD string
ServerName string ServerName string
FederationKey string FederationKey string
SessionIdleTimeout string
params *Params params *Params
log *slog.Logger log *slog.Logger
} }
@@ -61,11 +61,11 @@ func New(
viper.SetDefault("METRICS_USERNAME", "") viper.SetDefault("METRICS_USERNAME", "")
viper.SetDefault("METRICS_PASSWORD", "") viper.SetDefault("METRICS_PASSWORD", "")
viper.SetDefault("MAX_HISTORY", "10000") viper.SetDefault("MAX_HISTORY", "10000")
viper.SetDefault("SESSION_TIMEOUT", "86400")
viper.SetDefault("MAX_MESSAGE_SIZE", "4096") viper.SetDefault("MAX_MESSAGE_SIZE", "4096")
viper.SetDefault("MOTD", "") viper.SetDefault("MOTD", "")
viper.SetDefault("SERVER_NAME", "") viper.SetDefault("SERVER_NAME", "")
viper.SetDefault("FEDERATION_KEY", "") viper.SetDefault("FEDERATION_KEY", "")
viper.SetDefault("SESSION_IDLE_TIMEOUT", "24h")
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {
@@ -85,11 +85,11 @@ func New(
MetricsUsername: viper.GetString("METRICS_USERNAME"), MetricsUsername: viper.GetString("METRICS_USERNAME"),
MetricsPassword: viper.GetString("METRICS_PASSWORD"), MetricsPassword: viper.GetString("METRICS_PASSWORD"),
MaxHistory: viper.GetInt("MAX_HISTORY"), MaxHistory: viper.GetInt("MAX_HISTORY"),
SessionTimeout: viper.GetInt("SESSION_TIMEOUT"),
MaxMessageSize: viper.GetInt("MAX_MESSAGE_SIZE"), MaxMessageSize: viper.GetInt("MAX_MESSAGE_SIZE"),
MOTD: viper.GetString("MOTD"), MOTD: viper.GetString("MOTD"),
ServerName: viper.GetString("SERVER_NAME"), ServerName: viper.GetString("SERVER_NAME"),
FederationKey: viper.GetString("FEDERATION_KEY"), FederationKey: viper.GetString("FEDERATION_KEY"),
SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"),
log: log, log: log,
params: &params, params: &params,
} }

161
internal/db/auth.go Normal file
View File

@@ -0,0 +1,161 @@
package db
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
const bcryptCost = bcrypt.DefaultCost
var errNoPassword = errors.New(
"account has no password set",
)
// RegisterUser creates a session with a hashed password
// and returns session ID, client ID, and token.
func (database *Database) RegisterUser(
ctx context.Context,
nick, password string,
) (int64, int64, string, error) {
hash, err := bcrypt.GenerateFromPassword(
[]byte(password), bcryptCost,
)
if err != nil {
return 0, 0, "", fmt.Errorf(
"hash password: %w", err,
)
}
sessionUUID := uuid.New().String()
clientUUID := uuid.New().String()
token, err := generateToken()
if err != nil {
return 0, 0, "", err
}
now := time.Now()
transaction, err := database.conn.BeginTx(ctx, nil)
if err != nil {
return 0, 0, "", fmt.Errorf(
"begin tx: %w", err,
)
}
res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, password_hash,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
sessionUUID, nick, string(hash), now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create session: %w", err,
)
}
sessionID, _ := res.LastInsertId()
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, token, now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create client: %w", err,
)
}
clientID, _ := clientRes.LastInsertId()
err = transaction.Commit()
if err != nil {
return 0, 0, "", fmt.Errorf(
"commit registration: %w", err,
)
}
return sessionID, clientID, token, nil
}
// LoginUser verifies a nick/password and creates a new
// client token.
func (database *Database) LoginUser(
ctx context.Context,
nick, password string,
) (int64, int64, string, error) {
var (
sessionID int64
passwordHash string
)
err := database.conn.QueryRowContext(
ctx,
`SELECT id, password_hash
FROM sessions WHERE nick = ?`,
nick,
).Scan(&sessionID, &passwordHash)
if err != nil {
return 0, 0, "", fmt.Errorf(
"get session for login: %w", err,
)
}
if passwordHash == "" {
return 0, 0, "", fmt.Errorf(
"login: %w", errNoPassword,
)
}
err = bcrypt.CompareHashAndPassword(
[]byte(passwordHash), []byte(password),
)
if err != nil {
return 0, 0, "", fmt.Errorf(
"verify password: %w", err,
)
}
clientUUID := uuid.New().String()
token, err := generateToken()
if err != nil {
return 0, 0, "", err
}
now := time.Now()
res, err := database.conn.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, token, now, now)
if err != nil {
return 0, 0, "", fmt.Errorf(
"create login client: %w", err,
)
}
clientID, _ := res.LastInsertId()
_, _ = database.conn.ExecContext(
ctx,
"UPDATE sessions SET last_seen = ? WHERE id = ?",
now, sessionID,
)
return sessionID, clientID, token, nil
}

178
internal/db/auth_test.go Normal file
View File

@@ -0,0 +1,178 @@
package db_test
import (
"testing"
_ "modernc.org/sqlite"
)
func TestRegisterUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, clientID, token, err :=
database.RegisterUser(ctx, "reguser", "password123")
if err != nil {
t.Fatal(err)
}
if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token")
}
// Verify session works via token lookup.
sid, cid, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if sid != sessionID || cid != clientID {
t.Fatal("session/client id mismatch")
}
if nick != "reguser" {
t.Fatalf("expected reguser, got %s", nick)
}
}
func TestRegisterUserDuplicateNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "dupnick", "password123")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
dupSID, dupCID, dupToken, dupErr :=
database.RegisterUser(ctx, "dupnick", "other12345")
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
}
_ = dupSID
_ = dupCID
_ = dupToken
}
func TestLoginUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "loginuser", "mypassword")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
sessionID, clientID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword")
if err != nil {
t.Fatal(err)
}
if sessionID == 0 || clientID == 0 || token == "" {
t.Fatal("expected valid ids and token")
}
// Verify the new token works.
_, _, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if nick != "loginuser" {
t.Fatalf("expected loginuser, got %s", nick)
}
}
func TestLoginUserWrongPassword(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "wrongpw", "correctpass")
if err != nil {
t.Fatal(err)
}
_ = regSID
_ = regCID
_ = regToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
_ = loginSID
_ = loginCID
_ = loginToken
}
func TestLoginUserNoPassword(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
// Create anonymous session (no password).
anonSID, anonCID, anonToken, err :=
database.CreateSession(ctx, "anon")
if err != nil {
t.Fatal(err)
}
_ = anonSID
_ = anonCID
_ = anonToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "anon", "anything1")
if loginErr == nil {
t.Fatal(
"expected error for login on passwordless account",
)
}
_ = loginSID
_ = loginCID
_ = loginToken
}
func TestLoginUserNonexistent(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "ghost", "password123")
if err == nil {
t.Fatal("expected error for nonexistent user")
}
_ = loginSID
_ = loginCID
_ = loginToken
}

View File

@@ -55,76 +55,132 @@ type MemberInfo struct {
LastSeen time.Time `json:"lastSeen"` LastSeen time.Time `json:"lastSeen"`
} }
// CreateUser registers a new user with the given nick. // CreateSession registers a new session and its first client.
func (database *Database) CreateUser( func (database *Database) CreateSession(
ctx context.Context, ctx context.Context,
nick string, nick string,
) (int64, string, error) { ) (int64, int64, string, error) {
sessionUUID := uuid.New().String()
clientUUID := uuid.New().String()
token, err := generateToken() token, err := generateToken()
if err != nil { if err != nil {
return 0, "", err return 0, 0, "", err
} }
now := time.Now() now := time.Now()
res, err := database.conn.ExecContext(ctx, transaction, err := database.conn.BeginTx(ctx, nil)
`INSERT INTO users
(nick, token, created_at, last_seen)
VALUES (?, ?, ?, ?)`,
nick, token, now, now)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("create user: %w", err) return 0, 0, "", fmt.Errorf(
"begin tx: %w", err,
)
} }
userID, _ := res.LastInsertId() res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, created_at, last_seen)
VALUES (?, ?, ?, ?)`,
sessionUUID, nick, now, now)
if err != nil {
_ = transaction.Rollback()
return userID, token, nil return 0, 0, "", fmt.Errorf(
"create session: %w", err,
)
} }
// GetUserByToken returns user id and nick for a token. sessionID, _ := res.LastInsertId()
func (database *Database) GetUserByToken(
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, token, now, now)
if err != nil {
_ = transaction.Rollback()
return 0, 0, "", fmt.Errorf(
"create client: %w", err,
)
}
clientID, _ := clientRes.LastInsertId()
err = transaction.Commit()
if err != nil {
return 0, 0, "", fmt.Errorf(
"commit session: %w", err,
)
}
return sessionID, clientID, token, nil
}
// GetSessionByToken returns session id, client id, and
// nick for a client token.
func (database *Database) GetSessionByToken(
ctx context.Context, ctx context.Context,
token string, token string,
) (int64, string, error) { ) (int64, int64, string, error) {
var userID int64 var (
sessionID int64
var nick string clientID int64
nick string
)
err := database.conn.QueryRowContext( err := database.conn.QueryRowContext(
ctx, ctx,
"SELECT id, nick FROM users WHERE token = ?", `SELECT s.id, c.id, s.nick
FROM clients c
INNER JOIN sessions s
ON s.id = c.session_id
WHERE c.token = ?`,
token, token,
).Scan(&userID, &nick) ).Scan(&sessionID, &clientID, &nick)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("get user by token: %w", err) return 0, 0, "", fmt.Errorf(
"get session by token: %w", err,
)
} }
now := time.Now()
_, _ = database.conn.ExecContext( _, _ = database.conn.ExecContext(
ctx, ctx,
"UPDATE users SET last_seen = ? WHERE id = ?", "UPDATE sessions SET last_seen = ? WHERE id = ?",
time.Now(), userID, now, sessionID,
) )
return userID, nick, nil _, _ = database.conn.ExecContext(
ctx,
"UPDATE clients SET last_seen = ? WHERE id = ?",
now, clientID,
)
return sessionID, clientID, nick, nil
} }
// GetUserByNick returns user id for a given nick. // GetSessionByNick returns session id for a given nick.
func (database *Database) GetUserByNick( func (database *Database) GetSessionByNick(
ctx context.Context, ctx context.Context,
nick string, nick string,
) (int64, error) { ) (int64, error) {
var userID int64 var sessionID int64
err := database.conn.QueryRowContext( err := database.conn.QueryRowContext(
ctx, ctx,
"SELECT id FROM users WHERE nick = ?", "SELECT id FROM sessions WHERE nick = ?",
nick, nick,
).Scan(&userID) ).Scan(&sessionID)
if err != nil { if err != nil {
return 0, fmt.Errorf("get user by nick: %w", err) return 0, fmt.Errorf(
"get session by nick: %w", err,
)
} }
return userID, nil return sessionID, nil
} }
// GetChannelByName returns the channel ID for a name. // GetChannelByName returns the channel ID for a name.
@@ -179,16 +235,16 @@ func (database *Database) GetOrCreateChannel(
return channelID, nil return channelID, nil
} }
// JoinChannel adds a user to a channel. // JoinChannel adds a session to a channel.
func (database *Database) JoinChannel( func (database *Database) JoinChannel(
ctx context.Context, ctx context.Context,
channelID, userID int64, channelID, sessionID int64,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO channel_members `INSERT OR IGNORE INTO channel_members
(channel_id, user_id, joined_at) (channel_id, session_id, joined_at)
VALUES (?, ?, ?)`, VALUES (?, ?, ?)`,
channelID, userID, time.Now()) channelID, sessionID, time.Now())
if err != nil { if err != nil {
return fmt.Errorf("join channel: %w", err) return fmt.Errorf("join channel: %w", err)
} }
@@ -196,15 +252,15 @@ func (database *Database) JoinChannel(
return nil return nil
} }
// PartChannel removes a user from a channel. // PartChannel removes a session from a channel.
func (database *Database) PartChannel( func (database *Database) PartChannel(
ctx context.Context, ctx context.Context,
channelID, userID int64, channelID, sessionID int64,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
`DELETE FROM channel_members `DELETE FROM channel_members
WHERE channel_id = ? AND user_id = ?`, WHERE channel_id = ? AND session_id = ?`,
channelID, userID) channelID, sessionID)
if err != nil { if err != nil {
return fmt.Errorf("part channel: %w", err) return fmt.Errorf("part channel: %w", err)
} }
@@ -265,18 +321,18 @@ func scanChannels(
return out, nil return out, nil
} }
// ListChannels returns channels the user has joined. // ListChannels returns channels the session has joined.
func (database *Database) ListChannels( func (database *Database) ListChannels(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
) ([]ChannelInfo, error) { ) ([]ChannelInfo, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT c.id, c.name, c.topic `SELECT c.id, c.name, c.topic
FROM channels c FROM channels c
INNER JOIN channel_members cm INNER JOIN channel_members cm
ON cm.channel_id = c.id ON cm.channel_id = c.id
WHERE cm.user_id = ? WHERE cm.session_id = ?
ORDER BY c.name`, userID) ORDER BY c.name`, sessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf("list channels: %w", err) return nil, fmt.Errorf("list channels: %w", err)
} }
@@ -306,12 +362,12 @@ func (database *Database) ChannelMembers(
channelID int64, channelID int64,
) ([]MemberInfo, error) { ) ([]MemberInfo, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT u.id, u.nick, u.last_seen `SELECT s.id, s.nick, s.last_seen
FROM users u FROM sessions s
INNER JOIN channel_members cm INNER JOIN channel_members cm
ON cm.user_id = u.id ON cm.session_id = s.id
WHERE cm.channel_id = ? WHERE cm.channel_id = ?
ORDER BY u.nick`, channelID) ORDER BY s.nick`, channelID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"query channel members: %w", err, "query channel members: %w", err,
@@ -349,17 +405,17 @@ func (database *Database) ChannelMembers(
return members, nil return members, nil
} }
// IsChannelMember checks if a user belongs to a channel. // IsChannelMember checks if a session belongs to a channel.
func (database *Database) IsChannelMember( func (database *Database) IsChannelMember(
ctx context.Context, ctx context.Context,
channelID, userID int64, channelID, sessionID int64,
) (bool, error) { ) (bool, error) {
var count int var count int
err := database.conn.QueryRowContext(ctx, err := database.conn.QueryRowContext(ctx,
`SELECT COUNT(*) FROM channel_members `SELECT COUNT(*) FROM channel_members
WHERE channel_id = ? AND user_id = ?`, WHERE channel_id = ? AND session_id = ?`,
channelID, userID, channelID, sessionID,
).Scan(&count) ).Scan(&count)
if err != nil { if err != nil {
return false, fmt.Errorf( return false, fmt.Errorf(
@@ -397,13 +453,13 @@ func scanInt64s(rows *sql.Rows) ([]int64, error) {
return ids, nil return ids, nil
} }
// GetChannelMemberIDs returns user IDs in a channel. // GetChannelMemberIDs returns session IDs in a channel.
func (database *Database) GetChannelMemberIDs( func (database *Database) GetChannelMemberIDs(
ctx context.Context, ctx context.Context,
channelID int64, channelID int64,
) ([]int64, error) { ) ([]int64, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT user_id FROM channel_members `SELECT session_id FROM channel_members
WHERE channel_id = ?`, channelID) WHERE channel_id = ?`, channelID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
@@ -414,17 +470,17 @@ func (database *Database) GetChannelMemberIDs(
return scanInt64s(rows) return scanInt64s(rows)
} }
// GetUserChannelIDs returns channel IDs the user is in. // GetSessionChannelIDs returns channel IDs for a session.
func (database *Database) GetUserChannelIDs( func (database *Database) GetSessionChannelIDs(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
) ([]int64, error) { ) ([]int64, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT channel_id FROM channel_members `SELECT channel_id FROM channel_members
WHERE user_id = ?`, userID) WHERE session_id = ?`, sessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"get user channel ids: %w", err, "get session channel ids: %w", err,
) )
} }
@@ -467,27 +523,52 @@ func (database *Database) InsertMessage(
return dbID, msgUUID, nil return dbID, msgUUID, nil
} }
// EnqueueMessage adds a message to a user's queue. // EnqueueToSession adds a message to all clients of a
func (database *Database) EnqueueMessage( // session's queues.
func (database *Database) EnqueueToSession(
ctx context.Context, ctx context.Context,
userID, messageID int64, sessionID, messageID int64,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO client_queues `INSERT OR IGNORE INTO client_queues
(user_id, message_id, created_at) (client_id, message_id, created_at)
VALUES (?, ?, ?)`, SELECT c.id, ?, ?
userID, messageID, time.Now()) FROM clients c
WHERE c.session_id = ?`,
messageID, time.Now(), sessionID)
if err != nil { if err != nil {
return fmt.Errorf("enqueue message: %w", err) return fmt.Errorf(
"enqueue to session: %w", err,
)
} }
return nil return nil
} }
// PollMessages returns queued messages for a user. // EnqueueToClient adds a message to a specific client's
// queue.
func (database *Database) EnqueueToClient(
ctx context.Context,
clientID, messageID int64,
) error {
_, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO client_queues
(client_id, message_id, created_at)
VALUES (?, ?, ?)`,
clientID, messageID, time.Now())
if err != nil {
return fmt.Errorf(
"enqueue to client: %w", err,
)
}
return nil
}
// PollMessages returns queued messages for a client.
func (database *Database) PollMessages( func (database *Database) PollMessages(
ctx context.Context, ctx context.Context,
userID, afterQueueID int64, clientID, afterQueueID int64,
limit int, limit int,
) ([]IRCMessage, int64, error) { ) ([]IRCMessage, int64, error) {
if limit <= 0 { if limit <= 0 {
@@ -501,9 +582,9 @@ func (database *Database) PollMessages(
FROM client_queues cq FROM client_queues cq
INNER JOIN messages m INNER JOIN messages m
ON m.id = cq.message_id ON m.id = cq.message_id
WHERE cq.user_id = ? AND cq.id > ? WHERE cq.client_id = ? AND cq.id > ?
ORDER BY cq.id ASC LIMIT ?`, ORDER BY cq.id ASC LIMIT ?`,
userID, afterQueueID, limit) clientID, afterQueueID, limit)
if err != nil { if err != nil {
return nil, afterQueueID, fmt.Errorf( return nil, afterQueueID, fmt.Errorf(
"poll messages: %w", err, "poll messages: %w", err,
@@ -649,15 +730,15 @@ func reverseMessages(msgs []IRCMessage) {
} }
} }
// ChangeNick updates a user's nickname. // ChangeNick updates a session's nickname.
func (database *Database) ChangeNick( func (database *Database) ChangeNick(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
newNick string, newNick string,
) error { ) error {
_, err := database.conn.ExecContext(ctx, _, err := database.conn.ExecContext(ctx,
"UPDATE users SET nick = ? WHERE id = ?", "UPDATE sessions SET nick = ? WHERE id = ?",
newNick, userID) newNick, sessionID)
if err != nil { if err != nil {
return fmt.Errorf("change nick: %w", err) return fmt.Errorf("change nick: %w", err)
} }
@@ -681,38 +762,182 @@ func (database *Database) SetTopic(
return nil return nil
} }
// DeleteUser removes a user and all their data. // DeleteSession removes a session and all its data.
func (database *Database) DeleteUser( func (database *Database) DeleteSession(
ctx context.Context, ctx context.Context,
userID int64, sessionID int64,
) error { ) error {
_, err := database.conn.ExecContext( _, err := database.conn.ExecContext(
ctx, ctx,
"DELETE FROM users WHERE id = ?", "DELETE FROM sessions WHERE id = ?",
userID, sessionID,
) )
if err != nil { if err != nil {
return fmt.Errorf("delete user: %w", err) return fmt.Errorf("delete session: %w", err)
} }
return nil return nil
} }
// GetAllChannelMembershipsForUser returns channels // DeleteClient removes a single client record by ID.
// a user belongs to. func (database *Database) DeleteClient(
func (database *Database) GetAllChannelMembershipsForUser(
ctx context.Context, ctx context.Context,
userID int64, clientID int64,
) error {
_, err := database.conn.ExecContext(
ctx,
"DELETE FROM clients WHERE id = ?",
clientID,
)
if err != nil {
return fmt.Errorf("delete client: %w", err)
}
return nil
}
// GetUserCount returns the number of active users.
func (database *Database) GetUserCount(
ctx context.Context,
) (int64, error) {
var count int64
err := database.conn.QueryRowContext(
ctx,
"SELECT COUNT(*) FROM sessions",
).Scan(&count)
if err != nil {
return 0, fmt.Errorf(
"get user count: %w", err,
)
}
return count, nil
}
// ClientCountForSession returns the number of clients
// belonging to a session.
func (database *Database) ClientCountForSession(
ctx context.Context,
sessionID int64,
) (int64, error) {
var count int64
err := database.conn.QueryRowContext(
ctx,
`SELECT COUNT(*) FROM clients
WHERE session_id = ?`,
sessionID,
).Scan(&count)
if err != nil {
return 0, fmt.Errorf(
"client count for session: %w", err,
)
}
return count, nil
}
// DeleteStaleUsers removes clients not seen since the
// cutoff and cleans up orphaned users (sessions).
func (database *Database) DeleteStaleUsers(
ctx context.Context,
cutoff time.Time,
) (int64, error) {
res, err := database.conn.ExecContext(ctx,
"DELETE FROM clients WHERE last_seen < ?",
cutoff,
)
if err != nil {
return 0, fmt.Errorf(
"delete stale clients: %w", err,
)
}
deleted, _ := res.RowsAffected()
_, err = database.conn.ExecContext(ctx,
`DELETE FROM sessions WHERE id NOT IN
(SELECT DISTINCT session_id FROM clients)`,
)
if err != nil {
return deleted, fmt.Errorf(
"delete orphan sessions: %w", err,
)
}
return deleted, nil
}
// StaleSession holds the id and nick of a session
// whose clients are all stale.
type StaleSession struct {
ID int64
Nick string
}
// GetStaleOrphanSessions returns sessions where every
// client has a last_seen before cutoff.
func (database *Database) GetStaleOrphanSessions(
ctx context.Context,
cutoff time.Time,
) ([]StaleSession, error) {
rows, err := database.conn.QueryContext(ctx,
`SELECT s.id, s.nick
FROM sessions s
WHERE s.id IN (
SELECT DISTINCT session_id FROM clients
WHERE last_seen < ?
)
AND s.id NOT IN (
SELECT DISTINCT session_id FROM clients
WHERE last_seen >= ?
)`, cutoff, cutoff)
if err != nil {
return nil, fmt.Errorf(
"get stale orphan sessions: %w", err,
)
}
defer func() { _ = rows.Close() }()
var result []StaleSession
for rows.Next() {
var stale StaleSession
if err := rows.Scan(&stale.ID, &stale.Nick); err != nil {
return nil, fmt.Errorf(
"scan stale session: %w", err,
)
}
result = append(result, stale)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf(
"iterate stale sessions: %w", err,
)
}
return result, nil
}
// GetSessionChannels returns channels a session
// belongs to.
func (database *Database) GetSessionChannels(
ctx context.Context,
sessionID int64,
) ([]ChannelInfo, error) { ) ([]ChannelInfo, error) {
rows, err := database.conn.QueryContext(ctx, rows, err := database.conn.QueryContext(ctx,
`SELECT c.id, c.name, c.topic `SELECT c.id, c.name, c.topic
FROM channels c FROM channels c
INNER JOIN channel_members cm INNER JOIN channel_members cm
ON cm.channel_id = c.id ON cm.channel_id = c.id
WHERE cm.user_id = ?`, userID) WHERE cm.session_id = ?`, sessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"get memberships: %w", err, "get session channels: %w", err,
) )
} }

View File

@@ -27,70 +27,91 @@ func setupTestDB(t *testing.T) *db.Database {
return database return database
} }
func TestCreateUser(t *testing.T) { func TestCreateSession(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
id, token, err := database.CreateUser(ctx, "alice") sessionID, _, token, err := database.CreateSession(
ctx, "alice",
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if id == 0 || token == "" { if sessionID == 0 || token == "" {
t.Fatal("expected valid id and token") t.Fatal("expected valid id and token")
} }
_, _, err = database.CreateUser(ctx, "alice") _, _, dupToken, dupErr := database.CreateSession(
if err == nil { ctx, "alice",
)
if dupErr == nil {
t.Fatal("expected error for duplicate nick") t.Fatal("expected error for duplicate nick")
} }
_ = dupToken
} }
func TestGetUserByToken(t *testing.T) { func TestGetSessionByToken(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
_, token, err := database.CreateUser(ctx, "bob") _, _, token, err := database.CreateSession(ctx, "bob")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
id, nick, err := database.GetUserByToken(ctx, token) sessionID, clientID, nick, err :=
database.GetSessionByToken(ctx, token)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if nick != "bob" || id == 0 { if nick != "bob" || sessionID == 0 || clientID == 0 {
t.Fatalf("expected bob, got %s", nick) t.Fatalf("expected bob, got %s", nick)
} }
_, _, err = database.GetUserByToken(ctx, "badtoken") badSID, badCID, badNick, badErr :=
if err == nil { database.GetSessionByToken(ctx, "badtoken")
if badErr == nil {
t.Fatal("expected error for bad token") t.Fatal("expected error for bad token")
} }
if badSID != 0 || badCID != 0 || badNick != "" {
t.Fatal("expected zero values on error")
}
} }
func TestGetUserByNick(t *testing.T) { func TestGetSessionByNick(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
_, _, err := database.CreateUser(ctx, "charlie") charlieID, charlieClientID, charlieToken, err :=
database.CreateSession(ctx, "charlie")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
id, err := database.GetUserByNick(ctx, "charlie") if charlieID == 0 || charlieClientID == 0 {
t.Fatal("expected valid session/client IDs")
}
if charlieToken == "" {
t.Fatal("expected non-empty token")
}
id, err := database.GetSessionByNick(ctx, "charlie")
if err != nil || id == 0 { if err != nil || id == 0 {
t.Fatal("expected to find charlie") t.Fatal("expected to find charlie")
} }
_, err = database.GetUserByNick(ctx, "nobody") _, err = database.GetSessionByNick(ctx, "nobody")
if err == nil { if err == nil {
t.Fatal("expected error for unknown nick") t.Fatal("expected error for unknown nick")
} }
@@ -129,7 +150,7 @@ func TestJoinAndPart(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, "user1") sid, _, _, err := database.CreateSession(ctx, "user1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -139,22 +160,22 @@ func TestJoinAndPart(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ids, err := database.GetChannelMemberIDs(ctx, chID) ids, err := database.GetChannelMemberIDs(ctx, chID)
if err != nil || len(ids) != 1 || ids[0] != uid { if err != nil || len(ids) != 1 || ids[0] != sid {
t.Fatal("expected user in channel") t.Fatal("expected session in channel")
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.PartChannel(ctx, chID, uid) err = database.PartChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -178,17 +199,17 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
uid, _, err := database.CreateUser(ctx, "temp") sid, _, _, err := database.CreateSession(ctx, "temp")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.PartChannel(ctx, chID, uid) err = database.PartChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -204,7 +225,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
} }
} }
func createUserWithChannels( func createSessionWithChannels(
t *testing.T, t *testing.T,
database *db.Database, database *db.Database,
nick, ch1Name, ch2Name string, nick, ch1Name, ch2Name string,
@@ -213,7 +234,7 @@ func createUserWithChannels(
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, nick) sid, _, _, err := database.CreateSession(ctx, nick)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -232,29 +253,29 @@ func createUserWithChannels(
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, ch1, uid) err = database.JoinChannel(ctx, ch1, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, ch2, uid) err = database.JoinChannel(ctx, ch2, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return uid, ch1, ch2 return sid, ch1, ch2
} }
func TestListChannels(t *testing.T) { func TestListChannels(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
uid, _, _ := createUserWithChannels( sid, _, _ := createSessionWithChannels(
t, database, "lister", "#a", "#b", t, database, "lister", "#a", "#b",
) )
channels, err := database.ListChannels( channels, err := database.ListChannels(
t.Context(), uid, t.Context(), sid,
) )
if err != nil || len(channels) != 2 { if err != nil || len(channels) != 2 {
t.Fatalf( t.Fatalf(
@@ -295,17 +316,21 @@ func TestChangeNick(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, token, err := database.CreateUser(ctx, "old") sid, _, token, err := database.CreateSession(
ctx, "old",
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.ChangeNick(ctx, uid, "new") err = database.ChangeNick(ctx, sid, "new")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, nick, err := database.GetUserByToken(ctx, token) _, _, nick, err := database.GetSessionByToken(
ctx, token,
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -375,7 +400,16 @@ func TestPollMessages(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, "poller") sid, _, token, err := database.CreateSession(
ctx, "poller",
)
if err != nil {
t.Fatal(err)
}
_, clientID, _, err := database.GetSessionByToken(
ctx, token,
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -389,7 +423,7 @@ func TestPollMessages(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.EnqueueMessage(ctx, uid, dbID) err = database.EnqueueToSession(ctx, sid, dbID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -397,7 +431,7 @@ func TestPollMessages(t *testing.T) {
const batchSize = 10 const batchSize = 10
msgs, lastQID, err := database.PollMessages( msgs, lastQID, err := database.PollMessages(
ctx, uid, 0, batchSize, ctx, clientID, 0, batchSize,
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -420,7 +454,7 @@ func TestPollMessages(t *testing.T) {
} }
msgs, _, _ = database.PollMessages( msgs, _, _ = database.PollMessages(
ctx, uid, lastQID, batchSize, ctx, clientID, lastQID, batchSize,
) )
if len(msgs) != 0 { if len(msgs) != 0 {
@@ -467,13 +501,15 @@ func TestGetHistory(t *testing.T) {
} }
} }
func TestDeleteUser(t *testing.T) { func TestDeleteSession(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid, _, err := database.CreateUser(ctx, "deleteme") sid, _, _, err := database.CreateSession(
ctx, "deleteme",
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -485,19 +521,19 @@ func TestDeleteUser(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid) err = database.JoinChannel(ctx, chID, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.DeleteUser(ctx, uid) err = database.DeleteSession(ctx, sid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = database.GetUserByNick(ctx, "deleteme") _, err = database.GetSessionByNick(ctx, "deleteme")
if err == nil { if err == nil {
t.Fatal("user should be deleted") t.Fatal("session should be deleted")
} }
ids, _ := database.GetChannelMemberIDs(ctx, chID) ids, _ := database.GetChannelMemberIDs(ctx, chID)
@@ -512,12 +548,12 @@ func TestChannelMembers(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
uid1, _, err := database.CreateUser(ctx, "m1") sid1, _, _, err := database.CreateSession(ctx, "m1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
uid2, _, err := database.CreateUser(ctx, "m2") sid2, _, _, err := database.CreateSession(ctx, "m2")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -529,12 +565,12 @@ func TestChannelMembers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid1) err = database.JoinChannel(ctx, chID, sid1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = database.JoinChannel(ctx, chID, uid2) err = database.JoinChannel(ctx, chID, sid2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -548,17 +584,17 @@ func TestChannelMembers(t *testing.T) {
} }
} }
func TestGetAllChannelMembershipsForUser(t *testing.T) { func TestGetSessionChannels(t *testing.T) {
t.Parallel() t.Parallel()
database := setupTestDB(t) database := setupTestDB(t)
uid, _, _ := createUserWithChannels( sid, _, _ := createSessionWithChannels(
t, database, "multi", "#m1", "#m2", t, database, "multi", "#m1", "#m2",
) )
channels, err := channels, err :=
database.GetAllChannelMembershipsForUser( database.GetSessionChannels(
t.Context(), uid, t.Context(), sid,
) )
if err != nil || len(channels) != 2 { if err != nil || len(channels) != 2 {
t.Fatalf( t.Fatalf(
@@ -567,3 +603,51 @@ func TestGetAllChannelMembershipsForUser(t *testing.T) {
) )
} }
} }
func TestEnqueueToClient(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
_, _, token, err := database.CreateSession(
ctx, "enqclient",
)
if err != nil {
t.Fatal(err)
}
_, clientID, _, err := database.GetSessionByToken(
ctx, token,
)
if err != nil {
t.Fatal(err)
}
body := json.RawMessage(`["test"]`)
dbID, _, err := database.InsertMessage(
ctx, "PRIVMSG", "sender", "#ch", body, nil,
)
if err != nil {
t.Fatal(err)
}
err = database.EnqueueToClient(ctx, clientID, dbID)
if err != nil {
t.Fatal(err)
}
const batchSize = 10
msgs, _, err := database.PollMessages(
ctx, clientID, 0, batchSize,
)
if err != nil {
t.Fatal(err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1, got %d", len(msgs))
}
}

View File

@@ -1,15 +1,29 @@
-- Chat server schema (pre-1.0 consolidated) -- Chat server schema (pre-1.0 consolidated)
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
-- Users: IRC-style sessions (no passwords, just nick + token) -- Sessions: each session is a user identity (nick + optional password + signing key)
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT NOT NULL UNIQUE,
nick TEXT NOT NULL UNIQUE, nick TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL DEFAULT '',
signing_key TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid);
-- Clients: each session can have multiple connected clients
CREATE TABLE IF NOT EXISTS clients (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT NOT NULL UNIQUE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE, token TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
); );
CREATE INDEX IF NOT EXISTS idx_users_token ON users(token); CREATE INDEX IF NOT EXISTS idx_clients_token ON clients(token);
CREATE INDEX IF NOT EXISTS idx_clients_session ON clients(session_id);
-- Channels -- Channels
CREATE TABLE IF NOT EXISTS channels ( CREATE TABLE IF NOT EXISTS channels (
@@ -24,9 +38,9 @@ CREATE TABLE IF NOT EXISTS channels (
CREATE TABLE IF NOT EXISTS channel_members ( CREATE TABLE IF NOT EXISTS channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP, joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, user_id) UNIQUE(channel_id, session_id)
); );
-- Messages: IRC envelope format -- Messages: IRC envelope format
@@ -46,9 +60,9 @@ CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
-- Per-client message queues for fan-out delivery -- Per-client message queues for fan-out delivery
CREATE TABLE IF NOT EXISTS client_queues ( CREATE TABLE IF NOT EXISTS client_queues (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, client_id INTEGER NOT NULL REFERENCES clients(id) ON DELETE CASCADE,
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE, message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, message_id) UNIQUE(client_id, message_id)
); );
CREATE INDEX IF NOT EXISTS idx_client_queues_user ON client_queues(user_id, id); CREATE INDEX IF NOT EXISTS idx_client_queues_client ON client_queues(client_id, id);

View File

@@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -37,35 +38,37 @@ func (hdlr *Handlers) maxBodySize() int64 {
return defaultMaxBodySize return defaultMaxBodySize
} }
// authUser extracts the user from the Authorization header. // authSession extracts the session from the client token.
func (hdlr *Handlers) authUser( func (hdlr *Handlers) authSession(
request *http.Request, request *http.Request,
) (int64, string, error) { ) (int64, int64, string, error) {
auth := request.Header.Get("Authorization") auth := request.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") { if !strings.HasPrefix(auth, "Bearer ") {
return 0, "", errUnauthorized return 0, 0, "", errUnauthorized
} }
token := strings.TrimPrefix(auth, "Bearer ") token := strings.TrimPrefix(auth, "Bearer ")
if token == "" { if token == "" {
return 0, "", errUnauthorized return 0, 0, "", errUnauthorized
} }
uid, nick, err := hdlr.params.Database.GetUserByToken( sessionID, clientID, nick, err :=
hdlr.params.Database.GetSessionByToken(
request.Context(), token, request.Context(), token,
) )
if err != nil { if err != nil {
return 0, "", fmt.Errorf("auth: %w", err) return 0, 0, "", fmt.Errorf("auth: %w", err)
} }
return uid, nick, nil return sessionID, clientID, nick, nil
} }
func (hdlr *Handlers) requireAuth( func (hdlr *Handlers) requireAuth(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) (int64, string, bool) { ) (int64, int64, string, bool) {
uid, nick, err := hdlr.authUser(request) sessionID, clientID, nick, err :=
hdlr.authSession(request)
if err != nil { if err != nil {
hdlr.respondError( hdlr.respondError(
writer, request, writer, request,
@@ -73,19 +76,19 @@ func (hdlr *Handlers) requireAuth(
http.StatusUnauthorized, http.StatusUnauthorized,
) )
return 0, "", false return 0, 0, "", false
} }
return uid, nick, true return sessionID, clientID, nick, true
} }
// fanOut stores a message and enqueues it to all specified // fanOut stores a message and enqueues it to all specified
// user IDs, then notifies them. // session IDs, then notifies them.
func (hdlr *Handlers) fanOut( func (hdlr *Handlers) fanOut(
request *http.Request, request *http.Request,
command, from, target string, command, from, target string,
body json.RawMessage, body json.RawMessage,
userIDs []int64, sessionIDs []int64,
) (string, error) { ) (string, error) {
dbID, msgUUID, err := hdlr.params.Database.InsertMessage( dbID, msgUUID, err := hdlr.params.Database.InsertMessage(
request.Context(), command, from, target, body, nil, request.Context(), command, from, target, body, nil,
@@ -94,16 +97,16 @@ func (hdlr *Handlers) fanOut(
return "", fmt.Errorf("insert message: %w", err) return "", fmt.Errorf("insert message: %w", err)
} }
for _, uid := range userIDs { for _, sid := range sessionIDs {
enqErr := hdlr.params.Database.EnqueueMessage( enqErr := hdlr.params.Database.EnqueueToSession(
request.Context(), uid, dbID, request.Context(), sid, dbID,
) )
if enqErr != nil { if enqErr != nil {
hdlr.log.Error("enqueue failed", hdlr.log.Error("enqueue failed",
"error", enqErr, "user_id", uid) "error", enqErr, "session_id", sid)
} }
hdlr.broker.Notify(uid) hdlr.broker.Notify(sid)
} }
return msgUUID, nil return msgUUID, nil
@@ -114,10 +117,10 @@ func (hdlr *Handlers) fanOutSilent(
request *http.Request, request *http.Request,
command, from, target string, command, from, target string,
body json.RawMessage, body json.RawMessage,
userIDs []int64, sessionIDs []int64,
) error { ) error {
_, err := hdlr.fanOut( _, err := hdlr.fanOut(
request, command, from, target, body, userIDs, request, command, from, target, body, sessionIDs,
) )
return err return err
@@ -125,16 +128,6 @@ func (hdlr *Handlers) fanOutSilent(
// HandleCreateSession creates a new user session. // HandleCreateSession creates a new user session.
func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc { func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
type createRequest struct {
Nick string `json:"nick"`
}
type createResponse struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
Token string `json:"token"`
}
return func( return func(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
@@ -143,6 +136,18 @@ func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
writer, request.Body, hdlr.maxBodySize(), writer, request.Body, hdlr.maxBodySize(),
) )
hdlr.handleCreateSession(writer, request)
}
}
func (hdlr *Handlers) handleCreateSession(
writer http.ResponseWriter,
request *http.Request,
) {
type createRequest struct {
Nick string `json:"nick"`
}
var payload createRequest var payload createRequest
err := json.NewDecoder(request.Body).Decode(&payload) err := json.NewDecoder(request.Body).Decode(&payload)
@@ -168,10 +173,32 @@ func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
return return
} }
userID, token, err := hdlr.params.Database.CreateUser( sessionID, clientID, token, err :=
hdlr.params.Database.CreateSession(
request.Context(), payload.Nick, request.Context(), payload.Nick,
) )
if err != nil { if err != nil {
hdlr.handleCreateSessionError(
writer, request, err,
)
return
}
hdlr.deliverMOTD(request, clientID, sessionID)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusCreated)
}
func (hdlr *Handlers) handleCreateSessionError(
writer http.ResponseWriter,
request *http.Request,
err error,
) {
if strings.Contains(err.Error(), "UNIQUE") { if strings.Contains(err.Error(), "UNIQUE") {
hdlr.respondError( hdlr.respondError(
writer, request, writer, request,
@@ -183,42 +210,100 @@ func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
} }
hdlr.log.Error( hdlr.log.Error(
"create user failed", "error", err, "create session failed", "error", err,
) )
hdlr.respondError( hdlr.respondError(
writer, request, writer, request,
"internal error", "internal error",
http.StatusInternalServerError, http.StatusInternalServerError,
) )
}
// deliverMOTD sends the MOTD as IRC numeric messages to a
// new client.
func (hdlr *Handlers) deliverMOTD(
request *http.Request,
clientID, sessionID int64,
) {
motd := hdlr.params.Config.MOTD
serverName := hdlr.params.Config.ServerName
if serverName == "" {
serverName = "chat"
}
if motd == "" {
return
}
ctx := request.Context()
hdlr.enqueueNumeric(
ctx, clientID, "375", serverName,
"- "+serverName+" Message of the Day -",
)
for line := range strings.SplitSeq(motd, "\n") {
hdlr.enqueueNumeric(
ctx, clientID, "372", serverName,
"- "+line,
)
}
hdlr.enqueueNumeric(
ctx, clientID, "376", serverName,
"End of /MOTD command.",
)
hdlr.broker.Notify(sessionID)
}
func (hdlr *Handlers) enqueueNumeric(
ctx context.Context,
clientID int64,
command, serverName, text string,
) {
body, err := json.Marshal([]string{text})
if err != nil {
hdlr.log.Error(
"marshal numeric body", "error", err,
)
return return
} }
hdlr.respondJSON( dbID, _, insertErr := hdlr.params.Database.InsertMessage(
writer, request, ctx, command, serverName, "",
&createResponse{ json.RawMessage(body), nil,
ID: userID,
Nick: payload.Nick,
Token: token,
},
http.StatusCreated,
) )
} if insertErr != nil {
hdlr.log.Error(
"insert numeric message", "error", insertErr,
)
return
} }
// HandleState returns the current user's info and channels. _ = hdlr.params.Database.EnqueueToClient(
ctx, clientID, dbID,
)
}
// HandleState returns the current session's info and
// channels.
func (hdlr *Handlers) HandleState() http.HandlerFunc { func (hdlr *Handlers) HandleState() http.HandlerFunc {
return func( return func(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
uid, nick, ok := hdlr.requireAuth(writer, request) sessionID, _, nick, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
channels, err := hdlr.params.Database.ListChannels( channels, err := hdlr.params.Database.ListChannels(
request.Context(), uid, request.Context(), sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@@ -234,7 +319,7 @@ func (hdlr *Handlers) HandleState() http.HandlerFunc {
} }
hdlr.respondJSON(writer, request, map[string]any{ hdlr.respondJSON(writer, request, map[string]any{
"id": uid, "id": sessionID,
"nick": nick, "nick": nick,
"channels": channels, "channels": channels,
}, http.StatusOK) }, http.StatusOK)
@@ -247,7 +332,7 @@ func (hdlr *Handlers) HandleListAllChannels() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
_, _, ok := hdlr.requireAuth(writer, request) _, _, _, ok := hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@@ -280,7 +365,7 @@ func (hdlr *Handlers) HandleChannelMembers() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
_, _, ok := hdlr.requireAuth(writer, request) _, _, _, ok := hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@@ -328,7 +413,8 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
uid, _, ok := hdlr.requireAuth(writer, request) sessionID, clientID, _, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@@ -349,7 +435,7 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
} }
msgs, lastQID, err := hdlr.params.Database.PollMessages( msgs, lastQID, err := hdlr.params.Database.PollMessages(
request.Context(), uid, request.Context(), clientID,
afterID, pollMessageLimit, afterID, pollMessageLimit,
) )
if err != nil { if err != nil {
@@ -374,17 +460,20 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
return return
} }
hdlr.longPoll(writer, request, uid, afterID, timeout) hdlr.longPoll(
writer, request,
sessionID, clientID, afterID, timeout,
)
} }
} }
func (hdlr *Handlers) longPoll( func (hdlr *Handlers) longPoll(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid, afterID int64, sessionID, clientID, afterID int64,
timeout int, timeout int,
) { ) {
waitCh := hdlr.broker.Wait(uid) waitCh := hdlr.broker.Wait(sessionID)
timer := time.NewTimer( timer := time.NewTimer(
time.Duration(timeout) * time.Second, time.Duration(timeout) * time.Second,
@@ -396,15 +485,15 @@ func (hdlr *Handlers) longPoll(
case <-waitCh: case <-waitCh:
case <-timer.C: case <-timer.C:
case <-request.Context().Done(): case <-request.Context().Done():
hdlr.broker.Remove(uid, waitCh) hdlr.broker.Remove(sessionID, waitCh)
return return
} }
hdlr.broker.Remove(uid, waitCh) hdlr.broker.Remove(sessionID, waitCh)
msgs, lastQID, err := hdlr.params.Database.PollMessages( msgs, lastQID, err := hdlr.params.Database.PollMessages(
request.Context(), uid, request.Context(), clientID,
afterID, pollMessageLimit, afterID, pollMessageLimit,
) )
if err != nil { if err != nil {
@@ -443,7 +532,8 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
writer, request.Body, hdlr.maxBodySize(), writer, request.Body, hdlr.maxBodySize(),
) )
uid, nick, ok := hdlr.requireAuth(writer, request) sessionID, _, nick, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@@ -492,7 +582,7 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
} }
hdlr.dispatchCommand( hdlr.dispatchCommand(
writer, request, uid, nick, writer, request, sessionID, nick,
payload.Command, payload.To, payload.Command, payload.To,
payload.Body, bodyLines, payload.Body, bodyLines,
) )
@@ -502,7 +592,7 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
func (hdlr *Handlers) dispatchCommand( func (hdlr *Handlers) dispatchCommand(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
bodyLines func() []string, bodyLines func() []string,
@@ -510,20 +600,20 @@ func (hdlr *Handlers) dispatchCommand(
switch command { switch command {
case cmdPrivmsg, "NOTICE": case cmdPrivmsg, "NOTICE":
hdlr.handlePrivmsg( hdlr.handlePrivmsg(
writer, request, uid, nick, writer, request, sessionID, nick,
command, target, body, bodyLines, command, target, body, bodyLines,
) )
case "JOIN": case "JOIN":
hdlr.handleJoin( hdlr.handleJoin(
writer, request, uid, nick, target, writer, request, sessionID, nick, target,
) )
case "PART": case "PART":
hdlr.handlePart( hdlr.handlePart(
writer, request, uid, nick, target, body, writer, request, sessionID, nick, target, body,
) )
case "NICK": case "NICK":
hdlr.handleNick( hdlr.handleNick(
writer, request, uid, nick, bodyLines, writer, request, sessionID, nick, bodyLines,
) )
case "TOPIC": case "TOPIC":
hdlr.handleTopic( hdlr.handleTopic(
@@ -531,7 +621,7 @@ func (hdlr *Handlers) dispatchCommand(
) )
case "QUIT": case "QUIT":
hdlr.handleQuit( hdlr.handleQuit(
writer, request, uid, nick, body, writer, request, sessionID, nick, body,
) )
case "PING": case "PING":
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
@@ -552,7 +642,7 @@ func (hdlr *Handlers) dispatchCommand(
func (hdlr *Handlers) handlePrivmsg( func (hdlr *Handlers) handlePrivmsg(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
bodyLines func() []string, bodyLines func() []string,
@@ -580,7 +670,7 @@ func (hdlr *Handlers) handlePrivmsg(
if strings.HasPrefix(target, "#") { if strings.HasPrefix(target, "#") {
hdlr.handleChannelMsg( hdlr.handleChannelMsg(
writer, request, uid, nick, writer, request, sessionID, nick,
command, target, body, command, target, body,
) )
@@ -588,7 +678,7 @@ func (hdlr *Handlers) handlePrivmsg(
} }
hdlr.handleDirectMsg( hdlr.handleDirectMsg(
writer, request, uid, nick, writer, request, sessionID, nick,
command, target, body, command, target, body,
) )
} }
@@ -596,7 +686,7 @@ func (hdlr *Handlers) handlePrivmsg(
func (hdlr *Handlers) handleChannelMsg( func (hdlr *Handlers) handleChannelMsg(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
) { ) {
@@ -614,7 +704,7 @@ func (hdlr *Handlers) handleChannelMsg(
} }
isMember, err := hdlr.params.Database.IsChannelMember( isMember, err := hdlr.params.Database.IsChannelMember(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@@ -677,11 +767,11 @@ func (hdlr *Handlers) handleChannelMsg(
func (hdlr *Handlers) handleDirectMsg( func (hdlr *Handlers) handleDirectMsg(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, command, target string, nick, command, target string,
body json.RawMessage, body json.RawMessage,
) { ) {
targetUID, err := hdlr.params.Database.GetUserByNick( targetSID, err := hdlr.params.Database.GetSessionByNick(
request.Context(), target, request.Context(), target,
) )
if err != nil { if err != nil {
@@ -694,9 +784,9 @@ func (hdlr *Handlers) handleDirectMsg(
return return
} }
recipients := []int64{targetUID} recipients := []int64{targetSID}
if targetUID != uid { if targetSID != sessionID {
recipients = append(recipients, uid) recipients = append(recipients, sessionID)
} }
msgUUID, err := hdlr.fanOut( msgUUID, err := hdlr.fanOut(
@@ -721,7 +811,7 @@ func (hdlr *Handlers) handleDirectMsg(
func (hdlr *Handlers) handleJoin( func (hdlr *Handlers) handleJoin(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, target string, nick, target string,
) { ) {
if target == "" { if target == "" {
@@ -766,7 +856,7 @@ func (hdlr *Handlers) handleJoin(
} }
err = hdlr.params.Database.JoinChannel( err = hdlr.params.Database.JoinChannel(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@@ -800,7 +890,7 @@ func (hdlr *Handlers) handleJoin(
func (hdlr *Handlers) handlePart( func (hdlr *Handlers) handlePart(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, target string, nick, target string,
body json.RawMessage, body json.RawMessage,
) { ) {
@@ -841,7 +931,7 @@ func (hdlr *Handlers) handlePart(
) )
err = hdlr.params.Database.PartChannel( err = hdlr.params.Database.PartChannel(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@@ -871,7 +961,7 @@ func (hdlr *Handlers) handlePart(
func (hdlr *Handlers) handleNick( func (hdlr *Handlers) handleNick(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick string, nick string,
bodyLines func() []string, bodyLines func() []string,
) { ) {
@@ -909,7 +999,7 @@ func (hdlr *Handlers) handleNick(
} }
err := hdlr.params.Database.ChangeNick( err := hdlr.params.Database.ChangeNick(
request.Context(), uid, newNick, request.Context(), sessionID, newNick,
) )
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE") { if strings.Contains(err.Error(), "UNIQUE") {
@@ -934,7 +1024,7 @@ func (hdlr *Handlers) handleNick(
return return
} }
hdlr.broadcastNick(request, uid, nick, newNick) hdlr.broadcastNick(request, sessionID, nick, newNick)
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
map[string]string{ map[string]string{
@@ -945,15 +1035,15 @@ func (hdlr *Handlers) handleNick(
func (hdlr *Handlers) broadcastNick( func (hdlr *Handlers) broadcastNick(
request *http.Request, request *http.Request,
uid int64, sessionID int64,
oldNick, newNick string, oldNick, newNick string,
) { ) {
channels, _ := hdlr.params.Database. channels, _ := hdlr.params.Database.
GetAllChannelMembershipsForUser( GetSessionChannels(
request.Context(), uid, request.Context(), sessionID,
) )
notified := map[int64]bool{uid: true} notified := map[int64]bool{sessionID: true}
nickBody, err := json.Marshal([]string{newNick}) nickBody, err := json.Marshal([]string{newNick})
if err != nil { if err != nil {
@@ -969,11 +1059,11 @@ func (hdlr *Handlers) broadcastNick(
json.RawMessage(nickBody), nil, json.RawMessage(nickBody), nil,
) )
_ = hdlr.params.Database.EnqueueMessage( _ = hdlr.params.Database.EnqueueToSession(
request.Context(), uid, dbID, request.Context(), sessionID, dbID,
) )
hdlr.broker.Notify(uid) hdlr.broker.Notify(sessionID)
for _, chanInfo := range channels { for _, chanInfo := range channels {
memberIDs, _ := hdlr.params.Database. memberIDs, _ := hdlr.params.Database.
@@ -985,7 +1075,7 @@ func (hdlr *Handlers) broadcastNick(
if !notified[mid] { if !notified[mid] {
notified[mid] = true notified[mid] = true
_ = hdlr.params.Database.EnqueueMessage( _ = hdlr.params.Database.EnqueueToSession(
request.Context(), mid, dbID, request.Context(), mid, dbID,
) )
@@ -1077,13 +1167,13 @@ func (hdlr *Handlers) handleTopic(
func (hdlr *Handlers) handleQuit( func (hdlr *Handlers) handleQuit(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick string, nick string,
body json.RawMessage, body json.RawMessage,
) { ) {
channels, _ := hdlr.params.Database. channels, _ := hdlr.params.Database.
GetAllChannelMembershipsForUser( GetSessionChannels(
request.Context(), uid, request.Context(), sessionID,
) )
notified := map[int64]bool{} notified := map[int64]bool{}
@@ -1103,10 +1193,10 @@ func (hdlr *Handlers) handleQuit(
) )
for _, mid := range memberIDs { for _, mid := range memberIDs {
if mid != uid && !notified[mid] { if mid != sessionID && !notified[mid] {
notified[mid] = true notified[mid] = true
_ = hdlr.params.Database.EnqueueMessage( _ = hdlr.params.Database.EnqueueToSession(
request.Context(), mid, dbID, request.Context(), mid, dbID,
) )
@@ -1115,7 +1205,7 @@ func (hdlr *Handlers) handleQuit(
} }
_ = hdlr.params.Database.PartChannel( _ = hdlr.params.Database.PartChannel(
request.Context(), chanInfo.ID, uid, request.Context(), chanInfo.ID, sessionID,
) )
_ = hdlr.params.Database.DeleteChannelIfEmpty( _ = hdlr.params.Database.DeleteChannelIfEmpty(
@@ -1123,8 +1213,8 @@ func (hdlr *Handlers) handleQuit(
) )
} }
_ = hdlr.params.Database.DeleteUser( _ = hdlr.params.Database.DeleteSession(
request.Context(), uid, request.Context(), sessionID,
) )
hdlr.respondJSON(writer, request, hdlr.respondJSON(writer, request,
@@ -1138,7 +1228,8 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
uid, nick, ok := hdlr.requireAuth(writer, request) sessionID, _, nick, ok :=
hdlr.requireAuth(writer, request)
if !ok { if !ok {
return return
} }
@@ -1155,7 +1246,7 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
} }
if !hdlr.canAccessHistory( if !hdlr.canAccessHistory(
writer, request, uid, nick, target, writer, request, sessionID, nick, target,
) { ) {
return return
} }
@@ -1198,12 +1289,12 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
func (hdlr *Handlers) canAccessHistory( func (hdlr *Handlers) canAccessHistory(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
nick, target string, nick, target string,
) bool { ) bool {
if strings.HasPrefix(target, "#") { if strings.HasPrefix(target, "#") {
return hdlr.canAccessChannelHistory( return hdlr.canAccessChannelHistory(
writer, request, uid, target, writer, request, sessionID, target,
) )
} }
@@ -1225,7 +1316,7 @@ func (hdlr *Handlers) canAccessHistory(
func (hdlr *Handlers) canAccessChannelHistory( func (hdlr *Handlers) canAccessChannelHistory(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
uid int64, sessionID int64,
target string, target string,
) bool { ) bool {
chID, err := hdlr.params.Database.GetChannelByName( chID, err := hdlr.params.Database.GetChannelByName(
@@ -1242,7 +1333,7 @@ func (hdlr *Handlers) canAccessChannelHistory(
} }
isMember, err := hdlr.params.Database.IsChannelMember( isMember, err := hdlr.params.Database.IsChannelMember(
request.Context(), chID, uid, request.Context(), chID, sessionID,
) )
if err != nil { if err != nil {
hdlr.log.Error( hdlr.log.Error(
@@ -1270,20 +1361,139 @@ func (hdlr *Handlers) canAccessChannelHistory(
return true return true
} }
// HandleServerInfo returns server metadata. // HandleLogout deletes the authenticated client's token
func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc { // and cleans up the user (session) if no clients remain.
type infoResponse struct { func (hdlr *Handlers) HandleLogout() http.HandlerFunc {
Name string `json:"name"`
MOTD string `json:"motd"`
}
return func( return func(
writer http.ResponseWriter, writer http.ResponseWriter,
request *http.Request, request *http.Request,
) { ) {
hdlr.respondJSON(writer, request, &infoResponse{ sessionID, clientID, nick, ok :=
Name: hdlr.params.Config.ServerName, hdlr.requireAuth(writer, request)
MOTD: hdlr.params.Config.MOTD, if !ok {
return
}
ctx := request.Context()
err := hdlr.params.Database.DeleteClient(
ctx, clientID,
)
if err != nil {
hdlr.log.Error(
"delete client failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
// If no clients remain, clean up the user fully:
// part all channels (notifying members) and
// delete the session.
remaining, err := hdlr.params.Database.
ClientCountForSession(ctx, sessionID)
if err != nil {
hdlr.log.Error(
"client count check failed", "error", err,
)
}
if remaining == 0 {
hdlr.cleanupUser(
ctx, sessionID, nick,
)
}
hdlr.respondJSON(writer, request,
map[string]string{"status": "ok"},
http.StatusOK)
}
}
// cleanupUser parts the user from all channels (notifying
// members) and deletes the session.
func (hdlr *Handlers) cleanupUser(
ctx context.Context,
sessionID int64,
nick string,
) {
channels, _ := hdlr.params.Database.
GetSessionChannels(ctx, sessionID)
notified := map[int64]bool{}
var quitDBID int64
if len(channels) > 0 {
quitDBID, _, _ = hdlr.params.Database.InsertMessage(
ctx, "QUIT", nick, "", nil, nil,
)
}
for _, chanInfo := range channels {
memberIDs, _ := hdlr.params.Database.
GetChannelMemberIDs(ctx, chanInfo.ID)
for _, mid := range memberIDs {
if mid != sessionID && !notified[mid] {
notified[mid] = true
_ = hdlr.params.Database.EnqueueToSession(
ctx, mid, quitDBID,
)
hdlr.broker.Notify(mid)
}
}
_ = hdlr.params.Database.PartChannel(
ctx, chanInfo.ID, sessionID,
)
_ = hdlr.params.Database.DeleteChannelIfEmpty(
ctx, chanInfo.ID,
)
}
_ = hdlr.params.Database.DeleteSession(ctx, sessionID)
}
// HandleUsersMe returns the current user's session info.
func (hdlr *Handlers) HandleUsersMe() http.HandlerFunc {
return hdlr.HandleState()
}
// HandleServerInfo returns server metadata.
func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
users, err := hdlr.params.Database.GetUserCount(
request.Context(),
)
if err != nil {
hdlr.log.Error(
"get user count failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
return
}
hdlr.respondJSON(writer, request, map[string]any{
"name": hdlr.params.Config.ServerName,
"motd": hdlr.params.Config.MOTD,
"users": users,
}, http.StatusOK) }, http.StatusOK)
} }
} }

View File

@@ -1469,6 +1469,310 @@ func TestHealthcheck(t *testing.T) {
} }
} }
func TestRegisterValid(t *testing.T) {
tserver := newTestServer(t)
body, err := json.Marshal(map[string]string{
"nick": "reguser", "password": "password123",
})
if err != nil {
t.Fatal(err)
}
resp, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/register"),
bytes.NewReader(body),
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusCreated {
respBody, _ := io.ReadAll(resp.Body)
t.Fatalf(
"expected 201, got %d: %s",
resp.StatusCode, respBody,
)
}
var result map[string]any
_ = json.NewDecoder(resp.Body).Decode(&result)
if result["token"] == nil || result["token"] == "" {
t.Fatal("expected token in response")
}
if result["nick"] != "reguser" {
t.Fatalf(
"expected reguser, got %v", result["nick"],
)
}
}
func TestRegisterDuplicate(t *testing.T) {
tserver := newTestServer(t)
body, err := json.Marshal(map[string]string{
"nick": "dupuser", "password": "password123",
})
if err != nil {
t.Fatal(err)
}
resp, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/register"),
bytes.NewReader(body),
)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()
resp2, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/register"),
bytes.NewReader(body),
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp2.Body.Close() }()
if resp2.StatusCode != http.StatusConflict {
t.Fatalf("expected 409, got %d", resp2.StatusCode)
}
}
func postJSONExpectStatus(
t *testing.T,
tserver *testServer,
path string,
payload map[string]string,
expectedStatus int,
) {
t.Helper()
body, err := json.Marshal(payload)
if err != nil {
t.Fatal(err)
}
resp, err := doRequest(
t,
http.MethodPost,
tserver.url(path),
bytes.NewReader(body),
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != expectedStatus {
t.Fatalf(
"expected %d, got %d",
expectedStatus, resp.StatusCode,
)
}
}
func TestRegisterShortPassword(t *testing.T) {
tserver := newTestServer(t)
postJSONExpectStatus(
t, tserver, "/api/v1/register",
map[string]string{
"nick": "shortpw", "password": "short",
},
http.StatusBadRequest,
)
}
func TestRegisterInvalidNick(t *testing.T) {
tserver := newTestServer(t)
postJSONExpectStatus(
t, tserver, "/api/v1/register",
map[string]string{
"nick": "bad nick!",
"password": "password123",
},
http.StatusBadRequest,
)
}
func TestLoginValid(t *testing.T) {
tserver := newTestServer(t)
// Register first.
regBody, err := json.Marshal(map[string]string{
"nick": "loginuser", "password": "password123",
})
if err != nil {
t.Fatal(err)
}
resp, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/register"),
bytes.NewReader(regBody),
)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()
// Login.
loginBody, err := json.Marshal(map[string]string{
"nick": "loginuser", "password": "password123",
})
if err != nil {
t.Fatal(err)
}
resp2, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/login"),
bytes.NewReader(loginBody),
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp2.Body.Close() }()
if resp2.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp2.Body)
t.Fatalf(
"expected 200, got %d: %s",
resp2.StatusCode, respBody,
)
}
var result map[string]any
_ = json.NewDecoder(resp2.Body).Decode(&result)
if result["token"] == nil || result["token"] == "" {
t.Fatal("expected token in response")
}
// Verify token works.
token, ok := result["token"].(string)
if !ok {
t.Fatal("token not a string")
}
status, state := tserver.getState(token)
if status != http.StatusOK {
t.Fatalf("expected 200, got %d", status)
}
if state["nick"] != "loginuser" {
t.Fatalf(
"expected loginuser, got %v",
state["nick"],
)
}
}
func TestLoginWrongPassword(t *testing.T) {
tserver := newTestServer(t)
regBody, err := json.Marshal(map[string]string{
"nick": "wrongpwuser", "password": "correctpass1",
})
if err != nil {
t.Fatal(err)
}
resp, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/register"),
bytes.NewReader(regBody),
)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()
loginBody, err := json.Marshal(map[string]string{
"nick": "wrongpwuser", "password": "wrongpass12",
})
if err != nil {
t.Fatal(err)
}
resp2, err := doRequest(
t,
http.MethodPost,
tserver.url("/api/v1/login"),
bytes.NewReader(loginBody),
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = resp2.Body.Close() }()
if resp2.StatusCode != http.StatusUnauthorized {
t.Fatalf(
"expected 401, got %d", resp2.StatusCode,
)
}
}
func TestLoginNonexistentUser(t *testing.T) {
tserver := newTestServer(t)
postJSONExpectStatus(
t, tserver, "/api/v1/login",
map[string]string{
"nick": "ghostuser",
"password": "password123",
},
http.StatusUnauthorized,
)
}
func TestSessionStillWorks(t *testing.T) {
tserver := newTestServer(t)
// Verify anonymous session creation still works.
token := tserver.createSession("anon_user")
if token == "" {
t.Fatal("expected token for anonymous session")
}
status, state := tserver.getState(token)
if status != http.StatusOK {
t.Fatalf("expected 200, got %d", status)
}
if state["nick"] != "anon_user" {
t.Fatalf(
"expected anon_user, got %v",
state["nick"],
)
}
}
func TestNickBroadcastToChannels(t *testing.T) { func TestNickBroadcastToChannels(t *testing.T) {
tserver := newTestServer(t) tserver := newTestServer(t)
aliceToken := tserver.createSession("nick_a") aliceToken := tserver.createSession("nick_a")

186
internal/handlers/auth.go Normal file
View File

@@ -0,0 +1,186 @@
package handlers
import (
"encoding/json"
"net/http"
"strings"
)
const minPasswordLength = 8
// HandleRegister creates a new user with a password.
func (hdlr *Handlers) HandleRegister() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
request.Body = http.MaxBytesReader(
writer, request.Body, hdlr.maxBodySize(),
)
hdlr.handleRegister(writer, request)
}
}
func (hdlr *Handlers) handleRegister(
writer http.ResponseWriter,
request *http.Request,
) {
type registerRequest struct {
Nick string `json:"nick"`
Password string `json:"password"`
}
var payload registerRequest
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
hdlr.respondError(
writer, request,
"invalid nick format",
http.StatusBadRequest,
)
return
}
if len(payload.Password) < minPasswordLength {
hdlr.respondError(
writer, request,
"password must be at least 8 characters",
http.StatusBadRequest,
)
return
}
sessionID, clientID, token, err :=
hdlr.params.Database.RegisterUser(
request.Context(),
payload.Nick,
payload.Password,
)
if err != nil {
hdlr.handleRegisterError(
writer, request, err,
)
return
}
hdlr.deliverMOTD(request, clientID, sessionID)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusCreated)
}
func (hdlr *Handlers) handleRegisterError(
writer http.ResponseWriter,
request *http.Request,
err error,
) {
if strings.Contains(err.Error(), "UNIQUE") {
hdlr.respondError(
writer, request,
"nick already taken",
http.StatusConflict,
)
return
}
hdlr.log.Error(
"register user failed", "error", err,
)
hdlr.respondError(
writer, request,
"internal error",
http.StatusInternalServerError,
)
}
// HandleLogin authenticates a user with nick and password.
func (hdlr *Handlers) HandleLogin() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
request.Body = http.MaxBytesReader(
writer, request.Body, hdlr.maxBodySize(),
)
hdlr.handleLogin(writer, request)
}
}
func (hdlr *Handlers) handleLogin(
writer http.ResponseWriter,
request *http.Request,
) {
type loginRequest struct {
Nick string `json:"nick"`
Password string `json:"password"`
}
var payload loginRequest
err := json.NewDecoder(request.Body).Decode(&payload)
if err != nil {
hdlr.respondError(
writer, request,
"invalid request body",
http.StatusBadRequest,
)
return
}
payload.Nick = strings.TrimSpace(payload.Nick)
if payload.Nick == "" || payload.Password == "" {
hdlr.respondError(
writer, request,
"nick and password required",
http.StatusBadRequest,
)
return
}
sessionID, _, token, err :=
hdlr.params.Database.LoginUser(
request.Context(),
payload.Nick,
payload.Password,
)
if err != nil {
hdlr.respondError(
writer, request,
"invalid credentials",
http.StatusUnauthorized,
)
return
}
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"token": token,
}, http.StatusOK)
}

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"net/http" "net/http"
"time"
"git.eeqj.de/sneak/chat/internal/broker" "git.eeqj.de/sneak/chat/internal/broker"
"git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/config"
@@ -30,12 +31,15 @@ type Params struct {
Healthcheck *healthcheck.Healthcheck Healthcheck *healthcheck.Healthcheck
} }
const defaultIdleTimeout = 24 * time.Hour
// Handlers manages HTTP request handling. // Handlers manages HTTP request handling.
type Handlers struct { type Handlers struct {
params *Params params *Params
log *slog.Logger log *slog.Logger
hc *healthcheck.Healthcheck hc *healthcheck.Healthcheck
broker *broker.Broker broker *broker.Broker
cancelCleanup context.CancelFunc
} }
// New creates a new Handlers instance. // New creates a new Handlers instance.
@@ -43,7 +47,7 @@ func New(
lifecycle fx.Lifecycle, lifecycle fx.Lifecycle,
params Params, params Params,
) (*Handlers, error) { ) (*Handlers, error) {
hdlr := &Handlers{ hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
params: &params, params: &params,
log: params.Logger.Get(), log: params.Logger.Get(),
hc: params.Healthcheck, hc: params.Healthcheck,
@@ -51,10 +55,14 @@ func New(
} }
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(ctx context.Context) error {
hdlr.startCleanup(ctx)
return nil return nil
}, },
OnStop: func(_ context.Context) error { OnStop: func(_ context.Context) error {
hdlr.stopCleanup()
return nil return nil
}, },
}) })
@@ -96,3 +104,100 @@ func (hdlr *Handlers) respondError(
status, status,
) )
} }
func (hdlr *Handlers) idleTimeout() time.Duration {
raw := hdlr.params.Config.SessionIdleTimeout
if raw == "" {
return defaultIdleTimeout
}
dur, err := time.ParseDuration(raw)
if err != nil {
hdlr.log.Error(
"invalid SESSION_IDLE_TIMEOUT, using default",
"value", raw, "error", err,
)
return defaultIdleTimeout
}
return dur
}
// startCleanup launches the idle-user cleanup goroutine.
// We use context.Background rather than the OnStart ctx
// because the OnStart context is startup-scoped and would
// cancel the goroutine once all start hooks complete.
//
//nolint:contextcheck // intentional Background ctx
func (hdlr *Handlers) startCleanup(_ context.Context) {
cleanupCtx, cancel := context.WithCancel(
context.Background(),
)
hdlr.cancelCleanup = cancel
go hdlr.cleanupLoop(cleanupCtx)
}
func (hdlr *Handlers) stopCleanup() {
if hdlr.cancelCleanup != nil {
hdlr.cancelCleanup()
}
}
func (hdlr *Handlers) cleanupLoop(ctx context.Context) {
timeout := hdlr.idleTimeout()
interval := max(timeout/2, time.Minute) //nolint:mnd // half the timeout
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
hdlr.runCleanup(ctx, timeout)
case <-ctx.Done():
return
}
}
}
func (hdlr *Handlers) runCleanup(
ctx context.Context,
timeout time.Duration,
) {
cutoff := time.Now().Add(-timeout)
// Find sessions that will be orphaned so we can send
// QUIT notifications before deleting anything.
stale, err := hdlr.params.Database.
GetStaleOrphanSessions(ctx, cutoff)
if err != nil {
hdlr.log.Error(
"stale session lookup failed", "error", err,
)
}
for _, ss := range stale {
hdlr.cleanupUser(ctx, ss.ID, ss.Nick)
}
deleted, err := hdlr.params.Database.DeleteStaleUsers(
ctx, cutoff,
)
if err != nil {
hdlr.log.Error(
"user cleanup failed", "error", err,
)
return
}
if deleted > 0 {
hdlr.log.Info(
"cleaned up stale users",
"deleted", deleted,
)
}
}

View File

@@ -59,9 +59,13 @@ func (srv *Server) SetupRoutes() {
} }
// API v1. // API v1.
srv.router.Route( srv.router.Route("/api/v1", srv.setupAPIv1)
"/api/v1",
func(router chi.Router) { // Serve embedded SPA.
srv.setupSPA()
}
func (srv *Server) setupAPIv1(router chi.Router) {
router.Get( router.Get(
"/server", "/server",
srv.handlers.HandleServerInfo(), srv.handlers.HandleServerInfo(),
@@ -70,10 +74,26 @@ func (srv *Server) SetupRoutes() {
"/session", "/session",
srv.handlers.HandleCreateSession(), srv.handlers.HandleCreateSession(),
) )
router.Post(
"/register",
srv.handlers.HandleRegister(),
)
router.Post(
"/login",
srv.handlers.HandleLogin(),
)
router.Get( router.Get(
"/state", "/state",
srv.handlers.HandleState(), srv.handlers.HandleState(),
) )
router.Post(
"/logout",
srv.handlers.HandleLogout(),
)
router.Get(
"/users/me",
srv.handlers.HandleUsersMe(),
)
router.Get( router.Get(
"/messages", "/messages",
srv.handlers.HandleGetMessages(), srv.handlers.HandleGetMessages(),
@@ -94,11 +114,6 @@ func (srv *Server) SetupRoutes() {
"/channels/{channel}/members", "/channels/{channel}/members",
srv.handlers.HandleChannelMembers(), srv.handlers.HandleChannelMembers(),
) )
},
)
// Serve embedded SPA.
srv.setupSPA()
} }
func (srv *Server) setupSPA() { func (srv *Server) setupSPA() {

466
web/dist/app.js vendored

File diff suppressed because one or more lines are too long

43
web/dist/style.css vendored
View File

@@ -14,6 +14,9 @@
--tab-active: #e94560; --tab-active: #e94560;
--tab-bg: #16213e; --tab-bg: #16213e;
--tab-hover: #1a1a3e; --tab-hover: #1a1a3e;
--topic-bg: #121a30;
--unread-bg: #e94560;
--warn: #f0ad4e;
} }
html, body, #root { html, body, #root {
@@ -86,6 +89,7 @@ html, body, #root {
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
overflow-x: auto; overflow-x: auto;
flex-shrink: 0; flex-shrink: 0;
align-items: center;
} }
.tab { .tab {
@@ -95,6 +99,7 @@ html, body, #root {
white-space: nowrap; white-space: nowrap;
color: var(--text-muted); color: var(--text-muted);
user-select: none; user-select: none;
position: relative;
} }
.tab:hover { .tab:hover {
@@ -116,6 +121,43 @@ html, body, #root {
color: var(--accent); color: var(--accent);
} }
.tab .unread-badge {
display: inline-block;
background: var(--unread-bg);
color: white;
font-size: 10px;
font-weight: bold;
padding: 1px 5px;
border-radius: 8px;
margin-left: 6px;
min-width: 16px;
text-align: center;
}
/* Connection status */
.connection-status {
padding: 4px 12px;
background: var(--warn);
color: #1a1a2e;
font-size: 12px;
font-weight: bold;
white-space: nowrap;
flex-shrink: 0;
}
/* Topic bar */
.topic-bar {
padding: 6px 12px;
background: var(--topic-bg);
border-bottom: 1px solid var(--border);
color: var(--text-muted);
font-size: 12px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
flex-shrink: 0;
}
/* Content area */ /* Content area */
.content { .content {
display: flex; display: flex;
@@ -243,6 +285,7 @@ html, body, #root {
gap: 8px; gap: 8px;
background: var(--bg-secondary); background: var(--bg-secondary);
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
margin-left: auto;
} }
.join-dialog input { .join-dialog input {

View File

@@ -1,13 +1,17 @@
import { h, render, Component } from 'preact'; import { h, render } from 'preact';
import { useState, useEffect, useRef, useCallback } from 'preact/hooks'; import { useState, useEffect, useRef, useCallback } from 'preact/hooks';
const API = '/api/v1'; const API = '/api/v1';
const POLL_TIMEOUT = 15;
const RECONNECT_DELAY = 3000;
const MEMBER_REFRESH_INTERVAL = 10000;
function api(path, opts = {}) { function api(path, opts = {}) {
const token = localStorage.getItem('chat_token'); const token = localStorage.getItem('chat_token');
const headers = { 'Content-Type': 'application/json', ...(opts.headers || {}) }; const headers = { 'Content-Type': 'application/json', ...(opts.headers || {}) };
if (token) headers['Authorization'] = `Bearer ${token}`; if (token) headers['Authorization'] = `Bearer ${token}`;
return fetch(API + path, { ...opts, headers }).then(async r => { const { signal, ...rest } = opts;
return fetch(API + path, { ...rest, headers, signal }).then(async r => {
const data = await r.json().catch(() => null); const data = await r.json().catch(() => null);
if (!r.ok) throw { status: r.status, data }; if (!r.ok) throw { status: r.status, data };
return data; return data;
@@ -19,7 +23,6 @@ function formatTime(ts) {
return d.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' }); return d.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' });
} }
// Nick color hashing
function nickColor(nick) { function nickColor(nick) {
let h = 0; let h = 0;
for (let i = 0; i < nick.length; i++) h = nick.charCodeAt(i) + ((h << 5) - h); for (let i = 0; i < nick.length; i++) h = nick.charCodeAt(i) + ((h << 5) - h);
@@ -39,10 +42,9 @@ function LoginScreen({ onLogin }) {
if (s.name) setServerName(s.name); if (s.name) setServerName(s.name);
if (s.motd) setMotd(s.motd); if (s.motd) setMotd(s.motd);
}).catch(() => {}); }).catch(() => {});
// Check for saved token
const saved = localStorage.getItem('chat_token'); const saved = localStorage.getItem('chat_token');
if (saved) { if (saved) {
api('/state').then(u => onLogin(u.nick, saved)).catch(() => localStorage.removeItem('chat_token')); api('/state').then(u => onLogin(u.nick)).catch(() => localStorage.removeItem('chat_token'));
} }
inputRef.current?.focus(); inputRef.current?.focus();
}, []); }, []);
@@ -56,7 +58,7 @@ function LoginScreen({ onLogin }) {
body: JSON.stringify({ nick: nick.trim() }) body: JSON.stringify({ nick: nick.trim() })
}); });
localStorage.setItem('chat_token', res.token); localStorage.setItem('chat_token', res.token);
onLogin(res.nick, res.token); onLogin(res.nick);
} catch (err) { } catch (err) {
setError(err.data?.error || 'Connection failed'); setError(err.data?.error || 'Connection failed');
} }
@@ -84,11 +86,19 @@ function LoginScreen({ onLogin }) {
} }
function Message({ msg }) { function Message({ msg }) {
if (msg.system) {
return ( return (
<div class={`message ${msg.system ? 'system' : ''}`}> <div class="message system">
<span class="timestamp">{formatTime(msg.createdAt)}</span> <span class="timestamp">{formatTime(msg.ts)}</span>
<span class="nick" style={{ color: msg.system ? undefined : nickColor(msg.nick) }}>{msg.nick}</span> <span class="content">{msg.text}</span>
<span class="content">{msg.content}</span> </div>
);
}
return (
<div class="message">
<span class="timestamp">{formatTime(msg.ts)}</span>
<span class="nick" style={{ color: nickColor(msg.from) }}>{msg.from}</span>
<span class="content">{msg.text}</span>
</div> </div>
); );
} }
@@ -98,93 +108,194 @@ function App() {
const [nick, setNick] = useState(''); const [nick, setNick] = useState('');
const [tabs, setTabs] = useState([{ type: 'server', name: 'Server' }]); const [tabs, setTabs] = useState([{ type: 'server', name: 'Server' }]);
const [activeTab, setActiveTab] = useState(0); const [activeTab, setActiveTab] = useState(0);
const [messages, setMessages] = useState({ server: [] }); // keyed by tab name const [messages, setMessages] = useState({ Server: [] });
const [members, setMembers] = useState({}); // keyed by channel name const [members, setMembers] = useState({});
const [topics, setTopics] = useState({});
const [unread, setUnread] = useState({});
const [input, setInput] = useState(''); const [input, setInput] = useState('');
const [joinInput, setJoinInput] = useState(''); const [joinInput, setJoinInput] = useState('');
const [lastMsgId, setLastMsgId] = useState(0); const [connected, setConnected] = useState(true);
const lastIdRef = useRef(0);
const seenIdsRef = useRef(new Set());
const pollAbortRef = useRef(null);
const tabsRef = useRef(tabs);
const activeTabRef = useRef(activeTab);
const nickRef = useRef(nick);
const messagesEndRef = useRef(); const messagesEndRef = useRef();
const inputRef = useRef(); const inputRef = useRef();
const pollRef = useRef();
useEffect(() => { tabsRef.current = tabs; }, [tabs]);
useEffect(() => { activeTabRef.current = activeTab; }, [activeTab]);
useEffect(() => { nickRef.current = nick; }, [nick]);
// Persist joined channels
useEffect(() => {
const channels = tabs.filter(t => t.type === 'channel').map(t => t.name);
localStorage.setItem('chat_channels', JSON.stringify(channels));
}, [tabs]);
// Clear unread on tab switch
useEffect(() => {
const tab = tabs[activeTab];
if (tab) setUnread(prev => ({ ...prev, [tab.name]: 0 }));
}, [activeTab, tabs]);
const addMessage = useCallback((tabName, msg) => { const addMessage = useCallback((tabName, msg) => {
if (msg.id && seenIdsRef.current.has(msg.id)) return;
if (msg.id) seenIdsRef.current.add(msg.id);
setMessages(prev => ({ setMessages(prev => ({
...prev, ...prev,
[tabName]: [...(prev[tabName] || []), msg] [tabName]: [...(prev[tabName] || []), msg]
})); }));
const currentTab = tabsRef.current[activeTabRef.current];
if (!currentTab || currentTab.name !== tabName) {
setUnread(prev => ({ ...prev, [tabName]: (prev[tabName] || 0) + 1 }));
}
}, []); }, []);
const addSystemMessage = useCallback((tabName, text) => { const addSystemMessage = useCallback((tabName, text) => {
addMessage(tabName, { setMessages(prev => ({
id: Date.now(), ...prev,
nick: '*', [tabName]: [...(prev[tabName] || []), {
content: text, id: 'sys-' + Date.now() + '-' + Math.random(),
createdAt: new Date().toISOString(), ts: new Date().toISOString(),
text,
system: true system: true
}); }]
}, [addMessage]); }));
}, []);
const onLogin = useCallback((userNick, token) => { const refreshMembers = useCallback((channel) => {
setNick(userNick); const chName = channel.replace('#', '');
setLoggedIn(true); api(`/channels/${chName}/members`).then(m => {
addSystemMessage('server', `Connected as ${userNick}`); setMembers(prev => ({ ...prev, [channel]: m }));
// Fetch server info
api('/server').then(s => {
if (s.motd) addSystemMessage('server', `MOTD: ${s.motd}`);
}).catch(() => {}); }).catch(() => {});
}, [addSystemMessage]); }, []);
// Poll for new messages const processMessage = useCallback((msg) => {
useEffect(() => { const body = Array.isArray(msg.body) ? msg.body.join('\n') : '';
if (!loggedIn) return; const base = { id: msg.id, ts: msg.ts, from: msg.from, to: msg.to, command: msg.command };
let alive = true;
const poll = async () => { switch (msg.command) {
try { case 'PRIVMSG':
const msgs = await api(`/messages?after=${lastMsgId}`); case 'NOTICE': {
if (!alive) return; const parsed = { ...base, text: body, system: false };
let maxId = lastMsgId; const target = msg.to;
for (const msg of msgs) { if (target && target.startsWith('#')) {
if (msg.id > maxId) maxId = msg.id; addMessage(target, parsed);
if (msg.isDm) { } else {
const dmTab = msg.nick === nick ? msg.dmTarget : msg.nick; const dmPeer = msg.from === nickRef.current ? msg.to : msg.from;
// Ensure DM tab exists
setTabs(prev => { setTabs(prev => {
if (!prev.find(t => t.type === 'dm' && t.name === dmTab)) { if (!prev.find(t => t.type === 'dm' && t.name === dmPeer)) {
return [...prev, { type: 'dm', name: dmTab }]; return [...prev, { type: 'dm', name: dmPeer }];
} }
return prev; return prev;
}); });
addMessage(dmTab, msg); addMessage(dmPeer, parsed);
} else if (msg.channel) {
addMessage(msg.channel, msg);
} }
break;
}
case 'JOIN': {
const text = `${msg.from} has joined ${msg.to}`;
if (msg.to) addMessage(msg.to, { ...base, text, system: true });
if (msg.to && msg.to.startsWith('#')) refreshMembers(msg.to);
break;
}
case 'PART': {
const reason = body ? ': ' + body : '';
const text = `${msg.from} has left ${msg.to}${reason}`;
if (msg.to) addMessage(msg.to, { ...base, text, system: true });
if (msg.to && msg.to.startsWith('#')) refreshMembers(msg.to);
break;
}
case 'QUIT': {
const reason = body ? ': ' + body : '';
const text = `${msg.from} has quit${reason}`;
tabsRef.current.forEach(tab => {
if (tab.type === 'channel') {
addMessage(tab.name, { ...base, text, system: true });
}
});
break;
}
case 'NICK': {
const newNick = Array.isArray(msg.body) ? msg.body[0] : body;
const text = `${msg.from} is now known as ${newNick}`;
tabsRef.current.forEach(tab => {
if (tab.type === 'channel') {
addMessage(tab.name, { ...base, text, system: true });
}
});
if (msg.from === nickRef.current && newNick) setNick(newNick);
// Refresh members in all channels
tabsRef.current.forEach(tab => {
if (tab.type === 'channel') refreshMembers(tab.name);
});
break;
}
case 'TOPIC': {
const text = `${msg.from} set the topic: ${body}`;
if (msg.to) {
addMessage(msg.to, { ...base, text, system: true });
setTopics(prev => ({ ...prev, [msg.to]: body }));
}
break;
}
case '375':
case '372':
case '376':
addMessage('Server', { ...base, text: body, system: true });
break;
default:
addMessage('Server', { ...base, text: body || msg.command, system: true });
}
}, [addMessage, refreshMembers]);
// Long-poll loop
useEffect(() => {
if (!loggedIn) return;
let alive = true;
const poll = async () => {
while (alive) {
try {
const controller = new AbortController();
pollAbortRef.current = controller;
const result = await api(
`/messages?after=${lastIdRef.current}&timeout=${POLL_TIMEOUT}`,
{ signal: controller.signal }
);
if (!alive) break;
setConnected(true);
if (result.messages) {
for (const m of result.messages) processMessage(m);
}
if (result.last_id > lastIdRef.current) {
lastIdRef.current = result.last_id;
} }
if (maxId > lastMsgId) setLastMsgId(maxId);
} catch (err) { } catch (err) {
// silent if (!alive) break;
if (err.name === 'AbortError') continue;
setConnected(false);
await new Promise(r => setTimeout(r, RECONNECT_DELAY));
}
} }
}; };
pollRef.current = setInterval(poll, 1500);
poll();
return () => { alive = false; clearInterval(pollRef.current); };
}, [loggedIn, lastMsgId, nick, addMessage]);
// Fetch members for active channel tab poll();
return () => { alive = false; pollAbortRef.current?.abort(); };
}, [loggedIn, processMessage]);
// Refresh members for active channel
useEffect(() => { useEffect(() => {
if (!loggedIn) return; if (!loggedIn) return;
const tab = tabs[activeTab]; const tab = tabs[activeTab];
if (!tab || tab.type !== 'channel') return; if (!tab || tab.type !== 'channel') return;
const chName = tab.name.replace('#', ''); refreshMembers(tab.name);
api(`/channels/${chName}/members`).then(m => { const iv = setInterval(() => refreshMembers(tab.name), MEMBER_REFRESH_INTERVAL);
setMembers(prev => ({ ...prev, [tab.name]: m }));
}).catch(() => {});
const iv = setInterval(() => {
api(`/channels/${chName}/members`).then(m => {
setMembers(prev => ({ ...prev, [tab.name]: m }));
}).catch(() => {});
}, 5000);
return () => clearInterval(iv); return () => clearInterval(iv);
}, [loggedIn, activeTab, tabs]); }, [loggedIn, activeTab, tabs, refreshMembers]);
// Auto-scroll // Auto-scroll
useEffect(() => { useEffect(() => {
@@ -192,9 +303,37 @@ function App() {
}, [messages, activeTab]); }, [messages, activeTab]);
// Focus input on tab change // Focus input on tab change
useEffect(() => { inputRef.current?.focus(); }, [activeTab]);
// Fetch topic for active channel
useEffect(() => { useEffect(() => {
inputRef.current?.focus(); if (!loggedIn) return;
}, [activeTab]); const tab = tabs[activeTab];
if (!tab || tab.type !== 'channel') return;
api('/channels').then(channels => {
const ch = channels.find(c => c.name === tab.name);
if (ch && ch.topic) setTopics(prev => ({ ...prev, [tab.name]: ch.topic }));
}).catch(() => {});
}, [loggedIn, activeTab, tabs]);
const onLogin = useCallback(async (userNick) => {
setNick(userNick);
setLoggedIn(true);
addSystemMessage('Server', `Connected as ${userNick}`);
// Auto-rejoin saved channels
const saved = JSON.parse(localStorage.getItem('chat_channels') || '[]');
for (const ch of saved) {
try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'JOIN', to: ch }) });
setTabs(prev => {
if (prev.find(t => t.type === 'channel' && t.name === ch)) return prev;
return [...prev, { type: 'channel', name: ch }];
});
} catch (e) {
// Channel may not exist anymore
}
}
}, [addSystemMessage]);
const joinChannel = async (name) => { const joinChannel = async (name) => {
if (!name) return; if (!name) return;
@@ -206,22 +345,29 @@ function App() {
if (prev.find(t => t.type === 'channel' && t.name === name)) return prev; if (prev.find(t => t.type === 'channel' && t.name === name)) return prev;
return [...prev, { type: 'channel', name }]; return [...prev, { type: 'channel', name }];
}); });
setActiveTab(tabs.length); // switch to new tab setActiveTab(tabs.length);
addSystemMessage(name, `Joined ${name}`); // Load history
try {
const hist = await api(`/history?target=${encodeURIComponent(name)}&limit=50`);
if (Array.isArray(hist)) {
for (const m of hist) processMessage(m);
}
} catch (e) {
// History may be empty
}
setJoinInput(''); setJoinInput('');
} catch (err) { } catch (err) {
addSystemMessage('server', `Failed to join ${name}: ${err.data?.error || 'error'}`); addSystemMessage('Server', `Failed to join ${name}: ${err.data?.error || 'error'}`);
} }
}; };
const partChannel = async (name) => { const partChannel = async (name) => {
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PART', to: name }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PART', to: name }) });
} catch (err) { /* ignore */ } } catch (e) {
setTabs(prev => { // Ignore
const next = prev.filter(t => !(t.type === 'channel' && t.name === name)); }
return next; setTabs(prev => prev.filter(t => !(t.type === 'channel' && t.name === name)));
});
setActiveTab(0); setActiveTab(0);
}; };
@@ -240,7 +386,8 @@ function App() {
if (prev.find(t => t.type === 'dm' && t.name === targetNick)) return prev; if (prev.find(t => t.type === 'dm' && t.name === targetNick)) return prev;
return [...prev, { type: 'dm', name: targetNick }]; return [...prev, { type: 'dm', name: targetNick }];
}); });
setActiveTab(tabs.findIndex(t => t.type === 'dm' && t.name === targetNick) || tabs.length); const idx = tabs.findIndex(t => t.type === 'dm' && t.name === targetNick);
setActiveTab(idx >= 0 ? idx : tabs.length);
}; };
const sendMessage = async () => { const sendMessage = async () => {
@@ -250,46 +397,45 @@ function App() {
const tab = tabs[activeTab]; const tab = tabs[activeTab];
if (!tab || tab.type === 'server') return; if (!tab || tab.type === 'server') return;
// Handle /commands
if (text.startsWith('/')) { if (text.startsWith('/')) {
const parts = text.split(' '); const parts = text.split(' ');
const cmd = parts[0].toLowerCase(); const cmd = parts[0].toLowerCase();
if (cmd === '/join' && parts[1]) { if (cmd === '/join' && parts[1]) { joinChannel(parts[1]); return; }
joinChannel(parts[1]); if (cmd === '/part') { if (tab.type === 'channel') partChannel(tab.name); return; }
return;
}
if (cmd === '/part') {
if (tab.type === 'channel') partChannel(tab.name);
return;
}
if (cmd === '/msg' && parts[1] && parts.slice(2).join(' ')) { if (cmd === '/msg' && parts[1] && parts.slice(2).join(' ')) {
const target = parts[1]; const target = parts[1];
const msg = parts.slice(2).join(' '); const body = parts.slice(2).join(' ');
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: target, body: [msg] }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: target, body: [body] }) });
openDM(target); openDM(target);
} catch (err) { } catch (err) {
addSystemMessage('server', `Failed to send DM: ${err.data?.error || 'error'}`); addSystemMessage('Server', `DM failed: ${err.data?.error || 'error'}`);
} }
return; return;
} }
if (cmd === '/nick' && parts[1]) { if (cmd === '/nick' && parts[1]) {
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'NICK', body: [parts[1]] }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'NICK', body: [parts[1]] }) });
setNick(parts[1]);
addSystemMessage('server', `Nick changed to ${parts[1]}`);
} catch (err) { } catch (err) {
addSystemMessage('server', `Nick change failed: ${err.data?.error || 'error'}`); addSystemMessage('Server', `Nick change failed: ${err.data?.error || 'error'}`);
} }
return; return;
} }
addSystemMessage('server', `Unknown command: ${cmd}`); if (cmd === '/topic' && tab.type === 'channel') {
const topicText = parts.slice(1).join(' ');
try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'TOPIC', to: tab.name, body: [topicText] }) });
} catch (err) {
addSystemMessage('Server', `Topic failed: ${err.data?.error || 'error'}`);
}
return;
}
addSystemMessage('Server', `Unknown command: ${cmd}`);
return; return;
} }
const to = tab.type === 'channel' ? tab.name : tab.name;
try { try {
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to, body: [text] }) }); await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: tab.name, body: [text] }) });
} catch (err) { } catch (err) {
addSystemMessage(tab.name, `Send failed: ${err.data?.error || 'error'}`); addSystemMessage(tab.name, `Send failed: ${err.data?.error || 'error'}`);
} }
@@ -300,16 +446,21 @@ function App() {
const currentTab = tabs[activeTab] || tabs[0]; const currentTab = tabs[activeTab] || tabs[0];
const currentMessages = messages[currentTab.name] || []; const currentMessages = messages[currentTab.name] || [];
const currentMembers = members[currentTab.name] || []; const currentMembers = members[currentTab.name] || [];
const currentTopic = topics[currentTab.name] || '';
return ( return (
<div class="app"> <div class="app">
<div class="tab-bar"> <div class="tab-bar">
{!connected && <div class="connection-status"> Reconnecting...</div>}
{tabs.map((tab, i) => ( {tabs.map((tab, i) => (
<div <div
class={`tab ${i === activeTab ? 'active' : ''}`} class={`tab ${i === activeTab ? 'active' : ''}`}
onClick={() => setActiveTab(i)} onClick={() => setActiveTab(i)}
> >
{tab.type === 'dm' ? `${tab.name}` : tab.name} {tab.type === 'dm' ? `${tab.name}` : tab.name}
{unread[tab.name] > 0 && i !== activeTab && (
<span class="unread-badge">{unread[tab.name]}</span>
)}
{tab.type !== 'server' && ( {tab.type !== 'server' && (
<span class="close-btn" onClick={(e) => { e.stopPropagation(); closeTab(i); }}>×</span> <span class="close-btn" onClick={(e) => { e.stopPropagation(); closeTab(i); }}>×</span>
)} )}
@@ -326,19 +477,17 @@ function App() {
</div> </div>
</div> </div>
{currentTab.type === 'channel' && currentTopic && (
<div class="topic-bar" title={currentTopic}>{currentTopic}</div>
)}
<div class="content"> <div class="content">
<div class="messages-pane"> <div class="messages-pane">
{currentTab.type === 'server' ? ( <div class={currentTab.type === 'server' ? 'server-messages' : 'messages'}>
<div class="server-messages">
{currentMessages.map(m => <Message msg={m} />)}
<div ref={messagesEndRef} />
</div>
) : (
<>
<div class="messages">
{currentMessages.map(m => <Message msg={m} />)} {currentMessages.map(m => <Message msg={m} />)}
<div ref={messagesEndRef} /> <div ref={messagesEndRef} />
</div> </div>
{currentTab.type !== 'server' && (
<div class="input-bar"> <div class="input-bar">
<input <input
ref={inputRef} ref={inputRef}
@@ -349,7 +498,6 @@ function App() {
/> />
<button onClick={sendMessage}>Send</button> <button onClick={sendMessage}>Send</button>
</div> </div>
</>
)} )}
</div> </div>

View File

@@ -14,6 +14,9 @@
--tab-active: #e94560; --tab-active: #e94560;
--tab-bg: #16213e; --tab-bg: #16213e;
--tab-hover: #1a1a3e; --tab-hover: #1a1a3e;
--topic-bg: #121a30;
--unread-bg: #e94560;
--warn: #f0ad4e;
} }
html, body, #root { html, body, #root {
@@ -86,6 +89,7 @@ html, body, #root {
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
overflow-x: auto; overflow-x: auto;
flex-shrink: 0; flex-shrink: 0;
align-items: center;
} }
.tab { .tab {
@@ -95,6 +99,7 @@ html, body, #root {
white-space: nowrap; white-space: nowrap;
color: var(--text-muted); color: var(--text-muted);
user-select: none; user-select: none;
position: relative;
} }
.tab:hover { .tab:hover {
@@ -116,6 +121,43 @@ html, body, #root {
color: var(--accent); color: var(--accent);
} }
.tab .unread-badge {
display: inline-block;
background: var(--unread-bg);
color: white;
font-size: 10px;
font-weight: bold;
padding: 1px 5px;
border-radius: 8px;
margin-left: 6px;
min-width: 16px;
text-align: center;
}
/* Connection status */
.connection-status {
padding: 4px 12px;
background: var(--warn);
color: #1a1a2e;
font-size: 12px;
font-weight: bold;
white-space: nowrap;
flex-shrink: 0;
}
/* Topic bar */
.topic-bar {
padding: 6px 12px;
background: var(--topic-bg);
border-bottom: 1px solid var(--border);
color: var(--text-muted);
font-size: 12px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
flex-shrink: 0;
}
/* Content area */ /* Content area */
.content { .content {
display: flex; display: flex;
@@ -243,6 +285,7 @@ html, body, #root {
gap: 8px; gap: 8px;
background: var(--bg-secondary); background: var(--bg-secondary);
border-bottom: 1px solid var(--border); border-bottom: 1px solid var(--border);
margin-left: auto;
} }
.join-dialog input { .join-dialog input {