22 Commits

Author SHA1 Message Date
clawbot
e60c83f191 docs: add project tagline, ignore built binaries 2026-02-10 17:55:15 -08:00
clawbot
f4a9ec13bd docs: bold tagline, simplify opening paragraph 2026-02-10 17:55:13 -08:00
clawbot
a2c47f5618 docs: add design philosophy section explaining protocol rationale
Explains why the protocol is IRC over HTTP (not a new protocol borrowing
IRC names), the four specific changes from IRC (HTTP transport, server-held
sessions, structured bodies, message metadata), and why the C2S command
dispatch resembles but is not JSON-RPC.
2026-02-10 17:54:56 -08:00
clawbot
1a6e929f56 docs: update README API spec for unified command endpoint
- Document all C2S commands with required/optional fields table
- Remove separate join/part/nick/topic endpoint docs
- Update /channels/all to /channels
- Update /register to /session
2026-02-10 17:53:48 -08:00
clawbot
af3a26fcdf feat(web): update SPA to use unified command endpoint
- Session creation uses /session instead of /register
- JOIN/PART via POST /messages with command field
- NICK changes via POST /messages with NICK command
- Messages sent as PRIVMSG commands with body array
- DMs use PRIVMSG command format
2026-02-10 17:53:17 -08:00
clawbot
d06bb5334a feat(cli): update client to use unified command endpoint
- JoinChannel/PartChannel now send via POST /messages with command field
- ListChannels uses /channels instead of /channels/all
- gofmt whitespace fixes
2026-02-10 17:53:13 -08:00
clawbot
0ee3fd78d2 refactor: unify all C2S commands through POST /messages
All client-to-server commands now go through POST /api/v1/messages
with a 'command' field. The server dispatches by command type:

- PRIVMSG/NOTICE: send message to channel or user
- JOIN: join channel (creates if needed)
- PART: leave channel
- NICK: change nickname
- TOPIC: set channel topic
- PING: keepalive (returns PONG)

Removed separate routes:
- POST /channels/join
- DELETE /channels/{channel}
- POST /register (renamed to POST /session)
- GET /channels/all (moved to GET /channels)

Added DB methods: ChangeNick, SetTopic
2026-02-10 17:53:08 -08:00
clawbot
f7776f8d3f feat: scaffold IRC-style CLI client (chat-cli)
Adds cmd/chat-cli/ with:
- tview-based irssi-like TUI (message buffer, status bar, input line)
- IRC-style commands: /connect, /nick, /join, /part, /msg, /query,
  /topic, /names, /list, /window, /quit, /help
- Multi-buffer window model with Alt+N switching and unread indicators
- Background long-poll goroutine for message delivery
- Clean API client wrapper in cmd/chat-cli/api/
- Structured types matching the JSON message protocol
2026-02-10 11:51:01 -08:00
clawbot
4b074aafd7 docs: add PUBKEY schema for signing key distribution
Dedicated JSON Schema for the PUBKEY command — announces/relays
user signing public keys. Body is a structured object with alg
and key fields. Already documented in README; this adds the
schema file and index entry.
2026-02-10 10:36:55 -08:00
clawbot
ab70f889a6 refactor: structured body (array|object, never string) for canonicalization
Message bodies are always arrays of strings (text lines) or objects
(structured data like PUBKEY). Never raw strings. This enables:
- Multiline messages without escape sequences
- Deterministic JSON canonicalization (RFC 8785 JCS) for signing
- Structured data where needed

Update all schemas: body fields use array type with string items.
Update message.json envelope: body is oneOf[array, object], id is UUID.
Update README: message envelope table, examples, and canonicalization docs.
Update schema/README.md: field types, examples with array bodies.
2026-02-10 10:36:02 -08:00
clawbot
dfb1636be5 refactor: model message schemas after IRC RFC 1459/2812
Replace c2s/s2c/s2s taxonomy with IRC-native structure:
- schema/commands/ — IRC command schemas (PRIVMSG, NOTICE, JOIN, PART,
  QUIT, NICK, TOPIC, MODE, KICK, PING, PONG)
- schema/numerics/ — IRC numeric reply codes (001-004, 322-323, 332,
  353, 366, 372-376, 401, 403, 433, 442, 482)
- schema/message.json — base envelope mapping IRC wire format to JSON

Messages use 'command' field with IRC command names or 3-digit numeric
codes. 'body' is a string (IRC trailing parameter), not object/array.
'from'/'to' map to IRC prefix and first parameter.

Federation uses the same IRC commands (no custom RELAY/LINK/SYNC).

Update README message format, command tables, and examples to match.
2026-02-10 10:31:26 -08:00
clawbot
c8d88de8c5 chore: remove superseded schema files
Remove schema/commands/ and schema/message.json, replaced by
the new schema/{c2s,s2c,s2s}/*.schema.json structure.
2026-02-10 10:26:44 -08:00
clawbot
02acf1c919 docs: document IRC message protocol, signing, and canonicalization
- Add IRC command/numeric mapping tables (C2S, S2C, S2S)
- Document structured message bodies (array/object, never raw strings)
- Document RFC 8785 JCS canonicalization for deterministic hashing
- Document Ed25519 signing/verification flow with TOFU key distribution
- Document PUBKEY message type for public key announcement
- Update message examples to use IRC command format
- Update curl examples to use command-based messages
- Note web client as convenience UI; primary interface is IRC-style clients
- Add schema/ to project structure
2026-02-10 10:26:32 -08:00
clawbot
909da3cc99 feat: add IRC-style message protocol JSON schemas (draft 2020-12)
Add JSON Schema definitions for all message types:
- Base message envelope (message.schema.json)
- C2S: PRIVMSG, NOTICE, JOIN, PART, QUIT, NICK, MODE, TOPIC, KICK, PING, PUBKEY
- S2C: named commands + numeric reply codes (001, 002, 322, 353, 366, 372, 375, 376, 401, 403, 433)
- S2S: RELAY, LINK, UNLINK, SYNC, PING, PONG
- Schema index (schema/README.md)

All messages use IRC command names and numeric codes from RFC 1459/2812.
Bodies are always objects or arrays (never raw strings) to support
deterministic canonicalization (RFC 8785 JCS) and message signing.
2026-02-10 10:26:32 -08:00
clawbot
4645be5f20 style: fix whitespace alignment in config.go 2026-02-10 10:26:32 -08:00
user
065b243def docs: add JSON Schema definitions for all message types (draft 2020-12)
C2S (7): send, join, part, nick, topic, mode, ping
S2C (12): message, dm, notice, join, part, quit, nick, topic, mode, system, error, pong
S2S (6): relay, link, unlink, sync, ping, pong

Each message type has its own schema file under schema/{c2s,s2c,s2s}/.
schema/README.md provides an index of all types with descriptions.
2026-02-10 10:25:42 -08:00
user
6483670dc7 docs: emphasize API-first design, add curl examples, note web client as reference impl 2026-02-10 10:21:10 -08:00
user
16e08c2839 docs: update README with IRC-inspired unified API design
- Document simplified endpoint structure
- Single message stream (GET /messages) for all message types
- Unified send (POST /messages) with 'to' field
- GET /state replaces separate /me and /channels
- GET /history for scrollback
- Update project status
2026-02-10 10:20:13 -08:00
user
aabf8e902c feat: update web client for unified API endpoints
- Use /state instead of /me for auth check
- Use /messages instead of /poll for message stream
- Use unified POST /messages with 'to' field for all sends
- Update part channel URL to DELETE /channels/{name}
2026-02-10 10:20:09 -08:00
user
74437b8372 refactor: update routes for unified API endpoints
- GET /api/v1/state replaces /me and /channels
- GET/POST /api/v1/messages replaces /poll, /channels/{ch}/messages, /dm/{nick}/messages
- GET /api/v1/history for scrollback
- DELETE /api/v1/channels/{name} replaces /channels/{channel}/part
2026-02-10 10:20:05 -08:00
user
7361e8bd9b refactor: merge /me + /channels into /state, unify message endpoints
- HandleState returns user info (id, nick) + joined channels in one response
- HandleGetMessages now serves unified message stream (was HandlePoll)
- HandleSendMessage accepts 'to' field for routing to #channel or nick
- HandleGetHistory supports scrollback for channels and DMs
- Remove separate HandleMe, HandleListChannels, HandleSendDM, HandleGetDMs, HandlePoll
2026-02-10 10:20:00 -08:00
user
ac933d07d2 Add embedded web chat client with C2S HTTP API
- New DB schema: users, channel_members, messages tables (migration 003)
- Full C2S HTTP API: register, channels, messages, DMs, polling
- Preact SPA embedded via embed.FS, served at GET /
- IRC-style UI: tab bar, channel messages, user list, DM tabs, /commands
- Dark theme, responsive, esbuild-bundled (~19KB)
- Polling-based message delivery (1.5s interval)
- Commands: /join, /part, /msg, /nick
2026-02-10 09:22:22 -08:00
30 changed files with 813 additions and 3609 deletions

View File

@@ -1,9 +1,8 @@
.git
node_modules
.DS_Store
bin/ bin/
chatd
data.db data.db
.env .env
.git
*.test *.test
*.out *.out
debug.log debug.log

View File

@@ -1,12 +0,0 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[Makefile]
indent_style = tab

View File

@@ -1,9 +0,0 @@
name: check
on: [push]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4.2.2, 2026-02-22
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- run: docker build .

30
.gitignore vendored
View File

@@ -1,28 +1,7 @@
# OS
.DS_Store
Thumbs.db
# Editors
*.swp
*.swo
*~
*.bak
.idea/
.vscode/
*.sublime-*
# Node
node_modules/
# Environment / secrets
.env
.env.*
*.pem
*.key
# Build artifacts
/chatd /chatd
/bin/ /bin/
data.db
.env
*.exe *.exe
*.dll *.dll
*.so *.so
@@ -30,9 +9,6 @@ node_modules/
*.test *.test
*.out *.out
vendor/ vendor/
# Project
data.db
debug.log debug.log
/chat-cli
web/node_modules/ web/node_modules/
chat-cli

View File

@@ -1,7 +1,6 @@
# golang:1.24-alpine, 2026-02-26 FROM golang:1.24-alpine AS builder
FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder
RUN apk add --no-cache git build-base RUN apk add --no-cache git
WORKDIR /src WORKDIR /src
COPY go.mod go.sum ./ COPY go.mod go.sum ./
@@ -9,15 +8,10 @@ RUN go mod download
COPY . . COPY . .
# Run all checks — build fails if branch is not green
RUN go install github.com/golangci/golangci-lint/cmd/golangci-lint@v2.1.6
RUN make check
ARG VERSION=dev ARG VERSION=dev
RUN go build -ldflags "-X main.Version=${VERSION}" -o /chatd ./cmd/chatd RUN go build -ldflags "-X main.Version=${VERSION}" -o /chatd ./cmd/chatd
# alpine:3.21, 2026-02-26 FROM alpine:3.21
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk add --no-cache ca-certificates RUN apk add --no-cache ca-certificates
COPY --from=builder /chatd /usr/local/bin/chatd COPY --from=builder /chatd /usr/local/bin/chatd

21
LICENSE
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2025 sneak
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks .PHONY: all build lint fmt test check clean run debug
BINARY := chatd BINARY := chatd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -17,15 +17,18 @@ fmt:
gofmt -s -w . gofmt -s -w .
goimports -w . goimports -w .
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test: test:
go test -timeout 30s -v -race -cover ./... go test -v -race -cover ./...
# check runs all validation without making changes # Check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong # Used by CI and Docker build — fails if anything is wrong
check: test lint fmt-check check:
@echo "==> Checking formatting..."
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
@echo "==> Running linter..."
golangci-lint run --config .golangci.yml ./...
@echo "==> Running tests..."
go test -v -race ./...
@echo "==> Building..." @echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/chatd go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/chatd
@echo "==> All checks passed!" @echo "==> All checks passed!"
@@ -38,12 +41,3 @@ debug: build
clean: clean:
rm -rf bin/ chatd data.db rm -rf bin/ chatd data.db
docker:
docker build -t chat .
hooks:
@printf '#!/bin/sh\nset -e\n' > .git/hooks/pre-commit
@printf 'go mod tidy\ngo fmt ./...\ngit diff --exit-code -- go.mod go.sum || { echo "go mod tidy changed files; please stage and retry"; exit 1; }\n' >> .git/hooks/pre-commit
@printf 'make check\n' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit

View File

@@ -656,8 +656,3 @@ embedded web client.
## License ## License
MIT MIT
## Author
[@sneak](https://sneak.berlin)

View File

@@ -1,182 +0,0 @@
---
title: Repository Policies
last_modified: 2026-02-22
---
This document covers repository structure, tooling, and workflow standards. Code
style conventions are in separate documents:
- [Code Styleguide](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE.md)
(general, bash, Docker)
- [Go](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_GO.md)
- [JavaScript](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_JS.md)
- [Python](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_PYTHON.md)
- [Go HTTP Server Conventions](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/GO_HTTP_SERVER_CONVENTIONS.md)
---
- Cross-project documentation (such as this file) must include
`last_modified: YYYY-MM-DD` in the YAML front matter so it can be kept in sync
with the authoritative source as policies evolve.
- **ALL external references must be pinned by cryptographic hash.** This
includes Docker base images, Go modules, npm packages, GitHub Actions, and
anything else fetched from a remote source. Version tags (`@v4`, `@latest`,
`:3.21`, etc.) are server-mutable and therefore remote code execution
vulnerabilities. The ONLY acceptable way to reference an external dependency
is by its content hash (Docker `@sha256:...`, Go module hash in `go.sum`, npm
integrity hash in lockfile, GitHub Actions `@<commit-sha>`). No exceptions.
This also means never `curl | bash` to install tools like pyenv, nvm, rustup,
etc. Instead, download a specific release archive from GitHub, verify its hash
(hardcoded in the Dockerfile or script), and only then install. Unverified
install scripts are arbitrary remote code execution. This is the single most
important rule in this document. Double-check every external reference in
every file before committing. There are zero exceptions to this rule.
- Every repo with software must have a root `Makefile` with these targets:
`make test`, `make lint`, `make fmt` (writes), `make fmt-check` (read-only),
`make check` (prereqs: `test`, `lint`, `fmt-check`), `make docker`, and
`make hooks` (installs pre-commit hook). A model Makefile is at
`https://git.eeqj.de/sneak/prompts/raw/branch/main/Makefile`.
- Always use Makefile targets (`make fmt`, `make test`, `make lint`, etc.)
instead of invoking the underlying tools directly. The Makefile is the single
source of truth for how these operations are run.
- The Makefile is authoritative documentation for how the repo is used. Beyond
the required targets above, it should have targets for every common operation:
running a local development server (`make run`, `make dev`), re-initializing
or migrating the database (`make db-reset`, `make migrate`), building
artifacts (`make build`), generating code, seeding data, or anything else a
developer would do regularly. If someone checks out the repo and types
`make<tab>`, they should see every meaningful operation available. A new
contributor should be able to understand the entire development workflow by
reading the Makefile.
- Every repo should have a `Dockerfile`. All Dockerfiles must run `make check`
as a build step so the build fails if the branch is not green. For non-server
repos, the Dockerfile should bring up a development environment and run
`make check`. For server repos, `make check` should run as an early build
stage before the final image is assembled.
- Every repo should have a Gitea Actions workflow (`.gitea/workflows/`) that
runs `docker build .` on push. Since the Dockerfile already runs `make check`,
a successful build implies all checks pass.
- Use platform-standard formatters: `black` for Python, `prettier` for
JS/CSS/Markdown/HTML, `go fmt` for Go. Always use default configuration with
two exceptions: four-space indents (except Go), and `proseWrap: always` for
Markdown (hard-wrap at 80 columns). Documentation and writing repos (Markdown,
HTML, CSS) should also have `.prettierrc` and `.prettierignore`.
- Pre-commit hook: `make check` if local testing is possible, otherwise
`make lint && make fmt-check`. The Makefile should provide a `make hooks`
target to install the pre-commit hook.
- All repos with software must have tests that run via the platform-standard
test framework (`go test`, `pytest`, `jest`/`vitest`, etc.). If no meaningful
tests exist yet, add the most minimal test possible — e.g. importing the
module under test to verify it compiles/parses. There is no excuse for
`make test` to be a no-op.
- `make test` must complete in under 20 seconds. Add a 30-second timeout in the
Makefile.
- Docker builds must complete in under 5 minutes.
- `make check` must not modify any files in the repo. Tests may use temporary
directories.
- `main` must always pass `make check`, no exceptions.
- Never commit secrets. `.env` files, credentials, API keys, and private keys
must be in `.gitignore`. No exceptions.
- `.gitignore` should be comprehensive from the start: OS files (`.DS_Store`),
editor files (`.swp`, `*~`), language build artifacts, and `node_modules/`.
Fetch the standard `.gitignore` from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.gitignore` when setting up
a new repo.
- Never use `git add -A` or `git add .`. Always stage files explicitly by name.
- Never force-push to `main`.
- Make all changes on a feature branch. You can do whatever you want on a
feature branch.
- `.golangci.yml` is standardized and must _NEVER_ be modified by an agent, only
manually by the user. Fetch from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.golangci.yml`.
- When pinning images or packages by hash, add a comment above the reference
with the version and date (YYYY-MM-DD).
- Use `yarn`, not `npm`.
- Write all dates as YYYY-MM-DD (ISO 8601).
- Simple projects should be configured with environment variables.
- Dockerized web services listen on port 8080 by default, overridable with
`PORT`.
- `README.md` is the primary documentation. Required sections:
- **Description**: First line must include the project name, purpose,
category (web server, SPA, CLI tool, etc.), license, and author. Example:
"µPaaS is an MIT-licensed Go web application by @sneak that receives
git-frontend webhooks and deploys applications via Docker in realtime."
- **Getting Started**: Copy-pasteable install/usage code block.
- **Rationale**: Why does this exist?
- **Design**: How is the program structured?
- **TODO**: Update meticulously, even between commits. When planning, put
the todo list in the README so a new agent can pick up where the last one
left off.
- **License**: MIT, GPL, or WTFPL. Ask the user for new projects. Include a
`LICENSE` file in the repo root and a License section in the README.
- **Author**: [@sneak](https://sneak.berlin).
- First commit of a new repo should contain only `README.md`.
- Go module root: `sneak.berlin/go/<name>`. Always run `go mod tidy` before
committing.
- Use SemVer.
- Database migrations live in `internal/db/migrations/` and must be embedded in
the binary. Pre-1.0.0: modify existing migrations (no installed base assumed).
Post-1.0.0: add new migration files.
- All repos should have an `.editorconfig` enforcing the project's indentation
settings.
- Avoid putting files in the repo root unless necessary. Root should contain
only project-level config files (`README.md`, `Makefile`, `Dockerfile`,
`LICENSE`, `.gitignore`, `.editorconfig`, `REPO_POLICIES.md`, and
language-specific config). Everything else goes in a subdirectory. Canonical
subdirectory names:
- `bin/` — executable scripts and tools
- `cmd/` — Go command entrypoints
- `configs/` — configuration templates and examples
- `deploy/` — deployment manifests (k8s, compose, terraform)
- `docs/` — documentation and markdown (README.md stays in root)
- `internal/` — Go internal packages
- `internal/db/migrations/` — database migrations
- `pkg/` — Go library packages
- `share/` — systemd units, data files
- `static/` — static assets (images, fonts, etc.)
- `web/` — web frontend source
- When setting up a new repo, files from the `prompts` repo may be used as
templates. Fetch them from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/<path>`.
- New repos must contain at minimum:
- `README.md`, `.git`, `.gitignore`, `.editorconfig`
- `LICENSE`, `REPO_POLICIES.md` (copy from the `prompts` repo)
- `Makefile`
- `Dockerfile`, `.dockerignore`
- `.gitea/workflows/check.yml`
- Go: `go.mod`, `go.sum`, `.golangci.yml`
- JS: `package.json`, `yarn.lock`, `.prettierrc`, `.prettierignore`
- Python: `pyproject.toml`

View File

@@ -1,31 +1,15 @@
// Package chatapi provides a client for the chat server HTTP API. package api
package chatapi
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"time" "time"
) )
const (
httpTimeout = 30 * time.Second
pollExtraDelay = 5
httpErrThreshold = 400
)
// ErrHTTP is returned for non-2xx responses.
var ErrHTTP = errors.New("http error")
// ErrUnexpectedFormat is returned when the response format is
// not recognised.
var ErrUnexpectedFormat = errors.New("unexpected format")
// Client wraps HTTP calls to the chat server API. // Client wraps HTTP calls to the chat server API.
type Client struct { type Client struct {
BaseURL string BaseURL string
@@ -38,32 +22,59 @@ func NewClient(baseURL string) *Client {
return &Client{ return &Client{
BaseURL: baseURL, BaseURL: baseURL,
HTTPClient: &http.Client{ HTTPClient: &http.Client{
Timeout: httpTimeout, Timeout: 30 * time.Second,
}, },
} }
} }
func (c *Client) do(method, path string, body interface{}) ([]byte, error) {
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal: %w", err)
}
bodyReader = bytes.NewReader(data)
}
req, err := http.NewRequest(method, c.BaseURL+path, bodyReader)
if err != nil {
return nil, fmt.Errorf("request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if c.Token != "" {
req.Header.Set("Authorization", "Bearer "+c.Token)
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("http: %w", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
if resp.StatusCode >= 400 {
return data, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(data))
}
return data, nil
}
// CreateSession creates a new session on the server. // CreateSession creates a new session on the server.
func (c *Client) CreateSession( func (c *Client) CreateSession(nick string) (*SessionResponse, error) {
nick string, data, err := c.do("POST", "/api/v1/session", &SessionRequest{Nick: nick})
) (*SessionResponse, error) {
data, err := c.do(
"POST", "/api/v1/session",
&SessionRequest{Nick: nick},
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var resp SessionResponse var resp SessionResponse
if err := json.Unmarshal(data, &resp); err != nil {
err = json.Unmarshal(data, &resp)
if err != nil {
return nil, fmt.Errorf("decode session: %w", err) return nil, fmt.Errorf("decode session: %w", err)
} }
c.Token = resp.Token c.Token = resp.Token
return &resp, nil return &resp, nil
} }
@@ -73,113 +84,78 @@ func (c *Client) GetState() (*StateResponse, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var resp StateResponse var resp StateResponse
if err := json.Unmarshal(data, &resp); err != nil {
err = json.Unmarshal(data, &resp)
if err != nil {
return nil, fmt.Errorf("decode state: %w", err) return nil, fmt.Errorf("decode state: %w", err)
} }
return &resp, nil return &resp, nil
} }
// SendMessage sends a message (any IRC command). // SendMessage sends a message (any IRC command).
func (c *Client) SendMessage(msg *Message) error { func (c *Client) SendMessage(msg *Message) error {
_, err := c.do("POST", "/api/v1/messages", msg) _, err := c.do("POST", "/api/v1/messages", msg)
return err return err
} }
// PollMessages long-polls for new messages. // PollMessages long-polls for new messages.
func (c *Client) PollMessages( func (c *Client) PollMessages(afterID string, timeout int) ([]Message, error) {
afterID string, // Use a longer HTTP timeout than the server long-poll timeout.
timeout int, client := &http.Client{Timeout: time.Duration(timeout+5) * time.Second}
) ([]Message, error) {
pollTimeout := time.Duration(
timeout+pollExtraDelay,
) * time.Second
client := &http.Client{Timeout: pollTimeout}
params := url.Values{} params := url.Values{}
if afterID != "" { if afterID != "" {
params.Set("after", afterID) params.Set("after", afterID)
} }
params.Set("timeout", fmt.Sprintf("%d", timeout))
params.Set("timeout", strconv.Itoa(timeout))
path := "/api/v1/messages" path := "/api/v1/messages"
if len(params) > 0 { if len(params) > 0 {
path += "?" + params.Encode() path += "?" + params.Encode()
} }
req, err := http.NewRequest( //nolint:noctx // CLI tool req, err := http.NewRequest("GET", c.BaseURL+path, nil)
http.MethodGet, c.BaseURL+path, nil,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Authorization", "Bearer "+c.Token) req.Header.Set("Authorization", "Bearer "+c.Token)
resp, err := client.Do(req) //nolint:gosec // URL from user config resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
data, err := io.ReadAll(resp.Body) data, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if resp.StatusCode >= httpErrThreshold { if resp.StatusCode >= 400 {
return nil, fmt.Errorf( return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(data))
"%w: %d: %s",
ErrHTTP, resp.StatusCode, string(data),
)
} }
return decodeMessages(data) // The server may return an array directly or wrapped.
}
func decodeMessages(data []byte) ([]Message, error) {
var msgs []Message var msgs []Message
if err := json.Unmarshal(data, &msgs); err != nil {
err := json.Unmarshal(data, &msgs) // Try wrapped format.
if err == nil {
return msgs, nil
}
var wrapped MessagesResponse var wrapped MessagesResponse
if err2 := json.Unmarshal(data, &wrapped); err2 != nil {
err2 := json.Unmarshal(data, &wrapped) return nil, fmt.Errorf("decode messages: %w (raw: %s)", err, string(data))
if err2 != nil { }
return nil, fmt.Errorf( msgs = wrapped.Messages
"decode messages: %w (raw: %s)",
err, string(data),
)
} }
return wrapped.Messages, nil return msgs, nil
} }
// JoinChannel joins a channel via the unified command // JoinChannel joins a channel via the unified command endpoint.
// endpoint.
func (c *Client) JoinChannel(channel string) error { func (c *Client) JoinChannel(channel string) error {
return c.SendMessage( return c.SendMessage(&Message{Command: "JOIN", To: channel})
&Message{Command: "JOIN", To: channel},
)
} }
// PartChannel leaves a channel via the unified command // PartChannel leaves a channel via the unified command endpoint.
// endpoint.
func (c *Client) PartChannel(channel string) error { func (c *Client) PartChannel(channel string) error {
return c.SendMessage( return c.SendMessage(&Message{Command: "PART", To: channel})
&Message{Command: "PART", To: channel},
)
} }
// ListChannels returns all channels on the server. // ListChannels returns all channels on the server.
@@ -188,39 +164,29 @@ func (c *Client) ListChannels() ([]Channel, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var channels []Channel var channels []Channel
if err := json.Unmarshal(data, &channels); err != nil {
err = json.Unmarshal(data, &channels)
if err != nil {
return nil, err return nil, err
} }
return channels, nil return channels, nil
} }
// GetMembers returns members of a channel. // GetMembers returns members of a channel.
func (c *Client) GetMembers( func (c *Client) GetMembers(channel string) ([]string, error) {
channel string, data, err := c.do("GET", "/api/v1/channels/"+url.PathEscape(channel)+"/members", nil)
) ([]string, error) {
path := "/api/v1/channels/" +
url.PathEscape(channel) + "/members"
data, err := c.do("GET", path, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var members []string var members []string
if err := json.Unmarshal(data, &members); err != nil {
err = json.Unmarshal(data, &members) // Try object format.
if err != nil { var obj map[string]interface{}
return nil, fmt.Errorf( if err2 := json.Unmarshal(data, &obj); err2 != nil {
"%w: members: %s", return nil, err
ErrUnexpectedFormat, string(data), }
) // Extract member names from whatever format.
return nil, fmt.Errorf("unexpected members format: %s", string(data))
} }
return members, nil return members, nil
} }
@@ -230,63 +196,9 @@ func (c *Client) GetServerInfo() (*ServerInfo, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var info ServerInfo var info ServerInfo
if err := json.Unmarshal(data, &info); err != nil {
err = json.Unmarshal(data, &info)
if err != nil {
return nil, err return nil, err
} }
return &info, nil return &info, nil
} }
func (c *Client) do(
method, path string,
body any,
) ([]byte, error) {
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal: %w", err)
}
bodyReader = bytes.NewReader(data)
}
req, err := http.NewRequest( //nolint:noctx // CLI tool
method, c.BaseURL+path, bodyReader,
)
if err != nil {
return nil, fmt.Errorf("request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if c.Token != "" {
req.Header.Set("Authorization", "Bearer "+c.Token)
}
resp, err := c.HTTPClient.Do(req) //nolint:gosec // URL from user config
if err != nil {
return nil, fmt.Errorf("http: %w", err)
}
defer func() { _ = resp.Body.Close() }()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
if resp.StatusCode >= httpErrThreshold {
return data, fmt.Errorf(
"%w: %d: %s",
ErrHTTP, resp.StatusCode, string(data),
)
}
return data, nil
}

View File

@@ -1,4 +1,4 @@
package chatapi package api
import "time" import "time"
@@ -9,16 +9,16 @@ type SessionRequest struct {
// SessionResponse is the response from POST /api/v1/session. // SessionResponse is the response from POST /api/v1/session.
type SessionResponse struct { type SessionResponse struct {
SessionID string `json:"sessionId"` SessionID string `json:"session_id"`
ClientID string `json:"clientId"` ClientID string `json:"client_id"`
Nick string `json:"nick"` Nick string `json:"nick"`
Token string `json:"token"` Token string `json:"token"`
} }
// StateResponse is the response from GET /api/v1/state. // StateResponse is the response from GET /api/v1/state.
type StateResponse struct { type StateResponse struct {
SessionID string `json:"sessionId"` SessionID string `json:"session_id"`
ClientID string `json:"clientId"` ClientID string `json:"client_id"`
Nick string `json:"nick"` Nick string `json:"nick"`
Channels []string `json:"channels"` Channels []string `json:"channels"`
} }
@@ -29,25 +29,22 @@ type Message struct {
From string `json:"from,omitempty"` From string `json:"from,omitempty"`
To string `json:"to,omitempty"` To string `json:"to,omitempty"`
Params []string `json:"params,omitempty"` Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"` Body interface{} `json:"body,omitempty"`
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
TS string `json:"ts,omitempty"` TS string `json:"ts,omitempty"`
Meta any `json:"meta,omitempty"` Meta interface{} `json:"meta,omitempty"`
} }
// BodyLines returns the body as a slice of strings (for text // BodyLines returns the body as a slice of strings (for text messages).
// messages).
func (m *Message) BodyLines() []string { func (m *Message) BodyLines() []string {
switch v := m.Body.(type) { switch v := m.Body.(type) {
case []any: case []interface{}:
lines := make([]string, 0, len(v)) lines := make([]string, 0, len(v))
for _, item := range v { for _, item := range v {
if s, ok := item.(string); ok { if s, ok := item.(string); ok {
lines = append(lines, s) lines = append(lines, s)
} }
} }
return lines return lines
case []string: case []string:
return v return v
@@ -61,7 +58,7 @@ type Channel struct {
Name string `json:"name"` Name string `json:"name"`
Topic string `json:"topic"` Topic string `json:"topic"`
Members int `json:"members"` Members int `json:"members"`
CreatedAt string `json:"createdAt"` CreatedAt string `json:"created_at"`
} }
// ServerInfo is the response from GET /api/v1/server. // ServerInfo is the response from GET /api/v1/server.
@@ -82,6 +79,5 @@ func (m *Message) ParseTS() time.Time {
if err != nil { if err != nil {
return time.Now() return time.Now()
} }
return t return t
} }

View File

@@ -1,21 +1,13 @@
// Package main implements chat-cli, an IRC-style terminal client.
package main package main
import ( import (
"fmt" "fmt"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
api "git.eeqj.de/sneak/chat/cmd/chat-cli/api" "git.eeqj.de/sneak/chat/cmd/chat-cli/api"
)
const (
pollTimeoutSec = 15
retryDelay = 2 * time.Second
maxNickLength = 32
) )
// App holds the application state. // App holds the application state.
@@ -40,18 +32,11 @@ func main() {
app.ui.OnInput(app.handleInput) app.ui.OnInput(app.handleInput)
app.ui.SetStatus(app.nick, "", "disconnected") app.ui.SetStatus(app.nick, "", "disconnected")
app.ui.AddStatus( app.ui.AddStatus("Welcome to chat-cli — an IRC-style client")
"Welcome to chat-cli \u2014 an IRC-style client", app.ui.AddStatus("Type [yellow]/connect <server-url>[white] to begin, or [yellow]/help[white] for commands")
)
app.ui.AddStatus(
"Type [yellow]/connect <server-url>[white] " +
"to begin, or [yellow]/help[white] for commands",
)
err := app.ui.Run()
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err)
if err := app.ui.Run(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
} }
@@ -59,34 +44,21 @@ func main() {
func (a *App) handleInput(text string) { func (a *App) handleInput(text string) {
if strings.HasPrefix(text, "/") { if strings.HasPrefix(text, "/") {
a.handleCommand(text) a.handleCommand(text)
return return
} }
a.sendPlainText(text) // Plain text → PRIVMSG to current target.
}
func (a *App) sendPlainText(text string) {
a.mu.Lock() a.mu.Lock()
target := a.target target := a.target
connected := a.connected connected := a.connected
nick := a.nick
a.mu.Unlock() a.mu.Unlock()
if !connected { if !connected {
a.ui.AddStatus( a.ui.AddStatus("[red]Not connected. Use /connect <url>")
"[red]Not connected. Use /connect <url>",
)
return return
} }
if target == "" { if target == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]No target. Use /join #channel or /query nick")
"[red]No target. " +
"Use /join #channel or /query nick",
)
return return
} }
@@ -96,28 +68,21 @@ func (a *App) sendPlainText(text string) {
Body: []string{text}, Body: []string{text},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Send error: %v", err))
fmt.Sprintf("[red]Send error: %v", err),
)
return return
} }
// Echo locally.
ts := time.Now().Format("15:04") ts := time.Now().Format("15:04")
a.mu.Lock()
a.ui.AddLine( nick := a.nick
target, a.mu.Unlock()
fmt.Sprintf( a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text))
"[gray]%s [green]<%s>[white] %s",
ts, nick, text,
),
)
} }
func (a *App) handleCommand(text string) { //nolint:cyclop // command dispatch func (a *App) handleCommand(text string) {
parts := strings.SplitN(text, " ", 2) //nolint:mnd // split into cmd+args parts := strings.SplitN(text, " ", 2)
cmd := strings.ToLower(parts[0]) cmd := strings.ToLower(parts[0])
args := "" args := ""
if len(parts) > 1 { if len(parts) > 1 {
args = parts[1] args = parts[1]
@@ -149,41 +114,27 @@ func (a *App) handleCommand(text string) { //nolint:cyclop // command dispatch
case "/help": case "/help":
a.cmdHelp() a.cmdHelp()
default: default:
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Unknown command: %s", cmd))
"[red]Unknown command: " + cmd,
)
} }
} }
func (a *App) cmdConnect(serverURL string) { func (a *App) cmdConnect(serverURL string) {
if serverURL == "" { if serverURL == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]Usage: /connect <server-url>")
"[red]Usage: /connect <server-url>",
)
return return
} }
serverURL = strings.TrimRight(serverURL, "/") serverURL = strings.TrimRight(serverURL, "/")
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("Connecting to %s...", serverURL))
fmt.Sprintf("Connecting to %s...", serverURL),
)
a.mu.Lock() a.mu.Lock()
nick := a.nick nick := a.nick
a.mu.Unlock() a.mu.Unlock()
client := api.NewClient(serverURL) client := api.NewClient(serverURL)
resp, err := client.CreateSession(nick) resp, err := client.CreateSession(nick)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Connection failed: %v", err))
fmt.Sprintf(
"[red]Connection failed: %v", err,
),
)
return return
} }
@@ -194,26 +145,19 @@ func (a *App) cmdConnect(serverURL string) {
a.lastMsgID = "" a.lastMsgID = ""
a.mu.Unlock() a.mu.Unlock()
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[green]Connected! Nick: %s, Session: %s", resp.Nick, resp.SessionID))
fmt.Sprintf(
"[green]Connected! Nick: %s, Session: %s",
resp.Nick, resp.SessionID,
),
)
a.ui.SetStatus(resp.Nick, "", "connected") a.ui.SetStatus(resp.Nick, "", "connected")
// Start polling.
a.stopPoll = make(chan struct{}) a.stopPoll = make(chan struct{})
go a.pollLoop() go a.pollLoop()
} }
func (a *App) cmdNick(nick string) { func (a *App) cmdNick(nick string) {
if nick == "" { if nick == "" {
a.ui.AddStatus("[red]Usage: /nick <name>") a.ui.AddStatus("[red]Usage: /nick <name>")
return return
} }
a.mu.Lock() a.mu.Lock()
connected := a.connected connected := a.connected
a.mu.Unlock() a.mu.Unlock()
@@ -222,14 +166,7 @@ func (a *App) cmdNick(nick string) {
a.mu.Lock() a.mu.Lock()
a.nick = nick a.nick = nick
a.mu.Unlock() a.mu.Unlock()
a.ui.AddStatus(fmt.Sprintf("Nick set to %s (will be used on connect)", nick))
a.ui.AddStatus(
fmt.Sprintf(
"Nick set to %s (will be used on connect)",
nick,
),
)
return return
} }
@@ -238,12 +175,7 @@ func (a *App) cmdNick(nick string) {
Body: []string{nick}, Body: []string{nick},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Nick change failed: %v", err))
fmt.Sprintf(
"[red]Nick change failed: %v", err,
),
)
return return
} }
@@ -251,20 +183,15 @@ func (a *App) cmdNick(nick string) {
a.nick = nick a.nick = nick
target := a.target target := a.target
a.mu.Unlock() a.mu.Unlock()
a.ui.SetStatus(nick, target, "connected") a.ui.SetStatus(nick, target, "connected")
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("Nick changed to %s", nick))
"Nick changed to " + nick,
)
} }
func (a *App) cmdJoin(channel string) { func (a *App) cmdJoin(channel string) {
if channel == "" { if channel == "" {
a.ui.AddStatus("[red]Usage: /join #channel") a.ui.AddStatus("[red]Usage: /join #channel")
return return
} }
if !strings.HasPrefix(channel, "#") { if !strings.HasPrefix(channel, "#") {
channel = "#" + channel channel = "#" + channel
} }
@@ -272,19 +199,14 @@ func (a *App) cmdJoin(channel string) {
a.mu.Lock() a.mu.Lock()
connected := a.connected connected := a.connected
a.mu.Unlock() a.mu.Unlock()
if !connected { if !connected {
a.ui.AddStatus("[red]Not connected") a.ui.AddStatus("[red]Not connected")
return return
} }
err := a.client.JoinChannel(channel) err := a.client.JoinChannel(channel)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Join failed: %v", err))
fmt.Sprintf("[red]Join failed: %v", err),
)
return return
} }
@@ -294,55 +216,39 @@ func (a *App) cmdJoin(channel string) {
a.mu.Unlock() a.mu.Unlock()
a.ui.SwitchToBuffer(channel) a.ui.SwitchToBuffer(channel)
a.ui.AddLine( a.ui.AddLine(channel, fmt.Sprintf("[yellow]*** Joined %s", channel))
channel,
"[yellow]*** Joined "+channel,
)
a.ui.SetStatus(nick, channel, "connected") a.ui.SetStatus(nick, channel, "connected")
} }
func (a *App) cmdPart(channel string) { func (a *App) cmdPart(channel string) {
a.mu.Lock() a.mu.Lock()
if channel == "" { if channel == "" {
channel = a.target channel = a.target
} }
connected := a.connected connected := a.connected
a.mu.Unlock() a.mu.Unlock()
if channel == "" || !strings.HasPrefix(channel, "#") { if channel == "" || !strings.HasPrefix(channel, "#") {
a.ui.AddStatus("[red]No channel to part") a.ui.AddStatus("[red]No channel to part")
return return
} }
if !connected { if !connected {
a.ui.AddStatus("[red]Not connected") a.ui.AddStatus("[red]Not connected")
return return
} }
err := a.client.PartChannel(channel) err := a.client.PartChannel(channel)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Part failed: %v", err))
fmt.Sprintf("[red]Part failed: %v", err),
)
return return
} }
a.ui.AddLine( a.ui.AddLine(channel, fmt.Sprintf("[yellow]*** Left %s", channel))
channel,
"[yellow]*** Left "+channel,
)
a.mu.Lock() a.mu.Lock()
if a.target == channel { if a.target == channel {
a.target = "" a.target = ""
} }
nick := a.nick nick := a.nick
a.mu.Unlock() a.mu.Unlock()
@@ -351,23 +257,19 @@ func (a *App) cmdPart(channel string) {
} }
func (a *App) cmdMsg(args string) { func (a *App) cmdMsg(args string) {
parts := strings.SplitN(args, " ", 2) //nolint:mnd // split into target+text parts := strings.SplitN(args, " ", 2)
if len(parts) < 2 { //nolint:mnd // min args if len(parts) < 2 {
a.ui.AddStatus("[red]Usage: /msg <nick> <text>") a.ui.AddStatus("[red]Usage: /msg <nick> <text>")
return return
} }
target, text := parts[0], parts[1] target, text := parts[0], parts[1]
a.mu.Lock() a.mu.Lock()
connected := a.connected connected := a.connected
nick := a.nick nick := a.nick
a.mu.Unlock() a.mu.Unlock()
if !connected { if !connected {
a.ui.AddStatus("[red]Not connected") a.ui.AddStatus("[red]Not connected")
return return
} }
@@ -377,28 +279,17 @@ func (a *App) cmdMsg(args string) {
Body: []string{text}, Body: []string{text},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Send failed: %v", err))
fmt.Sprintf("[red]Send failed: %v", err),
)
return return
} }
ts := time.Now().Format("15:04") ts := time.Now().Format("15:04")
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text))
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
ts, nick, text,
),
)
} }
func (a *App) cmdQuery(nick string) { func (a *App) cmdQuery(nick string) {
if nick == "" { if nick == "" {
a.ui.AddStatus("[red]Usage: /query <nick>") a.ui.AddStatus("[red]Usage: /query <nick>")
return return
} }
@@ -419,29 +310,22 @@ func (a *App) cmdTopic(args string) {
if !connected { if !connected {
a.ui.AddStatus("[red]Not connected") a.ui.AddStatus("[red]Not connected")
return return
} }
if !strings.HasPrefix(target, "#") { if !strings.HasPrefix(target, "#") {
a.ui.AddStatus("[red]Not in a channel") a.ui.AddStatus("[red]Not in a channel")
return return
} }
if args == "" { if args == "" {
// Query topic.
err := a.client.SendMessage(&api.Message{ err := a.client.SendMessage(&api.Message{
Command: "TOPIC", Command: "TOPIC",
To: target, To: target,
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Topic query failed: %v", err))
fmt.Sprintf(
"[red]Topic query failed: %v", err,
),
)
} }
return return
} }
@@ -451,11 +335,7 @@ func (a *App) cmdTopic(args string) {
Body: []string{args}, Body: []string{args},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Topic set failed: %v", err))
fmt.Sprintf(
"[red]Topic set failed: %v", err,
),
)
} }
} }
@@ -467,32 +347,20 @@ func (a *App) cmdNames() {
if !connected { if !connected {
a.ui.AddStatus("[red]Not connected") a.ui.AddStatus("[red]Not connected")
return return
} }
if !strings.HasPrefix(target, "#") { if !strings.HasPrefix(target, "#") {
a.ui.AddStatus("[red]Not in a channel") a.ui.AddStatus("[red]Not in a channel")
return return
} }
members, err := a.client.GetMembers(target) members, err := a.client.GetMembers(target)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Names failed: %v", err))
fmt.Sprintf("[red]Names failed: %v", err),
)
return return
} }
a.ui.AddLine( a.ui.AddLine(target, fmt.Sprintf("[cyan]*** Members of %s: %s", target, strings.Join(members, " ")))
target,
fmt.Sprintf(
"[cyan]*** Members of %s: %s",
target, strings.Join(members, " "),
),
)
} }
func (a *App) cmdList() { func (a *App) cmdList() {
@@ -502,55 +370,46 @@ func (a *App) cmdList() {
if !connected { if !connected {
a.ui.AddStatus("[red]Not connected") a.ui.AddStatus("[red]Not connected")
return return
} }
channels, err := a.client.ListChannels() channels, err := a.client.ListChannels()
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]List failed: %v", err))
fmt.Sprintf("[red]List failed: %v", err),
)
return return
} }
a.ui.AddStatus("[cyan]*** Channel list:") a.ui.AddStatus("[cyan]*** Channel list:")
for _, ch := range channels { for _, ch := range channels {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf(" %s (%d members) %s", ch.Name, ch.Members, ch.Topic))
fmt.Sprintf(
" %s (%d members) %s",
ch.Name, ch.Members, ch.Topic,
),
)
} }
a.ui.AddStatus("[cyan]*** End of channel list") a.ui.AddStatus("[cyan]*** End of channel list")
} }
func (a *App) cmdWindow(args string) { func (a *App) cmdWindow(args string) {
if args == "" { if args == "" {
a.ui.AddStatus("[red]Usage: /window <number>") a.ui.AddStatus("[red]Usage: /window <number>")
return return
} }
n := 0
n, _ := strconv.Atoi(args) fmt.Sscanf(args, "%d", &n)
a.ui.SwitchBuffer(n) a.ui.SwitchBuffer(n)
a.mu.Lock() a.mu.Lock()
if n < a.ui.BufferCount() && n >= 0 {
// Update target to the buffer name.
// Needs to be done carefully.
}
nick := a.nick nick := a.nick
a.mu.Unlock() a.mu.Unlock()
if n >= 0 && n < a.ui.BufferCount() { // Update target based on buffer.
if n < a.ui.BufferCount() {
buf := a.ui.buffers[n] buf := a.ui.buffers[n]
if buf.Name != "(status)" { if buf.Name != "(status)" {
a.mu.Lock() a.mu.Lock()
a.target = buf.Name a.target = buf.Name
a.mu.Unlock() a.mu.Unlock()
a.ui.SetStatus(nick, buf.Name, "connected") a.ui.SetStatus(nick, buf.Name, "connected")
} else { } else {
a.ui.SetStatus(nick, "", "connected") a.ui.SetStatus(nick, "", "connected")
@@ -560,17 +419,12 @@ func (a *App) cmdWindow(args string) {
func (a *App) cmdQuit() { func (a *App) cmdQuit() {
a.mu.Lock() a.mu.Lock()
if a.connected && a.client != nil { if a.connected && a.client != nil {
_ = a.client.SendMessage( _ = a.client.SendMessage(&api.Message{Command: "QUIT"})
&api.Message{Command: "QUIT"},
)
} }
if a.stopPoll != nil { if a.stopPoll != nil {
close(a.stopPoll) close(a.stopPoll)
} }
a.mu.Unlock() a.mu.Unlock()
a.ui.Stop() a.ui.Stop()
} }
@@ -578,21 +432,20 @@ func (a *App) cmdQuit() {
func (a *App) cmdHelp() { func (a *App) cmdHelp() {
help := []string{ help := []string{
"[cyan]*** chat-cli commands:", "[cyan]*** chat-cli commands:",
" /connect <url> \u2014 Connect to server", " /connect <url> Connect to server",
" /nick <name> \u2014 Change nickname", " /nick <name> Change nickname",
" /join #channel \u2014 Join channel", " /join #channel Join channel",
" /part [#chan] \u2014 Leave channel", " /part [#chan] Leave channel",
" /msg <nick> <text> \u2014 Send DM", " /msg <nick> <text> Send DM",
" /query <nick> \u2014 Open DM window", " /query <nick> Open DM window",
" /topic [text] \u2014 View/set topic", " /topic [text] View/set topic",
" /names \u2014 List channel members", " /names List channel members",
" /list \u2014 List channels", " /list List channels",
" /window <n> \u2014 Switch buffer (Alt+0-9)", " /window <n> Switch buffer (Alt+0-9)",
" /quit \u2014 Disconnect and exit", " /quit Disconnect and exit",
" /help \u2014 This help", " /help This help",
" Plain text sends to current target.", " Plain text sends to current target.",
} }
for _, line := range help { for _, line := range help {
a.ui.AddStatus(line) a.ui.AddStatus(line)
} }
@@ -616,31 +469,32 @@ func (a *App) pollLoop() {
return return
} }
msgs, err := client.PollMessages( msgs, err := client.PollMessages(lastID, 15)
lastID, pollTimeoutSec,
)
if err != nil { if err != nil {
time.Sleep(retryDelay) // Transient error — retry after delay.
time.Sleep(2 * time.Second)
continue continue
} }
for i := range msgs { for _, msg := range msgs {
a.handleServerMessage(&msgs[i]) a.handleServerMessage(&msg)
if msg.ID != "" {
if msgs[i].ID != "" {
a.mu.Lock() a.mu.Lock()
a.lastMsgID = msgs[i].ID a.lastMsgID = msg.ID
a.mu.Unlock() a.mu.Unlock()
} }
} }
} }
} }
func (a *App) handleServerMessage( func (a *App) handleServerMessage(msg *api.Message) {
msg *api.Message, ts := ""
) { if msg.TS != "" {
ts := a.parseMessageTS(msg) t := msg.ParseTS()
ts = t.Local().Format("15:04")
} else {
ts = time.Now().Format("15:04")
}
a.mu.Lock() a.mu.Lock()
myNick := a.nick myNick := a.nick
@@ -648,203 +502,79 @@ func (a *App) handleServerMessage(
switch msg.Command { switch msg.Command {
case "PRIVMSG": case "PRIVMSG":
a.handlePrivmsgMsg(msg, ts, myNick)
case "JOIN":
a.handleJoinMsg(msg, ts)
case "PART":
a.handlePartMsg(msg, ts)
case "QUIT":
a.handleQuitMsg(msg, ts)
case "NICK":
a.handleNickMsg(msg, ts, myNick)
case "NOTICE":
a.handleNoticeMsg(msg, ts)
case "TOPIC":
a.handleTopicMsg(msg, ts)
default:
a.handleDefaultMsg(msg, ts)
}
}
func (a *App) parseMessageTS(msg *api.Message) string {
if msg.TS != "" {
t := msg.ParseTS()
return t.In(time.Local).Format("15:04") //nolint:gosmopolitan // CLI uses local time
}
return time.Now().Format("15:04")
}
func (a *App) handlePrivmsgMsg(
msg *api.Message,
ts, myNick string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
if msg.From == myNick { if msg.From == myNick {
// Skip our own echoed messages (already displayed locally).
return return
} }
target := msg.To target := msg.To
if !strings.HasPrefix(target, "#") { if !strings.HasPrefix(target, "#") {
// DM — use sender's nick as buffer name.
target = msg.From target = msg.From
} }
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, msg.From, text))
a.ui.AddLine( case "JOIN":
target,
fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
ts, msg.From, text,
),
)
}
func (a *App) handleJoinMsg(
msg *api.Message, ts string,
) {
target := msg.To target := msg.To
if target == "" { if target != "" {
return a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has joined %s", ts, msg.From, target))
} }
a.ui.AddLine( case "PART":
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has joined %s",
ts, msg.From, target,
),
)
}
func (a *App) handlePartMsg(
msg *api.Message, ts string,
) {
target := msg.To target := msg.To
if target == "" {
return
}
lines := msg.BodyLines() lines := msg.BodyLines()
reason := strings.Join(lines, " ") reason := strings.Join(lines, " ")
if target != "" {
if reason != "" { if reason != "" {
a.ui.AddLine( a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s (%s)", ts, msg.From, target, reason))
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s (%s)",
ts, msg.From, target, reason,
),
)
} else { } else {
a.ui.AddLine( a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s", ts, msg.From, target))
target, }
fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s",
ts, msg.From, target,
),
)
} }
}
func (a *App) handleQuitMsg( case "QUIT":
msg *api.Message, ts string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
reason := strings.Join(lines, " ") reason := strings.Join(lines, " ")
if reason != "" { if reason != "" {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit (%s)", ts, msg.From, reason))
fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit (%s)",
ts, msg.From, reason,
),
)
} else { } else {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit", ts, msg.From))
fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit",
ts, msg.From,
),
)
} }
}
func (a *App) handleNickMsg( case "NICK":
msg *api.Message, ts, myNick string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
newNick := "" newNick := ""
if len(lines) > 0 { if len(lines) > 0 {
newNick = lines[0] newNick = lines[0]
} }
if msg.From == myNick && newNick != "" { if msg.From == myNick && newNick != "" {
a.mu.Lock() a.mu.Lock()
a.nick = newNick a.nick = newNick
target := a.target target := a.target
a.mu.Unlock() a.mu.Unlock()
a.ui.SetStatus(newNick, target, "connected") a.ui.SetStatus(newNick, target, "connected")
} }
a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s is now known as %s", ts, msg.From, newNick))
a.ui.AddStatus( case "NOTICE":
fmt.Sprintf(
"[gray]%s [yellow]*** %s is now known as %s",
ts, msg.From, newNick,
),
)
}
func (a *App) handleNoticeMsg(
msg *api.Message, ts string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
a.ui.AddStatus(fmt.Sprintf("[gray]%s [magenta]--%s-- %s", ts, msg.From, text))
a.ui.AddStatus( case "TOPIC":
fmt.Sprintf( lines := msg.BodyLines()
"[gray]%s [magenta]--%s-- %s", text := strings.Join(lines, " ")
ts, msg.From, text, if msg.To != "" {
), a.ui.AddLine(msg.To, fmt.Sprintf("[gray]%s [cyan]*** %s set topic: %s", ts, msg.From, text))
)
}
func (a *App) handleTopicMsg(
msg *api.Message, ts string,
) {
if msg.To == "" {
return
} }
default:
// Numeric replies and other messages → status window.
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
if text != "" {
a.ui.AddLine( a.ui.AddStatus(fmt.Sprintf("[gray]%s [white][%s] %s", ts, msg.Command, text))
msg.To, }
fmt.Sprintf(
"[gray]%s [cyan]*** %s set topic: %s",
ts, msg.From, text,
),
)
}
func (a *App) handleDefaultMsg(
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
if text == "" {
return
} }
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [white][%s] %s",
ts, msg.Command, text,
),
)
} }

View File

@@ -31,7 +31,6 @@ type UI struct {
} }
// NewUI creates the tview-based IRC-like UI. // NewUI creates the tview-based IRC-like UI.
func NewUI() *UI { func NewUI() *UI {
ui := &UI{ ui := &UI{
app: tview.NewApplication(), app: tview.NewApplication(),
@@ -40,145 +39,7 @@ func NewUI() *UI {
}, },
} }
ui.setupMessages() // Message area.
ui.setupStatusBar()
ui.setupInput()
ui.setupKeybindings()
ui.setupLayout()
return ui
}
// Run starts the UI event loop (blocks).
func (ui *UI) Run() error {
return ui.app.Run()
}
// Stop stops the UI.
func (ui *UI) Stop() {
ui.app.Stop()
}
// OnInput sets the callback for user input.
func (ui *UI) OnInput(fn func(string)) {
ui.onInput = fn
}
// AddLine adds a line to the specified buffer.
func (ui *UI) AddLine(bufferName string, line string) {
ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(bufferName)
buf.Lines = append(buf.Lines, line)
// Mark unread if not currently viewing this buffer.
if ui.buffers[ui.currentBuffer] != buf {
buf.Unread++
ui.refreshStatus()
}
// If viewing this buffer, append to display.
if ui.buffers[ui.currentBuffer] == buf {
_, _ = fmt.Fprintln(ui.messages, line)
}
})
}
// AddStatus adds a line to the status buffer (buffer 0).
func (ui *UI) AddStatus(line string) {
ts := time.Now().Format("15:04")
ui.AddLine(
"(status)",
fmt.Sprintf("[gray]%s[white] %s", ts, line),
)
}
// SwitchBuffer switches to the buffer at index n.
func (ui *UI) SwitchBuffer(n int) {
ui.app.QueueUpdateDraw(func() {
if n < 0 || n >= len(ui.buffers) {
return
}
ui.currentBuffer = n
buf := ui.buffers[n]
buf.Unread = 0
ui.messages.Clear()
for _, line := range buf.Lines {
_, _ = fmt.Fprintln(ui.messages, line)
}
ui.messages.ScrollToEnd()
ui.refreshStatus()
})
}
// SwitchToBuffer switches to the named buffer, creating it
func (ui *UI) SwitchToBuffer(name string) {
ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(name)
for i, b := range ui.buffers {
if b == buf {
ui.currentBuffer = i
break
}
}
buf.Unread = 0
ui.messages.Clear()
for _, line := range buf.Lines {
_, _ = fmt.Fprintln(ui.messages, line)
}
ui.messages.ScrollToEnd()
ui.refreshStatus()
})
}
// SetStatus updates the status bar text.
func (ui *UI) SetStatus(
nick, target, connStatus string,
) {
ui.app.QueueUpdateDraw(func() {
ui.refreshStatusWith(nick, target, connStatus)
})
}
// BufferCount returns the number of buffers.
func (ui *UI) BufferCount() int {
return len(ui.buffers)
}
// BufferIndex returns the index of a named buffer, or -1.
func (ui *UI) BufferIndex(name string) int {
for i, buf := range ui.buffers {
if buf.Name == name {
return i
}
}
return -1
}
func (ui *UI) setupMessages() {
ui.messages = tview.NewTextView(). ui.messages = tview.NewTextView().
SetDynamicColors(true). SetDynamicColors(true).
SetScrollable(true). SetScrollable(true).
@@ -187,109 +48,162 @@ func (ui *UI) setupMessages() {
ui.app.Draw() ui.app.Draw()
}) })
ui.messages.SetBorder(false) ui.messages.SetBorder(false)
}
func (ui *UI) setupStatusBar() { // Status bar.
ui.statusBar = tview.NewTextView(). ui.statusBar = tview.NewTextView().
SetDynamicColors(true) SetDynamicColors(true)
ui.statusBar.SetBackgroundColor(tcell.ColorNavy) ui.statusBar.SetBackgroundColor(tcell.ColorNavy)
ui.statusBar.SetTextColor(tcell.ColorWhite) ui.statusBar.SetTextColor(tcell.ColorWhite)
}
func (ui *UI) setupInput() { // Input field.
ui.input = tview.NewInputField(). ui.input = tview.NewInputField().
SetFieldBackgroundColor(tcell.ColorBlack). SetFieldBackgroundColor(tcell.ColorBlack).
SetFieldTextColor(tcell.ColorWhite) SetFieldTextColor(tcell.ColorWhite)
ui.input.SetDoneFunc(func(key tcell.Key) { ui.input.SetDoneFunc(func(key tcell.Key) {
if key != tcell.KeyEnter { if key == tcell.KeyEnter {
return
}
text := ui.input.GetText() text := ui.input.GetText()
if text == "" { if text == "" {
return return
} }
ui.input.SetText("") ui.input.SetText("")
if ui.onInput != nil { if ui.onInput != nil {
ui.onInput(text) ui.onInput(text)
} }
})
}
func (ui *UI) setupKeybindings() {
ui.app.SetInputCapture(
func(event *tcell.EventKey) *tcell.EventKey {
if event.Modifiers()&tcell.ModAlt == 0 {
return event
} }
})
// Capture Alt+N for window switching.
ui.app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Modifiers()&tcell.ModAlt != 0 {
r := event.Rune() r := event.Rune()
if r >= '0' && r <= '9' { if r >= '0' && r <= '9' {
idx := int(r - '0') idx := int(r - '0')
ui.SwitchBuffer(idx) ui.SwitchBuffer(idx)
return nil return nil
} }
}
return event return event
}, })
)
}
func (ui *UI) setupLayout() { // Layout: messages on top, status bar, input at bottom.
ui.layout = tview.NewFlex(). ui.layout = tview.NewFlex().SetDirection(tview.FlexRow).
SetDirection(tview.FlexRow).
AddItem(ui.messages, 0, 1, false). AddItem(ui.messages, 0, 1, false).
AddItem(ui.statusBar, 1, 0, false). AddItem(ui.statusBar, 1, 0, false).
AddItem(ui.input, 1, 0, true) AddItem(ui.input, 1, 0, true)
ui.app.SetRoot(ui.layout, true) ui.app.SetRoot(ui.layout, true)
ui.app.SetFocus(ui.input) ui.app.SetFocus(ui.input)
return ui
} }
// if needed. // Run starts the UI event loop (blocks).
func (ui *UI) Run() error {
return ui.app.Run()
}
// Stop stops the UI.
func (ui *UI) Stop() {
ui.app.Stop()
}
// OnInput sets the callback for user input.
func (ui *UI) OnInput(fn func(string)) {
ui.onInput = fn
}
// AddLine adds a line to the specified buffer.
func (ui *UI) AddLine(bufferName string, line string) {
ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(bufferName)
buf.Lines = append(buf.Lines, line)
// Mark unread if not currently viewing this buffer.
if ui.buffers[ui.currentBuffer] != buf {
buf.Unread++
ui.refreshStatus()
}
// If viewing this buffer, append to display.
if ui.buffers[ui.currentBuffer] == buf {
fmt.Fprintln(ui.messages, line)
}
})
}
// AddStatus adds a line to the status buffer (buffer 0).
func (ui *UI) AddStatus(line string) {
ts := time.Now().Format("15:04")
ui.AddLine("(status)", fmt.Sprintf("[gray]%s[white] %s", ts, line))
}
// SwitchBuffer switches to the buffer at index n.
func (ui *UI) SwitchBuffer(n int) {
ui.app.QueueUpdateDraw(func() {
if n < 0 || n >= len(ui.buffers) {
return
}
ui.currentBuffer = n
buf := ui.buffers[n]
buf.Unread = 0
ui.messages.Clear()
for _, line := range buf.Lines {
fmt.Fprintln(ui.messages, line)
}
ui.messages.ScrollToEnd()
ui.refreshStatus()
})
}
// SwitchToBuffer switches to the named buffer, creating it if needed.
func (ui *UI) SwitchToBuffer(name string) {
ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(name)
for i, b := range ui.buffers {
if b == buf {
ui.currentBuffer = i
break
}
}
buf.Unread = 0
ui.messages.Clear()
for _, line := range buf.Lines {
fmt.Fprintln(ui.messages, line)
}
ui.messages.ScrollToEnd()
ui.refreshStatus()
})
}
// SetStatus updates the status bar text.
func (ui *UI) SetStatus(nick, target, connStatus string) {
ui.app.QueueUpdateDraw(func() {
ui.refreshStatusWith(nick, target, connStatus)
})
}
func (ui *UI) refreshStatus() { func (ui *UI) refreshStatus() {
// Rebuilt from app state by parent QueueUpdateDraw. // Will be called from the main goroutine via QueueUpdateDraw parent.
// Rebuild status from app state — caller must provide context.
} }
func (ui *UI) refreshStatusWith( func (ui *UI) refreshStatusWith(nick, target, connStatus string) {
nick, target, connStatus string,
) {
var unreadParts []string var unreadParts []string
for i, buf := range ui.buffers { for i, buf := range ui.buffers {
if buf.Unread > 0 { if buf.Unread > 0 {
unreadParts = append( unreadParts = append(unreadParts, fmt.Sprintf("%d:%s(%d)", i, buf.Name, buf.Unread))
unreadParts,
fmt.Sprintf(
"%d:%s(%d)", i, buf.Name, buf.Unread,
),
)
} }
} }
unread := "" unread := ""
if len(unreadParts) > 0 { if len(unreadParts) > 0 {
unread = " [Act: " + unread = " [Act: " + strings.Join(unreadParts, ",") + "]"
strings.Join(unreadParts, ",") + "]"
} }
bufInfo := fmt.Sprintf( bufInfo := fmt.Sprintf("[%d:%s]", ui.currentBuffer, ui.buffers[ui.currentBuffer].Name)
"[%d:%s]",
ui.currentBuffer,
ui.buffers[ui.currentBuffer].Name,
)
ui.statusBar.Clear() ui.statusBar.Clear()
fmt.Fprintf(ui.statusBar, " [%s] %s %s %s%s",
_, _ = fmt.Fprintf( connStatus, nick, bufInfo, target, unread)
ui.statusBar, " [%s] %s %s %s%s",
connStatus, nick, bufInfo, target, unread,
)
} }
func (ui *UI) getOrCreateBuffer(name string) *Buffer { func (ui *UI) getOrCreateBuffer(name string) *Buffer {
@@ -298,9 +212,22 @@ func (ui *UI) getOrCreateBuffer(name string) *Buffer {
return buf return buf
} }
} }
buf := &Buffer{Name: name} buf := &Buffer{Name: name}
ui.buffers = append(ui.buffers, buf) ui.buffers = append(ui.buffers, buf)
return buf return buf
} }
// BufferCount returns the number of buffers.
func (ui *UI) BufferCount() int {
return len(ui.buffers)
}
// BufferIndex returns the index of a named buffer, or -1.
func (ui *UI) BufferIndex(name string) int {
for i, buf := range ui.buffers {
if buf.Name == name {
return i
}
}
return -1
}

View File

@@ -46,6 +46,26 @@ type Database struct {
params *Params params *Params
} }
// GetDB returns the underlying sql.DB connection.
func (s *Database) GetDB() *sql.DB {
return s.db
}
// NewChannel creates a Channel model instance with the db reference injected.
func (s *Database) NewChannel(id int64, name, topic, modes string, createdAt, updatedAt time.Time) *models.Channel {
c := &models.Channel{
ID: id,
Name: name,
Topic: topic,
Modes: modes,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
}
c.SetDB(s)
return c
}
// New creates a new Database instance and registers lifecycle hooks. // New creates a new Database instance and registers lifecycle hooks.
func New(lc fx.Lifecycle, params Params) (*Database, error) { func New(lc fx.Lifecycle, params Params) (*Database, error) {
s := new(Database) s := new(Database)
@@ -74,460 +94,6 @@ func New(lc fx.Lifecycle, params Params) (*Database, error) {
return s, nil return s, nil
} }
// NewTest creates a Database for testing, bypassing fx lifecycle.
// It connects to the given DSN and runs all migrations.
func NewTest(dsn string) (*Database, error) {
d, err := sql.Open("sqlite", dsn)
if err != nil {
return nil, err
}
s := &Database{
db: d,
log: slog.Default(),
}
// Item 9: Enable foreign keys
_, err = d.Exec("PRAGMA foreign_keys = ON") //nolint:noctx // no context in sql.Open path
if err != nil {
_ = d.Close()
return nil, fmt.Errorf("enable foreign keys: %w", err)
}
ctx := context.Background()
err = s.runMigrations(ctx)
if err != nil {
_ = d.Close()
return nil, err
}
return s, nil
}
// GetDB returns the underlying sql.DB connection.
func (s *Database) GetDB() *sql.DB {
return s.db
}
// Hydrate injects the database reference into any model that
// embeds Base.
func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) {
m.SetDB(s)
}
// GetUserByID looks up a user by their ID.
func (s *Database) GetUserByID(
ctx context.Context,
id string,
) (*models.User, error) {
u := &models.User{}
s.Hydrate(u)
err := s.db.QueryRowContext(ctx, `
SELECT id, nick, password_hash, created_at, updated_at, last_seen_at
FROM users WHERE id = ?`,
id,
).Scan(
&u.ID, &u.Nick, &u.PasswordHash,
&u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt,
)
if err != nil {
return nil, err
}
return u, nil
}
// GetChannelByID looks up a channel by its ID.
func (s *Database) GetChannelByID(
ctx context.Context,
id string,
) (*models.Channel, error) {
c := &models.Channel{}
s.Hydrate(c)
err := s.db.QueryRowContext(ctx, `
SELECT id, name, topic, modes, created_at, updated_at
FROM channels WHERE id = ?`,
id,
).Scan(
&c.ID, &c.Name, &c.Topic, &c.Modes,
&c.CreatedAt, &c.UpdatedAt,
)
if err != nil {
return nil, err
}
return c, nil
}
// GetUserByNickModel looks up a user by their nick.
func (s *Database) GetUserByNickModel(
ctx context.Context,
nick string,
) (*models.User, error) {
u := &models.User{}
s.Hydrate(u)
err := s.db.QueryRowContext(ctx, `
SELECT id, nick, password_hash, created_at, updated_at, last_seen_at
FROM users WHERE nick = ?`,
nick,
).Scan(
&u.ID, &u.Nick, &u.PasswordHash,
&u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt,
)
if err != nil {
return nil, err
}
return u, nil
}
// GetUserByTokenModel looks up a user by their auth token.
func (s *Database) GetUserByTokenModel(
ctx context.Context,
token string,
) (*models.User, error) {
u := &models.User{}
s.Hydrate(u)
err := s.db.QueryRowContext(ctx, `
SELECT u.id, u.nick, u.password_hash,
u.created_at, u.updated_at, u.last_seen_at
FROM users u
JOIN auth_tokens t ON t.user_id = u.id
WHERE t.token = ?`,
token,
).Scan(
&u.ID, &u.Nick, &u.PasswordHash,
&u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt,
)
if err != nil {
return nil, err
}
return u, nil
}
// DeleteAuthToken removes an auth token from the database.
func (s *Database) DeleteAuthToken(
ctx context.Context,
token string,
) error {
_, err := s.db.ExecContext(ctx,
`DELETE FROM auth_tokens WHERE token = ?`, token,
)
return err
}
// UpdateUserLastSeen updates the last_seen_at timestamp for a user.
func (s *Database) UpdateUserLastSeen(
ctx context.Context,
userID string,
) error {
_, err := s.db.ExecContext(ctx,
`UPDATE users SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`,
userID,
)
return err
}
// CreateUserModel inserts a new user into the database.
func (s *Database) CreateUserModel(
ctx context.Context,
id, nick, passwordHash string,
) (*models.User, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO users (id, nick, password_hash)
VALUES (?, ?, ?)`,
id, nick, passwordHash,
)
if err != nil {
return nil, err
}
u := &models.User{
ID: id, Nick: nick, PasswordHash: passwordHash,
CreatedAt: now, UpdatedAt: now,
}
s.Hydrate(u)
return u, nil
}
// CreateChannel inserts a new channel into the database.
func (s *Database) CreateChannel(
ctx context.Context,
id, name, topic, modes string,
) (*models.Channel, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO channels (id, name, topic, modes)
VALUES (?, ?, ?, ?)`,
id, name, topic, modes,
)
if err != nil {
return nil, err
}
c := &models.Channel{
ID: id, Name: name, Topic: topic, Modes: modes,
CreatedAt: now, UpdatedAt: now,
}
s.Hydrate(c)
return c, nil
}
// AddChannelMember adds a user to a channel with the given modes.
func (s *Database) AddChannelMember(
ctx context.Context,
channelID, userID, modes string,
) (*models.ChannelMember, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO channel_members
(channel_id, user_id, modes)
VALUES (?, ?, ?)`,
channelID, userID, modes,
)
if err != nil {
return nil, err
}
cm := &models.ChannelMember{
ChannelID: channelID,
UserID: userID,
Modes: modes,
JoinedAt: now,
}
s.Hydrate(cm)
return cm, nil
}
// CreateMessage inserts a new message into the database.
func (s *Database) CreateMessage(
ctx context.Context,
id, fromUserID, fromNick, target, msgType, body string,
) (*models.Message, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO messages
(id, from_user_id, from_nick, target, type, body)
VALUES (?, ?, ?, ?, ?, ?)`,
id, fromUserID, fromNick, target, msgType, body,
)
if err != nil {
return nil, err
}
m := &models.Message{
ID: id,
FromUserID: fromUserID,
FromNick: fromNick,
Target: target,
Type: msgType,
Body: body,
Timestamp: now,
CreatedAt: now,
}
s.Hydrate(m)
return m, nil
}
// QueueMessage adds a message to a user's delivery queue.
func (s *Database) QueueMessage(
ctx context.Context,
userID, messageID string,
) (*models.MessageQueueEntry, error) {
now := time.Now()
res, err := s.db.ExecContext(ctx,
`INSERT INTO message_queue (user_id, message_id)
VALUES (?, ?)`,
userID, messageID,
)
if err != nil {
return nil, err
}
entryID, err := res.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get last insert id: %w", err)
}
mq := &models.MessageQueueEntry{
ID: entryID,
UserID: userID,
MessageID: messageID,
QueuedAt: now,
}
s.Hydrate(mq)
return mq, nil
}
// DequeueMessages returns up to limit pending messages for a user,
// ordered by queue time (oldest first).
func (s *Database) DequeueMessages(
ctx context.Context,
userID string,
limit int,
) ([]*models.MessageQueueEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT id, user_id, message_id, queued_at
FROM message_queue
WHERE user_id = ?
ORDER BY queued_at ASC
LIMIT ?`,
userID, limit,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
entries := []*models.MessageQueueEntry{}
for rows.Next() {
e := &models.MessageQueueEntry{}
s.Hydrate(e)
err = rows.Scan(&e.ID, &e.UserID, &e.MessageID, &e.QueuedAt)
if err != nil {
return nil, err
}
entries = append(entries, e)
}
return entries, rows.Err()
}
// AckMessages removes the given queue entry IDs, marking them as delivered.
func (s *Database) AckMessages(
ctx context.Context,
entryIDs []int64,
) error {
if len(entryIDs) == 0 {
return nil
}
placeholders := make([]string, len(entryIDs))
args := make([]any, len(entryIDs))
for i, id := range entryIDs {
placeholders[i] = "?"
args[i] = id
}
query := fmt.Sprintf( //nolint:gosec // placeholders are ?, not user input
"DELETE FROM message_queue WHERE id IN (%s)",
strings.Join(placeholders, ","),
)
_, err := s.db.ExecContext(ctx, query, args...)
return err
}
// CreateAuthToken inserts a new auth token for a user.
func (s *Database) CreateAuthToken(
ctx context.Context,
token, userID string,
) (*models.AuthToken, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO auth_tokens (token, user_id)
VALUES (?, ?)`,
token, userID,
)
if err != nil {
return nil, err
}
at := &models.AuthToken{Token: token, UserID: userID, CreatedAt: now}
s.Hydrate(at)
return at, nil
}
// CreateSession inserts a new session for a user.
func (s *Database) CreateSession(
ctx context.Context,
id, userID string,
) (*models.Session, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO sessions (id, user_id)
VALUES (?, ?)`,
id, userID,
)
if err != nil {
return nil, err
}
sess := &models.Session{
ID: id, UserID: userID,
CreatedAt: now, LastActiveAt: now,
}
s.Hydrate(sess)
return sess, nil
}
// CreateServerLink inserts a new server link.
func (s *Database) CreateServerLink(
ctx context.Context,
id, name, url, sharedKeyHash string,
isActive bool,
) (*models.ServerLink, error) {
now := time.Now()
active := 0
if isActive {
active = 1
}
_, err := s.db.ExecContext(ctx,
`INSERT INTO server_links
(id, name, url, shared_key_hash, is_active)
VALUES (?, ?, ?, ?, ?)`,
id, name, url, sharedKeyHash, active,
)
if err != nil {
return nil, err
}
sl := &models.ServerLink{
ID: id,
Name: name,
URL: url,
SharedKeyHash: sharedKeyHash,
IsActive: isActive,
CreatedAt: now,
}
s.Hydrate(sl)
return sl, nil
}
func (s *Database) connect(ctx context.Context) error { func (s *Database) connect(ctx context.Context) error {
dbURL := s.params.Config.DBURL dbURL := s.params.Config.DBURL
if dbURL == "" { if dbURL == "" {
@@ -553,12 +119,6 @@ func (s *Database) connect(ctx context.Context) error {
s.db = d s.db = d
s.log.Info("database connected") s.log.Info("database connected")
// Item 9: Enable foreign keys on every connection
_, err = s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON")
if err != nil {
return fmt.Errorf("enable foreign keys: %w", err)
}
return s.runMigrations(ctx) return s.runMigrations(ctx)
} }
@@ -589,18 +149,13 @@ func (s *Database) runMigrations(ctx context.Context) error {
return nil return nil
} }
func (s *Database) bootstrapMigrationsTable( func (s *Database) bootstrapMigrationsTable(ctx context.Context) error {
ctx context.Context, _, err := s.db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations (
) error {
_, err := s.db.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY, version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf("failed to create schema_migrations table: %w", err)
"create schema_migrations table: %w", err,
)
} }
return nil return nil
@@ -609,20 +164,17 @@ func (s *Database) bootstrapMigrationsTable(
func (s *Database) loadMigrations() ([]migration, error) { func (s *Database) loadMigrations() ([]migration, error) {
entries, err := fs.ReadDir(SchemaFiles, "schema") entries, err := fs.ReadDir(SchemaFiles, "schema")
if err != nil { if err != nil {
return nil, fmt.Errorf("read schema dir: %w", err) return nil, fmt.Errorf("failed to read schema dir: %w", err)
} }
var migrations []migration var migrations []migration
for _, entry := range entries { for _, entry := range entries {
if entry.IsDir() || if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") {
!strings.HasSuffix(entry.Name(), ".sql") {
continue continue
} }
parts := strings.SplitN( parts := strings.SplitN(entry.Name(), "_", minMigrationParts)
entry.Name(), "_", minMigrationParts,
)
if len(parts) < minMigrationParts { if len(parts) < minMigrationParts {
continue continue
} }
@@ -632,13 +184,9 @@ func (s *Database) loadMigrations() ([]migration, error) {
continue continue
} }
content, err := SchemaFiles.ReadFile( content, err := SchemaFiles.ReadFile("schema/" + entry.Name())
"schema/" + entry.Name(),
)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("failed to read migration %s: %w", entry.Name(), err)
"read migration %s: %w", entry.Name(), err,
)
} }
migrations = append(migrations, migration{ migrations = append(migrations, migration{
@@ -655,81 +203,31 @@ func (s *Database) loadMigrations() ([]migration, error) {
return migrations, nil return migrations, nil
} }
// Item 4: Wrap each migration in a transaction func (s *Database) applyMigrations(ctx context.Context, migrations []migration) error {
func (s *Database) applyMigrations(
ctx context.Context,
migrations []migration,
) error {
for _, m := range migrations { for _, m := range migrations {
var exists int var exists int
err := s.db.QueryRowContext(ctx, err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&exists)
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
m.version,
).Scan(&exists)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf("failed to check migration %d: %w", m.version, err)
"check migration %d: %w", m.version, err,
)
} }
if exists > 0 { if exists > 0 {
continue continue
} }
s.log.Info( s.log.Info("applying migration", "version", m.version, "name", m.name)
"applying migration",
"version", m.version, "name", m.name,
)
err = s.executeMigration(ctx, m) _, err = s.db.ExecContext(ctx, m.sql)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to apply migration %d (%s): %w", m.version, m.name, err)
}
_, err = s.db.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", m.version)
if err != nil {
return fmt.Errorf("failed to record migration %d: %w", m.version, err)
} }
} }
return nil return nil
} }
func (s *Database) executeMigration(
ctx context.Context,
m migration,
) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf(
"begin tx for migration %d: %w", m.version, err,
)
}
_, err = tx.ExecContext(ctx, m.sql)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf(
"apply migration %d (%s): %w",
m.version, m.name, err,
)
}
_, err = tx.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
m.version,
)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf(
"record migration %d: %w", m.version, err,
)
}
err = tx.Commit()
if err != nil {
return fmt.Errorf(
"commit migration %d: %w", m.version, err,
)
}
return nil
}

View File

@@ -1,425 +0,0 @@
package db_test
import (
"context"
"fmt"
"path/filepath"
"testing"
"time"
"git.eeqj.de/sneak/chat/internal/db"
)
const (
nickAlice = "alice"
nickBob = "bob"
nickCharlie = "charlie"
)
// setupTestDB creates a fresh database in a temp directory with
// all migrations applied.
func setupTestDB(t *testing.T) *db.Database {
t.Helper()
dir := t.TempDir()
dsn := fmt.Sprintf(
"file:%s?_journal_mode=WAL",
filepath.Join(dir, "test.db"),
)
d, err := db.NewTest(dsn)
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}
t.Cleanup(func() { _ = d.GetDB().Close() })
return d
}
func TestCreateUser(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
id, token, err := d.CreateUser(ctx, nickAlice)
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if id <= 0 {
t.Errorf("expected positive id, got %d", id)
}
if token == "" {
t.Error("expected non-empty token")
}
}
func TestGetUserByToken(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, token, _ := d.CreateUser(ctx, nickAlice)
id, nick, err := d.GetUserByToken(ctx, token)
if err != nil {
t.Fatalf("GetUserByToken: %v", err)
}
if id <= 0 || nick != nickAlice {
t.Errorf(
"got id=%d nick=%s, want nick=%s",
id, nick, nickAlice,
)
}
}
func TestGetUserByNick(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
origID, _, _ := d.CreateUser(ctx, nickAlice)
id, err := d.GetUserByNick(ctx, nickAlice)
if err != nil {
t.Fatalf("GetUserByNick: %v", err)
}
if id != origID {
t.Errorf("got id %d, want %d", id, origID)
}
}
func TestGetOrCreateChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
id1, err := d.GetOrCreateChannel(ctx, "#general")
if err != nil {
t.Fatalf("GetOrCreateChannel: %v", err)
}
if id1 <= 0 {
t.Errorf("expected positive id, got %d", id1)
}
// Same channel returns same ID.
id2, err := d.GetOrCreateChannel(ctx, "#general")
if err != nil {
t.Fatalf("GetOrCreateChannel(2): %v", err)
}
if id1 != id2 {
t.Errorf("got different ids: %d vs %d", id1, id2)
}
}
func TestJoinAndListChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice)
ch1, _ := d.GetOrCreateChannel(ctx, "#alpha")
ch2, _ := d.GetOrCreateChannel(ctx, "#beta")
_ = d.JoinChannel(ctx, ch1, uid)
_ = d.JoinChannel(ctx, ch2, uid)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 2 {
t.Fatalf("expected 2 channels, got %d", len(channels))
}
}
func TestListChannelsEmpty(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 0 {
t.Errorf("expected 0 channels, got %d", len(channels))
}
}
func TestPartChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
_ = d.PartChannel(ctx, chID, uid)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 0 {
t.Errorf("expected 0 after part, got %d", len(channels))
}
}
func TestSendAndGetMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
_, err := d.SendMessage(ctx, chID, uid, "hello world")
if err != nil {
t.Fatalf("SendMessage: %v", err)
}
msgs, err := d.GetMessages(ctx, chID, 0, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
if msgs[0].Content != "hello world" {
t.Errorf(
"got content %q, want %q",
msgs[0].Content, "hello world",
)
}
}
func TestChannelMembers(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
uid3, _, _ := d.CreateUser(ctx, nickCharlie)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid1)
_ = d.JoinChannel(ctx, chID, uid2)
_ = d.JoinChannel(ctx, chID, uid3)
members, err := d.ChannelMembers(ctx, chID)
if err != nil {
t.Fatalf("ChannelMembers: %v", err)
}
if len(members) != 3 {
t.Fatalf("expected 3 members, got %d", len(members))
}
}
func TestChannelMembersEmpty(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
chID, _ := d.GetOrCreateChannel(ctx, "#empty")
members, err := d.ChannelMembers(ctx, chID)
if err != nil {
t.Fatalf("ChannelMembers: %v", err)
}
if len(members) != 0 {
t.Errorf("expected 0, got %d", len(members))
}
}
func TestSendDM(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
msgID, err := d.SendDM(ctx, uid1, uid2, "hey bob")
if err != nil {
t.Fatalf("SendDM: %v", err)
}
if msgID <= 0 {
t.Errorf("expected positive msgID, got %d", msgID)
}
msgs, err := d.GetDMs(ctx, uid1, uid2, 0, 0)
if err != nil {
t.Fatalf("GetDMs: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 DM, got %d", len(msgs))
}
if msgs[0].Content != "hey bob" {
t.Errorf("got %q, want %q", msgs[0].Content, "hey bob")
}
}
func TestPollMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid1)
_ = d.JoinChannel(ctx, chID, uid2)
_, _ = d.SendMessage(ctx, chID, uid2, "hello")
_, _ = d.SendDM(ctx, uid2, uid1, "private")
time.Sleep(10 * time.Millisecond)
msgs, err := d.PollMessages(ctx, uid1, 0, 0)
if err != nil {
t.Fatalf("PollMessages: %v", err)
}
if len(msgs) < 2 {
t.Fatalf("expected >=2 messages, got %d", len(msgs))
}
}
func TestChangeNick(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, token, _ := d.CreateUser(ctx, nickAlice)
err := d.ChangeNick(ctx, 1, "alice2")
if err != nil {
t.Fatalf("ChangeNick: %v", err)
}
_, nick, err := d.GetUserByToken(ctx, token)
if err != nil {
t.Fatalf("GetUserByToken: %v", err)
}
if nick != "alice2" {
t.Errorf("got nick %q, want alice2", nick)
}
}
func TestSetTopic(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice)
_, _ = d.GetOrCreateChannel(ctx, "#general")
err := d.SetTopic(ctx, "#general", uid, "new topic")
if err != nil {
t.Fatalf("SetTopic: %v", err)
}
channels, err := d.ListAllChannels(ctx)
if err != nil {
t.Fatalf("ListAllChannels: %v", err)
}
found := false
for _, ch := range channels {
if ch.Name == "#general" && ch.Topic == "new topic" {
found = true
}
}
if !found {
t.Error("topic was not updated")
}
}
func TestGetMessagesBefore(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
for i := range 5 {
_, _ = d.SendMessage(
ctx, chID, uid,
fmt.Sprintf("msg%d", i),
)
time.Sleep(10 * time.Millisecond)
}
msgs, err := d.GetMessagesBefore(ctx, chID, 0, 3)
if err != nil {
t.Fatalf("GetMessagesBefore: %v", err)
}
if len(msgs) != 3 {
t.Fatalf("expected 3, got %d", len(msgs))
}
}
func TestListAllChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.GetOrCreateChannel(ctx, "#alpha")
_, _ = d.GetOrCreateChannel(ctx, "#beta")
channels, err := d.ListAllChannels(ctx)
if err != nil {
t.Fatalf("ListAllChannels: %v", err)
}
if len(channels) != 2 {
t.Errorf("expected 2, got %d", len(channels))
}
}

View File

@@ -3,146 +3,109 @@ package db
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"database/sql"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"time" "time"
) )
const (
defaultMessageLimit = 50
defaultPollLimit = 100
tokenBytes = 32
)
func generateToken() string { func generateToken() string {
b := make([]byte, tokenBytes) b := make([]byte, 32)
_, _ = rand.Read(b) _, _ = rand.Read(b)
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
// CreateUser registers a new user with the given nick and // CreateUser registers a new user with the given nick and returns the user with token.
// returns the user with token. func (s *Database) CreateUser(ctx context.Context, nick string) (int64, string, error) {
func (s *Database) CreateUser(
ctx context.Context,
nick string,
) (int64, string, error) {
token := generateToken() token := generateToken()
now := time.Now() now := time.Now()
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
"INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)", "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)",
nick, token, now, now) nick, token, now, now)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("create user: %w", err) return 0, "", fmt.Errorf("create user: %w", err)
} }
id, _ := res.LastInsertId() id, _ := res.LastInsertId()
return id, token, nil return id, token, nil
} }
// GetUserByToken returns user id and nick for a given auth // GetUserByToken returns user id and nick for a given auth token.
// token. func (s *Database) GetUserByToken(ctx context.Context, token string) (int64, string, error) {
func (s *Database) GetUserByToken(
ctx context.Context,
token string,
) (int64, string, error) {
var id int64 var id int64
var nick string var nick string
err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick)
err := s.db.QueryRowContext(
ctx,
"SELECT id, nick FROM users WHERE token = ?",
token,
).Scan(&id, &nick)
if err != nil { if err != nil {
return 0, "", err return 0, "", err
} }
// Update last_seen // Update last_seen
_, _ = s.db.ExecContext( _, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id)
ctx,
"UPDATE users SET last_seen = ? WHERE id = ?",
time.Now(), id,
)
return id, nick, nil return id, nick, nil
} }
// GetUserByNick returns user id for a given nick. // GetUserByNick returns user id for a given nick.
func (s *Database) GetUserByNick( func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error) {
ctx context.Context,
nick string,
) (int64, error) {
var id int64 var id int64
err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id)
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM users WHERE nick = ?",
nick,
).Scan(&id)
return id, err return id, err
} }
// GetOrCreateChannel returns the channel id, creating it if // GetOrCreateChannel returns the channel id, creating it if needed.
// needed. func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) {
func (s *Database) GetOrCreateChannel(
ctx context.Context,
name string,
) (int64, error) {
var id int64 var id int64
err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id)
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&id)
if err == nil { if err == nil {
return id, nil return id, nil
} }
now := time.Now() now := time.Now()
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
"INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)", "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)",
name, now, now) name, now, now)
if err != nil { if err != nil {
return 0, fmt.Errorf("create channel: %w", err) return 0, fmt.Errorf("create channel: %w", err)
} }
id, _ = res.LastInsertId() id, _ = res.LastInsertId()
return id, nil return id, nil
} }
// JoinChannel adds a user to a channel. // JoinChannel adds a user to a channel.
func (s *Database) JoinChannel( func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error {
ctx context.Context,
channelID, userID int64,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)",
channelID, userID, time.Now()) channelID, userID, time.Now())
return err return err
} }
// PartChannel removes a user from a channel. // PartChannel removes a user from a channel.
func (s *Database) PartChannel( func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error {
ctx context.Context,
channelID, userID int64,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?",
channelID, userID) channelID, userID)
return err return err
} }
// ListChannels returns all channels the user has joined.
func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic FROM channels c
INNER JOIN channel_members cm ON cm.channel_id = c.id
WHERE cm.user_id = ? ORDER BY c.name`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var channels []ChannelInfo
for rows.Next() {
var ch ChannelInfo
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
return nil, err
}
channels = append(channels, ch)
}
if channels == nil {
channels = []ChannelInfo{}
}
return channels, nil
}
// ChannelInfo is a lightweight channel representation. // ChannelInfo is a lightweight channel representation.
type ChannelInfo struct { type ChannelInfo struct {
ID int64 `json:"id"` ID int64 `json:"id"`
@@ -150,58 +113,8 @@ type ChannelInfo struct {
Topic string `json:"topic"` Topic string `json:"topic"`
} }
// ListChannels returns all channels the user has joined.
func (s *Database) ListChannels(
ctx context.Context,
userID int64,
) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic FROM channels c
INNER JOIN channel_members cm ON cm.channel_id = c.id
WHERE cm.user_id = ? ORDER BY c.name`, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var channels []ChannelInfo
for rows.Next() {
var ch ChannelInfo
err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic)
if err != nil {
return nil, err
}
channels = append(channels, ch)
}
err = rows.Err()
if err != nil {
return nil, err
}
if channels == nil {
channels = []ChannelInfo{}
}
return channels, nil
}
// MemberInfo represents a channel member.
type MemberInfo struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
LastSeen time.Time `json:"lastSeen"`
}
// ChannelMembers returns all members of a channel. // ChannelMembers returns all members of a channel.
func (s *Database) ChannelMembers( func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) {
ctx context.Context,
channelID int64,
) ([]MemberInfo, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT u.id, u.nick, u.last_seen FROM users u `SELECT u.id, u.nick, u.last_seen FROM users u
INNER JOIN channel_members cm ON cm.user_id = u.id INNER JOIN channel_members cm ON cm.user_id = u.id
@@ -209,34 +122,28 @@ func (s *Database) ChannelMembers(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
defer func() { _ = rows.Close() }()
var members []MemberInfo var members []MemberInfo
for rows.Next() { for rows.Next() {
var m MemberInfo var m MemberInfo
if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil {
err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen)
if err != nil {
return nil, err return nil, err
} }
members = append(members, m) members = append(members, m)
} }
err = rows.Err()
if err != nil {
return nil, err
}
if members == nil { if members == nil {
members = []MemberInfo{} members = []MemberInfo{}
} }
return members, nil return members, nil
} }
// MemberInfo represents a channel member.
type MemberInfo struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
LastSeen time.Time `json:"lastSeen"`
}
// MessageInfo represents a chat message. // MessageInfo represents a chat message.
type MessageInfo struct { type MessageInfo struct {
ID int64 `json:"id"` ID int64 `json:"id"`
@@ -248,18 +155,11 @@ type MessageInfo struct {
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
} }
// GetMessages returns messages for a channel, optionally // GetMessages returns messages for a channel, optionally after a given ID.
// after a given ID. func (s *Database) GetMessages(ctx context.Context, channelID int64, afterID int64, limit int) ([]MessageInfo, error) {
func (s *Database) GetMessages(
ctx context.Context,
channelID int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = 50
} }
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT m.id, c.name, u.nick, m.content, m.created_at `SELECT m.id, c.name, u.nick, m.content, m.created_at
FROM messages m FROM messages m
@@ -270,288 +170,128 @@ func (s *Database) GetMessages(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
defer func() { _ = rows.Close() }()
var msgs []MessageInfo var msgs []MessageInfo
for rows.Next() { for rows.Next() {
var m MessageInfo var m MessageInfo
if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil {
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick,
&m.Content, &m.CreatedAt,
)
if err != nil {
return nil, err return nil, err
} }
msgs = append(msgs, m) msgs = append(msgs, m)
} }
err = rows.Err()
if err != nil {
return nil, err
}
if msgs == nil { if msgs == nil {
msgs = []MessageInfo{} msgs = []MessageInfo{}
} }
return msgs, nil return msgs, nil
} }
// SendMessage inserts a channel message. // SendMessage inserts a channel message.
func (s *Database) SendMessage( func (s *Database) SendMessage(ctx context.Context, channelID, userID int64, content string) (int64, error) {
ctx context.Context,
channelID, userID int64,
content string,
) (int64, error) {
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
"INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)", "INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)",
channelID, userID, content, time.Now()) channelID, userID, content, time.Now())
if err != nil { if err != nil {
return 0, err return 0, err
} }
return res.LastInsertId() return res.LastInsertId()
} }
// SendDM inserts a direct message. // SendDM inserts a direct message.
func (s *Database) SendDM( func (s *Database) SendDM(ctx context.Context, fromID, toID int64, content string) (int64, error) {
ctx context.Context,
fromID, toID int64,
content string,
) (int64, error) {
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
"INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)", "INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)",
fromID, content, toID, time.Now()) fromID, content, toID, time.Now())
if err != nil { if err != nil {
return 0, err return 0, err
} }
return res.LastInsertId() return res.LastInsertId()
} }
// GetDMs returns direct messages between two users after a // GetDMs returns direct messages between two users after a given ID.
// given ID. func (s *Database) GetDMs(ctx context.Context, userA, userB int64, afterID int64, limit int) ([]MessageInfo, error) {
func (s *Database) GetDMs(
ctx context.Context,
userA, userB int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = 50
} }
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT m.id, u.nick, m.content, t.nick, m.created_at `SELECT m.id, u.nick, m.content, t.nick, m.created_at
FROM messages m FROM messages m
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id > ? WHERE m.is_dm = 1 AND m.id > ?
AND ((m.user_id = ? AND m.dm_target_id = ?) AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
OR (m.user_id = ? AND m.dm_target_id = ?)) ORDER BY m.id ASC LIMIT ?`, afterID, userA, userB, userB, userA, limit)
ORDER BY m.id ASC LIMIT ?`,
afterID, userA, userB, userB, userA, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
defer func() { _ = rows.Close() }()
var msgs []MessageInfo var msgs []MessageInfo
for rows.Next() { for rows.Next() {
var m MessageInfo var m MessageInfo
if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil {
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err return nil, err
} }
m.IsDM = true m.IsDM = true
msgs = append(msgs, m) msgs = append(msgs, m)
} }
err = rows.Err()
if err != nil {
return nil, err
}
if msgs == nil { if msgs == nil {
msgs = []MessageInfo{} msgs = []MessageInfo{}
} }
return msgs, nil return msgs, nil
} }
// PollMessages returns all new messages (channel + DM) for // PollMessages returns all new messages (channel + DM) for a user after a given ID.
// a user after a given ID. func (s *Database) PollMessages(ctx context.Context, userID int64, afterID int64, limit int) ([]MessageInfo, error) {
func (s *Database) PollMessages(
ctx context.Context,
userID int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultPollLimit limit = 100
} }
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, m.is_dm, COALESCE(t.nick, ''), m.created_at
m.is_dm, COALESCE(t.nick, ''), m.created_at
FROM messages m FROM messages m
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
LEFT JOIN channels c ON c.id = m.channel_id LEFT JOIN channels c ON c.id = m.channel_id
LEFT JOIN users t ON t.id = m.dm_target_id LEFT JOIN users t ON t.id = m.dm_target_id
WHERE m.id > ? AND ( WHERE m.id > ? AND (
(m.is_dm = 0 AND m.channel_id IN (m.is_dm = 0 AND m.channel_id IN (SELECT channel_id FROM channel_members WHERE user_id = ?))
(SELECT channel_id FROM channel_members OR (m.is_dm = 1 AND (m.user_id = ? OR m.dm_target_id = ?))
WHERE user_id = ?))
OR (m.is_dm = 1
AND (m.user_id = ? OR m.dm_target_id = ?))
) )
ORDER BY m.id ASC LIMIT ?`, ORDER BY m.id ASC LIMIT ?`, afterID, userID, userID, userID, limit)
afterID, userID, userID, userID, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
defer func() { _ = rows.Close() }() var msgs []MessageInfo
msgs := make([]MessageInfo, 0)
for rows.Next() { for rows.Next() {
var ( var m MessageInfo
m MessageInfo var isDM int
isDM int if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &isDM, &m.DMTarget, &m.CreatedAt); err != nil {
)
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick, &m.Content,
&isDM, &m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err return nil, err
} }
m.IsDM = isDM == 1 m.IsDM = isDM == 1
msgs = append(msgs, m) msgs = append(msgs, m)
} }
err = rows.Err()
if err != nil {
return nil, err
}
return msgs, nil
}
func scanChannelMessages(
rows *sql.Rows,
) ([]MessageInfo, error) {
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick,
&m.Content, &m.CreatedAt,
)
if err != nil {
return nil, err
}
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil { if msgs == nil {
msgs = []MessageInfo{} msgs = []MessageInfo{}
} }
return msgs, nil return msgs, nil
} }
func scanDMMessages( // GetMessagesBefore returns channel messages before a given ID (for history scrollback).
rows *sql.Rows, func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, beforeID int64, limit int) ([]MessageInfo, error) {
) ([]MessageInfo, error) {
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func reverseMessages(msgs []MessageInfo) {
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
}
}
// GetMessagesBefore returns channel messages before a given
// ID (for history scrollback).
func (s *Database) GetMessagesBefore(
ctx context.Context,
channelID int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = 50
} }
var query string var query string
var args []any var args []any
if beforeID > 0 { if beforeID > 0 {
query = `SELECT m.id, c.name, u.nick, m.content, query = `SELECT m.id, c.name, u.nick, m.content, m.created_at
m.created_at
FROM messages m FROM messages m
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0 WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ?
AND m.id < ?
ORDER BY m.id DESC LIMIT ?` ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, beforeID, limit} args = []any{channelID, beforeID, limit}
} else { } else {
query = `SELECT m.id, c.name, u.nick, m.content, query = `SELECT m.id, c.name, u.nick, m.content, m.created_at
m.created_at
FROM messages m FROM messages m
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id INNER JOIN channels c ON c.id = m.channel_id
@@ -559,153 +299,116 @@ func (s *Database) GetMessagesBefore(
ORDER BY m.id DESC LIMIT ?` ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, limit} args = []any{channelID, limit}
} }
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
defer func() { _ = rows.Close() }() var msgs []MessageInfo
for rows.Next() {
msgs, scanErr := scanChannelMessages(rows) var m MessageInfo
if scanErr != nil { if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil {
return nil, scanErr return nil, err
}
msgs = append(msgs, m)
}
if msgs == nil {
msgs = []MessageInfo{}
}
// Reverse to ascending order
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
} }
// Reverse to ascending order.
reverseMessages(msgs)
return msgs, nil return msgs, nil
} }
// GetDMsBefore returns DMs between two users before a given // GetDMsBefore returns DMs between two users before a given ID (for history scrollback).
// ID (for history scrollback). func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeID int64, limit int) ([]MessageInfo, error) {
func (s *Database) GetDMsBefore(
ctx context.Context,
userA, userB int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = 50
} }
var query string var query string
var args []any var args []any
if beforeID > 0 { if beforeID > 0 {
query = `SELECT m.id, u.nick, m.content, t.nick, query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at
m.created_at
FROM messages m FROM messages m
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id < ? WHERE m.is_dm = 1 AND m.id < ?
AND ((m.user_id = ? AND m.dm_target_id = ?) AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?` ORDER BY m.id DESC LIMIT ?`
args = []any{ args = []any{beforeID, userA, userB, userB, userA, limit}
beforeID, userA, userB, userB, userA, limit,
}
} else { } else {
query = `SELECT m.id, u.nick, m.content, t.nick, query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at
m.created_at
FROM messages m FROM messages m
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 WHERE m.is_dm = 1
AND ((m.user_id = ? AND m.dm_target_id = ?) AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?` ORDER BY m.id DESC LIMIT ?`
args = []any{userA, userB, userB, userA, limit} args = []any{userA, userB, userB, userA, limit}
} }
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
defer func() { _ = rows.Close() }() var msgs []MessageInfo
for rows.Next() {
msgs, scanErr := scanDMMessages(rows) var m MessageInfo
if scanErr != nil { if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil {
return nil, scanErr return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
if msgs == nil {
msgs = []MessageInfo{}
}
// Reverse to ascending order
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
} }
// Reverse to ascending order.
reverseMessages(msgs)
return msgs, nil return msgs, nil
} }
// ChangeNick updates a user's nickname. // ChangeNick updates a user's nickname.
func (s *Database) ChangeNick( func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error {
ctx context.Context,
userID int64,
newNick string,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"UPDATE users SET nick = ? WHERE id = ?", "UPDATE users SET nick = ? WHERE id = ?", newNick, userID)
newNick, userID)
return err return err
} }
// SetTopic sets the topic for a channel. // SetTopic sets the topic for a channel.
func (s *Database) SetTopic( func (s *Database) SetTopic(ctx context.Context, channelName string, _ int64, topic string) error {
ctx context.Context,
channelName string,
_ int64,
topic string,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"UPDATE channels SET topic = ? WHERE name = ?", "UPDATE channels SET topic = ? WHERE name = ?", topic, channelName)
topic, channelName)
return err return err
} }
// GetServerName returns the server name (unused, config // GetServerName returns the server name (unused, config provides this).
// provides this).
func (s *Database) GetServerName() string { func (s *Database) GetServerName() string {
return "" return ""
} }
// ListAllChannels returns all channels. // ListAllChannels returns all channels.
func (s *Database) ListAllChannels( func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) {
ctx context.Context,
) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
"SELECT id, name, topic FROM channels ORDER BY name") "SELECT id, name, topic FROM channels ORDER BY name")
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
defer func() { _ = rows.Close() }()
var channels []ChannelInfo var channels []ChannelInfo
for rows.Next() { for rows.Next() {
var ch ChannelInfo var ch ChannelInfo
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
err := rows.Scan(
&ch.ID, &ch.Name, &ch.Topic,
)
if err != nil {
return nil, err return nil, err
} }
channels = append(channels, ch) channels = append(channels, ch)
} }
err = rows.Err()
if err != nil {
return nil, err
}
if channels == nil { if channels == nil {
channels = []ChannelInfo{} channels = []ChannelInfo{}
} }
return channels, nil return channels, nil
} }

View File

@@ -1,89 +0,0 @@
-- All schema changes go into this file until 1.0.0 is tagged.
-- There will not be migrations during the early development phase.
-- After 1.0.0, new changes get their own numbered migration files.
-- Users: accounts and authentication
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY, -- UUID
nick TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_seen_at DATETIME
);
-- Auth tokens: one user can have multiple active tokens (multiple devices)
CREATE TABLE IF NOT EXISTS auth_tokens (
token TEXT PRIMARY KEY, -- random token string
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME, -- NULL = no expiry
last_used_at DATETIME
);
CREATE INDEX IF NOT EXISTS idx_auth_tokens_user_id ON auth_tokens(user_id);
-- Channels: chat rooms
CREATE TABLE IF NOT EXISTS channels (
id TEXT PRIMARY KEY, -- UUID
name TEXT NOT NULL UNIQUE, -- #general, etc.
topic TEXT NOT NULL DEFAULT '',
modes TEXT NOT NULL DEFAULT '', -- +i, +m, +s, +t, +n
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Channel members: who is in which channel, with per-user modes
CREATE TABLE IF NOT EXISTS channel_members (
channel_id TEXT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
modes TEXT NOT NULL DEFAULT '', -- +o (operator), +v (voice)
joined_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (channel_id, user_id)
);
CREATE INDEX IF NOT EXISTS idx_channel_members_user_id ON channel_members(user_id);
-- Messages: channel and DM history (rotated per MAX_HISTORY)
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY, -- UUID
ts DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
from_user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
from_nick TEXT NOT NULL, -- denormalized for history
target TEXT NOT NULL, -- #channel name or user UUID for DMs
type TEXT NOT NULL DEFAULT 'message', -- message, action, notice, join, part, quit, topic, mode, nick, system
body TEXT NOT NULL DEFAULT '',
meta TEXT NOT NULL DEFAULT '{}', -- JSON extensible metadata
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_messages_target_ts ON messages(target, ts);
CREATE INDEX IF NOT EXISTS idx_messages_from_user ON messages(from_user_id);
-- Message queue: per-user pending delivery (unread messages)
CREATE TABLE IF NOT EXISTS message_queue (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
message_id TEXT NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
queued_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, message_id)
);
CREATE INDEX IF NOT EXISTS idx_message_queue_user_id ON message_queue(user_id, queued_at);
-- Sessions: server-held session state
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY, -- UUID
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_active_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME -- idle timeout
);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
-- Server links: federation peer configuration
CREATE TABLE IF NOT EXISTS server_links (
id TEXT PRIMARY KEY, -- UUID
name TEXT NOT NULL UNIQUE, -- human-readable peer name
url TEXT NOT NULL, -- base URL of peer server
shared_key_hash TEXT NOT NULL, -- hashed shared secret
is_active INTEGER NOT NULL DEFAULT 1,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_seen_at DATETIME
);

View File

@@ -0,0 +1,8 @@
CREATE TABLE IF NOT EXISTS channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
topic TEXT NOT NULL DEFAULT '',
modes TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

View File

@@ -1,18 +1,6 @@
-- Migration 003: Replace UUID-based tables with simple integer-keyed PRAGMA foreign_keys = ON;
-- tables for the HTTP API. Drops the 002 tables and recreates them.
PRAGMA foreign_keys = OFF; CREATE TABLE IF NOT EXISTS users (
DROP TABLE IF EXISTS message_queue;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS server_links;
DROP TABLE IF EXISTS messages;
DROP TABLE IF EXISTS channel_members;
DROP TABLE IF EXISTS auth_tokens;
DROP TABLE IF EXISTS channels;
DROP TABLE IF EXISTS users;
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
nick TEXT NOT NULL UNIQUE, nick TEXT NOT NULL UNIQUE,
token TEXT NOT NULL UNIQUE, token TEXT NOT NULL UNIQUE,
@@ -20,15 +8,7 @@ CREATE TABLE users (
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
); );
CREATE TABLE channels ( CREATE TABLE IF NOT EXISTS channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
topic TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
@@ -36,7 +16,7 @@ CREATE TABLE channel_members (
UNIQUE(channel_id, user_id) UNIQUE(channel_id, user_id)
); );
CREATE TABLE messages ( CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE, channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
@@ -46,8 +26,6 @@ CREATE TABLE messages (
created_at DATETIME DEFAULT CURRENT_TIMESTAMP created_at DATETIME DEFAULT CURRENT_TIMESTAMP
); );
CREATE INDEX idx_messages_channel ON messages(channel_id, created_at); CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel_id, created_at);
CREATE INDEX idx_messages_dm ON messages(user_id, dm_target_id, created_at); CREATE INDEX IF NOT EXISTS idx_messages_dm ON messages(user_id, dm_target_id, created_at);
CREATE INDEX idx_users_token ON users(token); CREATE INDEX IF NOT EXISTS idx_users_token ON users(token);
PRAGMA foreign_keys = ON;

View File

@@ -11,181 +11,79 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
) )
const ( // authUser extracts the user from the Authorization header (Bearer token).
maxNickLen = 32 func (s *Handlers) authUser(r *http.Request) (int64, string, error) {
defaultHistory = 50
)
// authUser extracts the user from the Authorization header
// (Bearer token).
func (s *Handlers) authUser(
r *http.Request,
) (int64, string, error) {
auth := r.Header.Get("Authorization") auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") { if !strings.HasPrefix(auth, "Bearer ") {
return 0, "", sql.ErrNoRows return 0, "", sql.ErrNoRows
} }
token := strings.TrimPrefix(auth, "Bearer ") token := strings.TrimPrefix(auth, "Bearer ")
return s.params.Database.GetUserByToken(r.Context(), token) return s.params.Database.GetUserByToken(r.Context(), token)
} }
func (s *Handlers) requireAuth( func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, string, bool) {
w http.ResponseWriter,
r *http.Request,
) (int64, string, bool) {
uid, nick, err := s.authUser(r) uid, nick, err := s.authUser(r)
if err != nil { if err != nil {
s.respondJSON( s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized)
w, r,
map[string]string{"error": "unauthorized"},
http.StatusUnauthorized,
)
return 0, "", false return 0, "", false
} }
return uid, nick, true return uid, nick, true
} }
func (s *Handlers) respondError( // HandleCreateSession creates a new user session and returns the auth token.
w http.ResponseWriter,
r *http.Request,
msg string,
code int,
) {
s.respondJSON(w, r, map[string]string{"error": msg}, code)
}
func (s *Handlers) internalError(
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
s.log.Error(msg, "error", err)
s.respondError(w, r, "internal error", http.StatusInternalServerError)
}
// bodyLines extracts body as string lines from a request body
// field.
func bodyLines(body any) []string {
switch v := body.(type) {
case []any:
lines := make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
lines = append(lines, s)
}
}
return lines
case []string:
return v
default:
return nil
}
}
// HandleCreateSession creates a new user session and returns
// the auth token.
func (s *Handlers) HandleCreateSession() http.HandlerFunc { func (s *Handlers) HandleCreateSession() http.HandlerFunc {
type request struct { type request struct {
Nick string `json:"nick"` Nick string `json:"nick"`
} }
type response struct { type response struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Nick string `json:"nick"` Nick string `json:"nick"`
Token string `json:"token"` Token string `json:"token"`
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req request var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
err := json.NewDecoder(r.Body).Decode(&req) s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
if err != nil {
s.respondError(
w, r, "invalid request",
http.StatusBadRequest,
)
return return
} }
req.Nick = strings.TrimSpace(req.Nick) req.Nick = strings.TrimSpace(req.Nick)
if req.Nick == "" || len(req.Nick) > 32 {
if req.Nick == "" || len(req.Nick) > maxNickLen { s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
s.respondError(
w, r, "nick must be 1-32 characters",
http.StatusBadRequest,
)
return return
} }
id, token, err := s.params.Database.CreateUser(r.Context(), req.Nick)
id, token, err := s.params.Database.CreateUser(
r.Context(), req.Nick,
)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE") { if strings.Contains(err.Error(), "UNIQUE") {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict)
w, r, "nick already taken",
http.StatusConflict,
)
return return
} }
s.log.Error("create user failed", "error", err)
s.internalError(w, r, "create user failed", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON(w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated)
s.respondJSON(
w, r,
&response{ID: id, Nick: req.Nick, Token: token},
http.StatusCreated,
)
} }
} }
// HandleState returns the current user's info and joined // HandleState returns the current user's info and joined channels.
// channels.
func (s *Handlers) HandleState() http.HandlerFunc { func (s *Handlers) HandleState() http.HandlerFunc {
type response struct { type response struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Nick string `json:"nick"` Nick string `json:"nick"`
Channels []db.ChannelInfo `json:"channels"` Channels []db.ChannelInfo `json:"channels"`
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
uid, nick, ok := s.requireAuth(w, r) uid, nick, ok := s.requireAuth(w, r)
if !ok { if !ok {
return return
} }
channels, err := s.params.Database.ListChannels(r.Context(), uid)
channels, err := s.params.Database.ListChannels(
r.Context(), uid,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("list channels failed", "error", err)
w, r, "list channels failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK)
s.respondJSON(
w, r,
&response{
ID: uid, Nick: nick,
Channels: channels,
},
http.StatusOK,
)
} }
} }
@@ -196,18 +94,12 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc {
if !ok { if !ok {
return return
} }
channels, err := s.params.Database.ListAllChannels(r.Context())
channels, err := s.params.Database.ListAllChannels(
r.Context(),
)
if err != nil { if err != nil {
s.internalError( s.log.Error("list all channels failed", "error", err)
w, r, "list all channels failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
s.respondJSON(w, r, channels, http.StatusOK) s.respondJSON(w, r, channels, http.StatusOK)
} }
} }
@@ -219,561 +111,278 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
if !ok { if !ok {
return return
} }
name := "#" + chi.URLParam(r, "channel") name := "#" + chi.URLParam(r, "channel")
var chID int64 var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query "SELECT id FROM channels WHERE name = ?", name).Scan(&chID)
r.Context(),
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&chID)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
w, r, "channel not found",
http.StatusNotFound,
)
return return
} }
members, err := s.params.Database.ChannelMembers(r.Context(), chID)
members, err := s.params.Database.ChannelMembers(
r.Context(), chID,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("channel members failed", "error", err)
w, r, "channel members failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
s.respondJSON(w, r, members, http.StatusOK) s.respondJSON(w, r, members, http.StatusOK)
} }
} }
// HandleGetMessages returns all new messages (channel + DM) // HandleGetMessages returns all new messages (channel + DM) for the user via long-polling.
// for the user via long-polling. // This is the single unified message stream — replaces separate channel/DM/poll endpoints.
func (s *Handlers) HandleGetMessages() http.HandlerFunc { func (s *Handlers) HandleGetMessages() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r) uid, _, ok := s.requireAuth(w, r)
if !ok { if !ok {
return return
} }
afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64)
afterID, _ := strconv.ParseInt( limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
r.URL.Query().Get("after"), 10, 64, msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterID, limit)
)
limit, _ := strconv.Atoi(
r.URL.Query().Get("limit"),
)
msgs, err := s.params.Database.PollMessages(
r.Context(), uid, afterID, limit,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("get messages failed", "error", err)
w, r, "get messages failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
s.respondJSON(w, r, msgs, http.StatusOK) s.respondJSON(w, r, msgs, http.StatusOK)
} }
} }
type sendRequest struct { // HandleSendCommand handles all C2S commands via POST /messages.
// The "command" field dispatches to the appropriate logic.
func (s *Handlers) HandleSendCommand() http.HandlerFunc {
type request struct {
Command string `json:"command"` Command string `json:"command"`
To string `json:"to"` To string `json:"to"`
Params []string `json:"params,omitempty"` Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"` Body interface{} `json:"body,omitempty"`
} }
// HandleSendCommand handles all C2S commands via POST
// /messages.
func (s *Handlers) HandleSendCommand() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
uid, nick, ok := s.requireAuth(w, r) uid, nick, ok := s.requireAuth(w, r)
if !ok { if !ok {
return return
} }
var req request
var req sendRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
s.respondError(
w, r, "invalid request",
http.StatusBadRequest,
)
return return
} }
req.Command = strings.ToUpper(strings.TrimSpace(req.Command))
req.Command = strings.ToUpper(
strings.TrimSpace(req.Command),
)
req.To = strings.TrimSpace(req.To) req.To = strings.TrimSpace(req.To)
s.dispatchCommand(w, r, uid, nick, &req) // Helper to extract body as string lines.
bodyLines := func() []string {
switch v := req.Body.(type) {
case []interface{}:
lines := make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
lines = append(lines, s)
}
}
return lines
case []string:
return v
default:
return nil
}
} }
}
func (s *Handlers) dispatchCommand(
w http.ResponseWriter,
r *http.Request,
uid int64,
nick string,
req *sendRequest,
) {
switch req.Command { switch req.Command {
case "PRIVMSG", "NOTICE": case "PRIVMSG", "NOTICE":
s.handlePrivmsg(w, r, uid, req)
case "JOIN":
s.handleJoin(w, r, uid, req)
case "PART":
s.handlePart(w, r, uid, req)
case "NICK":
s.handleNick(w, r, uid, req)
case "TOPIC":
s.handleTopic(w, r, uid, req)
case "PING":
s.respondJSON(
w, r,
map[string]string{
"command": "PONG",
"from": s.params.Config.ServerName,
},
http.StatusOK,
)
default:
_ = nick
s.respondError(
w, r,
"unknown command: "+req.Command,
http.StatusBadRequest,
)
}
}
func (s *Handlers) handlePrivmsg(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" { if req.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
lines := bodyLines()
lines := bodyLines(req.Body)
if len(lines) == 0 { if len(lines) == 0 {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest)
w, r, "body required", http.StatusBadRequest,
)
return return
} }
content := strings.Join(lines, "\n") content := strings.Join(lines, "\n")
if strings.HasPrefix(req.To, "#") { if strings.HasPrefix(req.To, "#") {
s.sendChannelMsg(w, r, uid, req.To, content) // Channel message
var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
return
}
msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, content)
if err != nil {
s.log.Error("send message failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
} else { } else {
s.sendDM(w, r, uid, req.To, content) // DM
} targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To)
}
func (s *Handlers) sendChannelMsg(
w http.ResponseWriter,
r *http.Request,
uid int64,
channel, content string,
) {
var chID int64
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
channel,
).Scan(&chID)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
w, r, "channel not found",
http.StatusNotFound,
)
return return
} }
msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, content)
msgID, err := s.params.Database.SendMessage(
r.Context(), chID, uid, content,
)
if err != nil { if err != nil {
s.internalError(w, r, "send message failed", err) s.log.Error("send dm failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
s.respondJSON(
w, r,
map[string]any{"id": msgID, "status": "sent"},
http.StatusCreated,
)
}
func (s *Handlers) sendDM(
w http.ResponseWriter,
r *http.Request,
uid int64,
toNick, content string,
) {
targetID, err := s.params.Database.GetUserByNick(
r.Context(), toNick,
)
if err != nil {
s.respondError(
w, r, "user not found", http.StatusNotFound,
)
return
} }
msgID, err := s.params.Database.SendDM( case "JOIN":
r.Context(), uid, targetID, content,
)
if err != nil {
s.internalError(w, r, "send dm failed", err)
return
}
s.respondJSON(
w, r,
map[string]any{"id": msgID, "status": "sent"},
http.StatusCreated,
)
}
func (s *Handlers) handleJoin(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" { if req.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
channel := req.To channel := req.To
if !strings.HasPrefix(channel, "#") { if !strings.HasPrefix(channel, "#") {
channel = "#" + channel channel = "#" + channel
} }
chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel)
chID, err := s.params.Database.GetOrCreateChannel(
r.Context(), channel,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("get/create channel failed", "error", err)
w, r, "get/create channel failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil {
err = s.params.Database.JoinChannel( s.log.Error("join channel failed", "error", err)
r.Context(), chID, uid, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
if err != nil {
s.internalError(w, r, "join channel failed", err)
return return
} }
s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK)
s.respondJSON( case "PART":
w, r,
map[string]string{
"status": "joined", "channel": channel,
},
http.StatusOK,
)
}
func (s *Handlers) handlePart(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" { if req.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
channel := req.To channel := req.To
if !strings.HasPrefix(channel, "#") { if !strings.HasPrefix(channel, "#") {
channel = "#" + channel channel = "#" + channel
} }
var chID int64 var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query "SELECT id FROM channels WHERE name = ?", channel).Scan(&chID)
r.Context(),
"SELECT id FROM channels WHERE name = ?",
channel,
).Scan(&chID)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
w, r, "channel not found",
http.StatusNotFound,
)
return return
} }
if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil {
err = s.params.Database.PartChannel( s.log.Error("part channel failed", "error", err)
r.Context(), chID, uid, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
if err != nil {
s.internalError(w, r, "part channel failed", err)
return return
} }
s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK)
s.respondJSON( case "NICK":
w, r, lines := bodyLines()
map[string]string{
"status": "parted", "channel": channel,
},
http.StatusOK,
)
}
func (s *Handlers) handleNick(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
lines := bodyLines(req.Body)
if len(lines) == 0 { if len(lines) == 0 {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest)
w, r, "body required (new nick)",
http.StatusBadRequest,
)
return return
} }
newNick := strings.TrimSpace(lines[0]) newNick := strings.TrimSpace(lines[0])
if newNick == "" || len(newNick) > maxNickLen { if newNick == "" || len(newNick) > 32 {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
w, r, "nick must be 1-32 characters",
http.StatusBadRequest,
)
return return
} }
if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil {
err := s.params.Database.ChangeNick(
r.Context(), uid, newNick,
)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE") { if strings.Contains(err.Error(), "UNIQUE") {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict)
w, r, "nick already in use",
http.StatusConflict,
)
return return
} }
s.log.Error("change nick failed", "error", err)
s.internalError(w, r, "change nick failed", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK)
s.respondJSON( case "TOPIC":
w, r,
map[string]string{"status": "ok", "nick": newNick},
http.StatusOK,
)
}
func (s *Handlers) handleTopic(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" { if req.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
lines := bodyLines()
lines := bodyLines(req.Body)
if len(lines) == 0 { if len(lines) == 0 {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest)
w, r, "body required (topic text)",
http.StatusBadRequest,
)
return return
} }
topic := strings.Join(lines, " ") topic := strings.Join(lines, " ")
channel := req.To channel := req.To
if !strings.HasPrefix(channel, "#") { if !strings.HasPrefix(channel, "#") {
channel = "#" + channel channel = "#" + channel
} }
if err := s.params.Database.SetTopic(r.Context(), channel, uid, topic); err != nil {
err := s.params.Database.SetTopic( s.log.Error("set topic failed", "error", err)
r.Context(), channel, uid, topic, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
if err != nil {
s.internalError(w, r, "set topic failed", err)
return return
} }
s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK)
s.respondJSON( case "PING":
w, r, s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK)
map[string]string{"status": "ok", "topic": topic},
http.StatusOK, default:
) _ = nick // suppress unused warning
s.respondJSON(w, r, map[string]string{"error": "unknown command: " + req.Command}, http.StatusBadRequest)
}
}
} }
// HandleGetHistory returns message history for a specific // HandleGetHistory returns message history for a specific target (channel or DM).
// target (channel or DM).
func (s *Handlers) HandleGetHistory() http.HandlerFunc { func (s *Handlers) HandleGetHistory() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r) uid, _, ok := s.requireAuth(w, r)
if !ok { if !ok {
return return
} }
target := r.URL.Query().Get("target") target := r.URL.Query().Get("target")
if target == "" { if target == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest)
w, r, "target required",
http.StatusBadRequest,
)
return return
} }
beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64)
beforeID, _ := strconv.ParseInt( limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
r.URL.Query().Get("before"), 10, 64,
)
limit, _ := strconv.Atoi(
r.URL.Query().Get("limit"),
)
if limit <= 0 { if limit <= 0 {
limit = defaultHistory limit = 50
} }
if strings.HasPrefix(target, "#") { if strings.HasPrefix(target, "#") {
s.getChannelHistory( // Channel history
w, r, target, beforeID, limit,
)
} else {
s.getDMHistory(
w, r, uid, target, beforeID, limit,
)
}
}
}
func (s *Handlers) getChannelHistory(
w http.ResponseWriter,
r *http.Request,
target string,
beforeID int64,
limit int,
) {
var chID int64 var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query "SELECT id FROM channels WHERE name = ?", target).Scan(&chID)
r.Context(),
"SELECT id FROM channels WHERE name = ?",
target,
).Scan(&chID)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
w, r, "channel not found",
http.StatusNotFound,
)
return return
} }
msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeID, limit)
msgs, err := s.params.Database.GetMessagesBefore(
r.Context(), chID, beforeID, limit,
)
if err != nil { if err != nil {
s.internalError(w, r, "get history failed", err) s.log.Error("get history failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON(w, r, msgs, http.StatusOK) s.respondJSON(w, r, msgs, http.StatusOK)
} } else {
// DM history
func (s *Handlers) getDMHistory( targetID, err := s.params.Database.GetUserByNick(r.Context(), target)
w http.ResponseWriter,
r *http.Request,
uid int64,
target string,
beforeID int64,
limit int,
) {
targetID, err := s.params.Database.GetUserByNick(
r.Context(), target,
)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
w, r, "user not found", http.StatusNotFound,
)
return return
} }
msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetID, beforeID, limit)
msgs, err := s.params.Database.GetDMsBefore(
r.Context(), uid, targetID, beforeID, limit,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("get dm history failed", "error", err)
w, r, "get dm history failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
s.respondJSON(w, r, msgs, http.StatusOK) s.respondJSON(w, r, msgs, http.StatusOK)
}
}
} }
// HandleServerInfo returns server metadata (MOTD, name). // HandleServerInfo returns server metadata (MOTD, name).
@@ -782,7 +391,6 @@ func (s *Handlers) HandleServerInfo() http.HandlerFunc {
Name string `json:"name"` Name string `json:"name"`
MOTD string `json:"motd"` MOTD string `json:"motd"`
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
s.respondJSON(w, r, &response{ s.respondJSON(w, r, &response{
Name: s.params.Config.ServerName, Name: s.params.Config.ServerName,

View File

@@ -1,26 +0,0 @@
package models
import (
"context"
"time"
)
// AuthToken represents an authentication token for a user session.
type AuthToken struct {
Base
Token string `json:"-"`
UserID string `json:"userId"`
CreatedAt time.Time `json:"createdAt"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastUsedAt *time.Time `json:"lastUsedAt,omitempty"`
}
// User returns the user who owns this token.
func (t *AuthToken) User(ctx context.Context) (*User, error) {
if ul := t.GetUserLookup(); ul != nil {
return ul.GetUserByID(ctx, t.UserID)
}
return nil, ErrUserLookupNotAvailable
}

View File

@@ -1,7 +1,6 @@
package models package models
import ( import (
"context"
"time" "time"
) )
@@ -9,88 +8,10 @@ import (
type Channel struct { type Channel struct {
Base Base
ID string `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Topic string `json:"topic"` Topic string `json:"topic"`
Modes string `json:"modes"` Modes string `json:"modes"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"` UpdatedAt time.Time `json:"updatedAt"`
} }
// Members returns all users who are members of this channel.
func (c *Channel) Members(ctx context.Context) ([]*ChannelMember, error) {
rows, err := c.GetDB().QueryContext(ctx, `
SELECT cm.channel_id, cm.user_id, cm.modes, cm.joined_at,
u.nick
FROM channel_members cm
JOIN users u ON u.id = cm.user_id
WHERE cm.channel_id = ?
ORDER BY cm.joined_at`,
c.ID,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
members := []*ChannelMember{}
for rows.Next() {
m := &ChannelMember{}
m.SetDB(c.db)
err = rows.Scan(
&m.ChannelID, &m.UserID, &m.Modes,
&m.JoinedAt, &m.Nick,
)
if err != nil {
return nil, err
}
members = append(members, m)
}
return members, rows.Err()
}
// RecentMessages returns the most recent messages in this channel.
func (c *Channel) RecentMessages(
ctx context.Context,
limit int,
) ([]*Message, error) {
rows, err := c.GetDB().QueryContext(ctx, `
SELECT id, ts, from_user_id, from_nick,
target, type, body, meta, created_at
FROM messages
WHERE target = ?
ORDER BY ts DESC
LIMIT ?`,
c.Name, limit,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
messages := []*Message{}
for rows.Next() {
msg := &Message{}
msg.SetDB(c.db)
err = rows.Scan(
&msg.ID, &msg.Timestamp, &msg.FromUserID,
&msg.FromNick, &msg.Target, &msg.Type,
&msg.Body, &msg.Meta, &msg.CreatedAt,
)
if err != nil {
return nil, err
}
messages = append(messages, msg)
}
return messages, rows.Err()
}

View File

@@ -1,35 +0,0 @@
package models
import (
"context"
"time"
)
// ChannelMember represents a user's membership in a channel.
type ChannelMember struct {
Base
ChannelID string `json:"channelId"`
UserID string `json:"userId"`
Modes string `json:"modes"`
JoinedAt time.Time `json:"joinedAt"`
Nick string `json:"nick"` // denormalized from users table
}
// User returns the full User for this membership.
func (cm *ChannelMember) User(ctx context.Context) (*User, error) {
if ul := cm.GetUserLookup(); ul != nil {
return ul.GetUserByID(ctx, cm.UserID)
}
return nil, ErrUserLookupNotAvailable
}
// Channel returns the full Channel for this membership.
func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) {
if cl := cm.GetChannelLookup(); cl != nil {
return cl.GetChannelByID(ctx, cm.ChannelID)
}
return nil, ErrChannelLookupNotAvailable
}

View File

@@ -1,20 +0,0 @@
package models
import (
"time"
)
// Message represents a chat message (channel or DM).
type Message struct {
Base
ID string `json:"id"`
Timestamp time.Time `json:"ts"`
FromUserID string `json:"fromUserId"`
FromNick string `json:"from"`
Target string `json:"to"`
Type string `json:"type"`
Body string `json:"body"`
Meta string `json:"meta"`
CreatedAt time.Time `json:"createdAt"`
}

View File

@@ -1,15 +0,0 @@
package models
import (
"time"
)
// MessageQueueEntry represents a pending message delivery for a user.
type MessageQueueEntry struct {
Base
ID int64 `json:"id"`
UserID string `json:"userId"`
MessageID string `json:"messageId"`
QueuedAt time.Time `json:"queuedAt"`
}

View File

@@ -1,36 +1,14 @@
// Package models defines the data models used by the chat application. // Package models defines the data models used by the chat application.
// All model structs embed Base, which provides database access for
// relation-fetching methods directly on model instances.
package models package models
import ( import "database/sql"
"context"
"database/sql"
"errors"
)
// DB is the interface that models use to query the database. // DB is the interface that models use to query relations.
// This avoids a circular import with the db package. // This avoids a circular import with the db package.
type DB interface { type DB interface {
GetDB() *sql.DB GetDB() *sql.DB
} }
// UserLookup provides user lookup by ID without circular imports.
type UserLookup interface {
GetUserByID(ctx context.Context, id string) (*User, error)
}
// ChannelLookup provides channel lookup by ID without circular imports.
type ChannelLookup interface {
GetChannelByID(ctx context.Context, id string) (*Channel, error)
}
// Sentinel errors for model lookup methods.
var (
ErrUserLookupNotAvailable = errors.New("user lookup not available")
ErrChannelLookupNotAvailable = errors.New("channel lookup not available")
)
// Base is embedded in all model structs to provide database access. // Base is embedded in all model structs to provide database access.
type Base struct { type Base struct {
db DB db DB
@@ -40,26 +18,3 @@ type Base struct {
func (b *Base) SetDB(d DB) { func (b *Base) SetDB(d DB) {
b.db = d b.db = d
} }
// GetDB returns the database interface for use in model methods.
func (b *Base) GetDB() *sql.DB {
return b.db.GetDB()
}
// GetUserLookup returns the DB as a UserLookup if it implements the interface.
func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn
if ul, ok := b.db.(UserLookup); ok {
return ul
}
return nil
}
// GetChannelLookup returns the DB as a ChannelLookup if it implements the interface.
func (b *Base) GetChannelLookup() ChannelLookup { //nolint:ireturn
if cl, ok := b.db.(ChannelLookup); ok {
return cl
}
return nil
}

View File

@@ -1,18 +0,0 @@
package models
import (
"time"
)
// ServerLink represents a federation peer server configuration.
type ServerLink struct {
Base
ID string `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
SharedKeyHash string `json:"-"`
IsActive bool `json:"isActive"`
CreatedAt time.Time `json:"createdAt"`
LastSeenAt *time.Time `json:"lastSeenAt,omitempty"`
}

View File

@@ -1,26 +0,0 @@
package models
import (
"context"
"time"
)
// Session represents a server-held user session.
type Session struct {
Base
ID string `json:"id"`
UserID string `json:"userId"`
CreatedAt time.Time `json:"createdAt"`
LastActiveAt time.Time `json:"lastActiveAt"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
}
// User returns the user who owns this session.
func (s *Session) User(ctx context.Context) (*User, error) {
if ul := s.GetUserLookup(); ul != nil {
return ul.GetUserByID(ctx, s.UserID)
}
return nil, ErrUserLookupNotAvailable
}

View File

@@ -1,92 +0,0 @@
package models
import (
"context"
"time"
)
// User represents a registered user account.
type User struct {
Base
ID string `json:"id"`
Nick string `json:"nick"`
PasswordHash string `json:"-"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
LastSeenAt *time.Time `json:"lastSeenAt,omitempty"`
}
// Channels returns all channels the user is a member of.
func (u *User) Channels(ctx context.Context) ([]*Channel, error) {
rows, err := u.GetDB().QueryContext(ctx, `
SELECT c.id, c.name, c.topic, c.modes, c.created_at, c.updated_at
FROM channels c
JOIN channel_members cm ON cm.channel_id = c.id
WHERE cm.user_id = ?
ORDER BY c.name`,
u.ID,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
channels := []*Channel{}
for rows.Next() {
c := &Channel{}
c.SetDB(u.db)
err = rows.Scan(
&c.ID, &c.Name, &c.Topic, &c.Modes,
&c.CreatedAt, &c.UpdatedAt,
)
if err != nil {
return nil, err
}
channels = append(channels, c)
}
return channels, rows.Err()
}
// QueuedMessages returns undelivered messages for this user.
func (u *User) QueuedMessages(ctx context.Context) ([]*Message, error) {
rows, err := u.GetDB().QueryContext(ctx, `
SELECT m.id, m.ts, m.from_user_id, m.from_nick,
m.target, m.type, m.body, m.meta, m.created_at
FROM messages m
JOIN message_queue mq ON mq.message_id = m.id
WHERE mq.user_id = ?
ORDER BY mq.queued_at ASC`,
u.ID,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
messages := []*Message{}
for rows.Next() {
msg := &Message{}
msg.SetDB(u.db)
err = rows.Scan(
&msg.ID, &msg.Timestamp, &msg.FromUserID,
&msg.FromNick, &msg.Target, &msg.Type,
&msg.Body, &msg.Meta, &msg.CreatedAt,
)
if err != nil {
return nil, err
}
messages = append(messages, msg)
}
return messages, rows.Err()
}

View File

@@ -71,37 +71,17 @@ func (s *Server) SetupRoutes() {
s.log.Error("failed to get web dist filesystem", "error", err) s.log.Error("failed to get web dist filesystem", "error", err)
} else { } else {
fileServer := http.FileServer(http.FS(distFS)) fileServer := http.FileServer(http.FS(distFS))
s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) { s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
s.serveSPA(distFS, fileServer, w, r) // Try to serve the file; if not found, serve index.html for SPA routing
}) f, err := distFS.(fs.ReadFileFS).ReadFile(r.URL.Path[1:])
}
}
func (s *Server) serveSPA(
distFS fs.FS,
fileServer http.Handler,
w http.ResponseWriter,
r *http.Request,
) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
http.Error(w, "filesystem error", http.StatusInternalServerError)
return
}
// Try to serve the file; fall back to index.html for SPA routing.
f, err := readFS.ReadFile(r.URL.Path[1:])
if err != nil || len(f) == 0 { if err != nil || len(f) == 0 {
indexHTML, _ := readFS.ReadFile("index.html") indexHTML, _ := distFS.(fs.ReadFileFS).ReadFile("index.html")
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write(indexHTML) _, _ = w.Write(indexHTML)
return return
} }
fileServer.ServeHTTP(w, r) fileServer.ServeHTTP(w, r)
})
}
} }