Compare commits
18 Commits
feature/mv
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 2761ee156a | |||
| cd909d59c4 | |||
|
|
f5cc098b7b | ||
|
|
4d7b7618b2 | ||
|
|
910a5c2606 | ||
| bdc243224b | |||
| 5981c750a4 | |||
| 6cfab21eaa | |||
| 4a0ed57fc0 | |||
|
|
52c85724a7 | ||
|
|
69c9550bb2 | ||
| 7047167dc8 | |||
| 3cd942ffa5 | |||
| b8794c2587 | |||
| 70aa15e758 | |||
| 5e26e53187 | |||
| 02b906badb | |||
|
|
32419fb1f7 |
19
Dockerfile
19
Dockerfile
@@ -1,24 +1,35 @@
|
||||
# Lint stage — fast feedback on formatting and lint issues
|
||||
# golangci/golangci-lint:v2.1.6, 2026-03-02
|
||||
FROM golangci/golangci-lint@sha256:568ee1c1c53493575fa9494e280e579ac9ca865787bafe4df3023ae59ecf299b AS lint
|
||||
WORKDIR /src
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN make fmt-check
|
||||
RUN make lint
|
||||
|
||||
# Build stage
|
||||
# golang:1.24-alpine, 2026-02-26
|
||||
FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder
|
||||
WORKDIR /src
|
||||
RUN apk add --no-cache git build-base make
|
||||
|
||||
# golangci-lint v2.1.6 (eabc2638a66d), 2026-02-26
|
||||
RUN CGO_ENABLED=0 go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d
|
||||
# Force BuildKit to run the lint stage before proceeding
|
||||
COPY --from=lint /src/go.sum /dev/null
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
# Run all checks — build fails if branch is not green
|
||||
RUN make check
|
||||
RUN make test
|
||||
|
||||
# Build static binaries (no cgo needed at runtime — modernc.org/sqlite is pure Go)
|
||||
ARG VERSION=dev
|
||||
RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w -X main.Version=${VERSION}" -o /chatd ./cmd/chatd/
|
||||
RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o /chat-cli ./cmd/chat-cli/
|
||||
|
||||
# Runtime stage
|
||||
# alpine:3.21, 2026-02-26
|
||||
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
|
||||
RUN apk add --no-cache ca-certificates \
|
||||
|
||||
73
README.md
73
README.md
@@ -1158,6 +1158,55 @@ curl -s http://localhost:8080/api/v1/channels/general/members \
|
||||
-H "Authorization: Bearer $TOKEN" | jq .
|
||||
```
|
||||
|
||||
### POST /api/v1/logout — Logout
|
||||
|
||||
Destroy the current client's auth token. If no other clients remain on the
|
||||
session, the user is fully cleaned up: parted from all channels (with QUIT
|
||||
broadcast to members), session deleted, nick released.
|
||||
|
||||
**Request:** No body. Requires auth.
|
||||
|
||||
**Response:** `200 OK`
|
||||
```json
|
||||
{"status": "ok"}
|
||||
```
|
||||
|
||||
**Errors:**
|
||||
|
||||
| Status | Error | When |
|
||||
|--------|-------|------|
|
||||
| 401 | `unauthorized` | Missing or invalid auth token |
|
||||
|
||||
**curl example:**
|
||||
```bash
|
||||
curl -s -X POST http://localhost:8080/api/v1/logout \
|
||||
-H "Authorization: Bearer $TOKEN" | jq .
|
||||
```
|
||||
|
||||
### GET /api/v1/users/me — Current User Info
|
||||
|
||||
Return the current user's session state. This is an alias for
|
||||
`GET /api/v1/state`.
|
||||
|
||||
**Request:** No body. Requires auth.
|
||||
|
||||
**Response:** `200 OK`
|
||||
```json
|
||||
{
|
||||
"id": 1,
|
||||
"nick": "alice",
|
||||
"channels": [
|
||||
{"id": 1, "name": "#general", "topic": "Welcome!"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**curl example:**
|
||||
```bash
|
||||
curl -s http://localhost:8080/api/v1/users/me \
|
||||
-H "Authorization: Bearer $TOKEN" | jq .
|
||||
```
|
||||
|
||||
### GET /api/v1/server — Server Info
|
||||
|
||||
Return server metadata. No authentication required.
|
||||
@@ -1166,10 +1215,17 @@ Return server metadata. No authentication required.
|
||||
```json
|
||||
{
|
||||
"name": "My Chat Server",
|
||||
"motd": "Welcome! Be nice."
|
||||
"motd": "Welcome! Be nice.",
|
||||
"users": 42
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Type | Description |
|
||||
|---------|---------|-------------|
|
||||
| `name` | string | Server display name |
|
||||
| `motd` | string | Message of the day |
|
||||
| `users` | integer | Number of currently active user sessions |
|
||||
|
||||
### GET /.well-known/healthcheck.json — Health Check
|
||||
|
||||
Standard health check endpoint. No authentication required.
|
||||
@@ -1572,8 +1628,10 @@ skew issues) and simpler than UUIDs (integer comparison vs. string comparison).
|
||||
- **Queue entries**: Stored until pruned. Pruning by `QUEUE_MAX_AGE` is
|
||||
planned.
|
||||
- **Channels**: Deleted when the last member leaves (ephemeral).
|
||||
- **Users/sessions**: Deleted on `QUIT`. Session expiry by `SESSION_TIMEOUT`
|
||||
is planned.
|
||||
- **Users/sessions**: Deleted on `QUIT` or `POST /api/v1/logout`. Idle
|
||||
sessions are automatically expired after `SESSION_IDLE_TIMEOUT` (default
|
||||
24h) — the server runs a background cleanup loop that parts idle users
|
||||
from all channels, broadcasts QUIT, and releases their nicks.
|
||||
|
||||
---
|
||||
|
||||
@@ -1590,7 +1648,7 @@ directory is also loaded automatically via
|
||||
| `DBURL` | string | `file:./data.db?_journal_mode=WAL` | SQLite connection string. For file-based: `file:./path.db?_journal_mode=WAL`. For in-memory (testing): `file::memory:?cache=shared`. |
|
||||
| `DEBUG` | bool | `false` | Enable debug logging (verbose request/response logging) |
|
||||
| `MAX_HISTORY` | int | `10000` | Maximum messages retained per channel before rotation (planned) |
|
||||
| `SESSION_TIMEOUT` | int | `86400` | Session idle timeout in seconds (planned). Sessions with no activity for this long are expired and the nick is released. |
|
||||
| `SESSION_IDLE_TIMEOUT` | string | `24h` | Session idle timeout as a Go duration string (e.g. `24h`, `30m`). Sessions with no activity for this long are expired and the nick is released. |
|
||||
| `QUEUE_MAX_AGE` | int | `172800` | Maximum age of client queue entries in seconds (48h). Entries older than this are pruned (planned). |
|
||||
| `MAX_MESSAGE_SIZE` | int | `4096` | Maximum message body size in bytes (planned enforcement) |
|
||||
| `LONG_POLL_TIMEOUT`| int | `15` | Default long-poll timeout in seconds (client can override via query param, server caps at 30) |
|
||||
@@ -1610,7 +1668,7 @@ SERVER_NAME=My Chat Server
|
||||
MOTD=Welcome! Be excellent to each other.
|
||||
DEBUG=false
|
||||
DBURL=file:./data.db?_journal_mode=WAL
|
||||
SESSION_TIMEOUT=86400
|
||||
SESSION_IDLE_TIMEOUT=24h
|
||||
```
|
||||
|
||||
---
|
||||
@@ -2008,11 +2066,14 @@ GET /api/v1/challenge
|
||||
- [x] Docker deployment
|
||||
- [x] Prometheus metrics endpoint
|
||||
- [x] Health check endpoint
|
||||
- [x] Session expiry — auto-expire idle sessions, release nicks
|
||||
- [x] Logout endpoint (`POST /api/v1/logout`)
|
||||
- [x] Current user endpoint (`GET /api/v1/users/me`)
|
||||
- [x] User count in server info (`GET /api/v1/server`)
|
||||
|
||||
### Post-MVP (Planned)
|
||||
|
||||
- [ ] **Hashcash proof-of-work** for session creation (abuse prevention)
|
||||
- [ ] **Session expiry** — auto-expire idle sessions, release nicks
|
||||
- [ ] **Queue pruning** — delete old queue entries per `QUEUE_MAX_AGE`
|
||||
- [ ] **Message rotation** — enforce `MAX_HISTORY` per channel
|
||||
- [ ] **Channel modes** — enforce `+i`, `+m`, `+s`, `+t`, `+n`
|
||||
|
||||
13
go.mod
13
go.mod
@@ -4,14 +4,18 @@ go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8
|
||||
github.com/gdamore/tcell/v2 v2.13.8
|
||||
github.com/getsentry/sentry-go v0.42.0
|
||||
github.com/go-chi/chi v1.5.5
|
||||
github.com/go-chi/cors v1.2.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/rivo/tview v0.42.0
|
||||
github.com/slok/go-http-metrics v0.13.0
|
||||
github.com/spf13/viper v1.21.0
|
||||
go.uber.org/fx v1.24.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.45.0
|
||||
)
|
||||
|
||||
@@ -21,9 +25,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/gdamore/tcell/v2 v2.13.8 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
@@ -33,7 +35,6 @@ require (
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/tview v0.42.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
@@ -47,9 +48,9 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/term v0.37.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/term v0.40.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
30
go.sum
30
go.sum
@@ -113,12 +113,14 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
@@ -126,8 +128,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -135,30 +137,26 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
|
||||
@@ -23,21 +23,21 @@ type Params struct {
|
||||
|
||||
// Config holds all application configuration values.
|
||||
type Config struct {
|
||||
DBURL string
|
||||
Debug bool
|
||||
MaintenanceMode bool
|
||||
MetricsPassword string
|
||||
MetricsUsername string
|
||||
Port int
|
||||
SentryDSN string
|
||||
MaxHistory int
|
||||
SessionTimeout int
|
||||
MaxMessageSize int
|
||||
MOTD string
|
||||
ServerName string
|
||||
FederationKey string
|
||||
params *Params
|
||||
log *slog.Logger
|
||||
DBURL string
|
||||
Debug bool
|
||||
MaintenanceMode bool
|
||||
MetricsPassword string
|
||||
MetricsUsername string
|
||||
Port int
|
||||
SentryDSN string
|
||||
MaxHistory int
|
||||
MaxMessageSize int
|
||||
MOTD string
|
||||
ServerName string
|
||||
FederationKey string
|
||||
SessionIdleTimeout string
|
||||
params *Params
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a new Config by reading from files and environment variables.
|
||||
@@ -61,11 +61,11 @@ func New(
|
||||
viper.SetDefault("METRICS_USERNAME", "")
|
||||
viper.SetDefault("METRICS_PASSWORD", "")
|
||||
viper.SetDefault("MAX_HISTORY", "10000")
|
||||
viper.SetDefault("SESSION_TIMEOUT", "86400")
|
||||
viper.SetDefault("MAX_MESSAGE_SIZE", "4096")
|
||||
viper.SetDefault("MOTD", "")
|
||||
viper.SetDefault("SERVER_NAME", "")
|
||||
viper.SetDefault("FEDERATION_KEY", "")
|
||||
viper.SetDefault("SESSION_IDLE_TIMEOUT", "24h")
|
||||
|
||||
err := viper.ReadInConfig()
|
||||
if err != nil {
|
||||
@@ -77,21 +77,21 @@ func New(
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
DBURL: viper.GetString("DBURL"),
|
||||
Debug: viper.GetBool("DEBUG"),
|
||||
Port: viper.GetInt("PORT"),
|
||||
SentryDSN: viper.GetString("SENTRY_DSN"),
|
||||
MaintenanceMode: viper.GetBool("MAINTENANCE_MODE"),
|
||||
MetricsUsername: viper.GetString("METRICS_USERNAME"),
|
||||
MetricsPassword: viper.GetString("METRICS_PASSWORD"),
|
||||
MaxHistory: viper.GetInt("MAX_HISTORY"),
|
||||
SessionTimeout: viper.GetInt("SESSION_TIMEOUT"),
|
||||
MaxMessageSize: viper.GetInt("MAX_MESSAGE_SIZE"),
|
||||
MOTD: viper.GetString("MOTD"),
|
||||
ServerName: viper.GetString("SERVER_NAME"),
|
||||
FederationKey: viper.GetString("FEDERATION_KEY"),
|
||||
log: log,
|
||||
params: ¶ms,
|
||||
DBURL: viper.GetString("DBURL"),
|
||||
Debug: viper.GetBool("DEBUG"),
|
||||
Port: viper.GetInt("PORT"),
|
||||
SentryDSN: viper.GetString("SENTRY_DSN"),
|
||||
MaintenanceMode: viper.GetBool("MAINTENANCE_MODE"),
|
||||
MetricsUsername: viper.GetString("METRICS_USERNAME"),
|
||||
MetricsPassword: viper.GetString("METRICS_PASSWORD"),
|
||||
MaxHistory: viper.GetInt("MAX_HISTORY"),
|
||||
MaxMessageSize: viper.GetInt("MAX_MESSAGE_SIZE"),
|
||||
MOTD: viper.GetString("MOTD"),
|
||||
ServerName: viper.GetString("SERVER_NAME"),
|
||||
FederationKey: viper.GetString("FEDERATION_KEY"),
|
||||
SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"),
|
||||
log: log,
|
||||
params: ¶ms,
|
||||
}
|
||||
|
||||
if cfg.Debug {
|
||||
|
||||
161
internal/db/auth.go
Normal file
161
internal/db/auth.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const bcryptCost = bcrypt.DefaultCost
|
||||
|
||||
var errNoPassword = errors.New(
|
||||
"account has no password set",
|
||||
)
|
||||
|
||||
// RegisterUser creates a session with a hashed password
|
||||
// and returns session ID, client ID, and token.
|
||||
func (database *Database) RegisterUser(
|
||||
ctx context.Context,
|
||||
nick, password string,
|
||||
) (int64, int64, string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword(
|
||||
[]byte(password), bcryptCost,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"hash password: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
sessionUUID := uuid.New().String()
|
||||
clientUUID := uuid.New().String()
|
||||
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return 0, 0, "", err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
transaction, err := database.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"begin tx: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
res, err := transaction.ExecContext(ctx,
|
||||
`INSERT INTO sessions
|
||||
(uuid, nick, password_hash,
|
||||
created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
sessionUUID, nick, string(hash), now, now)
|
||||
if err != nil {
|
||||
_ = transaction.Rollback()
|
||||
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"create session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
sessionID, _ := res.LastInsertId()
|
||||
|
||||
clientRes, err := transaction.ExecContext(ctx,
|
||||
`INSERT INTO clients
|
||||
(uuid, session_id, token,
|
||||
created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
clientUUID, sessionID, token, now, now)
|
||||
if err != nil {
|
||||
_ = transaction.Rollback()
|
||||
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"create client: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
clientID, _ := clientRes.LastInsertId()
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"commit registration: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return sessionID, clientID, token, nil
|
||||
}
|
||||
|
||||
// LoginUser verifies a nick/password and creates a new
|
||||
// client token.
|
||||
func (database *Database) LoginUser(
|
||||
ctx context.Context,
|
||||
nick, password string,
|
||||
) (int64, int64, string, error) {
|
||||
var (
|
||||
sessionID int64
|
||||
passwordHash string
|
||||
)
|
||||
|
||||
err := database.conn.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT id, password_hash
|
||||
FROM sessions WHERE nick = ?`,
|
||||
nick,
|
||||
).Scan(&sessionID, &passwordHash)
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"get session for login: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
if passwordHash == "" {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"login: %w", errNoPassword,
|
||||
)
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword(
|
||||
[]byte(passwordHash), []byte(password),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"verify password: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
clientUUID := uuid.New().String()
|
||||
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return 0, 0, "", err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
res, err := database.conn.ExecContext(ctx,
|
||||
`INSERT INTO clients
|
||||
(uuid, session_id, token,
|
||||
created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
clientUUID, sessionID, token, now, now)
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"create login client: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
clientID, _ := res.LastInsertId()
|
||||
|
||||
_, _ = database.conn.ExecContext(
|
||||
ctx,
|
||||
"UPDATE sessions SET last_seen = ? WHERE id = ?",
|
||||
now, sessionID,
|
||||
)
|
||||
|
||||
return sessionID, clientID, token, nil
|
||||
}
|
||||
178
internal/db/auth_test.go
Normal file
178
internal/db/auth_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestRegisterUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
sessionID, clientID, token, err :=
|
||||
database.RegisterUser(ctx, "reguser", "password123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if sessionID == 0 || clientID == 0 || token == "" {
|
||||
t.Fatal("expected valid ids and token")
|
||||
}
|
||||
|
||||
// Verify session works via token lookup.
|
||||
sid, cid, nick, err :=
|
||||
database.GetSessionByToken(ctx, token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if sid != sessionID || cid != clientID {
|
||||
t.Fatal("session/client id mismatch")
|
||||
}
|
||||
|
||||
if nick != "reguser" {
|
||||
t.Fatalf("expected reguser, got %s", nick)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterUserDuplicateNick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
regSID, regCID, regToken, err :=
|
||||
database.RegisterUser(ctx, "dupnick", "password123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = regSID
|
||||
_ = regCID
|
||||
_ = regToken
|
||||
|
||||
dupSID, dupCID, dupToken, dupErr :=
|
||||
database.RegisterUser(ctx, "dupnick", "other12345")
|
||||
if dupErr == nil {
|
||||
t.Fatal("expected error for duplicate nick")
|
||||
}
|
||||
|
||||
_ = dupSID
|
||||
_ = dupCID
|
||||
_ = dupToken
|
||||
}
|
||||
|
||||
func TestLoginUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
regSID, regCID, regToken, err :=
|
||||
database.RegisterUser(ctx, "loginuser", "mypassword")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = regSID
|
||||
_ = regCID
|
||||
_ = regToken
|
||||
|
||||
sessionID, clientID, token, err :=
|
||||
database.LoginUser(ctx, "loginuser", "mypassword")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if sessionID == 0 || clientID == 0 || token == "" {
|
||||
t.Fatal("expected valid ids and token")
|
||||
}
|
||||
|
||||
// Verify the new token works.
|
||||
_, _, nick, err :=
|
||||
database.GetSessionByToken(ctx, token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if nick != "loginuser" {
|
||||
t.Fatalf("expected loginuser, got %s", nick)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginUserWrongPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
regSID, regCID, regToken, err :=
|
||||
database.RegisterUser(ctx, "wrongpw", "correctpass")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = regSID
|
||||
_ = regCID
|
||||
_ = regToken
|
||||
|
||||
loginSID, loginCID, loginToken, loginErr :=
|
||||
database.LoginUser(ctx, "wrongpw", "wrongpass12")
|
||||
if loginErr == nil {
|
||||
t.Fatal("expected error for wrong password")
|
||||
}
|
||||
|
||||
_ = loginSID
|
||||
_ = loginCID
|
||||
_ = loginToken
|
||||
}
|
||||
|
||||
func TestLoginUserNoPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
// Create anonymous session (no password).
|
||||
anonSID, anonCID, anonToken, err :=
|
||||
database.CreateSession(ctx, "anon")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = anonSID
|
||||
_ = anonCID
|
||||
_ = anonToken
|
||||
|
||||
loginSID, loginCID, loginToken, loginErr :=
|
||||
database.LoginUser(ctx, "anon", "anything1")
|
||||
if loginErr == nil {
|
||||
t.Fatal(
|
||||
"expected error for login on passwordless account",
|
||||
)
|
||||
}
|
||||
|
||||
_ = loginSID
|
||||
_ = loginCID
|
||||
_ = loginToken
|
||||
}
|
||||
|
||||
func TestLoginUserNonexistent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
loginSID, loginCID, loginToken, err :=
|
||||
database.LoginUser(ctx, "ghost", "password123")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent user")
|
||||
}
|
||||
|
||||
_ = loginSID
|
||||
_ = loginCID
|
||||
_ = loginToken
|
||||
}
|
||||
@@ -55,76 +55,132 @@ type MemberInfo struct {
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
}
|
||||
|
||||
// CreateUser registers a new user with the given nick.
|
||||
func (database *Database) CreateUser(
|
||||
// CreateSession registers a new session and its first client.
|
||||
func (database *Database) CreateSession(
|
||||
ctx context.Context,
|
||||
nick string,
|
||||
) (int64, string, error) {
|
||||
) (int64, int64, string, error) {
|
||||
sessionUUID := uuid.New().String()
|
||||
clientUUID := uuid.New().String()
|
||||
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
return 0, 0, "", err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
res, err := database.conn.ExecContext(ctx,
|
||||
`INSERT INTO users
|
||||
(nick, token, created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
nick, token, now, now)
|
||||
transaction, err := database.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("create user: %w", err)
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"begin tx: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
userID, _ := res.LastInsertId()
|
||||
res, err := transaction.ExecContext(ctx,
|
||||
`INSERT INTO sessions
|
||||
(uuid, nick, created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
sessionUUID, nick, now, now)
|
||||
if err != nil {
|
||||
_ = transaction.Rollback()
|
||||
|
||||
return userID, token, nil
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"create session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
sessionID, _ := res.LastInsertId()
|
||||
|
||||
clientRes, err := transaction.ExecContext(ctx,
|
||||
`INSERT INTO clients
|
||||
(uuid, session_id, token,
|
||||
created_at, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
clientUUID, sessionID, token, now, now)
|
||||
if err != nil {
|
||||
_ = transaction.Rollback()
|
||||
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"create client: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
clientID, _ := clientRes.LastInsertId()
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"commit session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return sessionID, clientID, token, nil
|
||||
}
|
||||
|
||||
// GetUserByToken returns user id and nick for a token.
|
||||
func (database *Database) GetUserByToken(
|
||||
// GetSessionByToken returns session id, client id, and
|
||||
// nick for a client token.
|
||||
func (database *Database) GetSessionByToken(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
) (int64, string, error) {
|
||||
var userID int64
|
||||
|
||||
var nick string
|
||||
) (int64, int64, string, error) {
|
||||
var (
|
||||
sessionID int64
|
||||
clientID int64
|
||||
nick string
|
||||
)
|
||||
|
||||
err := database.conn.QueryRowContext(
|
||||
ctx,
|
||||
"SELECT id, nick FROM users WHERE token = ?",
|
||||
`SELECT s.id, c.id, s.nick
|
||||
FROM clients c
|
||||
INNER JOIN sessions s
|
||||
ON s.id = c.session_id
|
||||
WHERE c.token = ?`,
|
||||
token,
|
||||
).Scan(&userID, &nick)
|
||||
).Scan(&sessionID, &clientID, &nick)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("get user by token: %w", err)
|
||||
return 0, 0, "", fmt.Errorf(
|
||||
"get session by token: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
_, _ = database.conn.ExecContext(
|
||||
ctx,
|
||||
"UPDATE users SET last_seen = ? WHERE id = ?",
|
||||
time.Now(), userID,
|
||||
"UPDATE sessions SET last_seen = ? WHERE id = ?",
|
||||
now, sessionID,
|
||||
)
|
||||
|
||||
return userID, nick, nil
|
||||
_, _ = database.conn.ExecContext(
|
||||
ctx,
|
||||
"UPDATE clients SET last_seen = ? WHERE id = ?",
|
||||
now, clientID,
|
||||
)
|
||||
|
||||
return sessionID, clientID, nick, nil
|
||||
}
|
||||
|
||||
// GetUserByNick returns user id for a given nick.
|
||||
func (database *Database) GetUserByNick(
|
||||
// GetSessionByNick returns session id for a given nick.
|
||||
func (database *Database) GetSessionByNick(
|
||||
ctx context.Context,
|
||||
nick string,
|
||||
) (int64, error) {
|
||||
var userID int64
|
||||
var sessionID int64
|
||||
|
||||
err := database.conn.QueryRowContext(
|
||||
ctx,
|
||||
"SELECT id FROM users WHERE nick = ?",
|
||||
"SELECT id FROM sessions WHERE nick = ?",
|
||||
nick,
|
||||
).Scan(&userID)
|
||||
).Scan(&sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get user by nick: %w", err)
|
||||
return 0, fmt.Errorf(
|
||||
"get session by nick: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
// GetChannelByName returns the channel ID for a name.
|
||||
@@ -179,16 +235,16 @@ func (database *Database) GetOrCreateChannel(
|
||||
return channelID, nil
|
||||
}
|
||||
|
||||
// JoinChannel adds a user to a channel.
|
||||
// JoinChannel adds a session to a channel.
|
||||
func (database *Database) JoinChannel(
|
||||
ctx context.Context,
|
||||
channelID, userID int64,
|
||||
channelID, sessionID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO channel_members
|
||||
(channel_id, user_id, joined_at)
|
||||
(channel_id, session_id, joined_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
channelID, userID, time.Now())
|
||||
channelID, sessionID, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("join channel: %w", err)
|
||||
}
|
||||
@@ -196,15 +252,15 @@ func (database *Database) JoinChannel(
|
||||
return nil
|
||||
}
|
||||
|
||||
// PartChannel removes a user from a channel.
|
||||
// PartChannel removes a session from a channel.
|
||||
func (database *Database) PartChannel(
|
||||
ctx context.Context,
|
||||
channelID, userID int64,
|
||||
channelID, sessionID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`DELETE FROM channel_members
|
||||
WHERE channel_id = ? AND user_id = ?`,
|
||||
channelID, userID)
|
||||
WHERE channel_id = ? AND session_id = ?`,
|
||||
channelID, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("part channel: %w", err)
|
||||
}
|
||||
@@ -265,18 +321,18 @@ func scanChannels(
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ListChannels returns channels the user has joined.
|
||||
// ListChannels returns channels the session has joined.
|
||||
func (database *Database) ListChannels(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
) ([]ChannelInfo, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT c.id, c.name, c.topic
|
||||
FROM channels c
|
||||
INNER JOIN channel_members cm
|
||||
ON cm.channel_id = c.id
|
||||
WHERE cm.user_id = ?
|
||||
ORDER BY c.name`, userID)
|
||||
WHERE cm.session_id = ?
|
||||
ORDER BY c.name`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list channels: %w", err)
|
||||
}
|
||||
@@ -306,12 +362,12 @@ func (database *Database) ChannelMembers(
|
||||
channelID int64,
|
||||
) ([]MemberInfo, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT u.id, u.nick, u.last_seen
|
||||
FROM users u
|
||||
`SELECT s.id, s.nick, s.last_seen
|
||||
FROM sessions s
|
||||
INNER JOIN channel_members cm
|
||||
ON cm.user_id = u.id
|
||||
ON cm.session_id = s.id
|
||||
WHERE cm.channel_id = ?
|
||||
ORDER BY u.nick`, channelID)
|
||||
ORDER BY s.nick`, channelID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"query channel members: %w", err,
|
||||
@@ -349,17 +405,17 @@ func (database *Database) ChannelMembers(
|
||||
return members, nil
|
||||
}
|
||||
|
||||
// IsChannelMember checks if a user belongs to a channel.
|
||||
// IsChannelMember checks if a session belongs to a channel.
|
||||
func (database *Database) IsChannelMember(
|
||||
ctx context.Context,
|
||||
channelID, userID int64,
|
||||
channelID, sessionID int64,
|
||||
) (bool, error) {
|
||||
var count int
|
||||
|
||||
err := database.conn.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM channel_members
|
||||
WHERE channel_id = ? AND user_id = ?`,
|
||||
channelID, userID,
|
||||
WHERE channel_id = ? AND session_id = ?`,
|
||||
channelID, sessionID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf(
|
||||
@@ -397,13 +453,13 @@ func scanInt64s(rows *sql.Rows) ([]int64, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// GetChannelMemberIDs returns user IDs in a channel.
|
||||
// GetChannelMemberIDs returns session IDs in a channel.
|
||||
func (database *Database) GetChannelMemberIDs(
|
||||
ctx context.Context,
|
||||
channelID int64,
|
||||
) ([]int64, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT user_id FROM channel_members
|
||||
`SELECT session_id FROM channel_members
|
||||
WHERE channel_id = ?`, channelID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
@@ -414,17 +470,17 @@ func (database *Database) GetChannelMemberIDs(
|
||||
return scanInt64s(rows)
|
||||
}
|
||||
|
||||
// GetUserChannelIDs returns channel IDs the user is in.
|
||||
func (database *Database) GetUserChannelIDs(
|
||||
// GetSessionChannelIDs returns channel IDs for a session.
|
||||
func (database *Database) GetSessionChannelIDs(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
) ([]int64, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT channel_id FROM channel_members
|
||||
WHERE user_id = ?`, userID)
|
||||
WHERE session_id = ?`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"get user channel ids: %w", err,
|
||||
"get session channel ids: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -467,27 +523,52 @@ func (database *Database) InsertMessage(
|
||||
return dbID, msgUUID, nil
|
||||
}
|
||||
|
||||
// EnqueueMessage adds a message to a user's queue.
|
||||
func (database *Database) EnqueueMessage(
|
||||
// EnqueueToSession adds a message to all clients of a
|
||||
// session's queues.
|
||||
func (database *Database) EnqueueToSession(
|
||||
ctx context.Context,
|
||||
userID, messageID int64,
|
||||
sessionID, messageID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO client_queues
|
||||
(user_id, message_id, created_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
userID, messageID, time.Now())
|
||||
(client_id, message_id, created_at)
|
||||
SELECT c.id, ?, ?
|
||||
FROM clients c
|
||||
WHERE c.session_id = ?`,
|
||||
messageID, time.Now(), sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enqueue message: %w", err)
|
||||
return fmt.Errorf(
|
||||
"enqueue to session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PollMessages returns queued messages for a user.
|
||||
// EnqueueToClient adds a message to a specific client's
|
||||
// queue.
|
||||
func (database *Database) EnqueueToClient(
|
||||
ctx context.Context,
|
||||
clientID, messageID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO client_queues
|
||||
(client_id, message_id, created_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
clientID, messageID, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"enqueue to client: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PollMessages returns queued messages for a client.
|
||||
func (database *Database) PollMessages(
|
||||
ctx context.Context,
|
||||
userID, afterQueueID int64,
|
||||
clientID, afterQueueID int64,
|
||||
limit int,
|
||||
) ([]IRCMessage, int64, error) {
|
||||
if limit <= 0 {
|
||||
@@ -501,9 +582,9 @@ func (database *Database) PollMessages(
|
||||
FROM client_queues cq
|
||||
INNER JOIN messages m
|
||||
ON m.id = cq.message_id
|
||||
WHERE cq.user_id = ? AND cq.id > ?
|
||||
WHERE cq.client_id = ? AND cq.id > ?
|
||||
ORDER BY cq.id ASC LIMIT ?`,
|
||||
userID, afterQueueID, limit)
|
||||
clientID, afterQueueID, limit)
|
||||
if err != nil {
|
||||
return nil, afterQueueID, fmt.Errorf(
|
||||
"poll messages: %w", err,
|
||||
@@ -649,15 +730,15 @@ func reverseMessages(msgs []IRCMessage) {
|
||||
}
|
||||
}
|
||||
|
||||
// ChangeNick updates a user's nickname.
|
||||
// ChangeNick updates a session's nickname.
|
||||
func (database *Database) ChangeNick(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
newNick string,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(ctx,
|
||||
"UPDATE users SET nick = ? WHERE id = ?",
|
||||
newNick, userID)
|
||||
"UPDATE sessions SET nick = ? WHERE id = ?",
|
||||
newNick, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("change nick: %w", err)
|
||||
}
|
||||
@@ -681,38 +762,182 @@ func (database *Database) SetTopic(
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser removes a user and all their data.
|
||||
func (database *Database) DeleteUser(
|
||||
// DeleteSession removes a session and all its data.
|
||||
func (database *Database) DeleteSession(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
sessionID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(
|
||||
ctx,
|
||||
"DELETE FROM users WHERE id = ?",
|
||||
userID,
|
||||
"DELETE FROM sessions WHERE id = ?",
|
||||
sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete user: %w", err)
|
||||
return fmt.Errorf("delete session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllChannelMembershipsForUser returns channels
|
||||
// a user belongs to.
|
||||
func (database *Database) GetAllChannelMembershipsForUser(
|
||||
// DeleteClient removes a single client record by ID.
|
||||
func (database *Database) DeleteClient(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
clientID int64,
|
||||
) error {
|
||||
_, err := database.conn.ExecContext(
|
||||
ctx,
|
||||
"DELETE FROM clients WHERE id = ?",
|
||||
clientID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete client: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserCount returns the number of active users.
|
||||
func (database *Database) GetUserCount(
|
||||
ctx context.Context,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := database.conn.QueryRowContext(
|
||||
ctx,
|
||||
"SELECT COUNT(*) FROM sessions",
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf(
|
||||
"get user count: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ClientCountForSession returns the number of clients
|
||||
// belonging to a session.
|
||||
func (database *Database) ClientCountForSession(
|
||||
ctx context.Context,
|
||||
sessionID int64,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := database.conn.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT COUNT(*) FROM clients
|
||||
WHERE session_id = ?`,
|
||||
sessionID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf(
|
||||
"client count for session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// DeleteStaleUsers removes clients not seen since the
|
||||
// cutoff and cleans up orphaned users (sessions).
|
||||
func (database *Database) DeleteStaleUsers(
|
||||
ctx context.Context,
|
||||
cutoff time.Time,
|
||||
) (int64, error) {
|
||||
res, err := database.conn.ExecContext(ctx,
|
||||
"DELETE FROM clients WHERE last_seen < ?",
|
||||
cutoff,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf(
|
||||
"delete stale clients: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
deleted, _ := res.RowsAffected()
|
||||
|
||||
_, err = database.conn.ExecContext(ctx,
|
||||
`DELETE FROM sessions WHERE id NOT IN
|
||||
(SELECT DISTINCT session_id FROM clients)`,
|
||||
)
|
||||
if err != nil {
|
||||
return deleted, fmt.Errorf(
|
||||
"delete orphan sessions: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// StaleSession holds the id and nick of a session
|
||||
// whose clients are all stale.
|
||||
type StaleSession struct {
|
||||
ID int64
|
||||
Nick string
|
||||
}
|
||||
|
||||
// GetStaleOrphanSessions returns sessions where every
|
||||
// client has a last_seen before cutoff.
|
||||
func (database *Database) GetStaleOrphanSessions(
|
||||
ctx context.Context,
|
||||
cutoff time.Time,
|
||||
) ([]StaleSession, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT s.id, s.nick
|
||||
FROM sessions s
|
||||
WHERE s.id IN (
|
||||
SELECT DISTINCT session_id FROM clients
|
||||
WHERE last_seen < ?
|
||||
)
|
||||
AND s.id NOT IN (
|
||||
SELECT DISTINCT session_id FROM clients
|
||||
WHERE last_seen >= ?
|
||||
)`, cutoff, cutoff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"get stale orphan sessions: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var result []StaleSession
|
||||
|
||||
for rows.Next() {
|
||||
var stale StaleSession
|
||||
if err := rows.Scan(&stale.ID, &stale.Nick); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"scan stale session: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
result = append(result, stale)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"iterate stale sessions: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetSessionChannels returns channels a session
|
||||
// belongs to.
|
||||
func (database *Database) GetSessionChannels(
|
||||
ctx context.Context,
|
||||
sessionID int64,
|
||||
) ([]ChannelInfo, error) {
|
||||
rows, err := database.conn.QueryContext(ctx,
|
||||
`SELECT c.id, c.name, c.topic
|
||||
FROM channels c
|
||||
INNER JOIN channel_members cm
|
||||
ON cm.channel_id = c.id
|
||||
WHERE cm.user_id = ?`, userID)
|
||||
WHERE cm.session_id = ?`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"get memberships: %w", err,
|
||||
"get session channels: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,70 +27,91 @@ func setupTestDB(t *testing.T) *db.Database {
|
||||
return database
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
func TestCreateSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
id, token, err := database.CreateUser(ctx, "alice")
|
||||
sessionID, _, token, err := database.CreateSession(
|
||||
ctx, "alice",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id == 0 || token == "" {
|
||||
if sessionID == 0 || token == "" {
|
||||
t.Fatal("expected valid id and token")
|
||||
}
|
||||
|
||||
_, _, err = database.CreateUser(ctx, "alice")
|
||||
if err == nil {
|
||||
_, _, dupToken, dupErr := database.CreateSession(
|
||||
ctx, "alice",
|
||||
)
|
||||
if dupErr == nil {
|
||||
t.Fatal("expected error for duplicate nick")
|
||||
}
|
||||
|
||||
_ = dupToken
|
||||
}
|
||||
|
||||
func TestGetUserByToken(t *testing.T) {
|
||||
func TestGetSessionByToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
_, token, err := database.CreateUser(ctx, "bob")
|
||||
_, _, token, err := database.CreateSession(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
id, nick, err := database.GetUserByToken(ctx, token)
|
||||
sessionID, clientID, nick, err :=
|
||||
database.GetSessionByToken(ctx, token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if nick != "bob" || id == 0 {
|
||||
if nick != "bob" || sessionID == 0 || clientID == 0 {
|
||||
t.Fatalf("expected bob, got %s", nick)
|
||||
}
|
||||
|
||||
_, _, err = database.GetUserByToken(ctx, "badtoken")
|
||||
if err == nil {
|
||||
badSID, badCID, badNick, badErr :=
|
||||
database.GetSessionByToken(ctx, "badtoken")
|
||||
if badErr == nil {
|
||||
t.Fatal("expected error for bad token")
|
||||
}
|
||||
|
||||
if badSID != 0 || badCID != 0 || badNick != "" {
|
||||
t.Fatal("expected zero values on error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByNick(t *testing.T) {
|
||||
func TestGetSessionByNick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
_, _, err := database.CreateUser(ctx, "charlie")
|
||||
charlieID, charlieClientID, charlieToken, err :=
|
||||
database.CreateSession(ctx, "charlie")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
id, err := database.GetUserByNick(ctx, "charlie")
|
||||
if charlieID == 0 || charlieClientID == 0 {
|
||||
t.Fatal("expected valid session/client IDs")
|
||||
}
|
||||
|
||||
if charlieToken == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
|
||||
id, err := database.GetSessionByNick(ctx, "charlie")
|
||||
if err != nil || id == 0 {
|
||||
t.Fatal("expected to find charlie")
|
||||
}
|
||||
|
||||
_, err = database.GetUserByNick(ctx, "nobody")
|
||||
_, err = database.GetSessionByNick(ctx, "nobody")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown nick")
|
||||
}
|
||||
@@ -129,7 +150,7 @@ func TestJoinAndPart(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "user1")
|
||||
sid, _, _, err := database.CreateSession(ctx, "user1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -139,22 +160,22 @@ func TestJoinAndPart(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ids, err := database.GetChannelMemberIDs(ctx, chID)
|
||||
if err != nil || len(ids) != 1 || ids[0] != uid {
|
||||
t.Fatal("expected user in channel")
|
||||
if err != nil || len(ids) != 1 || ids[0] != sid {
|
||||
t.Fatal("expected session in channel")
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.PartChannel(ctx, chID, uid)
|
||||
err = database.PartChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -178,17 +199,17 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "temp")
|
||||
sid, _, _, err := database.CreateSession(ctx, "temp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.PartChannel(ctx, chID, uid)
|
||||
err = database.PartChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -204,7 +225,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserWithChannels(
|
||||
func createSessionWithChannels(
|
||||
t *testing.T,
|
||||
database *db.Database,
|
||||
nick, ch1Name, ch2Name string,
|
||||
@@ -213,7 +234,7 @@ func createUserWithChannels(
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, nick)
|
||||
sid, _, _, err := database.CreateSession(ctx, nick)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -232,29 +253,29 @@ func createUserWithChannels(
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, ch1, uid)
|
||||
err = database.JoinChannel(ctx, ch1, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, ch2, uid)
|
||||
err = database.JoinChannel(ctx, ch2, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return uid, ch1, ch2
|
||||
return sid, ch1, ch2
|
||||
}
|
||||
|
||||
func TestListChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
uid, _, _ := createUserWithChannels(
|
||||
sid, _, _ := createSessionWithChannels(
|
||||
t, database, "lister", "#a", "#b",
|
||||
)
|
||||
|
||||
channels, err := database.ListChannels(
|
||||
t.Context(), uid,
|
||||
t.Context(), sid,
|
||||
)
|
||||
if err != nil || len(channels) != 2 {
|
||||
t.Fatalf(
|
||||
@@ -295,17 +316,21 @@ func TestChangeNick(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, token, err := database.CreateUser(ctx, "old")
|
||||
sid, _, token, err := database.CreateSession(
|
||||
ctx, "old",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.ChangeNick(ctx, uid, "new")
|
||||
err = database.ChangeNick(ctx, sid, "new")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, nick, err := database.GetUserByToken(ctx, token)
|
||||
_, _, nick, err := database.GetSessionByToken(
|
||||
ctx, token,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -375,7 +400,16 @@ func TestPollMessages(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "poller")
|
||||
sid, _, token, err := database.CreateSession(
|
||||
ctx, "poller",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, clientID, _, err := database.GetSessionByToken(
|
||||
ctx, token,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -389,7 +423,7 @@ func TestPollMessages(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.EnqueueMessage(ctx, uid, dbID)
|
||||
err = database.EnqueueToSession(ctx, sid, dbID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -397,7 +431,7 @@ func TestPollMessages(t *testing.T) {
|
||||
const batchSize = 10
|
||||
|
||||
msgs, lastQID, err := database.PollMessages(
|
||||
ctx, uid, 0, batchSize,
|
||||
ctx, clientID, 0, batchSize,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -420,7 +454,7 @@ func TestPollMessages(t *testing.T) {
|
||||
}
|
||||
|
||||
msgs, _, _ = database.PollMessages(
|
||||
ctx, uid, lastQID, batchSize,
|
||||
ctx, clientID, lastQID, batchSize,
|
||||
)
|
||||
|
||||
if len(msgs) != 0 {
|
||||
@@ -467,13 +501,15 @@ func TestGetHistory(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
func TestDeleteSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid, _, err := database.CreateUser(ctx, "deleteme")
|
||||
sid, _, _, err := database.CreateSession(
|
||||
ctx, "deleteme",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -485,19 +521,19 @@ func TestDeleteUser(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid)
|
||||
err = database.JoinChannel(ctx, chID, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.DeleteUser(ctx, uid)
|
||||
err = database.DeleteSession(ctx, sid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = database.GetUserByNick(ctx, "deleteme")
|
||||
_, err = database.GetSessionByNick(ctx, "deleteme")
|
||||
if err == nil {
|
||||
t.Fatal("user should be deleted")
|
||||
t.Fatal("session should be deleted")
|
||||
}
|
||||
|
||||
ids, _ := database.GetChannelMemberIDs(ctx, chID)
|
||||
@@ -512,12 +548,12 @@ func TestChannelMembers(t *testing.T) {
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
uid1, _, err := database.CreateUser(ctx, "m1")
|
||||
sid1, _, _, err := database.CreateSession(ctx, "m1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
uid2, _, err := database.CreateUser(ctx, "m2")
|
||||
sid2, _, _, err := database.CreateSession(ctx, "m2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -529,12 +565,12 @@ func TestChannelMembers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid1)
|
||||
err = database.JoinChannel(ctx, chID, sid1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.JoinChannel(ctx, chID, uid2)
|
||||
err = database.JoinChannel(ctx, chID, sid2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -548,17 +584,17 @@ func TestChannelMembers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllChannelMembershipsForUser(t *testing.T) {
|
||||
func TestGetSessionChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
uid, _, _ := createUserWithChannels(
|
||||
sid, _, _ := createSessionWithChannels(
|
||||
t, database, "multi", "#m1", "#m2",
|
||||
)
|
||||
|
||||
channels, err :=
|
||||
database.GetAllChannelMembershipsForUser(
|
||||
t.Context(), uid,
|
||||
database.GetSessionChannels(
|
||||
t.Context(), sid,
|
||||
)
|
||||
if err != nil || len(channels) != 2 {
|
||||
t.Fatalf(
|
||||
@@ -567,3 +603,51 @@ func TestGetAllChannelMembershipsForUser(t *testing.T) {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueToClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
database := setupTestDB(t)
|
||||
ctx := t.Context()
|
||||
|
||||
_, _, token, err := database.CreateSession(
|
||||
ctx, "enqclient",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, clientID, _, err := database.GetSessionByToken(
|
||||
ctx, token,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := json.RawMessage(`["test"]`)
|
||||
|
||||
dbID, _, err := database.InsertMessage(
|
||||
ctx, "PRIVMSG", "sender", "#ch", body, nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = database.EnqueueToClient(ctx, clientID, dbID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const batchSize = 10
|
||||
|
||||
msgs, _, err := database.PollMessages(
|
||||
ctx, clientID, 0, batchSize,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(msgs))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,29 @@
|
||||
-- Chat server schema (pre-1.0 consolidated)
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- Users: IRC-style sessions (no passwords, just nick + token)
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
-- Sessions: each session is a user identity (nick + optional password + signing key)
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
uuid TEXT NOT NULL UNIQUE,
|
||||
nick TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL DEFAULT '',
|
||||
signing_key TEXT NOT NULL DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid);
|
||||
|
||||
-- Clients: each session can have multiple connected clients
|
||||
CREATE TABLE IF NOT EXISTS clients (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
uuid TEXT NOT NULL UNIQUE,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_token ON users(token);
|
||||
CREATE INDEX IF NOT EXISTS idx_clients_token ON clients(token);
|
||||
CREATE INDEX IF NOT EXISTS idx_clients_session ON clients(session_id);
|
||||
|
||||
-- Channels
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
@@ -24,9 +38,9 @@ CREATE TABLE IF NOT EXISTS channels (
|
||||
CREATE TABLE IF NOT EXISTS channel_members (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(channel_id, user_id)
|
||||
UNIQUE(channel_id, session_id)
|
||||
);
|
||||
|
||||
-- Messages: IRC envelope format
|
||||
@@ -46,9 +60,9 @@ CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
|
||||
-- Per-client message queues for fan-out delivery
|
||||
CREATE TABLE IF NOT EXISTS client_queues (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
client_id INTEGER NOT NULL REFERENCES clients(id) ON DELETE CASCADE,
|
||||
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(user_id, message_id)
|
||||
UNIQUE(client_id, message_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_client_queues_user ON client_queues(user_id, id);
|
||||
CREATE INDEX IF NOT EXISTS idx_client_queues_client ON client_queues(client_id, id);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -37,35 +38,37 @@ func (hdlr *Handlers) maxBodySize() int64 {
|
||||
return defaultMaxBodySize
|
||||
}
|
||||
|
||||
// authUser extracts the user from the Authorization header.
|
||||
func (hdlr *Handlers) authUser(
|
||||
// authSession extracts the session from the client token.
|
||||
func (hdlr *Handlers) authSession(
|
||||
request *http.Request,
|
||||
) (int64, string, error) {
|
||||
) (int64, int64, string, error) {
|
||||
auth := request.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
return 0, "", errUnauthorized
|
||||
return 0, 0, "", errUnauthorized
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
if token == "" {
|
||||
return 0, "", errUnauthorized
|
||||
return 0, 0, "", errUnauthorized
|
||||
}
|
||||
|
||||
uid, nick, err := hdlr.params.Database.GetUserByToken(
|
||||
request.Context(), token,
|
||||
)
|
||||
sessionID, clientID, nick, err :=
|
||||
hdlr.params.Database.GetSessionByToken(
|
||||
request.Context(), token,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("auth: %w", err)
|
||||
return 0, 0, "", fmt.Errorf("auth: %w", err)
|
||||
}
|
||||
|
||||
return uid, nick, nil
|
||||
return sessionID, clientID, nick, nil
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) requireAuth(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) (int64, string, bool) {
|
||||
uid, nick, err := hdlr.authUser(request)
|
||||
) (int64, int64, string, bool) {
|
||||
sessionID, clientID, nick, err :=
|
||||
hdlr.authSession(request)
|
||||
if err != nil {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
@@ -73,19 +76,19 @@ func (hdlr *Handlers) requireAuth(
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
return 0, "", false
|
||||
return 0, 0, "", false
|
||||
}
|
||||
|
||||
return uid, nick, true
|
||||
return sessionID, clientID, nick, true
|
||||
}
|
||||
|
||||
// fanOut stores a message and enqueues it to all specified
|
||||
// user IDs, then notifies them.
|
||||
// session IDs, then notifies them.
|
||||
func (hdlr *Handlers) fanOut(
|
||||
request *http.Request,
|
||||
command, from, target string,
|
||||
body json.RawMessage,
|
||||
userIDs []int64,
|
||||
sessionIDs []int64,
|
||||
) (string, error) {
|
||||
dbID, msgUUID, err := hdlr.params.Database.InsertMessage(
|
||||
request.Context(), command, from, target, body, nil,
|
||||
@@ -94,16 +97,16 @@ func (hdlr *Handlers) fanOut(
|
||||
return "", fmt.Errorf("insert message: %w", err)
|
||||
}
|
||||
|
||||
for _, uid := range userIDs {
|
||||
enqErr := hdlr.params.Database.EnqueueMessage(
|
||||
request.Context(), uid, dbID,
|
||||
for _, sid := range sessionIDs {
|
||||
enqErr := hdlr.params.Database.EnqueueToSession(
|
||||
request.Context(), sid, dbID,
|
||||
)
|
||||
if enqErr != nil {
|
||||
hdlr.log.Error("enqueue failed",
|
||||
"error", enqErr, "user_id", uid)
|
||||
"error", enqErr, "session_id", sid)
|
||||
}
|
||||
|
||||
hdlr.broker.Notify(uid)
|
||||
hdlr.broker.Notify(sid)
|
||||
}
|
||||
|
||||
return msgUUID, nil
|
||||
@@ -114,10 +117,10 @@ func (hdlr *Handlers) fanOutSilent(
|
||||
request *http.Request,
|
||||
command, from, target string,
|
||||
body json.RawMessage,
|
||||
userIDs []int64,
|
||||
sessionIDs []int64,
|
||||
) error {
|
||||
_, err := hdlr.fanOut(
|
||||
request, command, from, target, body, userIDs,
|
||||
request, command, from, target, body, sessionIDs,
|
||||
)
|
||||
|
||||
return err
|
||||
@@ -125,16 +128,6 @@ func (hdlr *Handlers) fanOutSilent(
|
||||
|
||||
// HandleCreateSession creates a new user session.
|
||||
func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
|
||||
type createRequest struct {
|
||||
Nick string `json:"nick"`
|
||||
}
|
||||
|
||||
type createResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Nick string `json:"nick"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
return func(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
@@ -143,82 +136,174 @@ func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc {
|
||||
writer, request.Body, hdlr.maxBodySize(),
|
||||
)
|
||||
|
||||
var payload createRequest
|
||||
|
||||
err := json.NewDecoder(request.Body).Decode(&payload)
|
||||
if err != nil {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid request body",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
payload.Nick = strings.TrimSpace(payload.Nick)
|
||||
|
||||
if !validNickRe.MatchString(payload.Nick) {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid nick format",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
userID, token, err := hdlr.params.Database.CreateUser(
|
||||
request.Context(), payload.Nick,
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE") {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"nick already taken",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.log.Error(
|
||||
"create user failed", "error", err,
|
||||
)
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"internal error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.respondJSON(
|
||||
writer, request,
|
||||
&createResponse{
|
||||
ID: userID,
|
||||
Nick: payload.Nick,
|
||||
Token: token,
|
||||
},
|
||||
http.StatusCreated,
|
||||
)
|
||||
hdlr.handleCreateSession(writer, request)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleState returns the current user's info and channels.
|
||||
func (hdlr *Handlers) handleCreateSession(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
type createRequest struct {
|
||||
Nick string `json:"nick"`
|
||||
}
|
||||
|
||||
var payload createRequest
|
||||
|
||||
err := json.NewDecoder(request.Body).Decode(&payload)
|
||||
if err != nil {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid request body",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
payload.Nick = strings.TrimSpace(payload.Nick)
|
||||
|
||||
if !validNickRe.MatchString(payload.Nick) {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid nick format",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, clientID, token, err :=
|
||||
hdlr.params.Database.CreateSession(
|
||||
request.Context(), payload.Nick,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.handleCreateSessionError(
|
||||
writer, request, err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.deliverMOTD(request, clientID, sessionID)
|
||||
|
||||
hdlr.respondJSON(writer, request, map[string]any{
|
||||
"id": sessionID,
|
||||
"nick": payload.Nick,
|
||||
"token": token,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) handleCreateSessionError(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
err error,
|
||||
) {
|
||||
if strings.Contains(err.Error(), "UNIQUE") {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"nick already taken",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.log.Error(
|
||||
"create session failed", "error", err,
|
||||
)
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"internal error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// deliverMOTD sends the MOTD as IRC numeric messages to a
|
||||
// new client.
|
||||
func (hdlr *Handlers) deliverMOTD(
|
||||
request *http.Request,
|
||||
clientID, sessionID int64,
|
||||
) {
|
||||
motd := hdlr.params.Config.MOTD
|
||||
serverName := hdlr.params.Config.ServerName
|
||||
|
||||
if serverName == "" {
|
||||
serverName = "chat"
|
||||
}
|
||||
|
||||
if motd == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := request.Context()
|
||||
|
||||
hdlr.enqueueNumeric(
|
||||
ctx, clientID, "375", serverName,
|
||||
"- "+serverName+" Message of the Day -",
|
||||
)
|
||||
|
||||
for line := range strings.SplitSeq(motd, "\n") {
|
||||
hdlr.enqueueNumeric(
|
||||
ctx, clientID, "372", serverName,
|
||||
"- "+line,
|
||||
)
|
||||
}
|
||||
|
||||
hdlr.enqueueNumeric(
|
||||
ctx, clientID, "376", serverName,
|
||||
"End of /MOTD command.",
|
||||
)
|
||||
|
||||
hdlr.broker.Notify(sessionID)
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) enqueueNumeric(
|
||||
ctx context.Context,
|
||||
clientID int64,
|
||||
command, serverName, text string,
|
||||
) {
|
||||
body, err := json.Marshal([]string{text})
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
"marshal numeric body", "error", err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
dbID, _, insertErr := hdlr.params.Database.InsertMessage(
|
||||
ctx, command, serverName, "",
|
||||
json.RawMessage(body), nil,
|
||||
)
|
||||
if insertErr != nil {
|
||||
hdlr.log.Error(
|
||||
"insert numeric message", "error", insertErr,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
_ = hdlr.params.Database.EnqueueToClient(
|
||||
ctx, clientID, dbID,
|
||||
)
|
||||
}
|
||||
|
||||
// HandleState returns the current session's info and
|
||||
// channels.
|
||||
func (hdlr *Handlers) HandleState() http.HandlerFunc {
|
||||
return func(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
uid, nick, ok := hdlr.requireAuth(writer, request)
|
||||
sessionID, _, nick, ok :=
|
||||
hdlr.requireAuth(writer, request)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
channels, err := hdlr.params.Database.ListChannels(
|
||||
request.Context(), uid,
|
||||
request.Context(), sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
@@ -234,7 +319,7 @@ func (hdlr *Handlers) HandleState() http.HandlerFunc {
|
||||
}
|
||||
|
||||
hdlr.respondJSON(writer, request, map[string]any{
|
||||
"id": uid,
|
||||
"id": sessionID,
|
||||
"nick": nick,
|
||||
"channels": channels,
|
||||
}, http.StatusOK)
|
||||
@@ -247,7 +332,7 @@ func (hdlr *Handlers) HandleListAllChannels() http.HandlerFunc {
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
_, _, ok := hdlr.requireAuth(writer, request)
|
||||
_, _, _, ok := hdlr.requireAuth(writer, request)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -280,7 +365,7 @@ func (hdlr *Handlers) HandleChannelMembers() http.HandlerFunc {
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
_, _, ok := hdlr.requireAuth(writer, request)
|
||||
_, _, _, ok := hdlr.requireAuth(writer, request)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -328,7 +413,8 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
uid, _, ok := hdlr.requireAuth(writer, request)
|
||||
sessionID, clientID, _, ok :=
|
||||
hdlr.requireAuth(writer, request)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -349,7 +435,7 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
|
||||
}
|
||||
|
||||
msgs, lastQID, err := hdlr.params.Database.PollMessages(
|
||||
request.Context(), uid,
|
||||
request.Context(), clientID,
|
||||
afterID, pollMessageLimit,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -374,17 +460,20 @@ func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.longPoll(writer, request, uid, afterID, timeout)
|
||||
hdlr.longPoll(
|
||||
writer, request,
|
||||
sessionID, clientID, afterID, timeout,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) longPoll(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid, afterID int64,
|
||||
sessionID, clientID, afterID int64,
|
||||
timeout int,
|
||||
) {
|
||||
waitCh := hdlr.broker.Wait(uid)
|
||||
waitCh := hdlr.broker.Wait(sessionID)
|
||||
|
||||
timer := time.NewTimer(
|
||||
time.Duration(timeout) * time.Second,
|
||||
@@ -396,15 +485,15 @@ func (hdlr *Handlers) longPoll(
|
||||
case <-waitCh:
|
||||
case <-timer.C:
|
||||
case <-request.Context().Done():
|
||||
hdlr.broker.Remove(uid, waitCh)
|
||||
hdlr.broker.Remove(sessionID, waitCh)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.broker.Remove(uid, waitCh)
|
||||
hdlr.broker.Remove(sessionID, waitCh)
|
||||
|
||||
msgs, lastQID, err := hdlr.params.Database.PollMessages(
|
||||
request.Context(), uid,
|
||||
request.Context(), clientID,
|
||||
afterID, pollMessageLimit,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -443,7 +532,8 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
|
||||
writer, request.Body, hdlr.maxBodySize(),
|
||||
)
|
||||
|
||||
uid, nick, ok := hdlr.requireAuth(writer, request)
|
||||
sessionID, _, nick, ok :=
|
||||
hdlr.requireAuth(writer, request)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -492,7 +582,7 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
|
||||
}
|
||||
|
||||
hdlr.dispatchCommand(
|
||||
writer, request, uid, nick,
|
||||
writer, request, sessionID, nick,
|
||||
payload.Command, payload.To,
|
||||
payload.Body, bodyLines,
|
||||
)
|
||||
@@ -502,7 +592,7 @@ func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc {
|
||||
func (hdlr *Handlers) dispatchCommand(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick, command, target string,
|
||||
body json.RawMessage,
|
||||
bodyLines func() []string,
|
||||
@@ -510,20 +600,20 @@ func (hdlr *Handlers) dispatchCommand(
|
||||
switch command {
|
||||
case cmdPrivmsg, "NOTICE":
|
||||
hdlr.handlePrivmsg(
|
||||
writer, request, uid, nick,
|
||||
writer, request, sessionID, nick,
|
||||
command, target, body, bodyLines,
|
||||
)
|
||||
case "JOIN":
|
||||
hdlr.handleJoin(
|
||||
writer, request, uid, nick, target,
|
||||
writer, request, sessionID, nick, target,
|
||||
)
|
||||
case "PART":
|
||||
hdlr.handlePart(
|
||||
writer, request, uid, nick, target, body,
|
||||
writer, request, sessionID, nick, target, body,
|
||||
)
|
||||
case "NICK":
|
||||
hdlr.handleNick(
|
||||
writer, request, uid, nick, bodyLines,
|
||||
writer, request, sessionID, nick, bodyLines,
|
||||
)
|
||||
case "TOPIC":
|
||||
hdlr.handleTopic(
|
||||
@@ -531,7 +621,7 @@ func (hdlr *Handlers) dispatchCommand(
|
||||
)
|
||||
case "QUIT":
|
||||
hdlr.handleQuit(
|
||||
writer, request, uid, nick, body,
|
||||
writer, request, sessionID, nick, body,
|
||||
)
|
||||
case "PING":
|
||||
hdlr.respondJSON(writer, request,
|
||||
@@ -552,7 +642,7 @@ func (hdlr *Handlers) dispatchCommand(
|
||||
func (hdlr *Handlers) handlePrivmsg(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick, command, target string,
|
||||
body json.RawMessage,
|
||||
bodyLines func() []string,
|
||||
@@ -580,7 +670,7 @@ func (hdlr *Handlers) handlePrivmsg(
|
||||
|
||||
if strings.HasPrefix(target, "#") {
|
||||
hdlr.handleChannelMsg(
|
||||
writer, request, uid, nick,
|
||||
writer, request, sessionID, nick,
|
||||
command, target, body,
|
||||
)
|
||||
|
||||
@@ -588,7 +678,7 @@ func (hdlr *Handlers) handlePrivmsg(
|
||||
}
|
||||
|
||||
hdlr.handleDirectMsg(
|
||||
writer, request, uid, nick,
|
||||
writer, request, sessionID, nick,
|
||||
command, target, body,
|
||||
)
|
||||
}
|
||||
@@ -596,7 +686,7 @@ func (hdlr *Handlers) handlePrivmsg(
|
||||
func (hdlr *Handlers) handleChannelMsg(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick, command, target string,
|
||||
body json.RawMessage,
|
||||
) {
|
||||
@@ -614,7 +704,7 @@ func (hdlr *Handlers) handleChannelMsg(
|
||||
}
|
||||
|
||||
isMember, err := hdlr.params.Database.IsChannelMember(
|
||||
request.Context(), chID, uid,
|
||||
request.Context(), chID, sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
@@ -677,11 +767,11 @@ func (hdlr *Handlers) handleChannelMsg(
|
||||
func (hdlr *Handlers) handleDirectMsg(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick, command, target string,
|
||||
body json.RawMessage,
|
||||
) {
|
||||
targetUID, err := hdlr.params.Database.GetUserByNick(
|
||||
targetSID, err := hdlr.params.Database.GetSessionByNick(
|
||||
request.Context(), target,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -694,9 +784,9 @@ func (hdlr *Handlers) handleDirectMsg(
|
||||
return
|
||||
}
|
||||
|
||||
recipients := []int64{targetUID}
|
||||
if targetUID != uid {
|
||||
recipients = append(recipients, uid)
|
||||
recipients := []int64{targetSID}
|
||||
if targetSID != sessionID {
|
||||
recipients = append(recipients, sessionID)
|
||||
}
|
||||
|
||||
msgUUID, err := hdlr.fanOut(
|
||||
@@ -721,7 +811,7 @@ func (hdlr *Handlers) handleDirectMsg(
|
||||
func (hdlr *Handlers) handleJoin(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick, target string,
|
||||
) {
|
||||
if target == "" {
|
||||
@@ -766,7 +856,7 @@ func (hdlr *Handlers) handleJoin(
|
||||
}
|
||||
|
||||
err = hdlr.params.Database.JoinChannel(
|
||||
request.Context(), chID, uid,
|
||||
request.Context(), chID, sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
@@ -800,7 +890,7 @@ func (hdlr *Handlers) handleJoin(
|
||||
func (hdlr *Handlers) handlePart(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick, target string,
|
||||
body json.RawMessage,
|
||||
) {
|
||||
@@ -841,7 +931,7 @@ func (hdlr *Handlers) handlePart(
|
||||
)
|
||||
|
||||
err = hdlr.params.Database.PartChannel(
|
||||
request.Context(), chID, uid,
|
||||
request.Context(), chID, sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
@@ -871,7 +961,7 @@ func (hdlr *Handlers) handlePart(
|
||||
func (hdlr *Handlers) handleNick(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick string,
|
||||
bodyLines func() []string,
|
||||
) {
|
||||
@@ -909,7 +999,7 @@ func (hdlr *Handlers) handleNick(
|
||||
}
|
||||
|
||||
err := hdlr.params.Database.ChangeNick(
|
||||
request.Context(), uid, newNick,
|
||||
request.Context(), sessionID, newNick,
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE") {
|
||||
@@ -934,7 +1024,7 @@ func (hdlr *Handlers) handleNick(
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.broadcastNick(request, uid, nick, newNick)
|
||||
hdlr.broadcastNick(request, sessionID, nick, newNick)
|
||||
|
||||
hdlr.respondJSON(writer, request,
|
||||
map[string]string{
|
||||
@@ -945,15 +1035,15 @@ func (hdlr *Handlers) handleNick(
|
||||
|
||||
func (hdlr *Handlers) broadcastNick(
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
oldNick, newNick string,
|
||||
) {
|
||||
channels, _ := hdlr.params.Database.
|
||||
GetAllChannelMembershipsForUser(
|
||||
request.Context(), uid,
|
||||
GetSessionChannels(
|
||||
request.Context(), sessionID,
|
||||
)
|
||||
|
||||
notified := map[int64]bool{uid: true}
|
||||
notified := map[int64]bool{sessionID: true}
|
||||
|
||||
nickBody, err := json.Marshal([]string{newNick})
|
||||
if err != nil {
|
||||
@@ -969,11 +1059,11 @@ func (hdlr *Handlers) broadcastNick(
|
||||
json.RawMessage(nickBody), nil,
|
||||
)
|
||||
|
||||
_ = hdlr.params.Database.EnqueueMessage(
|
||||
request.Context(), uid, dbID,
|
||||
_ = hdlr.params.Database.EnqueueToSession(
|
||||
request.Context(), sessionID, dbID,
|
||||
)
|
||||
|
||||
hdlr.broker.Notify(uid)
|
||||
hdlr.broker.Notify(sessionID)
|
||||
|
||||
for _, chanInfo := range channels {
|
||||
memberIDs, _ := hdlr.params.Database.
|
||||
@@ -985,7 +1075,7 @@ func (hdlr *Handlers) broadcastNick(
|
||||
if !notified[mid] {
|
||||
notified[mid] = true
|
||||
|
||||
_ = hdlr.params.Database.EnqueueMessage(
|
||||
_ = hdlr.params.Database.EnqueueToSession(
|
||||
request.Context(), mid, dbID,
|
||||
)
|
||||
|
||||
@@ -1077,13 +1167,13 @@ func (hdlr *Handlers) handleTopic(
|
||||
func (hdlr *Handlers) handleQuit(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick string,
|
||||
body json.RawMessage,
|
||||
) {
|
||||
channels, _ := hdlr.params.Database.
|
||||
GetAllChannelMembershipsForUser(
|
||||
request.Context(), uid,
|
||||
GetSessionChannels(
|
||||
request.Context(), sessionID,
|
||||
)
|
||||
|
||||
notified := map[int64]bool{}
|
||||
@@ -1103,10 +1193,10 @@ func (hdlr *Handlers) handleQuit(
|
||||
)
|
||||
|
||||
for _, mid := range memberIDs {
|
||||
if mid != uid && !notified[mid] {
|
||||
if mid != sessionID && !notified[mid] {
|
||||
notified[mid] = true
|
||||
|
||||
_ = hdlr.params.Database.EnqueueMessage(
|
||||
_ = hdlr.params.Database.EnqueueToSession(
|
||||
request.Context(), mid, dbID,
|
||||
)
|
||||
|
||||
@@ -1115,7 +1205,7 @@ func (hdlr *Handlers) handleQuit(
|
||||
}
|
||||
|
||||
_ = hdlr.params.Database.PartChannel(
|
||||
request.Context(), chanInfo.ID, uid,
|
||||
request.Context(), chanInfo.ID, sessionID,
|
||||
)
|
||||
|
||||
_ = hdlr.params.Database.DeleteChannelIfEmpty(
|
||||
@@ -1123,8 +1213,8 @@ func (hdlr *Handlers) handleQuit(
|
||||
)
|
||||
}
|
||||
|
||||
_ = hdlr.params.Database.DeleteUser(
|
||||
request.Context(), uid,
|
||||
_ = hdlr.params.Database.DeleteSession(
|
||||
request.Context(), sessionID,
|
||||
)
|
||||
|
||||
hdlr.respondJSON(writer, request,
|
||||
@@ -1138,7 +1228,8 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
uid, nick, ok := hdlr.requireAuth(writer, request)
|
||||
sessionID, _, nick, ok :=
|
||||
hdlr.requireAuth(writer, request)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -1155,7 +1246,7 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
|
||||
}
|
||||
|
||||
if !hdlr.canAccessHistory(
|
||||
writer, request, uid, nick, target,
|
||||
writer, request, sessionID, nick, target,
|
||||
) {
|
||||
return
|
||||
}
|
||||
@@ -1198,12 +1289,12 @@ func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc {
|
||||
func (hdlr *Handlers) canAccessHistory(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
nick, target string,
|
||||
) bool {
|
||||
if strings.HasPrefix(target, "#") {
|
||||
return hdlr.canAccessChannelHistory(
|
||||
writer, request, uid, target,
|
||||
writer, request, sessionID, target,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1225,7 +1316,7 @@ func (hdlr *Handlers) canAccessHistory(
|
||||
func (hdlr *Handlers) canAccessChannelHistory(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
uid int64,
|
||||
sessionID int64,
|
||||
target string,
|
||||
) bool {
|
||||
chID, err := hdlr.params.Database.GetChannelByName(
|
||||
@@ -1242,7 +1333,7 @@ func (hdlr *Handlers) canAccessChannelHistory(
|
||||
}
|
||||
|
||||
isMember, err := hdlr.params.Database.IsChannelMember(
|
||||
request.Context(), chID, uid,
|
||||
request.Context(), chID, sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
@@ -1270,20 +1361,139 @@ func (hdlr *Handlers) canAccessChannelHistory(
|
||||
return true
|
||||
}
|
||||
|
||||
// HandleServerInfo returns server metadata.
|
||||
func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc {
|
||||
type infoResponse struct {
|
||||
Name string `json:"name"`
|
||||
MOTD string `json:"motd"`
|
||||
}
|
||||
|
||||
// HandleLogout deletes the authenticated client's token
|
||||
// and cleans up the user (session) if no clients remain.
|
||||
func (hdlr *Handlers) HandleLogout() http.HandlerFunc {
|
||||
return func(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
hdlr.respondJSON(writer, request, &infoResponse{
|
||||
Name: hdlr.params.Config.ServerName,
|
||||
MOTD: hdlr.params.Config.MOTD,
|
||||
sessionID, clientID, nick, ok :=
|
||||
hdlr.requireAuth(writer, request)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := request.Context()
|
||||
|
||||
err := hdlr.params.Database.DeleteClient(
|
||||
ctx, clientID,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
"delete client failed", "error", err,
|
||||
)
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"internal error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// If no clients remain, clean up the user fully:
|
||||
// part all channels (notifying members) and
|
||||
// delete the session.
|
||||
remaining, err := hdlr.params.Database.
|
||||
ClientCountForSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
"client count check failed", "error", err,
|
||||
)
|
||||
}
|
||||
|
||||
if remaining == 0 {
|
||||
hdlr.cleanupUser(
|
||||
ctx, sessionID, nick,
|
||||
)
|
||||
}
|
||||
|
||||
hdlr.respondJSON(writer, request,
|
||||
map[string]string{"status": "ok"},
|
||||
http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupUser parts the user from all channels (notifying
|
||||
// members) and deletes the session.
|
||||
func (hdlr *Handlers) cleanupUser(
|
||||
ctx context.Context,
|
||||
sessionID int64,
|
||||
nick string,
|
||||
) {
|
||||
channels, _ := hdlr.params.Database.
|
||||
GetSessionChannels(ctx, sessionID)
|
||||
|
||||
notified := map[int64]bool{}
|
||||
|
||||
var quitDBID int64
|
||||
|
||||
if len(channels) > 0 {
|
||||
quitDBID, _, _ = hdlr.params.Database.InsertMessage(
|
||||
ctx, "QUIT", nick, "", nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
for _, chanInfo := range channels {
|
||||
memberIDs, _ := hdlr.params.Database.
|
||||
GetChannelMemberIDs(ctx, chanInfo.ID)
|
||||
|
||||
for _, mid := range memberIDs {
|
||||
if mid != sessionID && !notified[mid] {
|
||||
notified[mid] = true
|
||||
|
||||
_ = hdlr.params.Database.EnqueueToSession(
|
||||
ctx, mid, quitDBID,
|
||||
)
|
||||
|
||||
hdlr.broker.Notify(mid)
|
||||
}
|
||||
}
|
||||
|
||||
_ = hdlr.params.Database.PartChannel(
|
||||
ctx, chanInfo.ID, sessionID,
|
||||
)
|
||||
|
||||
_ = hdlr.params.Database.DeleteChannelIfEmpty(
|
||||
ctx, chanInfo.ID,
|
||||
)
|
||||
}
|
||||
|
||||
_ = hdlr.params.Database.DeleteSession(ctx, sessionID)
|
||||
}
|
||||
|
||||
// HandleUsersMe returns the current user's session info.
|
||||
func (hdlr *Handlers) HandleUsersMe() http.HandlerFunc {
|
||||
return hdlr.HandleState()
|
||||
}
|
||||
|
||||
// HandleServerInfo returns server metadata.
|
||||
func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc {
|
||||
return func(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
users, err := hdlr.params.Database.GetUserCount(
|
||||
request.Context(),
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
"get user count failed", "error", err,
|
||||
)
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"internal error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.respondJSON(writer, request, map[string]any{
|
||||
"name": hdlr.params.Config.ServerName,
|
||||
"motd": hdlr.params.Config.MOTD,
|
||||
"users": users,
|
||||
}, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1469,6 +1469,310 @@ func TestHealthcheck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterValid(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
body, err := json.Marshal(map[string]string{
|
||||
"nick": "reguser", "password": "password123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url("/api/v1/register"),
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf(
|
||||
"expected 201, got %d: %s",
|
||||
resp.StatusCode, respBody,
|
||||
)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
|
||||
_ = json.NewDecoder(resp.Body).Decode(&result)
|
||||
|
||||
if result["token"] == nil || result["token"] == "" {
|
||||
t.Fatal("expected token in response")
|
||||
}
|
||||
|
||||
if result["nick"] != "reguser" {
|
||||
t.Fatalf(
|
||||
"expected reguser, got %v", result["nick"],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterDuplicate(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
body, err := json.Marshal(map[string]string{
|
||||
"nick": "dupuser", "password": "password123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url("/api/v1/register"),
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
|
||||
resp2, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url("/api/v1/register"),
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() { _ = resp2.Body.Close() }()
|
||||
|
||||
if resp2.StatusCode != http.StatusConflict {
|
||||
t.Fatalf("expected 409, got %d", resp2.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func postJSONExpectStatus(
|
||||
t *testing.T,
|
||||
tserver *testServer,
|
||||
path string,
|
||||
payload map[string]string,
|
||||
expectedStatus int,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url(path),
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != expectedStatus {
|
||||
t.Fatalf(
|
||||
"expected %d, got %d",
|
||||
expectedStatus, resp.StatusCode,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterShortPassword(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
postJSONExpectStatus(
|
||||
t, tserver, "/api/v1/register",
|
||||
map[string]string{
|
||||
"nick": "shortpw", "password": "short",
|
||||
},
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
}
|
||||
|
||||
func TestRegisterInvalidNick(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
postJSONExpectStatus(
|
||||
t, tserver, "/api/v1/register",
|
||||
map[string]string{
|
||||
"nick": "bad nick!",
|
||||
"password": "password123",
|
||||
},
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
}
|
||||
|
||||
func TestLoginValid(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
// Register first.
|
||||
regBody, err := json.Marshal(map[string]string{
|
||||
"nick": "loginuser", "password": "password123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url("/api/v1/register"),
|
||||
bytes.NewReader(regBody),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// Login.
|
||||
loginBody, err := json.Marshal(map[string]string{
|
||||
"nick": "loginuser", "password": "password123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp2, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url("/api/v1/login"),
|
||||
bytes.NewReader(loginBody),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() { _ = resp2.Body.Close() }()
|
||||
|
||||
if resp2.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp2.Body)
|
||||
t.Fatalf(
|
||||
"expected 200, got %d: %s",
|
||||
resp2.StatusCode, respBody,
|
||||
)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
|
||||
_ = json.NewDecoder(resp2.Body).Decode(&result)
|
||||
|
||||
if result["token"] == nil || result["token"] == "" {
|
||||
t.Fatal("expected token in response")
|
||||
}
|
||||
|
||||
// Verify token works.
|
||||
token, ok := result["token"].(string)
|
||||
if !ok {
|
||||
t.Fatal("token not a string")
|
||||
}
|
||||
|
||||
status, state := tserver.getState(token)
|
||||
if status != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", status)
|
||||
}
|
||||
|
||||
if state["nick"] != "loginuser" {
|
||||
t.Fatalf(
|
||||
"expected loginuser, got %v",
|
||||
state["nick"],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginWrongPassword(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
regBody, err := json.Marshal(map[string]string{
|
||||
"nick": "wrongpwuser", "password": "correctpass1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url("/api/v1/register"),
|
||||
bytes.NewReader(regBody),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
|
||||
loginBody, err := json.Marshal(map[string]string{
|
||||
"nick": "wrongpwuser", "password": "wrongpass12",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp2, err := doRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
tserver.url("/api/v1/login"),
|
||||
bytes.NewReader(loginBody),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() { _ = resp2.Body.Close() }()
|
||||
|
||||
if resp2.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf(
|
||||
"expected 401, got %d", resp2.StatusCode,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginNonexistentUser(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
postJSONExpectStatus(
|
||||
t, tserver, "/api/v1/login",
|
||||
map[string]string{
|
||||
"nick": "ghostuser",
|
||||
"password": "password123",
|
||||
},
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
}
|
||||
|
||||
func TestSessionStillWorks(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
|
||||
// Verify anonymous session creation still works.
|
||||
token := tserver.createSession("anon_user")
|
||||
if token == "" {
|
||||
t.Fatal("expected token for anonymous session")
|
||||
}
|
||||
|
||||
status, state := tserver.getState(token)
|
||||
if status != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", status)
|
||||
}
|
||||
|
||||
if state["nick"] != "anon_user" {
|
||||
t.Fatalf(
|
||||
"expected anon_user, got %v",
|
||||
state["nick"],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNickBroadcastToChannels(t *testing.T) {
|
||||
tserver := newTestServer(t)
|
||||
aliceToken := tserver.createSession("nick_a")
|
||||
|
||||
186
internal/handlers/auth.go
Normal file
186
internal/handlers/auth.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const minPasswordLength = 8
|
||||
|
||||
// HandleRegister creates a new user with a password.
|
||||
func (hdlr *Handlers) HandleRegister() http.HandlerFunc {
|
||||
return func(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
request.Body = http.MaxBytesReader(
|
||||
writer, request.Body, hdlr.maxBodySize(),
|
||||
)
|
||||
|
||||
hdlr.handleRegister(writer, request)
|
||||
}
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) handleRegister(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
type registerRequest struct {
|
||||
Nick string `json:"nick"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var payload registerRequest
|
||||
|
||||
err := json.NewDecoder(request.Body).Decode(&payload)
|
||||
if err != nil {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid request body",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
payload.Nick = strings.TrimSpace(payload.Nick)
|
||||
|
||||
if !validNickRe.MatchString(payload.Nick) {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid nick format",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(payload.Password) < minPasswordLength {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"password must be at least 8 characters",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, clientID, token, err :=
|
||||
hdlr.params.Database.RegisterUser(
|
||||
request.Context(),
|
||||
payload.Nick,
|
||||
payload.Password,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.handleRegisterError(
|
||||
writer, request, err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.deliverMOTD(request, clientID, sessionID)
|
||||
|
||||
hdlr.respondJSON(writer, request, map[string]any{
|
||||
"id": sessionID,
|
||||
"nick": payload.Nick,
|
||||
"token": token,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) handleRegisterError(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
err error,
|
||||
) {
|
||||
if strings.Contains(err.Error(), "UNIQUE") {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"nick already taken",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.log.Error(
|
||||
"register user failed", "error", err,
|
||||
)
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"internal error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// HandleLogin authenticates a user with nick and password.
|
||||
func (hdlr *Handlers) HandleLogin() http.HandlerFunc {
|
||||
return func(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
request.Body = http.MaxBytesReader(
|
||||
writer, request.Body, hdlr.maxBodySize(),
|
||||
)
|
||||
|
||||
hdlr.handleLogin(writer, request)
|
||||
}
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) handleLogin(
|
||||
writer http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) {
|
||||
type loginRequest struct {
|
||||
Nick string `json:"nick"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var payload loginRequest
|
||||
|
||||
err := json.NewDecoder(request.Body).Decode(&payload)
|
||||
if err != nil {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid request body",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
payload.Nick = strings.TrimSpace(payload.Nick)
|
||||
|
||||
if payload.Nick == "" || payload.Password == "" {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"nick and password required",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, _, token, err :=
|
||||
hdlr.params.Database.LoginUser(
|
||||
request.Context(),
|
||||
payload.Nick,
|
||||
payload.Password,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.respondError(
|
||||
writer, request,
|
||||
"invalid credentials",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hdlr.respondJSON(writer, request, map[string]any{
|
||||
"id": sessionID,
|
||||
"nick": payload.Nick,
|
||||
"token": token,
|
||||
}, http.StatusOK)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/chat/internal/broker"
|
||||
"git.eeqj.de/sneak/chat/internal/config"
|
||||
@@ -30,12 +31,15 @@ type Params struct {
|
||||
Healthcheck *healthcheck.Healthcheck
|
||||
}
|
||||
|
||||
const defaultIdleTimeout = 24 * time.Hour
|
||||
|
||||
// Handlers manages HTTP request handling.
|
||||
type Handlers struct {
|
||||
params *Params
|
||||
log *slog.Logger
|
||||
hc *healthcheck.Healthcheck
|
||||
broker *broker.Broker
|
||||
params *Params
|
||||
log *slog.Logger
|
||||
hc *healthcheck.Healthcheck
|
||||
broker *broker.Broker
|
||||
cancelCleanup context.CancelFunc
|
||||
}
|
||||
|
||||
// New creates a new Handlers instance.
|
||||
@@ -43,7 +47,7 @@ func New(
|
||||
lifecycle fx.Lifecycle,
|
||||
params Params,
|
||||
) (*Handlers, error) {
|
||||
hdlr := &Handlers{
|
||||
hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
|
||||
params: ¶ms,
|
||||
log: params.Logger.Get(),
|
||||
hc: params.Healthcheck,
|
||||
@@ -51,10 +55,14 @@ func New(
|
||||
}
|
||||
|
||||
lifecycle.Append(fx.Hook{
|
||||
OnStart: func(_ context.Context) error {
|
||||
OnStart: func(ctx context.Context) error {
|
||||
hdlr.startCleanup(ctx)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnStop: func(_ context.Context) error {
|
||||
hdlr.stopCleanup()
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -96,3 +104,100 @@ func (hdlr *Handlers) respondError(
|
||||
status,
|
||||
)
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) idleTimeout() time.Duration {
|
||||
raw := hdlr.params.Config.SessionIdleTimeout
|
||||
if raw == "" {
|
||||
return defaultIdleTimeout
|
||||
}
|
||||
|
||||
dur, err := time.ParseDuration(raw)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
"invalid SESSION_IDLE_TIMEOUT, using default",
|
||||
"value", raw, "error", err,
|
||||
)
|
||||
|
||||
return defaultIdleTimeout
|
||||
}
|
||||
|
||||
return dur
|
||||
}
|
||||
|
||||
// startCleanup launches the idle-user cleanup goroutine.
|
||||
// We use context.Background rather than the OnStart ctx
|
||||
// because the OnStart context is startup-scoped and would
|
||||
// cancel the goroutine once all start hooks complete.
|
||||
//
|
||||
//nolint:contextcheck // intentional Background ctx
|
||||
func (hdlr *Handlers) startCleanup(_ context.Context) {
|
||||
cleanupCtx, cancel := context.WithCancel(
|
||||
context.Background(),
|
||||
)
|
||||
hdlr.cancelCleanup = cancel
|
||||
|
||||
go hdlr.cleanupLoop(cleanupCtx)
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) stopCleanup() {
|
||||
if hdlr.cancelCleanup != nil {
|
||||
hdlr.cancelCleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) cleanupLoop(ctx context.Context) {
|
||||
timeout := hdlr.idleTimeout()
|
||||
|
||||
interval := max(timeout/2, time.Minute) //nolint:mnd // half the timeout
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hdlr.runCleanup(ctx, timeout)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hdlr *Handlers) runCleanup(
|
||||
ctx context.Context,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
cutoff := time.Now().Add(-timeout)
|
||||
|
||||
// Find sessions that will be orphaned so we can send
|
||||
// QUIT notifications before deleting anything.
|
||||
stale, err := hdlr.params.Database.
|
||||
GetStaleOrphanSessions(ctx, cutoff)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
"stale session lookup failed", "error", err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ss := range stale {
|
||||
hdlr.cleanupUser(ctx, ss.ID, ss.Nick)
|
||||
}
|
||||
|
||||
deleted, err := hdlr.params.Database.DeleteStaleUsers(
|
||||
ctx, cutoff,
|
||||
)
|
||||
if err != nil {
|
||||
hdlr.log.Error(
|
||||
"user cleanup failed", "error", err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if deleted > 0 {
|
||||
hdlr.log.Info(
|
||||
"cleaned up stale users",
|
||||
"deleted", deleted,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,48 +59,63 @@ func (srv *Server) SetupRoutes() {
|
||||
}
|
||||
|
||||
// API v1.
|
||||
srv.router.Route(
|
||||
"/api/v1",
|
||||
func(router chi.Router) {
|
||||
router.Get(
|
||||
"/server",
|
||||
srv.handlers.HandleServerInfo(),
|
||||
)
|
||||
router.Post(
|
||||
"/session",
|
||||
srv.handlers.HandleCreateSession(),
|
||||
)
|
||||
router.Get(
|
||||
"/state",
|
||||
srv.handlers.HandleState(),
|
||||
)
|
||||
router.Get(
|
||||
"/messages",
|
||||
srv.handlers.HandleGetMessages(),
|
||||
)
|
||||
router.Post(
|
||||
"/messages",
|
||||
srv.handlers.HandleSendCommand(),
|
||||
)
|
||||
router.Get(
|
||||
"/history",
|
||||
srv.handlers.HandleGetHistory(),
|
||||
)
|
||||
router.Get(
|
||||
"/channels",
|
||||
srv.handlers.HandleListAllChannels(),
|
||||
)
|
||||
router.Get(
|
||||
"/channels/{channel}/members",
|
||||
srv.handlers.HandleChannelMembers(),
|
||||
)
|
||||
},
|
||||
)
|
||||
srv.router.Route("/api/v1", srv.setupAPIv1)
|
||||
|
||||
// Serve embedded SPA.
|
||||
srv.setupSPA()
|
||||
}
|
||||
|
||||
func (srv *Server) setupAPIv1(router chi.Router) {
|
||||
router.Get(
|
||||
"/server",
|
||||
srv.handlers.HandleServerInfo(),
|
||||
)
|
||||
router.Post(
|
||||
"/session",
|
||||
srv.handlers.HandleCreateSession(),
|
||||
)
|
||||
router.Post(
|
||||
"/register",
|
||||
srv.handlers.HandleRegister(),
|
||||
)
|
||||
router.Post(
|
||||
"/login",
|
||||
srv.handlers.HandleLogin(),
|
||||
)
|
||||
router.Get(
|
||||
"/state",
|
||||
srv.handlers.HandleState(),
|
||||
)
|
||||
router.Post(
|
||||
"/logout",
|
||||
srv.handlers.HandleLogout(),
|
||||
)
|
||||
router.Get(
|
||||
"/users/me",
|
||||
srv.handlers.HandleUsersMe(),
|
||||
)
|
||||
router.Get(
|
||||
"/messages",
|
||||
srv.handlers.HandleGetMessages(),
|
||||
)
|
||||
router.Post(
|
||||
"/messages",
|
||||
srv.handlers.HandleSendCommand(),
|
||||
)
|
||||
router.Get(
|
||||
"/history",
|
||||
srv.handlers.HandleGetHistory(),
|
||||
)
|
||||
router.Get(
|
||||
"/channels",
|
||||
srv.handlers.HandleListAllChannels(),
|
||||
)
|
||||
router.Get(
|
||||
"/channels/{channel}/members",
|
||||
srv.handlers.HandleChannelMembers(),
|
||||
)
|
||||
}
|
||||
|
||||
func (srv *Server) setupSPA() {
|
||||
distFS, err := fs.Sub(web.Dist, "dist")
|
||||
if err != nil {
|
||||
|
||||
466
web/dist/app.js
vendored
466
web/dist/app.js
vendored
File diff suppressed because one or more lines are too long
43
web/dist/style.css
vendored
43
web/dist/style.css
vendored
@@ -14,6 +14,9 @@
|
||||
--tab-active: #e94560;
|
||||
--tab-bg: #16213e;
|
||||
--tab-hover: #1a1a3e;
|
||||
--topic-bg: #121a30;
|
||||
--unread-bg: #e94560;
|
||||
--warn: #f0ad4e;
|
||||
}
|
||||
|
||||
html, body, #root {
|
||||
@@ -86,6 +89,7 @@ html, body, #root {
|
||||
border-bottom: 1px solid var(--border);
|
||||
overflow-x: auto;
|
||||
flex-shrink: 0;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.tab {
|
||||
@@ -95,6 +99,7 @@ html, body, #root {
|
||||
white-space: nowrap;
|
||||
color: var(--text-muted);
|
||||
user-select: none;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.tab:hover {
|
||||
@@ -116,6 +121,43 @@ html, body, #root {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.tab .unread-badge {
|
||||
display: inline-block;
|
||||
background: var(--unread-bg);
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-weight: bold;
|
||||
padding: 1px 5px;
|
||||
border-radius: 8px;
|
||||
margin-left: 6px;
|
||||
min-width: 16px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/* Connection status */
|
||||
.connection-status {
|
||||
padding: 4px 12px;
|
||||
background: var(--warn);
|
||||
color: #1a1a2e;
|
||||
font-size: 12px;
|
||||
font-weight: bold;
|
||||
white-space: nowrap;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Topic bar */
|
||||
.topic-bar {
|
||||
padding: 6px 12px;
|
||||
background: var(--topic-bg);
|
||||
border-bottom: 1px solid var(--border);
|
||||
color: var(--text-muted);
|
||||
font-size: 12px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Content area */
|
||||
.content {
|
||||
display: flex;
|
||||
@@ -243,6 +285,7 @@ html, body, #root {
|
||||
gap: 8px;
|
||||
background: var(--bg-secondary);
|
||||
border-bottom: 1px solid var(--border);
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.join-dialog input {
|
||||
|
||||
382
web/src/app.jsx
382
web/src/app.jsx
@@ -1,13 +1,17 @@
|
||||
import { h, render, Component } from 'preact';
|
||||
import { h, render } from 'preact';
|
||||
import { useState, useEffect, useRef, useCallback } from 'preact/hooks';
|
||||
|
||||
const API = '/api/v1';
|
||||
const POLL_TIMEOUT = 15;
|
||||
const RECONNECT_DELAY = 3000;
|
||||
const MEMBER_REFRESH_INTERVAL = 10000;
|
||||
|
||||
function api(path, opts = {}) {
|
||||
const token = localStorage.getItem('chat_token');
|
||||
const headers = { 'Content-Type': 'application/json', ...(opts.headers || {}) };
|
||||
if (token) headers['Authorization'] = `Bearer ${token}`;
|
||||
return fetch(API + path, { ...opts, headers }).then(async r => {
|
||||
const { signal, ...rest } = opts;
|
||||
return fetch(API + path, { ...rest, headers, signal }).then(async r => {
|
||||
const data = await r.json().catch(() => null);
|
||||
if (!r.ok) throw { status: r.status, data };
|
||||
return data;
|
||||
@@ -19,7 +23,6 @@ function formatTime(ts) {
|
||||
return d.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' });
|
||||
}
|
||||
|
||||
// Nick color hashing
|
||||
function nickColor(nick) {
|
||||
let h = 0;
|
||||
for (let i = 0; i < nick.length; i++) h = nick.charCodeAt(i) + ((h << 5) - h);
|
||||
@@ -39,10 +42,9 @@ function LoginScreen({ onLogin }) {
|
||||
if (s.name) setServerName(s.name);
|
||||
if (s.motd) setMotd(s.motd);
|
||||
}).catch(() => {});
|
||||
// Check for saved token
|
||||
const saved = localStorage.getItem('chat_token');
|
||||
if (saved) {
|
||||
api('/state').then(u => onLogin(u.nick, saved)).catch(() => localStorage.removeItem('chat_token'));
|
||||
api('/state').then(u => onLogin(u.nick)).catch(() => localStorage.removeItem('chat_token'));
|
||||
}
|
||||
inputRef.current?.focus();
|
||||
}, []);
|
||||
@@ -56,7 +58,7 @@ function LoginScreen({ onLogin }) {
|
||||
body: JSON.stringify({ nick: nick.trim() })
|
||||
});
|
||||
localStorage.setItem('chat_token', res.token);
|
||||
onLogin(res.nick, res.token);
|
||||
onLogin(res.nick);
|
||||
} catch (err) {
|
||||
setError(err.data?.error || 'Connection failed');
|
||||
}
|
||||
@@ -84,11 +86,19 @@ function LoginScreen({ onLogin }) {
|
||||
}
|
||||
|
||||
function Message({ msg }) {
|
||||
if (msg.system) {
|
||||
return (
|
||||
<div class="message system">
|
||||
<span class="timestamp">{formatTime(msg.ts)}</span>
|
||||
<span class="content">{msg.text}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div class={`message ${msg.system ? 'system' : ''}`}>
|
||||
<span class="timestamp">{formatTime(msg.createdAt)}</span>
|
||||
<span class="nick" style={{ color: msg.system ? undefined : nickColor(msg.nick) }}>{msg.nick}</span>
|
||||
<span class="content">{msg.content}</span>
|
||||
<div class="message">
|
||||
<span class="timestamp">{formatTime(msg.ts)}</span>
|
||||
<span class="nick" style={{ color: nickColor(msg.from) }}>{msg.from}</span>
|
||||
<span class="content">{msg.text}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -98,93 +108,194 @@ function App() {
|
||||
const [nick, setNick] = useState('');
|
||||
const [tabs, setTabs] = useState([{ type: 'server', name: 'Server' }]);
|
||||
const [activeTab, setActiveTab] = useState(0);
|
||||
const [messages, setMessages] = useState({ server: [] }); // keyed by tab name
|
||||
const [members, setMembers] = useState({}); // keyed by channel name
|
||||
const [messages, setMessages] = useState({ Server: [] });
|
||||
const [members, setMembers] = useState({});
|
||||
const [topics, setTopics] = useState({});
|
||||
const [unread, setUnread] = useState({});
|
||||
const [input, setInput] = useState('');
|
||||
const [joinInput, setJoinInput] = useState('');
|
||||
const [lastMsgId, setLastMsgId] = useState(0);
|
||||
const [connected, setConnected] = useState(true);
|
||||
|
||||
const lastIdRef = useRef(0);
|
||||
const seenIdsRef = useRef(new Set());
|
||||
const pollAbortRef = useRef(null);
|
||||
const tabsRef = useRef(tabs);
|
||||
const activeTabRef = useRef(activeTab);
|
||||
const nickRef = useRef(nick);
|
||||
const messagesEndRef = useRef();
|
||||
const inputRef = useRef();
|
||||
const pollRef = useRef();
|
||||
|
||||
useEffect(() => { tabsRef.current = tabs; }, [tabs]);
|
||||
useEffect(() => { activeTabRef.current = activeTab; }, [activeTab]);
|
||||
useEffect(() => { nickRef.current = nick; }, [nick]);
|
||||
|
||||
// Persist joined channels
|
||||
useEffect(() => {
|
||||
const channels = tabs.filter(t => t.type === 'channel').map(t => t.name);
|
||||
localStorage.setItem('chat_channels', JSON.stringify(channels));
|
||||
}, [tabs]);
|
||||
|
||||
// Clear unread on tab switch
|
||||
useEffect(() => {
|
||||
const tab = tabs[activeTab];
|
||||
if (tab) setUnread(prev => ({ ...prev, [tab.name]: 0 }));
|
||||
}, [activeTab, tabs]);
|
||||
|
||||
const addMessage = useCallback((tabName, msg) => {
|
||||
if (msg.id && seenIdsRef.current.has(msg.id)) return;
|
||||
if (msg.id) seenIdsRef.current.add(msg.id);
|
||||
setMessages(prev => ({
|
||||
...prev,
|
||||
[tabName]: [...(prev[tabName] || []), msg]
|
||||
}));
|
||||
const currentTab = tabsRef.current[activeTabRef.current];
|
||||
if (!currentTab || currentTab.name !== tabName) {
|
||||
setUnread(prev => ({ ...prev, [tabName]: (prev[tabName] || 0) + 1 }));
|
||||
}
|
||||
}, []);
|
||||
|
||||
const addSystemMessage = useCallback((tabName, text) => {
|
||||
addMessage(tabName, {
|
||||
id: Date.now(),
|
||||
nick: '*',
|
||||
content: text,
|
||||
createdAt: new Date().toISOString(),
|
||||
system: true
|
||||
});
|
||||
}, [addMessage]);
|
||||
setMessages(prev => ({
|
||||
...prev,
|
||||
[tabName]: [...(prev[tabName] || []), {
|
||||
id: 'sys-' + Date.now() + '-' + Math.random(),
|
||||
ts: new Date().toISOString(),
|
||||
text,
|
||||
system: true
|
||||
}]
|
||||
}));
|
||||
}, []);
|
||||
|
||||
const onLogin = useCallback((userNick, token) => {
|
||||
setNick(userNick);
|
||||
setLoggedIn(true);
|
||||
addSystemMessage('server', `Connected as ${userNick}`);
|
||||
// Fetch server info
|
||||
api('/server').then(s => {
|
||||
if (s.motd) addSystemMessage('server', `MOTD: ${s.motd}`);
|
||||
const refreshMembers = useCallback((channel) => {
|
||||
const chName = channel.replace('#', '');
|
||||
api(`/channels/${chName}/members`).then(m => {
|
||||
setMembers(prev => ({ ...prev, [channel]: m }));
|
||||
}).catch(() => {});
|
||||
}, [addSystemMessage]);
|
||||
}, []);
|
||||
|
||||
// Poll for new messages
|
||||
const processMessage = useCallback((msg) => {
|
||||
const body = Array.isArray(msg.body) ? msg.body.join('\n') : '';
|
||||
const base = { id: msg.id, ts: msg.ts, from: msg.from, to: msg.to, command: msg.command };
|
||||
|
||||
switch (msg.command) {
|
||||
case 'PRIVMSG':
|
||||
case 'NOTICE': {
|
||||
const parsed = { ...base, text: body, system: false };
|
||||
const target = msg.to;
|
||||
if (target && target.startsWith('#')) {
|
||||
addMessage(target, parsed);
|
||||
} else {
|
||||
const dmPeer = msg.from === nickRef.current ? msg.to : msg.from;
|
||||
setTabs(prev => {
|
||||
if (!prev.find(t => t.type === 'dm' && t.name === dmPeer)) {
|
||||
return [...prev, { type: 'dm', name: dmPeer }];
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
addMessage(dmPeer, parsed);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'JOIN': {
|
||||
const text = `${msg.from} has joined ${msg.to}`;
|
||||
if (msg.to) addMessage(msg.to, { ...base, text, system: true });
|
||||
if (msg.to && msg.to.startsWith('#')) refreshMembers(msg.to);
|
||||
break;
|
||||
}
|
||||
case 'PART': {
|
||||
const reason = body ? ': ' + body : '';
|
||||
const text = `${msg.from} has left ${msg.to}${reason}`;
|
||||
if (msg.to) addMessage(msg.to, { ...base, text, system: true });
|
||||
if (msg.to && msg.to.startsWith('#')) refreshMembers(msg.to);
|
||||
break;
|
||||
}
|
||||
case 'QUIT': {
|
||||
const reason = body ? ': ' + body : '';
|
||||
const text = `${msg.from} has quit${reason}`;
|
||||
tabsRef.current.forEach(tab => {
|
||||
if (tab.type === 'channel') {
|
||||
addMessage(tab.name, { ...base, text, system: true });
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
case 'NICK': {
|
||||
const newNick = Array.isArray(msg.body) ? msg.body[0] : body;
|
||||
const text = `${msg.from} is now known as ${newNick}`;
|
||||
tabsRef.current.forEach(tab => {
|
||||
if (tab.type === 'channel') {
|
||||
addMessage(tab.name, { ...base, text, system: true });
|
||||
}
|
||||
});
|
||||
if (msg.from === nickRef.current && newNick) setNick(newNick);
|
||||
// Refresh members in all channels
|
||||
tabsRef.current.forEach(tab => {
|
||||
if (tab.type === 'channel') refreshMembers(tab.name);
|
||||
});
|
||||
break;
|
||||
}
|
||||
case 'TOPIC': {
|
||||
const text = `${msg.from} set the topic: ${body}`;
|
||||
if (msg.to) {
|
||||
addMessage(msg.to, { ...base, text, system: true });
|
||||
setTopics(prev => ({ ...prev, [msg.to]: body }));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case '375':
|
||||
case '372':
|
||||
case '376':
|
||||
addMessage('Server', { ...base, text: body, system: true });
|
||||
break;
|
||||
default:
|
||||
addMessage('Server', { ...base, text: body || msg.command, system: true });
|
||||
}
|
||||
}, [addMessage, refreshMembers]);
|
||||
|
||||
// Long-poll loop
|
||||
useEffect(() => {
|
||||
if (!loggedIn) return;
|
||||
let alive = true;
|
||||
|
||||
const poll = async () => {
|
||||
try {
|
||||
const msgs = await api(`/messages?after=${lastMsgId}`);
|
||||
if (!alive) return;
|
||||
let maxId = lastMsgId;
|
||||
for (const msg of msgs) {
|
||||
if (msg.id > maxId) maxId = msg.id;
|
||||
if (msg.isDm) {
|
||||
const dmTab = msg.nick === nick ? msg.dmTarget : msg.nick;
|
||||
// Ensure DM tab exists
|
||||
setTabs(prev => {
|
||||
if (!prev.find(t => t.type === 'dm' && t.name === dmTab)) {
|
||||
return [...prev, { type: 'dm', name: dmTab }];
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
addMessage(dmTab, msg);
|
||||
} else if (msg.channel) {
|
||||
addMessage(msg.channel, msg);
|
||||
while (alive) {
|
||||
try {
|
||||
const controller = new AbortController();
|
||||
pollAbortRef.current = controller;
|
||||
const result = await api(
|
||||
`/messages?after=${lastIdRef.current}&timeout=${POLL_TIMEOUT}`,
|
||||
{ signal: controller.signal }
|
||||
);
|
||||
if (!alive) break;
|
||||
setConnected(true);
|
||||
if (result.messages) {
|
||||
for (const m of result.messages) processMessage(m);
|
||||
}
|
||||
if (result.last_id > lastIdRef.current) {
|
||||
lastIdRef.current = result.last_id;
|
||||
}
|
||||
} catch (err) {
|
||||
if (!alive) break;
|
||||
if (err.name === 'AbortError') continue;
|
||||
setConnected(false);
|
||||
await new Promise(r => setTimeout(r, RECONNECT_DELAY));
|
||||
}
|
||||
if (maxId > lastMsgId) setLastMsgId(maxId);
|
||||
} catch (err) {
|
||||
// silent
|
||||
}
|
||||
};
|
||||
pollRef.current = setInterval(poll, 1500);
|
||||
poll();
|
||||
return () => { alive = false; clearInterval(pollRef.current); };
|
||||
}, [loggedIn, lastMsgId, nick, addMessage]);
|
||||
|
||||
// Fetch members for active channel tab
|
||||
poll();
|
||||
return () => { alive = false; pollAbortRef.current?.abort(); };
|
||||
}, [loggedIn, processMessage]);
|
||||
|
||||
// Refresh members for active channel
|
||||
useEffect(() => {
|
||||
if (!loggedIn) return;
|
||||
const tab = tabs[activeTab];
|
||||
if (!tab || tab.type !== 'channel') return;
|
||||
const chName = tab.name.replace('#', '');
|
||||
api(`/channels/${chName}/members`).then(m => {
|
||||
setMembers(prev => ({ ...prev, [tab.name]: m }));
|
||||
}).catch(() => {});
|
||||
const iv = setInterval(() => {
|
||||
api(`/channels/${chName}/members`).then(m => {
|
||||
setMembers(prev => ({ ...prev, [tab.name]: m }));
|
||||
}).catch(() => {});
|
||||
}, 5000);
|
||||
refreshMembers(tab.name);
|
||||
const iv = setInterval(() => refreshMembers(tab.name), MEMBER_REFRESH_INTERVAL);
|
||||
return () => clearInterval(iv);
|
||||
}, [loggedIn, activeTab, tabs]);
|
||||
}, [loggedIn, activeTab, tabs, refreshMembers]);
|
||||
|
||||
// Auto-scroll
|
||||
useEffect(() => {
|
||||
@@ -192,9 +303,37 @@ function App() {
|
||||
}, [messages, activeTab]);
|
||||
|
||||
// Focus input on tab change
|
||||
useEffect(() => { inputRef.current?.focus(); }, [activeTab]);
|
||||
|
||||
// Fetch topic for active channel
|
||||
useEffect(() => {
|
||||
inputRef.current?.focus();
|
||||
}, [activeTab]);
|
||||
if (!loggedIn) return;
|
||||
const tab = tabs[activeTab];
|
||||
if (!tab || tab.type !== 'channel') return;
|
||||
api('/channels').then(channels => {
|
||||
const ch = channels.find(c => c.name === tab.name);
|
||||
if (ch && ch.topic) setTopics(prev => ({ ...prev, [tab.name]: ch.topic }));
|
||||
}).catch(() => {});
|
||||
}, [loggedIn, activeTab, tabs]);
|
||||
|
||||
const onLogin = useCallback(async (userNick) => {
|
||||
setNick(userNick);
|
||||
setLoggedIn(true);
|
||||
addSystemMessage('Server', `Connected as ${userNick}`);
|
||||
// Auto-rejoin saved channels
|
||||
const saved = JSON.parse(localStorage.getItem('chat_channels') || '[]');
|
||||
for (const ch of saved) {
|
||||
try {
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'JOIN', to: ch }) });
|
||||
setTabs(prev => {
|
||||
if (prev.find(t => t.type === 'channel' && t.name === ch)) return prev;
|
||||
return [...prev, { type: 'channel', name: ch }];
|
||||
});
|
||||
} catch (e) {
|
||||
// Channel may not exist anymore
|
||||
}
|
||||
}
|
||||
}, [addSystemMessage]);
|
||||
|
||||
const joinChannel = async (name) => {
|
||||
if (!name) return;
|
||||
@@ -206,22 +345,29 @@ function App() {
|
||||
if (prev.find(t => t.type === 'channel' && t.name === name)) return prev;
|
||||
return [...prev, { type: 'channel', name }];
|
||||
});
|
||||
setActiveTab(tabs.length); // switch to new tab
|
||||
addSystemMessage(name, `Joined ${name}`);
|
||||
setActiveTab(tabs.length);
|
||||
// Load history
|
||||
try {
|
||||
const hist = await api(`/history?target=${encodeURIComponent(name)}&limit=50`);
|
||||
if (Array.isArray(hist)) {
|
||||
for (const m of hist) processMessage(m);
|
||||
}
|
||||
} catch (e) {
|
||||
// History may be empty
|
||||
}
|
||||
setJoinInput('');
|
||||
} catch (err) {
|
||||
addSystemMessage('server', `Failed to join ${name}: ${err.data?.error || 'error'}`);
|
||||
addSystemMessage('Server', `Failed to join ${name}: ${err.data?.error || 'error'}`);
|
||||
}
|
||||
};
|
||||
|
||||
const partChannel = async (name) => {
|
||||
try {
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PART', to: name }) });
|
||||
} catch (err) { /* ignore */ }
|
||||
setTabs(prev => {
|
||||
const next = prev.filter(t => !(t.type === 'channel' && t.name === name));
|
||||
return next;
|
||||
});
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
}
|
||||
setTabs(prev => prev.filter(t => !(t.type === 'channel' && t.name === name)));
|
||||
setActiveTab(0);
|
||||
};
|
||||
|
||||
@@ -240,7 +386,8 @@ function App() {
|
||||
if (prev.find(t => t.type === 'dm' && t.name === targetNick)) return prev;
|
||||
return [...prev, { type: 'dm', name: targetNick }];
|
||||
});
|
||||
setActiveTab(tabs.findIndex(t => t.type === 'dm' && t.name === targetNick) || tabs.length);
|
||||
const idx = tabs.findIndex(t => t.type === 'dm' && t.name === targetNick);
|
||||
setActiveTab(idx >= 0 ? idx : tabs.length);
|
||||
};
|
||||
|
||||
const sendMessage = async () => {
|
||||
@@ -250,46 +397,45 @@ function App() {
|
||||
const tab = tabs[activeTab];
|
||||
if (!tab || tab.type === 'server') return;
|
||||
|
||||
// Handle /commands
|
||||
if (text.startsWith('/')) {
|
||||
const parts = text.split(' ');
|
||||
const cmd = parts[0].toLowerCase();
|
||||
if (cmd === '/join' && parts[1]) {
|
||||
joinChannel(parts[1]);
|
||||
return;
|
||||
}
|
||||
if (cmd === '/part') {
|
||||
if (tab.type === 'channel') partChannel(tab.name);
|
||||
return;
|
||||
}
|
||||
if (cmd === '/join' && parts[1]) { joinChannel(parts[1]); return; }
|
||||
if (cmd === '/part') { if (tab.type === 'channel') partChannel(tab.name); return; }
|
||||
if (cmd === '/msg' && parts[1] && parts.slice(2).join(' ')) {
|
||||
const target = parts[1];
|
||||
const msg = parts.slice(2).join(' ');
|
||||
const body = parts.slice(2).join(' ');
|
||||
try {
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: target, body: [msg] }) });
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: target, body: [body] }) });
|
||||
openDM(target);
|
||||
} catch (err) {
|
||||
addSystemMessage('server', `Failed to send DM: ${err.data?.error || 'error'}`);
|
||||
addSystemMessage('Server', `DM failed: ${err.data?.error || 'error'}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (cmd === '/nick' && parts[1]) {
|
||||
try {
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'NICK', body: [parts[1]] }) });
|
||||
setNick(parts[1]);
|
||||
addSystemMessage('server', `Nick changed to ${parts[1]}`);
|
||||
} catch (err) {
|
||||
addSystemMessage('server', `Nick change failed: ${err.data?.error || 'error'}`);
|
||||
addSystemMessage('Server', `Nick change failed: ${err.data?.error || 'error'}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
addSystemMessage('server', `Unknown command: ${cmd}`);
|
||||
if (cmd === '/topic' && tab.type === 'channel') {
|
||||
const topicText = parts.slice(1).join(' ');
|
||||
try {
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'TOPIC', to: tab.name, body: [topicText] }) });
|
||||
} catch (err) {
|
||||
addSystemMessage('Server', `Topic failed: ${err.data?.error || 'error'}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
addSystemMessage('Server', `Unknown command: ${cmd}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const to = tab.type === 'channel' ? tab.name : tab.name;
|
||||
try {
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to, body: [text] }) });
|
||||
await api('/messages', { method: 'POST', body: JSON.stringify({ command: 'PRIVMSG', to: tab.name, body: [text] }) });
|
||||
} catch (err) {
|
||||
addSystemMessage(tab.name, `Send failed: ${err.data?.error || 'error'}`);
|
||||
}
|
||||
@@ -300,16 +446,21 @@ function App() {
|
||||
const currentTab = tabs[activeTab] || tabs[0];
|
||||
const currentMessages = messages[currentTab.name] || [];
|
||||
const currentMembers = members[currentTab.name] || [];
|
||||
const currentTopic = topics[currentTab.name] || '';
|
||||
|
||||
return (
|
||||
<div class="app">
|
||||
<div class="tab-bar">
|
||||
{!connected && <div class="connection-status">⚠ Reconnecting...</div>}
|
||||
{tabs.map((tab, i) => (
|
||||
<div
|
||||
class={`tab ${i === activeTab ? 'active' : ''}`}
|
||||
onClick={() => setActiveTab(i)}
|
||||
>
|
||||
{tab.type === 'dm' ? `→${tab.name}` : tab.name}
|
||||
{unread[tab.name] > 0 && i !== activeTab && (
|
||||
<span class="unread-badge">{unread[tab.name]}</span>
|
||||
)}
|
||||
{tab.type !== 'server' && (
|
||||
<span class="close-btn" onClick={(e) => { e.stopPropagation(); closeTab(i); }}>×</span>
|
||||
)}
|
||||
@@ -326,30 +477,27 @@ function App() {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{currentTab.type === 'channel' && currentTopic && (
|
||||
<div class="topic-bar" title={currentTopic}>{currentTopic}</div>
|
||||
)}
|
||||
|
||||
<div class="content">
|
||||
<div class="messages-pane">
|
||||
{currentTab.type === 'server' ? (
|
||||
<div class="server-messages">
|
||||
{currentMessages.map(m => <Message msg={m} />)}
|
||||
<div ref={messagesEndRef} />
|
||||
<div class={currentTab.type === 'server' ? 'server-messages' : 'messages'}>
|
||||
{currentMessages.map(m => <Message msg={m} />)}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
{currentTab.type !== 'server' && (
|
||||
<div class="input-bar">
|
||||
<input
|
||||
ref={inputRef}
|
||||
placeholder={`Message ${currentTab.name}...`}
|
||||
value={input}
|
||||
onInput={e => setInput(e.target.value)}
|
||||
onKeyDown={e => e.key === 'Enter' && sendMessage()}
|
||||
/>
|
||||
<button onClick={sendMessage}>Send</button>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div class="messages">
|
||||
{currentMessages.map(m => <Message msg={m} />)}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
<div class="input-bar">
|
||||
<input
|
||||
ref={inputRef}
|
||||
placeholder={`Message ${currentTab.name}...`}
|
||||
value={input}
|
||||
onInput={e => setInput(e.target.value)}
|
||||
onKeyDown={e => e.key === 'Enter' && sendMessage()}
|
||||
/>
|
||||
<button onClick={sendMessage}>Send</button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
--tab-active: #e94560;
|
||||
--tab-bg: #16213e;
|
||||
--tab-hover: #1a1a3e;
|
||||
--topic-bg: #121a30;
|
||||
--unread-bg: #e94560;
|
||||
--warn: #f0ad4e;
|
||||
}
|
||||
|
||||
html, body, #root {
|
||||
@@ -86,6 +89,7 @@ html, body, #root {
|
||||
border-bottom: 1px solid var(--border);
|
||||
overflow-x: auto;
|
||||
flex-shrink: 0;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.tab {
|
||||
@@ -95,6 +99,7 @@ html, body, #root {
|
||||
white-space: nowrap;
|
||||
color: var(--text-muted);
|
||||
user-select: none;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.tab:hover {
|
||||
@@ -116,6 +121,43 @@ html, body, #root {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.tab .unread-badge {
|
||||
display: inline-block;
|
||||
background: var(--unread-bg);
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-weight: bold;
|
||||
padding: 1px 5px;
|
||||
border-radius: 8px;
|
||||
margin-left: 6px;
|
||||
min-width: 16px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/* Connection status */
|
||||
.connection-status {
|
||||
padding: 4px 12px;
|
||||
background: var(--warn);
|
||||
color: #1a1a2e;
|
||||
font-size: 12px;
|
||||
font-weight: bold;
|
||||
white-space: nowrap;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Topic bar */
|
||||
.topic-bar {
|
||||
padding: 6px 12px;
|
||||
background: var(--topic-bg);
|
||||
border-bottom: 1px solid var(--border);
|
||||
color: var(--text-muted);
|
||||
font-size: 12px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Content area */
|
||||
.content {
|
||||
display: flex;
|
||||
@@ -243,6 +285,7 @@ html, body, #root {
|
||||
gap: 8px;
|
||||
background: var(--bg-secondary);
|
||||
border-bottom: 1px solid var(--border);
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.join-dialog input {
|
||||
|
||||
Reference in New Issue
Block a user