12 Commits

Author SHA1 Message Date
clawbot
c7bcedd8d4 fix: pin golangci-lint install to commit SHA (fixes #13)
All checks were successful
check / check (push) Successful in 1m31s
- Fix Dockerfile: use correct v2 module path with commit SHA
  (github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638...)
- Add CGO_ENABLED=0 for alpine compatibility
- Fix 35 lint issues found by golangci-lint v2.1.6:
  - Remove unused nolint directives, add nolintlint where needed
  - Pre-allocate migrations slice
  - Use t.Context() instead of context.Background() in tests
  - Fix blank lines between comments and exported functions in ui.go
2026-02-26 20:20:31 -08:00
9daf836cbe Merge pull request 'fix: repo standards audit — fix all divergences (closes #17)' (#18) from fix/repo-standards-audit into main
Some checks failed
check / check (push) Failing after 12s
Reviewed-on: #18
2026-02-27 05:10:00 +01:00
84303c969a fix: pin golangci-lint to v2.1.6 in Dockerfile
Some checks failed
check / check (push) Failing after 14s
Replace @latest with @v2.1.6 to comply with hash-pinning policy
defined in REPO_POLICIES.md.
2026-02-26 11:43:52 -08:00
clawbot
d2bc467581 fix: resolve lint issues — rename api package, fix nolint directives
Some checks failed
check / check (push) Failing after 1m3s
2026-02-26 07:45:37 -08:00
clawbot
88af2ea98f fix: repair migration 003 schema conflict and rewrite tests (refs #17)
Some checks failed
check / check (push) Failing after 1m18s
Migration 003 created tables with INTEGER keys referencing TEXT primary
keys from migration 002, causing 'no such column' errors. Fix by
properly dropping old tables before recreating with the integer schema.

Rewrite all tests to use the queries.go API (which matches the live
schema) instead of the model-based API (which expected the old UUID
schema).
2026-02-26 06:28:07 -08:00
clawbot
b78d526f02 style: fix all golangci-lint issues and format code (refs #17)
Fix 380 lint violations across all Go source files including wsl_v5,
nlreturn, noinlineerr, errcheck, funlen, funcorder, tagliatelle,
perfsprint, modernize, revive, gosec, ireturn, mnd, forcetypeassert,
cyclop, and others.

Key changes:
- Split large handler/command functions into smaller methods
- Extract scan helpers for database queries
- Reorder exported/unexported methods per funcorder
- Add sentinel errors in models package
- Use camelCase JSON tags per tagliatelle defaults
- Add package comments
- Fix .gitignore to not exclude cmd/chat-cli directory
2026-02-26 06:27:56 -08:00
clawbot
636546d74a docs: add Author section to README (refs #17) 2026-02-26 06:09:08 -08:00
clawbot
27de1227c4 chore: pin Dockerfile images by sha256, run make check in build (refs #17) 2026-02-26 06:09:04 -08:00
clawbot
ef83d6624b chore: fix Makefile — add fmt-check, docker, hooks targets; 30s test timeout (refs #17) 2026-02-26 06:08:47 -08:00
clawbot
fc91dc37c0 chore: update .gitignore and .dockerignore to match standards (refs #17) 2026-02-26 06:08:31 -08:00
clawbot
1e5811edda chore: add missing required files (refs #17)
Add LICENSE (MIT), .editorconfig, REPO_POLICIES.md, and
.gitea/workflows/check.yml per repo standards.
2026-02-26 06:08:24 -08:00
clawbot
3f8ceefd52 fix: rename duplicate db methods to fix compilation (refs #17)
CreateUser, GetUserByNick, GetUserByToken exist in both db.go (model-based,
used by tests) and queries.go (simple, used by handlers). Rename the
model-based variants to CreateUserModel, GetUserByNickModel, and
GetUserByTokenModel to resolve the compilation error.
2026-02-26 06:08:07 -08:00
40 changed files with 3440 additions and 5834 deletions

View File

@@ -1,9 +1,9 @@
.git .git
*.md node_modules
!README.md .DS_Store
chatd bin/
chat-cli
data.db data.db
data.db-wal
data.db-shm
.env .env
*.test
*.out
debug.log

12
.editorconfig Normal file
View File

@@ -0,0 +1,12 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[Makefile]
indent_style = tab

View File

@@ -0,0 +1,9 @@
name: check
on: [push]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4.2.2, 2026-02-22
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- run: docker build .

30
.gitignore vendored
View File

@@ -1,7 +1,28 @@
# OS
.DS_Store
Thumbs.db
# Editors
*.swp
*.swo
*~
*.bak
.idea/
.vscode/
*.sublime-*
# Node
node_modules/
# Environment / secrets
.env
.env.*
*.pem
*.key
# Build artifacts
/chatd /chatd
/bin/ /bin/
data.db
.env
*.exe *.exe
*.dll *.dll
*.so *.so
@@ -9,6 +30,9 @@ data.db
*.test *.test
*.out *.out
vendor/ vendor/
# Project
data.db
debug.log debug.log
/chat-cli
web/node_modules/ web/node_modules/
chat-cli

View File

@@ -7,35 +7,26 @@ run:
linters: linters:
default: all default: all
disable: disable:
- exhaustruct # Genuinely incompatible with project patterns
- depguard - exhaustruct # Requires all struct fields
- godot - depguard # Dependency allow/block lists
- wsl - godot # Requires comments to end with periods
- wsl_v5 - wsl # Deprecated, replaced by wsl_v5
- wrapcheck - wrapcheck # Too verbose for internal packages
- varnamelen - varnamelen # Short names like db, id are idiomatic Go
- noinlineerr
- dupl linters-settings:
- paralleltest lll:
- nlreturn line-length: 88
- tagliatelle funlen:
- goconst lines: 80
- funlen statements: 50
- maintidx cyclop:
- cyclop max-complexity: 15
- gocognit dupl:
- lll threshold: 100
settings:
lll:
line-length: 88
funlen:
lines: 80
statements: 50
cyclop:
max-complexity: 15
dupl:
threshold: 100
issues: issues:
exclude-use-default: false
max-issues-per-linter: 0 max-issues-per-linter: 0
max-same-issues: 0 max-same-issues: 0

View File

@@ -1,29 +1,29 @@
# Build stage # golang:1.24-alpine, 2026-02-26
FROM golang:1.24-alpine AS builder FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder
WORKDIR /src
RUN apk add --no-cache make gcc musl-dev
RUN apk add --no-cache git build-base
WORKDIR /src
COPY go.mod go.sum ./ COPY go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
# Run tests # Run all checks — build fails if branch is not green
ENV DBURL="file::memory:?cache=shared" # golangci-lint v2.1.6
RUN go test ./... RUN CGO_ENABLED=0 go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d
RUN make check
# Build binaries ARG VERSION=dev
RUN CGO_ENABLED=1 go build -trimpath -ldflags="-s -w" -o /chatd ./cmd/chatd/ RUN go build -ldflags "-X main.Version=${VERSION}" -o /chatd ./cmd/chatd
RUN CGO_ENABLED=1 go build -trimpath -ldflags="-s -w" -o /chat-cli ./cmd/chat-cli/
# Final stage — server only # alpine:3.21, 2026-02-26
FROM alpine:3.21 FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk add --no-cache ca-certificates \
&& addgroup -S chat && adduser -S chat -G chat RUN apk add --no-cache ca-certificates
COPY --from=builder /chatd /usr/local/bin/chatd COPY --from=builder /chatd /usr/local/bin/chatd
USER chat WORKDIR /data
EXPOSE 8080 EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget -qO- http://localhost:8080/.well-known/healthcheck.json || exit 1
ENTRYPOINT ["chatd"] ENTRYPOINT ["chatd"]

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 sneak
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,20 +1,49 @@
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") .PHONY: all build lint fmt fmt-check test check clean run debug docker hooks
LDFLAGS := -ldflags "-X main.Version=$(VERSION)"
.PHONY: build test clean docker lint BINARY := chatd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILDARCH := $(shell go env GOARCH)
LDFLAGS := -X main.Version=$(VERSION) -X main.Buildarch=$(BUILDARCH)
all: check build
build: build:
go build $(LDFLAGS) -o chatd ./cmd/chatd/ go build -ldflags "$(LDFLAGS)" -o bin/$(BINARY) ./cmd/chatd
go build $(LDFLAGS) -o chat-cli ./cmd/chat-cli/
test:
DBURL="file::memory:?cache=shared" go test ./...
clean:
rm -f chatd chat-cli
lint: lint:
GOFLAGS=-buildvcs=false golangci-lint run ./... golangci-lint run --config .golangci.yml ./...
fmt:
gofmt -s -w .
goimports -w .
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test:
go test -timeout 30s -v -race -cover ./...
# check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong
check: test lint fmt-check
@echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/chatd
@echo "==> All checks passed!"
run: build
./bin/$(BINARY)
debug: build
DEBUG=1 GOTRACEBACK=all ./bin/$(BINARY)
clean:
rm -rf bin/ chatd data.db
docker: docker:
docker build -t chat:$(VERSION) . docker build -t chat .
hooks:
@printf '#!/bin/sh\nset -e\n' > .git/hooks/pre-commit
@printf 'go mod tidy\ngo fmt ./...\ngit diff --exit-code -- go.mod go.sum || { echo "go mod tidy changed files; please stage and retry"; exit 1; }\n' >> .git/hooks/pre-commit
@printf 'make check\n' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit

2308
README.md

File diff suppressed because it is too large Load Diff

182
REPO_POLICIES.md Normal file
View File

@@ -0,0 +1,182 @@
---
title: Repository Policies
last_modified: 2026-02-22
---
This document covers repository structure, tooling, and workflow standards. Code
style conventions are in separate documents:
- [Code Styleguide](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE.md)
(general, bash, Docker)
- [Go](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_GO.md)
- [JavaScript](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_JS.md)
- [Python](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_PYTHON.md)
- [Go HTTP Server Conventions](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/GO_HTTP_SERVER_CONVENTIONS.md)
---
- Cross-project documentation (such as this file) must include
`last_modified: YYYY-MM-DD` in the YAML front matter so it can be kept in sync
with the authoritative source as policies evolve.
- **ALL external references must be pinned by cryptographic hash.** This
includes Docker base images, Go modules, npm packages, GitHub Actions, and
anything else fetched from a remote source. Version tags (`@v4`, `@latest`,
`:3.21`, etc.) are server-mutable and therefore remote code execution
vulnerabilities. The ONLY acceptable way to reference an external dependency
is by its content hash (Docker `@sha256:...`, Go module hash in `go.sum`, npm
integrity hash in lockfile, GitHub Actions `@<commit-sha>`). No exceptions.
This also means never `curl | bash` to install tools like pyenv, nvm, rustup,
etc. Instead, download a specific release archive from GitHub, verify its hash
(hardcoded in the Dockerfile or script), and only then install. Unverified
install scripts are arbitrary remote code execution. This is the single most
important rule in this document. Double-check every external reference in
every file before committing. There are zero exceptions to this rule.
- Every repo with software must have a root `Makefile` with these targets:
`make test`, `make lint`, `make fmt` (writes), `make fmt-check` (read-only),
`make check` (prereqs: `test`, `lint`, `fmt-check`), `make docker`, and
`make hooks` (installs pre-commit hook). A model Makefile is at
`https://git.eeqj.de/sneak/prompts/raw/branch/main/Makefile`.
- Always use Makefile targets (`make fmt`, `make test`, `make lint`, etc.)
instead of invoking the underlying tools directly. The Makefile is the single
source of truth for how these operations are run.
- The Makefile is authoritative documentation for how the repo is used. Beyond
the required targets above, it should have targets for every common operation:
running a local development server (`make run`, `make dev`), re-initializing
or migrating the database (`make db-reset`, `make migrate`), building
artifacts (`make build`), generating code, seeding data, or anything else a
developer would do regularly. If someone checks out the repo and types
`make<tab>`, they should see every meaningful operation available. A new
contributor should be able to understand the entire development workflow by
reading the Makefile.
- Every repo should have a `Dockerfile`. All Dockerfiles must run `make check`
as a build step so the build fails if the branch is not green. For non-server
repos, the Dockerfile should bring up a development environment and run
`make check`. For server repos, `make check` should run as an early build
stage before the final image is assembled.
- Every repo should have a Gitea Actions workflow (`.gitea/workflows/`) that
runs `docker build .` on push. Since the Dockerfile already runs `make check`,
a successful build implies all checks pass.
- Use platform-standard formatters: `black` for Python, `prettier` for
JS/CSS/Markdown/HTML, `go fmt` for Go. Always use default configuration with
two exceptions: four-space indents (except Go), and `proseWrap: always` for
Markdown (hard-wrap at 80 columns). Documentation and writing repos (Markdown,
HTML, CSS) should also have `.prettierrc` and `.prettierignore`.
- Pre-commit hook: `make check` if local testing is possible, otherwise
`make lint && make fmt-check`. The Makefile should provide a `make hooks`
target to install the pre-commit hook.
- All repos with software must have tests that run via the platform-standard
test framework (`go test`, `pytest`, `jest`/`vitest`, etc.). If no meaningful
tests exist yet, add the most minimal test possible — e.g. importing the
module under test to verify it compiles/parses. There is no excuse for
`make test` to be a no-op.
- `make test` must complete in under 20 seconds. Add a 30-second timeout in the
Makefile.
- Docker builds must complete in under 5 minutes.
- `make check` must not modify any files in the repo. Tests may use temporary
directories.
- `main` must always pass `make check`, no exceptions.
- Never commit secrets. `.env` files, credentials, API keys, and private keys
must be in `.gitignore`. No exceptions.
- `.gitignore` should be comprehensive from the start: OS files (`.DS_Store`),
editor files (`.swp`, `*~`), language build artifacts, and `node_modules/`.
Fetch the standard `.gitignore` from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.gitignore` when setting up
a new repo.
- Never use `git add -A` or `git add .`. Always stage files explicitly by name.
- Never force-push to `main`.
- Make all changes on a feature branch. You can do whatever you want on a
feature branch.
- `.golangci.yml` is standardized and must _NEVER_ be modified by an agent, only
manually by the user. Fetch from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.golangci.yml`.
- When pinning images or packages by hash, add a comment above the reference
with the version and date (YYYY-MM-DD).
- Use `yarn`, not `npm`.
- Write all dates as YYYY-MM-DD (ISO 8601).
- Simple projects should be configured with environment variables.
- Dockerized web services listen on port 8080 by default, overridable with
`PORT`.
- `README.md` is the primary documentation. Required sections:
- **Description**: First line must include the project name, purpose,
category (web server, SPA, CLI tool, etc.), license, and author. Example:
"µPaaS is an MIT-licensed Go web application by @sneak that receives
git-frontend webhooks and deploys applications via Docker in realtime."
- **Getting Started**: Copy-pasteable install/usage code block.
- **Rationale**: Why does this exist?
- **Design**: How is the program structured?
- **TODO**: Update meticulously, even between commits. When planning, put
the todo list in the README so a new agent can pick up where the last one
left off.
- **License**: MIT, GPL, or WTFPL. Ask the user for new projects. Include a
`LICENSE` file in the repo root and a License section in the README.
- **Author**: [@sneak](https://sneak.berlin).
- First commit of a new repo should contain only `README.md`.
- Go module root: `sneak.berlin/go/<name>`. Always run `go mod tidy` before
committing.
- Use SemVer.
- Database migrations live in `internal/db/migrations/` and must be embedded in
the binary. Pre-1.0.0: modify existing migrations (no installed base assumed).
Post-1.0.0: add new migration files.
- All repos should have an `.editorconfig` enforcing the project's indentation
settings.
- Avoid putting files in the repo root unless necessary. Root should contain
only project-level config files (`README.md`, `Makefile`, `Dockerfile`,
`LICENSE`, `.gitignore`, `.editorconfig`, `REPO_POLICIES.md`, and
language-specific config). Everything else goes in a subdirectory. Canonical
subdirectory names:
- `bin/` — executable scripts and tools
- `cmd/` — Go command entrypoints
- `configs/` — configuration templates and examples
- `deploy/` — deployment manifests (k8s, compose, terraform)
- `docs/` — documentation and markdown (README.md stays in root)
- `internal/` — Go internal packages
- `internal/db/migrations/` — database migrations
- `pkg/` — Go library packages
- `share/` — systemd units, data files
- `static/` — static assets (images, fonts, etc.)
- `web/` — web frontend source
- When setting up a new repo, files from the `prompts` repo may be used as
templates. Fetch them from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/<path>`.
- New repos must contain at minimum:
- `README.md`, `.git`, `.gitignore`, `.editorconfig`
- `LICENSE`, `REPO_POLICIES.md` (copy from the `prompts` repo)
- `Makefile`
- `Dockerfile`, `.dockerignore`
- `.gitea/workflows/check.yml`
- Go: `go.mod`, `go.sum`, `.golangci.yml`
- JS: `package.json`, `yarn.lock`, `.prettierrc`, `.prettierignore`
- Python: `pyproject.toml`

View File

@@ -1,8 +1,8 @@
// Package chatapi provides a client for the chat server HTTP API.
package chatapi package chatapi
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -10,17 +10,21 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
) )
const ( const (
httpTimeout = 30 * time.Second httpTimeout = 30 * time.Second
pollExtraTime = 5 pollExtraDelay = 5
httpErrThreshold = 400 httpErrThreshold = 400
) )
var errHTTP = errors.New("HTTP error") // ErrHTTP is returned for non-2xx responses.
var ErrHTTP = errors.New("http error")
// ErrUnexpectedFormat is returned when the response format is
// not recognised.
var ErrUnexpectedFormat = errors.New("unexpected format")
// Client wraps HTTP calls to the chat server API. // Client wraps HTTP calls to the chat server API.
type Client struct { type Client struct {
@@ -32,8 +36,10 @@ type Client struct {
// NewClient creates a new API client. // NewClient creates a new API client.
func NewClient(baseURL string) *Client { func NewClient(baseURL string) *Client {
return &Client{ return &Client{
BaseURL: baseURL, BaseURL: baseURL,
HTTPClient: &http.Client{Timeout: httpTimeout}, HTTPClient: &http.Client{
Timeout: httpTimeout,
},
} }
} }
@@ -42,8 +48,7 @@ func (c *Client) CreateSession(
nick string, nick string,
) (*SessionResponse, error) { ) (*SessionResponse, error) {
data, err := c.do( data, err := c.do(
http.MethodPost, "POST", "/api/v1/session",
"/api/v1/session",
&SessionRequest{Nick: nick}, &SessionRequest{Nick: nick},
) )
if err != nil { if err != nil {
@@ -64,9 +69,7 @@ func (c *Client) CreateSession(
// GetState returns the current user state. // GetState returns the current user state.
func (c *Client) GetState() (*StateResponse, error) { func (c *Client) GetState() (*StateResponse, error) {
data, err := c.do( data, err := c.do("GET", "/api/v1/state", nil)
http.MethodGet, "/api/v1/state", nil,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -83,41 +86,36 @@ func (c *Client) GetState() (*StateResponse, error) {
// SendMessage sends a message (any IRC command). // SendMessage sends a message (any IRC command).
func (c *Client) SendMessage(msg *Message) error { func (c *Client) SendMessage(msg *Message) error {
_, err := c.do( _, err := c.do("POST", "/api/v1/messages", msg)
http.MethodPost, "/api/v1/messages", msg,
)
return err return err
} }
// PollMessages long-polls for new messages. // PollMessages long-polls for new messages.
func (c *Client) PollMessages( func (c *Client) PollMessages(
afterID int64, afterID string,
timeout int, timeout int,
) (*PollResult, error) { ) ([]Message, error) {
client := &http.Client{ pollTimeout := time.Duration(
Timeout: time.Duration( timeout+pollExtraDelay,
timeout+pollExtraTime, ) * time.Second
) * time.Second,
} client := &http.Client{Timeout: pollTimeout}
params := url.Values{} params := url.Values{}
if afterID > 0 { if afterID != "" {
params.Set( params.Set("after", afterID)
"after",
strconv.FormatInt(afterID, 10),
)
} }
params.Set("timeout", strconv.Itoa(timeout)) params.Set("timeout", strconv.Itoa(timeout))
path := "/api/v1/messages?" + params.Encode() path := "/api/v1/messages"
if len(params) > 0 {
path += "?" + params.Encode()
}
req, err := http.NewRequestWithContext( req, err := http.NewRequest( //nolint:noctx // CLI tool
context.Background(), http.MethodGet, c.BaseURL+path, nil,
http.MethodGet,
c.BaseURL+path,
nil,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -125,7 +123,7 @@ func (c *Client) PollMessages(
req.Header.Set("Authorization", "Bearer "+c.Token) req.Header.Set("Authorization", "Bearer "+c.Token)
resp, err := client.Do(req) //nolint:gosec // URL built from trusted BaseURL + hardcoded path resp, err := client.Do(req) //nolint:gosec,nolintlint // URL from user config
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -139,34 +137,45 @@ func (c *Client) PollMessages(
if resp.StatusCode >= httpErrThreshold { if resp.StatusCode >= httpErrThreshold {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"%w %d: %s", "%w: %d: %s",
errHTTP, resp.StatusCode, string(data), ErrHTTP, resp.StatusCode, string(data),
) )
} }
return decodeMessages(data)
}
func decodeMessages(data []byte) ([]Message, error) {
var msgs []Message
err := json.Unmarshal(data, &msgs)
if err == nil {
return msgs, nil
}
var wrapped MessagesResponse var wrapped MessagesResponse
err = json.Unmarshal(data, &wrapped) err2 := json.Unmarshal(data, &wrapped)
if err != nil { if err2 != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"decode messages: %w", err, "decode messages: %w (raw: %s)",
err, string(data),
) )
} }
return &PollResult{ return wrapped.Messages, nil
Messages: wrapped.Messages,
LastID: wrapped.LastID,
}, nil
} }
// JoinChannel joins a channel. // JoinChannel joins a channel via the unified command
// endpoint.
func (c *Client) JoinChannel(channel string) error { func (c *Client) JoinChannel(channel string) error {
return c.SendMessage( return c.SendMessage(
&Message{Command: "JOIN", To: channel}, &Message{Command: "JOIN", To: channel},
) )
} }
// PartChannel leaves a channel. // PartChannel leaves a channel via the unified command
// endpoint.
func (c *Client) PartChannel(channel string) error { func (c *Client) PartChannel(channel string) error {
return c.SendMessage( return c.SendMessage(
&Message{Command: "PART", To: channel}, &Message{Command: "PART", To: channel},
@@ -175,9 +184,7 @@ func (c *Client) PartChannel(channel string) error {
// ListChannels returns all channels on the server. // ListChannels returns all channels on the server.
func (c *Client) ListChannels() ([]Channel, error) { func (c *Client) ListChannels() ([]Channel, error) {
data, err := c.do( data, err := c.do("GET", "/api/v1/channels", nil)
http.MethodGet, "/api/v1/channels", nil,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -196,14 +203,10 @@ func (c *Client) ListChannels() ([]Channel, error) {
func (c *Client) GetMembers( func (c *Client) GetMembers(
channel string, channel string,
) ([]string, error) { ) ([]string, error) {
name := strings.TrimPrefix(channel, "#") path := "/api/v1/channels/" +
url.PathEscape(channel) + "/members"
data, err := c.do( data, err := c.do("GET", path, nil)
http.MethodGet,
"/api/v1/channels/"+url.PathEscape(name)+
"/members",
nil,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -213,7 +216,8 @@ func (c *Client) GetMembers(
err = json.Unmarshal(data, &members) err = json.Unmarshal(data, &members)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"unexpected members format: %w", err, "%w: members: %s",
ErrUnexpectedFormat, string(data),
) )
} }
@@ -222,9 +226,7 @@ func (c *Client) GetMembers(
// GetServerInfo returns server info. // GetServerInfo returns server info.
func (c *Client) GetServerInfo() (*ServerInfo, error) { func (c *Client) GetServerInfo() (*ServerInfo, error) {
data, err := c.do( data, err := c.do("GET", "/api/v1/server", nil)
http.MethodGet, "/api/v1/server", nil,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -254,11 +256,8 @@ func (c *Client) do(
bodyReader = bytes.NewReader(data) bodyReader = bytes.NewReader(data)
} }
req, err := http.NewRequestWithContext( req, err := http.NewRequest( //nolint:noctx // CLI tool
context.Background(), method, c.BaseURL+path, bodyReader,
method,
c.BaseURL+path,
bodyReader,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("request: %w", err) return nil, fmt.Errorf("request: %w", err)
@@ -267,12 +266,10 @@ func (c *Client) do(
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
if c.Token != "" { if c.Token != "" {
req.Header.Set( req.Header.Set("Authorization", "Bearer "+c.Token)
"Authorization", "Bearer "+c.Token,
)
} }
resp, err := c.HTTPClient.Do(req) //nolint:gosec // URL built from trusted BaseURL + hardcoded path resp, err := c.HTTPClient.Do(req) //nolint:gosec,nolintlint // URL from user config
if err != nil { if err != nil {
return nil, fmt.Errorf("http: %w", err) return nil, fmt.Errorf("http: %w", err)
} }
@@ -286,8 +283,8 @@ func (c *Client) do(
if resp.StatusCode >= httpErrThreshold { if resp.StatusCode >= httpErrThreshold {
return data, fmt.Errorf( return data, fmt.Errorf(
"%w %d: %s", "%w: %d: %s",
errHTTP, resp.StatusCode, string(data), ErrHTTP, resp.StatusCode, string(data),
) )
} }

View File

@@ -1,4 +1,3 @@
// Package chatapi provides API types and client for chat-cli.
package chatapi package chatapi
import "time" import "time"
@@ -8,25 +7,27 @@ type SessionRequest struct {
Nick string `json:"nick"` Nick string `json:"nick"`
} }
// SessionResponse is the response from session creation. // SessionResponse is the response from POST /api/v1/session.
type SessionResponse struct { type SessionResponse struct {
ID int64 `json:"id"` SessionID string `json:"sessionId"`
Nick string `json:"nick"` ClientID string `json:"clientId"`
Token string `json:"token"` Nick string `json:"nick"`
Token string `json:"token"`
} }
// StateResponse is the response from GET /api/v1/state. // StateResponse is the response from GET /api/v1/state.
type StateResponse struct { type StateResponse struct {
ID int64 `json:"id"` SessionID string `json:"sessionId"`
Nick string `json:"nick"` ClientID string `json:"clientId"`
Channels []string `json:"channels"` Nick string `json:"nick"`
Channels []string `json:"channels"`
} }
// Message represents a chat message envelope. // Message represents a chat message envelope.
type Message struct { type Message struct {
Command string `json:"command"` Command string `json:"command"`
From string `json:"from,omitempty"` From string `json:"from,omitempty"`
To string `json:"to,omitempty"` To string `json:"to,omitempty"`
Params []string `json:"params,omitempty"` Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"` Body any `json:"body,omitempty"`
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
@@ -34,7 +35,8 @@ type Message struct {
Meta any `json:"meta,omitempty"` Meta any `json:"meta,omitempty"`
} }
// BodyLines returns the body as a string slice. // BodyLines returns the body as a slice of strings (for text
// messages).
func (m *Message) BodyLines() []string { func (m *Message) BodyLines() []string {
switch v := m.Body.(type) { switch v := m.Body.(type) {
case []any: case []any:
@@ -72,13 +74,6 @@ type ServerInfo struct {
// MessagesResponse wraps polling results. // MessagesResponse wraps polling results.
type MessagesResponse struct { type MessagesResponse struct {
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
LastID int64 `json:"lastId"`
}
// PollResult wraps the poll response including the cursor.
type PollResult struct {
Messages []Message
LastID int64
} }
// ParseTS parses the message timestamp. // ParseTS parses the message timestamp.

View File

@@ -1,9 +1,10 @@
// Package main is the entry point for the chat-cli client. // Package main implements chat-cli, an IRC-style terminal client.
package main package main
import ( import (
"fmt" "fmt"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -12,10 +13,9 @@ import (
) )
const ( const (
splitParts = 2 pollTimeoutSec = 15
pollTimeout = 15 retryDelay = 2 * time.Second
pollRetry = 2 * time.Second maxNickLength = 32
timeFormat = "15:04"
) )
// App holds the application state. // App holds the application state.
@@ -25,9 +25,9 @@ type App struct {
mu sync.Mutex mu sync.Mutex
nick string nick string
target string target string // current target (#channel or nick for DM)
connected bool connected bool
lastQID int64 lastMsgID string
stopPoll chan struct{} stopPoll chan struct{}
} }
@@ -41,17 +41,17 @@ func main() {
app.ui.SetStatus(app.nick, "", "disconnected") app.ui.SetStatus(app.nick, "", "disconnected")
app.ui.AddStatus( app.ui.AddStatus(
"Welcome to chat-cli an IRC-style client", "Welcome to chat-cli \u2014 an IRC-style client",
) )
app.ui.AddStatus( app.ui.AddStatus(
"Type [yellow]/connect <server-url>" + "Type [yellow]/connect <server-url>[white] " +
"[white] to begin, " + "to begin, or [yellow]/help[white] for commands",
"or [yellow]/help[white] for commands",
) )
err := app.ui.Run() err := app.ui.Run()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) _, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
} }
@@ -63,9 +63,14 @@ func (a *App) handleInput(text string) {
return return
} }
a.sendPlainText(text)
}
func (a *App) sendPlainText(text string) {
a.mu.Lock() a.mu.Lock()
target := a.target target := a.target
connected := a.connected connected := a.connected
nick := a.nick
a.mu.Unlock() a.mu.Unlock()
if !connected { if !connected {
@@ -92,26 +97,25 @@ func (a *App) handleInput(text string) {
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(
"[red]Send error: " + err.Error(), fmt.Sprintf("[red]Send error: %v", err),
) )
return return
} }
ts := time.Now().Format(timeFormat) ts := time.Now().Format("15:04")
a.mu.Lock() a.ui.AddLine(
nick := a.nick target,
a.mu.Unlock() fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
a.ui.AddLine(target, fmt.Sprintf( ts, nick, text,
"[gray]%s [green]<%s>[white] %s", ),
ts, nick, text, )
))
} }
func (a *App) handleCommand(text string) { func (a *App) handleCommand(text string) { //nolint:cyclop // command dispatch
parts := strings.SplitN(text, " ", splitParts) parts := strings.SplitN(text, " ", 2) //nolint:mnd // split into cmd+args
cmd := strings.ToLower(parts[0]) cmd := strings.ToLower(parts[0])
args := "" args := ""
@@ -119,10 +123,6 @@ func (a *App) handleCommand(text string) {
args = parts[1] args = parts[1]
} }
a.dispatchCommand(cmd, args)
}
func (a *App) dispatchCommand(cmd, args string) {
switch cmd { switch cmd {
case "/connect": case "/connect":
a.cmdConnect(args) a.cmdConnect(args)
@@ -166,7 +166,9 @@ func (a *App) cmdConnect(serverURL string) {
serverURL = strings.TrimRight(serverURL, "/") serverURL = strings.TrimRight(serverURL, "/")
a.ui.AddStatus("Connecting to " + serverURL + "...") a.ui.AddStatus(
fmt.Sprintf("Connecting to %s...", serverURL),
)
a.mu.Lock() a.mu.Lock()
nick := a.nick nick := a.nick
@@ -176,9 +178,11 @@ func (a *App) cmdConnect(serverURL string) {
resp, err := client.CreateSession(nick) resp, err := client.CreateSession(nick)
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Connection failed: %v", err, fmt.Sprintf(
)) "[red]Connection failed: %v", err,
),
)
return return
} }
@@ -187,13 +191,15 @@ func (a *App) cmdConnect(serverURL string) {
a.client = client a.client = client
a.nick = resp.Nick a.nick = resp.Nick
a.connected = true a.connected = true
a.lastQID = 0 a.lastMsgID = ""
a.mu.Unlock() a.mu.Unlock()
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[green]Connected! Nick: %s, Session: %d", fmt.Sprintf(
resp.Nick, resp.ID, "[green]Connected! Nick: %s, Session: %s",
)) resp.Nick, resp.SessionID,
),
)
a.ui.SetStatus(resp.Nick, "", "connected") a.ui.SetStatus(resp.Nick, "", "connected")
a.stopPoll = make(chan struct{}) a.stopPoll = make(chan struct{})
@@ -203,9 +209,7 @@ func (a *App) cmdConnect(serverURL string) {
func (a *App) cmdNick(nick string) { func (a *App) cmdNick(nick string) {
if nick == "" { if nick == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]Usage: /nick <name>")
"[red]Usage: /nick <name>",
)
return return
} }
@@ -220,8 +224,10 @@ func (a *App) cmdNick(nick string) {
a.mu.Unlock() a.mu.Unlock()
a.ui.AddStatus( a.ui.AddStatus(
"Nick set to " + nick + fmt.Sprintf(
" (will be used on connect)", "Nick set to %s (will be used on connect)",
nick,
),
) )
return return
@@ -232,9 +238,11 @@ func (a *App) cmdNick(nick string) {
Body: []string{nick}, Body: []string{nick},
}) })
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Nick change failed: %v", err, fmt.Sprintf(
)) "[red]Nick change failed: %v", err,
),
)
return return
} }
@@ -245,14 +253,14 @@ func (a *App) cmdNick(nick string) {
a.mu.Unlock() a.mu.Unlock()
a.ui.SetStatus(nick, target, "connected") a.ui.SetStatus(nick, target, "connected")
a.ui.AddStatus("Nick changed to " + nick) a.ui.AddStatus(
"Nick changed to " + nick,
)
} }
func (a *App) cmdJoin(channel string) { func (a *App) cmdJoin(channel string) {
if channel == "" { if channel == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]Usage: /join #channel")
"[red]Usage: /join #channel",
)
return return
} }
@@ -273,9 +281,9 @@ func (a *App) cmdJoin(channel string) {
err := a.client.JoinChannel(channel) err := a.client.JoinChannel(channel)
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Join failed: %v", err, fmt.Sprintf("[red]Join failed: %v", err),
)) )
return return
} }
@@ -286,7 +294,8 @@ func (a *App) cmdJoin(channel string) {
a.mu.Unlock() a.mu.Unlock()
a.ui.SwitchToBuffer(channel) a.ui.SwitchToBuffer(channel)
a.ui.AddLine(channel, a.ui.AddLine(
channel,
"[yellow]*** Joined "+channel, "[yellow]*** Joined "+channel,
) )
a.ui.SetStatus(nick, channel, "connected") a.ui.SetStatus(nick, channel, "connected")
@@ -294,6 +303,7 @@ func (a *App) cmdJoin(channel string) {
func (a *App) cmdPart(channel string) { func (a *App) cmdPart(channel string) {
a.mu.Lock() a.mu.Lock()
if channel == "" { if channel == "" {
channel = a.target channel = a.target
} }
@@ -301,8 +311,7 @@ func (a *App) cmdPart(channel string) {
connected := a.connected connected := a.connected
a.mu.Unlock() a.mu.Unlock()
if channel == "" || if channel == "" || !strings.HasPrefix(channel, "#") {
!strings.HasPrefix(channel, "#") {
a.ui.AddStatus("[red]No channel to part") a.ui.AddStatus("[red]No channel to part")
return return
@@ -316,18 +325,20 @@ func (a *App) cmdPart(channel string) {
err := a.client.PartChannel(channel) err := a.client.PartChannel(channel)
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Part failed: %v", err, fmt.Sprintf("[red]Part failed: %v", err),
)) )
return return
} }
a.ui.AddLine(channel, a.ui.AddLine(
channel,
"[yellow]*** Left "+channel, "[yellow]*** Left "+channel,
) )
a.mu.Lock() a.mu.Lock()
if a.target == channel { if a.target == channel {
a.target = "" a.target = ""
} }
@@ -340,11 +351,9 @@ func (a *App) cmdPart(channel string) {
} }
func (a *App) cmdMsg(args string) { func (a *App) cmdMsg(args string) {
parts := strings.SplitN(args, " ", splitParts) parts := strings.SplitN(args, " ", 2) //nolint:mnd // split into target+text
if len(parts) < splitParts { if len(parts) < 2 { //nolint:mnd // min args
a.ui.AddStatus( a.ui.AddStatus("[red]Usage: /msg <nick> <text>")
"[red]Usage: /msg <nick> <text>",
)
return return
} }
@@ -368,26 +377,27 @@ func (a *App) cmdMsg(args string) {
Body: []string{text}, Body: []string{text},
}) })
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Send failed: %v", err, fmt.Sprintf("[red]Send failed: %v", err),
)) )
return return
} }
ts := time.Now().Format(timeFormat) ts := time.Now().Format("15:04")
a.ui.AddLine(target, fmt.Sprintf( a.ui.AddLine(
"[gray]%s [green]<%s>[white] %s", target,
ts, nick, text, fmt.Sprintf(
)) "[gray]%s [green]<%s>[white] %s",
ts, nick, text,
),
)
} }
func (a *App) cmdQuery(nick string) { func (a *App) cmdQuery(nick string) {
if nick == "" { if nick == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]Usage: /query <nick>")
"[red]Usage: /query <nick>",
)
return return
} }
@@ -425,9 +435,11 @@ func (a *App) cmdTopic(args string) {
To: target, To: target,
}) })
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Topic query failed: %v", err, fmt.Sprintf(
)) "[red]Topic query failed: %v", err,
),
)
} }
return return
@@ -439,9 +451,11 @@ func (a *App) cmdTopic(args string) {
Body: []string{args}, Body: []string{args},
}) })
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Topic set failed: %v", err, fmt.Sprintf(
)) "[red]Topic set failed: %v", err,
),
)
} }
} }
@@ -465,17 +479,20 @@ func (a *App) cmdNames() {
members, err := a.client.GetMembers(target) members, err := a.client.GetMembers(target)
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]Names failed: %v", err, fmt.Sprintf("[red]Names failed: %v", err),
)) )
return return
} }
a.ui.AddLine(target, fmt.Sprintf( a.ui.AddLine(
"[cyan]*** Members of %s: %s", target,
target, strings.Join(members, " "), fmt.Sprintf(
)) "[cyan]*** Members of %s: %s",
target, strings.Join(members, " "),
),
)
} }
func (a *App) cmdList() { func (a *App) cmdList() {
@@ -491,9 +508,9 @@ func (a *App) cmdList() {
channels, err := a.client.ListChannels() channels, err := a.client.ListChannels()
if err != nil { if err != nil {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[red]List failed: %v", err, fmt.Sprintf("[red]List failed: %v", err),
)) )
return return
} }
@@ -501,10 +518,12 @@ func (a *App) cmdList() {
a.ui.AddStatus("[cyan]*** Channel list:") a.ui.AddStatus("[cyan]*** Channel list:")
for _, ch := range channels { for _, ch := range channels {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
" %s (%d members) %s", fmt.Sprintf(
ch.Name, ch.Members, ch.Topic, " %s (%d members) %s",
)) ch.Name, ch.Members, ch.Topic,
),
)
} }
a.ui.AddStatus("[cyan]*** End of channel list") a.ui.AddStatus("[cyan]*** End of channel list")
@@ -512,17 +531,12 @@ func (a *App) cmdList() {
func (a *App) cmdWindow(args string) { func (a *App) cmdWindow(args string) {
if args == "" { if args == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]Usage: /window <number>")
"[red]Usage: /window <number>",
)
return return
} }
var n int n, _ := strconv.Atoi(args)
_, _ = fmt.Sscanf(args, "%d", &n)
a.ui.SwitchBuffer(n) a.ui.SwitchBuffer(n)
a.mu.Lock() a.mu.Lock()
@@ -531,14 +545,13 @@ func (a *App) cmdWindow(args string) {
if n >= 0 && n < a.ui.BufferCount() { if n >= 0 && n < a.ui.BufferCount() {
buf := a.ui.buffers[n] buf := a.ui.buffers[n]
if buf.Name != "(status)" { if buf.Name != "(status)" {
a.mu.Lock() a.mu.Lock()
a.target = buf.Name a.target = buf.Name
a.mu.Unlock() a.mu.Unlock()
a.ui.SetStatus( a.ui.SetStatus(nick, buf.Name, "connected")
nick, buf.Name, "connected",
)
} else { } else {
a.ui.SetStatus(nick, "", "connected") a.ui.SetStatus(nick, "", "connected")
} }
@@ -565,18 +578,18 @@ func (a *App) cmdQuit() {
func (a *App) cmdHelp() { func (a *App) cmdHelp() {
help := []string{ help := []string{
"[cyan]*** chat-cli commands:", "[cyan]*** chat-cli commands:",
" /connect <url> Connect to server", " /connect <url> \u2014 Connect to server",
" /nick <name> Change nickname", " /nick <name> \u2014 Change nickname",
" /join #channel Join channel", " /join #channel \u2014 Join channel",
" /part [#chan] Leave channel", " /part [#chan] \u2014 Leave channel",
" /msg <nick> <text> Send DM", " /msg <nick> <text> \u2014 Send DM",
" /query <nick> Open DM window", " /query <nick> \u2014 Open DM window",
" /topic [text] View/set topic", " /topic [text] \u2014 View/set topic",
" /names List channel members", " /names \u2014 List channel members",
" /list List channels", " /list \u2014 List channels",
" /window <n> Switch buffer", " /window <n> \u2014 Switch buffer (Alt+0-9)",
" /quit Disconnect and exit", " /quit \u2014 Disconnect and exit",
" /help This help", " /help \u2014 This help",
" Plain text sends to current target.", " Plain text sends to current target.",
} }
@@ -596,36 +609,38 @@ func (a *App) pollLoop() {
a.mu.Lock() a.mu.Lock()
client := a.client client := a.client
lastQID := a.lastQID lastID := a.lastMsgID
a.mu.Unlock() a.mu.Unlock()
if client == nil { if client == nil {
return return
} }
result, err := client.PollMessages( msgs, err := client.PollMessages(
lastQID, pollTimeout, lastID, pollTimeoutSec,
) )
if err != nil { if err != nil {
time.Sleep(pollRetry) time.Sleep(retryDelay)
continue continue
} }
if result.LastID > 0 { for i := range msgs {
a.mu.Lock() a.handleServerMessage(&msgs[i])
a.lastQID = result.LastID
a.mu.Unlock()
}
for i := range result.Messages { if msgs[i].ID != "" {
a.handleServerMessage(&result.Messages[i]) a.mu.Lock()
a.lastMsgID = msgs[i].ID
a.mu.Unlock()
}
} }
} }
} }
func (a *App) handleServerMessage(msg *api.Message) { func (a *App) handleServerMessage(
ts := a.formatTS(msg) msg *api.Message,
) {
ts := a.parseMessageTS(msg)
a.mu.Lock() a.mu.Lock()
myNick := a.nick myNick := a.nick
@@ -633,34 +648,37 @@ func (a *App) handleServerMessage(msg *api.Message) {
switch msg.Command { switch msg.Command {
case "PRIVMSG": case "PRIVMSG":
a.handlePrivmsgEvent(msg, ts, myNick) a.handlePrivmsgMsg(msg, ts, myNick)
case "JOIN": case "JOIN":
a.handleJoinEvent(msg, ts) a.handleJoinMsg(msg, ts)
case "PART": case "PART":
a.handlePartEvent(msg, ts) a.handlePartMsg(msg, ts)
case "QUIT": case "QUIT":
a.handleQuitEvent(msg, ts) a.handleQuitMsg(msg, ts)
case "NICK": case "NICK":
a.handleNickEvent(msg, ts, myNick) a.handleNickMsg(msg, ts, myNick)
case "NOTICE": case "NOTICE":
a.handleNoticeEvent(msg, ts) a.handleNoticeMsg(msg, ts)
case "TOPIC": case "TOPIC":
a.handleTopicEvent(msg, ts) a.handleTopicMsg(msg, ts)
default: default:
a.handleDefaultEvent(msg, ts) a.handleDefaultMsg(msg, ts)
} }
} }
func (a *App) formatTS(msg *api.Message) string { func (a *App) parseMessageTS(msg *api.Message) string {
if msg.TS != "" { if msg.TS != "" {
return msg.ParseTS().UTC().Format(timeFormat) t := msg.ParseTS()
return t.In(time.Local).Format("15:04") //nolint:gosmopolitan // CLI uses local time
} }
return time.Now().Format(timeFormat) return time.Now().Format("15:04")
} }
func (a *App) handlePrivmsgEvent( func (a *App) handlePrivmsgMsg(
msg *api.Message, ts, myNick string, msg *api.Message,
ts, myNick string,
) { ) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
@@ -674,29 +692,37 @@ func (a *App) handlePrivmsgEvent(
target = msg.From target = msg.From
} }
a.ui.AddLine(target, fmt.Sprintf( a.ui.AddLine(
"[gray]%s [green]<%s>[white] %s", target,
ts, msg.From, text, fmt.Sprintf(
)) "[gray]%s [green]<%s>[white] %s",
ts, msg.From, text,
),
)
} }
func (a *App) handleJoinEvent( func (a *App) handleJoinMsg(
msg *api.Message, ts string, msg *api.Message, ts string,
) { ) {
if msg.To == "" { target := msg.To
if target == "" {
return return
} }
a.ui.AddLine(msg.To, fmt.Sprintf( a.ui.AddLine(
"[gray]%s [yellow]*** %s has joined %s", target,
ts, msg.From, msg.To, fmt.Sprintf(
)) "[gray]%s [yellow]*** %s has joined %s",
ts, msg.From, target,
),
)
} }
func (a *App) handlePartEvent( func (a *App) handlePartMsg(
msg *api.Message, ts string, msg *api.Message, ts string,
) { ) {
if msg.To == "" { target := msg.To
if target == "" {
return return
} }
@@ -704,38 +730,48 @@ func (a *App) handlePartEvent(
reason := strings.Join(lines, " ") reason := strings.Join(lines, " ")
if reason != "" { if reason != "" {
a.ui.AddLine(msg.To, fmt.Sprintf( a.ui.AddLine(
"[gray]%s [yellow]*** %s has left %s (%s)", target,
ts, msg.From, msg.To, reason, fmt.Sprintf(
)) "[gray]%s [yellow]*** %s has left %s (%s)",
ts, msg.From, target, reason,
),
)
} else { } else {
a.ui.AddLine(msg.To, fmt.Sprintf( a.ui.AddLine(
"[gray]%s [yellow]*** %s has left %s", target,
ts, msg.From, msg.To, fmt.Sprintf(
)) "[gray]%s [yellow]*** %s has left %s",
ts, msg.From, target,
),
)
} }
} }
func (a *App) handleQuitEvent( func (a *App) handleQuitMsg(
msg *api.Message, ts string, msg *api.Message, ts string,
) { ) {
lines := msg.BodyLines() lines := msg.BodyLines()
reason := strings.Join(lines, " ") reason := strings.Join(lines, " ")
if reason != "" { if reason != "" {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[gray]%s [yellow]*** %s has quit (%s)", fmt.Sprintf(
ts, msg.From, reason, "[gray]%s [yellow]*** %s has quit (%s)",
)) ts, msg.From, reason,
),
)
} else { } else {
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[gray]%s [yellow]*** %s has quit", fmt.Sprintf(
ts, msg.From, "[gray]%s [yellow]*** %s has quit",
)) ts, msg.From,
),
)
} }
} }
func (a *App) handleNickEvent( func (a *App) handleNickMsg(
msg *api.Message, ts, myNick string, msg *api.Message, ts, myNick string,
) { ) {
lines := msg.BodyLines() lines := msg.BodyLines()
@@ -748,32 +784,35 @@ func (a *App) handleNickEvent(
if msg.From == myNick && newNick != "" { if msg.From == myNick && newNick != "" {
a.mu.Lock() a.mu.Lock()
a.nick = newNick a.nick = newNick
target := a.target target := a.target
a.mu.Unlock() a.mu.Unlock()
a.ui.SetStatus(newNick, target, "connected") a.ui.SetStatus(newNick, target, "connected")
} }
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[gray]%s [yellow]*** %s is now known as %s", fmt.Sprintf(
ts, msg.From, newNick, "[gray]%s [yellow]*** %s is now known as %s",
)) ts, msg.From, newNick,
),
)
} }
func (a *App) handleNoticeEvent( func (a *App) handleNoticeMsg(
msg *api.Message, ts string, msg *api.Message, ts string,
) { ) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
a.ui.AddStatus(fmt.Sprintf( a.ui.AddStatus(
"[gray]%s [magenta]--%s-- %s", fmt.Sprintf(
ts, msg.From, text, "[gray]%s [magenta]--%s-- %s",
)) ts, msg.From, text,
),
)
} }
func (a *App) handleTopicEvent( func (a *App) handleTopicMsg(
msg *api.Message, ts string, msg *api.Message, ts string,
) { ) {
if msg.To == "" { if msg.To == "" {
@@ -783,22 +822,29 @@ func (a *App) handleTopicEvent(
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
a.ui.AddLine(msg.To, fmt.Sprintf( a.ui.AddLine(
"[gray]%s [cyan]*** %s set topic: %s", msg.To,
ts, msg.From, text, fmt.Sprintf(
)) "[gray]%s [cyan]*** %s set topic: %s",
ts, msg.From, text,
),
)
} }
func (a *App) handleDefaultEvent( func (a *App) handleDefaultMsg(
msg *api.Message, ts string, msg *api.Message, ts string,
) { ) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
if text != "" { if text == "" {
a.ui.AddStatus(fmt.Sprintf( return
}
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [white][%s] %s", "[gray]%s [white][%s] %s",
ts, msg.Command, text, ts, msg.Command, text,
)) ),
} )
} }

View File

@@ -39,19 +39,11 @@ func NewUI() *UI {
}, },
} }
ui.initMessages() ui.setupMessages()
ui.initStatusBar() ui.setupStatusBar()
ui.initInput() ui.setupInput()
ui.initKeyCapture() ui.setupKeybindings()
ui.setupLayout()
ui.layout = tview.NewFlex().
SetDirection(tview.FlexRow).
AddItem(ui.messages, 0, 1, false).
AddItem(ui.statusBar, 1, 0, false).
AddItem(ui.input, 1, 0, true)
ui.app.SetRoot(ui.layout, true)
ui.app.SetFocus(ui.input)
return ui return ui
} }
@@ -72,29 +64,32 @@ func (ui *UI) OnInput(fn func(string)) {
} }
// AddLine adds a line to the specified buffer. // AddLine adds a line to the specified buffer.
func (ui *UI) AddLine(bufferName, line string) { func (ui *UI) AddLine(bufferName string, line string) {
ui.app.QueueUpdateDraw(func() { ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(bufferName) buf := ui.getOrCreateBuffer(bufferName)
buf.Lines = append(buf.Lines, line) buf.Lines = append(buf.Lines, line)
cur := ui.buffers[ui.currentBuffer] // Mark unread if not currently viewing this buffer.
if cur != buf { if ui.buffers[ui.currentBuffer] != buf {
buf.Unread++ buf.Unread++
ui.refreshStatusBar()
ui.refreshStatus()
} }
if cur == buf { // If viewing this buffer, append to display.
if ui.buffers[ui.currentBuffer] == buf {
_, _ = fmt.Fprintln(ui.messages, line) _, _ = fmt.Fprintln(ui.messages, line)
} }
}) })
} }
// AddStatus adds a line to the status buffer. // AddStatus adds a line to the status buffer (buffer 0).
func (ui *UI) AddStatus(line string) { func (ui *UI) AddStatus(line string) {
ts := time.Now().Format("15:04") ts := time.Now().Format("15:04")
ui.AddLine( ui.AddLine(
"(status)", "(status)",
"[gray]"+ts+"[white] "+line, fmt.Sprintf("[gray]%s[white] %s", ts, line),
) )
} }
@@ -117,12 +112,11 @@ func (ui *UI) SwitchBuffer(n int) {
} }
ui.messages.ScrollToEnd() ui.messages.ScrollToEnd()
ui.refreshStatusBar() ui.refreshStatus()
}) })
} }
// SwitchToBuffer switches to named buffer, creating if // SwitchToBuffer switches to the named buffer, creating it
// needed.
func (ui *UI) SwitchToBuffer(name string) { func (ui *UI) SwitchToBuffer(name string) {
ui.app.QueueUpdateDraw(func() { ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(name) buf := ui.getOrCreateBuffer(name)
@@ -144,7 +138,7 @@ func (ui *UI) SwitchToBuffer(name string) {
} }
ui.messages.ScrollToEnd() ui.messages.ScrollToEnd()
ui.refreshStatusBar() ui.refreshStatus()
}) })
} }
@@ -153,7 +147,7 @@ func (ui *UI) SetStatus(
nick, target, connStatus string, nick, target, connStatus string,
) { ) {
ui.app.QueueUpdateDraw(func() { ui.app.QueueUpdateDraw(func() {
ui.renderStatusBar(nick, target, connStatus) ui.refreshStatusWith(nick, target, connStatus)
}) })
} }
@@ -162,7 +156,7 @@ func (ui *UI) BufferCount() int {
return len(ui.buffers) return len(ui.buffers)
} }
// BufferIndex returns the index of a named buffer. // BufferIndex returns the index of a named buffer, or -1.
func (ui *UI) BufferIndex(name string) int { func (ui *UI) BufferIndex(name string) int {
for i, buf := range ui.buffers { for i, buf := range ui.buffers {
if buf.Name == name { if buf.Name == name {
@@ -173,7 +167,7 @@ func (ui *UI) BufferIndex(name string) int {
return -1 return -1
} }
func (ui *UI) initMessages() { func (ui *UI) setupMessages() {
ui.messages = tview.NewTextView(). ui.messages = tview.NewTextView().
SetDynamicColors(true). SetDynamicColors(true).
SetScrollable(true). SetScrollable(true).
@@ -184,14 +178,14 @@ func (ui *UI) initMessages() {
ui.messages.SetBorder(false) ui.messages.SetBorder(false)
} }
func (ui *UI) initStatusBar() { func (ui *UI) setupStatusBar() {
ui.statusBar = tview.NewTextView(). ui.statusBar = tview.NewTextView().
SetDynamicColors(true) SetDynamicColors(true)
ui.statusBar.SetBackgroundColor(tcell.ColorNavy) ui.statusBar.SetBackgroundColor(tcell.ColorNavy)
ui.statusBar.SetTextColor(tcell.ColorWhite) ui.statusBar.SetTextColor(tcell.ColorWhite)
} }
func (ui *UI) initInput() { func (ui *UI) setupInput() {
ui.input = tview.NewInputField(). ui.input = tview.NewInputField().
SetFieldBackgroundColor(tcell.ColorBlack). SetFieldBackgroundColor(tcell.ColorBlack).
SetFieldTextColor(tcell.ColorWhite) SetFieldTextColor(tcell.ColorWhite)
@@ -214,7 +208,7 @@ func (ui *UI) initInput() {
}) })
} }
func (ui *UI) initKeyCapture() { func (ui *UI) setupKeybindings() {
ui.app.SetInputCapture( ui.app.SetInputCapture(
func(event *tcell.EventKey) *tcell.EventKey { func(event *tcell.EventKey) *tcell.EventKey {
if event.Modifiers()&tcell.ModAlt == 0 { if event.Modifiers()&tcell.ModAlt == 0 {
@@ -234,21 +228,34 @@ func (ui *UI) initKeyCapture() {
) )
} }
func (ui *UI) refreshStatusBar() { func (ui *UI) setupLayout() {
// Placeholder; full refresh needs nick/target context. ui.layout = tview.NewFlex().
SetDirection(tview.FlexRow).
AddItem(ui.messages, 0, 1, false).
AddItem(ui.statusBar, 1, 0, false).
AddItem(ui.input, 1, 0, true)
ui.app.SetRoot(ui.layout, true)
ui.app.SetFocus(ui.input)
} }
func (ui *UI) renderStatusBar( // if needed.
func (ui *UI) refreshStatus() {
// Rebuilt from app state by parent QueueUpdateDraw.
}
func (ui *UI) refreshStatusWith(
nick, target, connStatus string, nick, target, connStatus string,
) { ) {
var unreadParts []string var unreadParts []string
for i, buf := range ui.buffers { for i, buf := range ui.buffers {
if buf.Unread > 0 { if buf.Unread > 0 {
unreadParts = append(unreadParts, unreadParts = append(
unreadParts,
fmt.Sprintf( fmt.Sprintf(
"%d:%s(%d)", "%d:%s(%d)", i, buf.Name, buf.Unread,
i, buf.Name, buf.Unread,
), ),
) )
} }
@@ -268,8 +275,8 @@ func (ui *UI) renderStatusBar(
ui.statusBar.Clear() ui.statusBar.Clear()
_, _ = fmt.Fprintf(ui.statusBar, _, _ = fmt.Fprintf(
" [%s] %s %s %s%s", ui.statusBar, " [%s] %s %s %s%s",
connStatus, nick, bufInfo, target, unread, connStatus, nick, bufInfo, target, unread,
) )
} }

View File

@@ -1,65 +0,0 @@
// Package broker provides an in-memory pub/sub for long-poll notifications.
package broker
import (
"sync"
)
// Broker notifies waiting clients when new messages are available.
type Broker struct {
mu sync.Mutex
listeners map[int64][]chan struct{} // userID -> list of waiting channels
}
// New creates a new Broker.
func New() *Broker {
return &Broker{
listeners: make(map[int64][]chan struct{}),
}
}
// Wait returns a channel that will be closed when a message is available for the user.
func (b *Broker) Wait(userID int64) chan struct{} {
ch := make(chan struct{}, 1)
b.mu.Lock()
b.listeners[userID] = append(b.listeners[userID], ch)
b.mu.Unlock()
return ch
}
// Notify wakes up all waiting clients for a user.
func (b *Broker) Notify(userID int64) {
b.mu.Lock()
waiters := b.listeners[userID]
delete(b.listeners, userID)
b.mu.Unlock()
for _, ch := range waiters {
select {
case ch <- struct{}{}:
default:
}
}
}
// Remove removes a specific wait channel (for cleanup on timeout).
func (b *Broker) Remove(userID int64, ch chan struct{}) {
b.mu.Lock()
defer b.mu.Unlock()
waiters := b.listeners[userID]
for i, w := range waiters {
if w == ch {
b.listeners[userID] = append(waiters[:i], waiters[i+1:]...)
break
}
}
if len(b.listeners[userID]) == 0 {
delete(b.listeners, userID)
}
}

View File

@@ -1,118 +0,0 @@
package broker_test
import (
"sync"
"testing"
"time"
"git.eeqj.de/sneak/chat/internal/broker"
)
func TestNewBroker(t *testing.T) {
t.Parallel()
b := broker.New()
if b == nil {
t.Fatal("expected non-nil broker")
}
}
func TestWaitAndNotify(t *testing.T) {
t.Parallel()
b := broker.New()
ch := b.Wait(1)
go func() {
time.Sleep(10 * time.Millisecond)
b.Notify(1)
}()
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatal("timeout")
}
}
func TestNotifyWithoutWaiters(t *testing.T) {
t.Parallel()
b := broker.New()
b.Notify(42) // should not panic
}
func TestRemove(t *testing.T) {
t.Parallel()
b := broker.New()
ch := b.Wait(1)
b.Remove(1, ch)
b.Notify(1)
select {
case <-ch:
t.Fatal("should not receive after remove")
case <-time.After(50 * time.Millisecond):
}
}
func TestMultipleWaiters(t *testing.T) {
t.Parallel()
b := broker.New()
ch1 := b.Wait(1)
ch2 := b.Wait(1)
b.Notify(1)
select {
case <-ch1:
case <-time.After(time.Second):
t.Fatal("ch1 timeout")
}
select {
case <-ch2:
case <-time.After(time.Second):
t.Fatal("ch2 timeout")
}
}
func TestConcurrentWaitNotify(t *testing.T) {
t.Parallel()
b := broker.New()
var wg sync.WaitGroup
const concurrency = 100
for i := range concurrency {
wg.Add(1)
go func(uid int64) {
defer wg.Done()
ch := b.Wait(uid)
b.Notify(uid)
select {
case <-ch:
case <-time.After(time.Second):
t.Error("timeout")
}
}(int64(i % 10))
}
wg.Wait()
}
func TestRemoveNonexistent(t *testing.T) {
t.Parallel()
b := broker.New()
ch := make(chan struct{}, 1)
b.Remove(999, ch) // should not panic
}

View File

@@ -11,16 +11,20 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"time"
"git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/config"
"git.eeqj.de/sneak/chat/internal/logger" "git.eeqj.de/sneak/chat/internal/logger"
"git.eeqj.de/sneak/chat/internal/models"
"go.uber.org/fx" "go.uber.org/fx"
_ "github.com/joho/godotenv/autoload" // .env _ "github.com/joho/godotenv/autoload" // loads .env file
_ "modernc.org/sqlite" // driver _ "modernc.org/sqlite" // SQLite driver
) )
const minMigrationParts = 2 const (
minMigrationParts = 2
)
// SchemaFiles contains embedded SQL migration files. // SchemaFiles contains embedded SQL migration files.
// //
@@ -35,18 +39,15 @@ type Params struct {
Config *config.Config Config *config.Config
} }
// Database manages the SQLite connection and migrations. // Database manages the SQLite database connection and migrations.
type Database struct { type Database struct {
db *sql.DB db *sql.DB
log *slog.Logger log *slog.Logger
params *Params params *Params
} }
// New creates a new Database and registers lifecycle hooks. // New creates a new Database instance and registers lifecycle hooks.
func New( func New(lc fx.Lifecycle, params Params) (*Database, error) {
lc fx.Lifecycle,
params Params,
) (*Database, error) {
s := new(Database) s := new(Database)
s.params = &params s.params = &params
s.log = params.Logger.Get() s.log = params.Logger.Get()
@@ -73,11 +74,460 @@ func New(
return s, nil return s, nil
} }
// NewTest creates a Database for testing, bypassing fx lifecycle.
// It connects to the given DSN and runs all migrations.
func NewTest(dsn string) (*Database, error) {
d, err := sql.Open("sqlite", dsn)
if err != nil {
return nil, err
}
s := &Database{
db: d,
log: slog.Default(),
}
// Item 9: Enable foreign keys
_, err = d.Exec("PRAGMA foreign_keys = ON") //nolint:noctx,nolintlint // no context in sql.Open path
if err != nil {
_ = d.Close()
return nil, fmt.Errorf("enable foreign keys: %w", err)
}
ctx := context.Background()
err = s.runMigrations(ctx)
if err != nil {
_ = d.Close()
return nil, err
}
return s, nil
}
// GetDB returns the underlying sql.DB connection. // GetDB returns the underlying sql.DB connection.
func (s *Database) GetDB() *sql.DB { func (s *Database) GetDB() *sql.DB {
return s.db return s.db
} }
// Hydrate injects the database reference into any model that
// embeds Base.
func (s *Database) Hydrate(m interface{ SetDB(d models.DB) }) {
m.SetDB(s)
}
// GetUserByID looks up a user by their ID.
func (s *Database) GetUserByID(
ctx context.Context,
id string,
) (*models.User, error) {
u := &models.User{}
s.Hydrate(u)
err := s.db.QueryRowContext(ctx, `
SELECT id, nick, password_hash, created_at, updated_at, last_seen_at
FROM users WHERE id = ?`,
id,
).Scan(
&u.ID, &u.Nick, &u.PasswordHash,
&u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt,
)
if err != nil {
return nil, err
}
return u, nil
}
// GetChannelByID looks up a channel by its ID.
func (s *Database) GetChannelByID(
ctx context.Context,
id string,
) (*models.Channel, error) {
c := &models.Channel{}
s.Hydrate(c)
err := s.db.QueryRowContext(ctx, `
SELECT id, name, topic, modes, created_at, updated_at
FROM channels WHERE id = ?`,
id,
).Scan(
&c.ID, &c.Name, &c.Topic, &c.Modes,
&c.CreatedAt, &c.UpdatedAt,
)
if err != nil {
return nil, err
}
return c, nil
}
// GetUserByNickModel looks up a user by their nick.
func (s *Database) GetUserByNickModel(
ctx context.Context,
nick string,
) (*models.User, error) {
u := &models.User{}
s.Hydrate(u)
err := s.db.QueryRowContext(ctx, `
SELECT id, nick, password_hash, created_at, updated_at, last_seen_at
FROM users WHERE nick = ?`,
nick,
).Scan(
&u.ID, &u.Nick, &u.PasswordHash,
&u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt,
)
if err != nil {
return nil, err
}
return u, nil
}
// GetUserByTokenModel looks up a user by their auth token.
func (s *Database) GetUserByTokenModel(
ctx context.Context,
token string,
) (*models.User, error) {
u := &models.User{}
s.Hydrate(u)
err := s.db.QueryRowContext(ctx, `
SELECT u.id, u.nick, u.password_hash,
u.created_at, u.updated_at, u.last_seen_at
FROM users u
JOIN auth_tokens t ON t.user_id = u.id
WHERE t.token = ?`,
token,
).Scan(
&u.ID, &u.Nick, &u.PasswordHash,
&u.CreatedAt, &u.UpdatedAt, &u.LastSeenAt,
)
if err != nil {
return nil, err
}
return u, nil
}
// DeleteAuthToken removes an auth token from the database.
func (s *Database) DeleteAuthToken(
ctx context.Context,
token string,
) error {
_, err := s.db.ExecContext(ctx,
`DELETE FROM auth_tokens WHERE token = ?`, token,
)
return err
}
// UpdateUserLastSeen updates the last_seen_at timestamp for a user.
func (s *Database) UpdateUserLastSeen(
ctx context.Context,
userID string,
) error {
_, err := s.db.ExecContext(ctx,
`UPDATE users SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`,
userID,
)
return err
}
// CreateUserModel inserts a new user into the database.
func (s *Database) CreateUserModel(
ctx context.Context,
id, nick, passwordHash string,
) (*models.User, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO users (id, nick, password_hash)
VALUES (?, ?, ?)`,
id, nick, passwordHash,
)
if err != nil {
return nil, err
}
u := &models.User{
ID: id, Nick: nick, PasswordHash: passwordHash,
CreatedAt: now, UpdatedAt: now,
}
s.Hydrate(u)
return u, nil
}
// CreateChannel inserts a new channel into the database.
func (s *Database) CreateChannel(
ctx context.Context,
id, name, topic, modes string,
) (*models.Channel, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO channels (id, name, topic, modes)
VALUES (?, ?, ?, ?)`,
id, name, topic, modes,
)
if err != nil {
return nil, err
}
c := &models.Channel{
ID: id, Name: name, Topic: topic, Modes: modes,
CreatedAt: now, UpdatedAt: now,
}
s.Hydrate(c)
return c, nil
}
// AddChannelMember adds a user to a channel with the given modes.
func (s *Database) AddChannelMember(
ctx context.Context,
channelID, userID, modes string,
) (*models.ChannelMember, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO channel_members
(channel_id, user_id, modes)
VALUES (?, ?, ?)`,
channelID, userID, modes,
)
if err != nil {
return nil, err
}
cm := &models.ChannelMember{
ChannelID: channelID,
UserID: userID,
Modes: modes,
JoinedAt: now,
}
s.Hydrate(cm)
return cm, nil
}
// CreateMessage inserts a new message into the database.
func (s *Database) CreateMessage(
ctx context.Context,
id, fromUserID, fromNick, target, msgType, body string,
) (*models.Message, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO messages
(id, from_user_id, from_nick, target, type, body)
VALUES (?, ?, ?, ?, ?, ?)`,
id, fromUserID, fromNick, target, msgType, body,
)
if err != nil {
return nil, err
}
m := &models.Message{
ID: id,
FromUserID: fromUserID,
FromNick: fromNick,
Target: target,
Type: msgType,
Body: body,
Timestamp: now,
CreatedAt: now,
}
s.Hydrate(m)
return m, nil
}
// QueueMessage adds a message to a user's delivery queue.
func (s *Database) QueueMessage(
ctx context.Context,
userID, messageID string,
) (*models.MessageQueueEntry, error) {
now := time.Now()
res, err := s.db.ExecContext(ctx,
`INSERT INTO message_queue (user_id, message_id)
VALUES (?, ?)`,
userID, messageID,
)
if err != nil {
return nil, err
}
entryID, err := res.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get last insert id: %w", err)
}
mq := &models.MessageQueueEntry{
ID: entryID,
UserID: userID,
MessageID: messageID,
QueuedAt: now,
}
s.Hydrate(mq)
return mq, nil
}
// DequeueMessages returns up to limit pending messages for a user,
// ordered by queue time (oldest first).
func (s *Database) DequeueMessages(
ctx context.Context,
userID string,
limit int,
) ([]*models.MessageQueueEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT id, user_id, message_id, queued_at
FROM message_queue
WHERE user_id = ?
ORDER BY queued_at ASC
LIMIT ?`,
userID, limit,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
entries := []*models.MessageQueueEntry{}
for rows.Next() {
e := &models.MessageQueueEntry{}
s.Hydrate(e)
err = rows.Scan(&e.ID, &e.UserID, &e.MessageID, &e.QueuedAt)
if err != nil {
return nil, err
}
entries = append(entries, e)
}
return entries, rows.Err()
}
// AckMessages removes the given queue entry IDs, marking them as delivered.
func (s *Database) AckMessages(
ctx context.Context,
entryIDs []int64,
) error {
if len(entryIDs) == 0 {
return nil
}
placeholders := make([]string, len(entryIDs))
args := make([]any, len(entryIDs))
for i, id := range entryIDs {
placeholders[i] = "?"
args[i] = id
}
query := fmt.Sprintf( //nolint:gosec // placeholders are ?, not user input
"DELETE FROM message_queue WHERE id IN (%s)",
strings.Join(placeholders, ","),
)
_, err := s.db.ExecContext(ctx, query, args...)
return err
}
// CreateAuthToken inserts a new auth token for a user.
func (s *Database) CreateAuthToken(
ctx context.Context,
token, userID string,
) (*models.AuthToken, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO auth_tokens (token, user_id)
VALUES (?, ?)`,
token, userID,
)
if err != nil {
return nil, err
}
at := &models.AuthToken{Token: token, UserID: userID, CreatedAt: now}
s.Hydrate(at)
return at, nil
}
// CreateSession inserts a new session for a user.
func (s *Database) CreateSession(
ctx context.Context,
id, userID string,
) (*models.Session, error) {
now := time.Now()
_, err := s.db.ExecContext(ctx,
`INSERT INTO sessions (id, user_id)
VALUES (?, ?)`,
id, userID,
)
if err != nil {
return nil, err
}
sess := &models.Session{
ID: id, UserID: userID,
CreatedAt: now, LastActiveAt: now,
}
s.Hydrate(sess)
return sess, nil
}
// CreateServerLink inserts a new server link.
func (s *Database) CreateServerLink(
ctx context.Context,
id, name, url, sharedKeyHash string,
isActive bool,
) (*models.ServerLink, error) {
now := time.Now()
active := 0
if isActive {
active = 1
}
_, err := s.db.ExecContext(ctx,
`INSERT INTO server_links
(id, name, url, shared_key_hash, is_active)
VALUES (?, ?, ?, ?, ?)`,
id, name, url, sharedKeyHash, active,
)
if err != nil {
return nil, err
}
sl := &models.ServerLink{
ID: id,
Name: name,
URL: url,
SharedKeyHash: sharedKeyHash,
IsActive: isActive,
CreatedAt: now,
}
s.Hydrate(sl)
return sl, nil
}
func (s *Database) connect(ctx context.Context) error { func (s *Database) connect(ctx context.Context) error {
dbURL := s.params.Config.DBURL dbURL := s.params.Config.DBURL
if dbURL == "" { if dbURL == "" {
@@ -88,18 +538,14 @@ func (s *Database) connect(ctx context.Context) error {
d, err := sql.Open("sqlite", dbURL) d, err := sql.Open("sqlite", dbURL)
if err != nil { if err != nil {
s.log.Error( s.log.Error("failed to open database", "error", err)
"failed to open database", "error", err,
)
return err return err
} }
err = d.PingContext(ctx) err = d.PingContext(ctx)
if err != nil { if err != nil {
s.log.Error( s.log.Error("failed to ping database", "error", err)
"failed to ping database", "error", err,
)
return err return err
} }
@@ -107,9 +553,8 @@ func (s *Database) connect(ctx context.Context) error {
s.db = d s.db = d
s.log.Info("database connected") s.log.Info("database connected")
_, err = s.db.ExecContext( // Item 9: Enable foreign keys on every connection
ctx, "PRAGMA foreign_keys = ON", _, err = s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON")
)
if err != nil { if err != nil {
return fmt.Errorf("enable foreign keys: %w", err) return fmt.Errorf("enable foreign keys: %w", err)
} }
@@ -123,17 +568,10 @@ type migration struct {
sql string sql string
} }
func (s *Database) runMigrations( func (s *Database) runMigrations(ctx context.Context) error {
ctx context.Context, err := s.bootstrapMigrationsTable(ctx)
) error {
_, err := s.db.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)`)
if err != nil { if err != nil {
return fmt.Errorf( return err
"create schema_migrations: %w", err,
)
} }
migrations, err := s.loadMigrations() migrations, err := s.loadMigrations()
@@ -141,11 +579,9 @@ func (s *Database) runMigrations(
return err return err
} }
for _, m := range migrations { err = s.applyMigrations(ctx, migrations)
err = s.applyMigration(ctx, m) if err != nil {
if err != nil { return err
return err
}
} }
s.log.Info("database migrations complete") s.log.Info("database migrations complete")
@@ -153,87 +589,30 @@ func (s *Database) runMigrations(
return nil return nil
} }
func (s *Database) applyMigration( func (s *Database) bootstrapMigrationsTable(
ctx context.Context, ctx context.Context,
m migration,
) error { ) error {
var exists int _, err := s.db.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS schema_migrations (
err := s.db.QueryRowContext(ctx, version INTEGER PRIMARY KEY,
`SELECT COUNT(*) FROM schema_migrations applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
WHERE version = ?`, )`)
m.version,
).Scan(&exists)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
"check migration %d: %w", m.version, err, "create schema_migrations table: %w", err,
) )
} }
if exists > 0 { return nil
return nil
}
s.log.Info(
"applying migration",
"version", m.version,
"name", m.name,
)
return s.execMigration(ctx, m)
} }
func (s *Database) execMigration( func (s *Database) loadMigrations() ([]migration, error) {
ctx context.Context,
m migration,
) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf(
"begin tx for migration %d: %w",
m.version, err,
)
}
_, err = tx.ExecContext(ctx, m.sql)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf(
"apply migration %d (%s): %w",
m.version, m.name, err,
)
}
_, err = tx.ExecContext(ctx,
`INSERT INTO schema_migrations (version)
VALUES (?)`,
m.version,
)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf(
"record migration %d: %w",
m.version, err,
)
}
return tx.Commit()
}
func (s *Database) loadMigrations() (
[]migration,
error,
) {
entries, err := fs.ReadDir(SchemaFiles, "schema") entries, err := fs.ReadDir(SchemaFiles, "schema")
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("read schema dir: %w", err)
"read schema dir: %w", err,
)
} }
var migrations []migration migrations := make([]migration, 0, len(entries))
for _, entry := range entries { for _, entry := range entries {
if entry.IsDir() || if entry.IsDir() ||
@@ -248,18 +627,17 @@ func (s *Database) loadMigrations() (
continue continue
} }
version, parseErr := strconv.Atoi(parts[0]) version, err := strconv.Atoi(parts[0])
if parseErr != nil { if err != nil {
continue continue
} }
content, readErr := SchemaFiles.ReadFile( content, err := SchemaFiles.ReadFile(
"schema/" + entry.Name(), "schema/" + entry.Name(),
) )
if readErr != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"read migration %s: %w", "read migration %s: %w", entry.Name(), err,
entry.Name(), readErr,
) )
} }
@@ -276,3 +654,82 @@ func (s *Database) loadMigrations() (
return migrations, nil return migrations, nil
} }
// Item 4: Wrap each migration in a transaction
func (s *Database) applyMigrations(
ctx context.Context,
migrations []migration,
) error {
for _, m := range migrations {
var exists int
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
m.version,
).Scan(&exists)
if err != nil {
return fmt.Errorf(
"check migration %d: %w", m.version, err,
)
}
if exists > 0 {
continue
}
s.log.Info(
"applying migration",
"version", m.version, "name", m.name,
)
err = s.executeMigration(ctx, m)
if err != nil {
return err
}
}
return nil
}
func (s *Database) executeMigration(
ctx context.Context,
m migration,
) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf(
"begin tx for migration %d: %w", m.version, err,
)
}
_, err = tx.ExecContext(ctx, m.sql)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf(
"apply migration %d (%s): %w",
m.version, m.name, err,
)
}
_, err = tx.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
m.version,
)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf(
"record migration %d: %w", m.version, err,
)
}
err = tx.Commit()
if err != nil {
return fmt.Errorf(
"commit migration %d: %w", m.version, err,
)
}
return nil
}

424
internal/db/db_test.go Normal file
View File

@@ -0,0 +1,424 @@
package db_test
import (
"fmt"
"path/filepath"
"testing"
"time"
"git.eeqj.de/sneak/chat/internal/db"
)
const (
nickAlice = "alice"
nickBob = "bob"
nickCharlie = "charlie"
)
// setupTestDB creates a fresh database in a temp directory with
// all migrations applied.
func setupTestDB(t *testing.T) *db.Database {
t.Helper()
dir := t.TempDir()
dsn := fmt.Sprintf(
"file:%s?_journal_mode=WAL",
filepath.Join(dir, "test.db"),
)
d, err := db.NewTest(dsn)
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}
t.Cleanup(func() { _ = d.GetDB().Close() })
return d
}
func TestCreateUser(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
id, token, err := d.CreateUser(ctx, nickAlice)
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if id <= 0 {
t.Errorf("expected positive id, got %d", id)
}
if token == "" {
t.Error("expected non-empty token")
}
}
func TestGetUserByToken(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
_, token, _ := d.CreateUser(ctx, nickAlice)
id, nick, err := d.GetUserByToken(ctx, token)
if err != nil {
t.Fatalf("GetUserByToken: %v", err)
}
if id <= 0 || nick != nickAlice {
t.Errorf(
"got id=%d nick=%s, want nick=%s",
id, nick, nickAlice,
)
}
}
func TestGetUserByNick(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
origID, _, _ := d.CreateUser(ctx, nickAlice)
id, err := d.GetUserByNick(ctx, nickAlice)
if err != nil {
t.Fatalf("GetUserByNick: %v", err)
}
if id != origID {
t.Errorf("got id %d, want %d", id, origID)
}
}
func TestGetOrCreateChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
id1, err := d.GetOrCreateChannel(ctx, "#general")
if err != nil {
t.Fatalf("GetOrCreateChannel: %v", err)
}
if id1 <= 0 {
t.Errorf("expected positive id, got %d", id1)
}
// Same channel returns same ID.
id2, err := d.GetOrCreateChannel(ctx, "#general")
if err != nil {
t.Fatalf("GetOrCreateChannel(2): %v", err)
}
if id1 != id2 {
t.Errorf("got different ids: %d vs %d", id1, id2)
}
}
func TestJoinAndListChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
ch1, _ := d.GetOrCreateChannel(ctx, "#alpha")
ch2, _ := d.GetOrCreateChannel(ctx, "#beta")
_ = d.JoinChannel(ctx, ch1, uid)
_ = d.JoinChannel(ctx, ch2, uid)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 2 {
t.Fatalf("expected 2 channels, got %d", len(channels))
}
}
func TestListChannelsEmpty(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 0 {
t.Errorf("expected 0 channels, got %d", len(channels))
}
}
func TestPartChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
_ = d.PartChannel(ctx, chID, uid)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 0 {
t.Errorf("expected 0 after part, got %d", len(channels))
}
}
func TestSendAndGetMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
_, err := d.SendMessage(ctx, chID, uid, "hello world")
if err != nil {
t.Fatalf("SendMessage: %v", err)
}
msgs, err := d.GetMessages(ctx, chID, 0, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
if msgs[0].Content != "hello world" {
t.Errorf(
"got content %q, want %q",
msgs[0].Content, "hello world",
)
}
}
func TestChannelMembers(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
uid3, _, _ := d.CreateUser(ctx, nickCharlie)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid1)
_ = d.JoinChannel(ctx, chID, uid2)
_ = d.JoinChannel(ctx, chID, uid3)
members, err := d.ChannelMembers(ctx, chID)
if err != nil {
t.Fatalf("ChannelMembers: %v", err)
}
if len(members) != 3 {
t.Fatalf("expected 3 members, got %d", len(members))
}
}
func TestChannelMembersEmpty(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
chID, _ := d.GetOrCreateChannel(ctx, "#empty")
members, err := d.ChannelMembers(ctx, chID)
if err != nil {
t.Fatalf("ChannelMembers: %v", err)
}
if len(members) != 0 {
t.Errorf("expected 0, got %d", len(members))
}
}
func TestSendDM(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
msgID, err := d.SendDM(ctx, uid1, uid2, "hey bob")
if err != nil {
t.Fatalf("SendDM: %v", err)
}
if msgID <= 0 {
t.Errorf("expected positive msgID, got %d", msgID)
}
msgs, err := d.GetDMs(ctx, uid1, uid2, 0, 0)
if err != nil {
t.Fatalf("GetDMs: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 DM, got %d", len(msgs))
}
if msgs[0].Content != "hey bob" {
t.Errorf("got %q, want %q", msgs[0].Content, "hey bob")
}
}
func TestPollMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid1)
_ = d.JoinChannel(ctx, chID, uid2)
_, _ = d.SendMessage(ctx, chID, uid2, "hello")
_, _ = d.SendDM(ctx, uid2, uid1, "private")
time.Sleep(10 * time.Millisecond)
msgs, err := d.PollMessages(ctx, uid1, 0, 0)
if err != nil {
t.Fatalf("PollMessages: %v", err)
}
if len(msgs) < 2 {
t.Fatalf("expected >=2 messages, got %d", len(msgs))
}
}
func TestChangeNick(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
_, token, _ := d.CreateUser(ctx, nickAlice)
err := d.ChangeNick(ctx, 1, "alice2")
if err != nil {
t.Fatalf("ChangeNick: %v", err)
}
_, nick, err := d.GetUserByToken(ctx, token)
if err != nil {
t.Fatalf("GetUserByToken: %v", err)
}
if nick != "alice2" {
t.Errorf("got nick %q, want alice2", nick)
}
}
func TestSetTopic(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
_, _ = d.GetOrCreateChannel(ctx, "#general")
err := d.SetTopic(ctx, "#general", uid, "new topic")
if err != nil {
t.Fatalf("SetTopic: %v", err)
}
channels, err := d.ListAllChannels(ctx)
if err != nil {
t.Fatalf("ListAllChannels: %v", err)
}
found := false
for _, ch := range channels {
if ch.Name == "#general" && ch.Topic == "new topic" {
found = true
}
}
if !found {
t.Error("topic was not updated")
}
}
func TestGetMessagesBefore(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
for i := range 5 {
_, _ = d.SendMessage(
ctx, chID, uid,
fmt.Sprintf("msg%d", i),
)
time.Sleep(10 * time.Millisecond)
}
msgs, err := d.GetMessagesBefore(ctx, chID, 0, 3)
if err != nil {
t.Fatalf("GetMessagesBefore: %v", err)
}
if len(msgs) != 3 {
t.Fatalf("expected 3, got %d", len(msgs))
}
}
func TestListAllChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
_, _ = d.GetOrCreateChannel(ctx, "#alpha")
_, _ = d.GetOrCreateChannel(ctx, "#beta")
channels, err := d.ListAllChannels(ctx)
if err != nil {
t.Fatalf("ListAllChannels: %v", err)
}
if len(channels) != 2 {
t.Errorf("expected 2, got %d", len(channels))
}
}

View File

@@ -1,47 +0,0 @@
package db
import (
"context"
"database/sql"
"fmt"
"log/slog"
"sync/atomic"
)
//nolint:gochecknoglobals // test counter
var testDBCounter atomic.Int64
// NewTestDatabase creates an in-memory database for testing.
func NewTestDatabase() (*Database, error) {
n := testDBCounter.Add(1)
dsn := fmt.Sprintf(
"file:testdb%d?mode=memory"+
"&cache=shared&_pragma=foreign_keys(1)",
n,
)
d, err := sql.Open("sqlite", dsn)
if err != nil {
return nil, err
}
database := &Database{db: d, log: slog.Default()}
err = database.runMigrations(context.Background())
if err != nil {
closeErr := d.Close()
if closeErr != nil {
return nil, closeErr
}
return nil, err
}
return database, nil
}
// Close closes the underlying database connection.
func (s *Database) Close() error {
return s.db.Close()
}

View File

@@ -5,17 +5,14 @@ import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"time" "time"
"github.com/google/uuid"
) )
const ( const (
tokenBytes = 32 defaultMessageLimit = 50
defaultPollLimit = 100 defaultPollLimit = 100
defaultHistLimit = 50 tokenBytes = 32
) )
func generateToken() string { func generateToken() string {
@@ -25,33 +22,8 @@ func generateToken() string {
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
// IRCMessage is the IRC envelope for all messages. // CreateUser registers a new user with the given nick and
type IRCMessage struct { // returns the user with token.
ID string `json:"id"`
Command string `json:"command"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Body json.RawMessage `json:"body,omitempty"`
TS string `json:"ts"`
Meta json.RawMessage `json:"meta,omitempty"`
DBID int64 `json:"-"`
}
// ChannelInfo is a lightweight channel representation.
type ChannelInfo struct {
ID int64 `json:"id"`
Name string `json:"name"`
Topic string `json:"topic"`
}
// MemberInfo represents a channel member.
type MemberInfo struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
LastSeen time.Time `json:"lastSeen"`
}
// CreateUser registers a new user with the given nick.
func (s *Database) CreateUser( func (s *Database) CreateUser(
ctx context.Context, ctx context.Context,
nick string, nick string,
@@ -60,9 +32,7 @@ func (s *Database) CreateUser(
now := time.Now() now := time.Now()
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
`INSERT INTO users "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)",
(nick, token, created_at, last_seen)
VALUES (?, ?, ?, ?)`,
nick, token, now, now) nick, token, now, now)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("create user: %w", err) return 0, "", fmt.Errorf("create user: %w", err)
@@ -73,7 +43,8 @@ func (s *Database) CreateUser(
return id, token, nil return id, token, nil
} }
// GetUserByToken returns user id and nick for a token. // GetUserByToken returns user id and nick for a given auth
// token.
func (s *Database) GetUserByToken( func (s *Database) GetUserByToken(
ctx context.Context, ctx context.Context,
token string, token string,
@@ -91,6 +62,7 @@ func (s *Database) GetUserByToken(
return 0, "", err return 0, "", err
} }
// Update last_seen
_, _ = s.db.ExecContext( _, _ = s.db.ExecContext(
ctx, ctx,
"UPDATE users SET last_seen = ? WHERE id = ?", "UPDATE users SET last_seen = ? WHERE id = ?",
@@ -116,23 +88,8 @@ func (s *Database) GetUserByNick(
return id, err return id, err
} }
// GetChannelByName returns the channel ID for a name. // GetOrCreateChannel returns the channel id, creating it if
func (s *Database) GetChannelByName( // needed.
ctx context.Context,
name string,
) (int64, error) {
var id int64
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&id)
return id, err
}
// GetOrCreateChannel returns channel id, creating if needed.
func (s *Database) GetOrCreateChannel( func (s *Database) GetOrCreateChannel(
ctx context.Context, ctx context.Context,
name string, name string,
@@ -151,9 +108,7 @@ func (s *Database) GetOrCreateChannel(
now := time.Now() now := time.Now()
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
`INSERT INTO channels "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)",
(name, created_at, updated_at)
VALUES (?, ?, ?)`,
name, now, now) name, now, now)
if err != nil { if err != nil {
return 0, fmt.Errorf("create channel: %w", err) return 0, fmt.Errorf("create channel: %w", err)
@@ -170,9 +125,7 @@ func (s *Database) JoinChannel(
channelID, userID int64, channelID, userID int64,
) error { ) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
`INSERT OR IGNORE INTO channel_members "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)",
(channel_id, user_id, joined_at)
VALUES (?, ?, ?)`,
channelID, userID, time.Now()) channelID, userID, time.Now())
return err return err
@@ -184,35 +137,35 @@ func (s *Database) PartChannel(
channelID, userID int64, channelID, userID int64,
) error { ) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
`DELETE FROM channel_members "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?",
WHERE channel_id = ? AND user_id = ?`,
channelID, userID) channelID, userID)
return err return err
} }
// DeleteChannelIfEmpty removes a channel with no members. // ChannelInfo is a lightweight channel representation.
func (s *Database) DeleteChannelIfEmpty( type ChannelInfo struct {
ctx context.Context, ID int64 `json:"id"`
channelID int64, Name string `json:"name"`
) error { Topic string `json:"topic"`
_, err := s.db.ExecContext(ctx,
`DELETE FROM channels WHERE id = ?
AND NOT EXISTS
(SELECT 1 FROM channel_members
WHERE channel_id = ?)`,
channelID, channelID)
return err
} }
// scanChannels scans rows into a ChannelInfo slice. // ListChannels returns all channels the user has joined.
func scanChannels( func (s *Database) ListChannels(
rows *sql.Rows, ctx context.Context,
userID int64,
) ([]ChannelInfo, error) { ) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic FROM channels c
INNER JOIN channel_members cm ON cm.channel_id = c.id
WHERE cm.user_id = ? ORDER BY c.name`, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }() defer func() { _ = rows.Close() }()
var out []ChannelInfo var channels []ChannelInfo
for rows.Next() { for rows.Next() {
var ch ChannelInfo var ch ChannelInfo
@@ -222,52 +175,26 @@ func scanChannels(
return nil, err return nil, err
} }
out = append(out, ch) channels = append(channels, ch)
} }
err := rows.Err() err = rows.Err()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if out == nil { if channels == nil {
out = []ChannelInfo{} channels = []ChannelInfo{}
} }
return out, nil return channels, nil
} }
// ListChannels returns channels the user has joined. // MemberInfo represents a channel member.
func (s *Database) ListChannels( type MemberInfo struct {
ctx context.Context, ID int64 `json:"id"`
userID int64, Nick string `json:"nick"`
) ([]ChannelInfo, error) { LastSeen time.Time `json:"lastSeen"`
rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic
FROM channels c
INNER JOIN channel_members cm
ON cm.channel_id = c.id
WHERE cm.user_id = ?
ORDER BY c.name`, userID)
if err != nil {
return nil, err
}
return scanChannels(rows)
}
// ListAllChannels returns every channel.
func (s *Database) ListAllChannels(
ctx context.Context,
) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT id, name, topic
FROM channels ORDER BY name`)
if err != nil {
return nil, err
}
return scanChannels(rows)
} }
// ChannelMembers returns all members of a channel. // ChannelMembers returns all members of a channel.
@@ -276,12 +203,9 @@ func (s *Database) ChannelMembers(
channelID int64, channelID int64,
) ([]MemberInfo, error) { ) ([]MemberInfo, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT u.id, u.nick, u.last_seen `SELECT u.id, u.nick, u.last_seen FROM users u
FROM users u INNER JOIN channel_members cm ON cm.user_id = u.id
INNER JOIN channel_members cm WHERE cm.channel_id = ? ORDER BY u.nick`, channelID)
ON cm.user_id = u.id
WHERE cm.channel_id = ?
ORDER BY u.nick`, channelID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -293,7 +217,7 @@ func (s *Database) ChannelMembers(
for rows.Next() { for rows.Next() {
var m MemberInfo var m MemberInfo
err = rows.Scan(&m.ID, &m.Nick, &m.LastSeen) err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -313,256 +237,407 @@ func (s *Database) ChannelMembers(
return members, nil return members, nil
} }
// scanInt64s scans rows into an int64 slice. // MessageInfo represents a chat message.
func scanInt64s(rows *sql.Rows) ([]int64, error) { type MessageInfo struct {
ID int64 `json:"id"`
Channel string `json:"channel,omitempty"`
Nick string `json:"nick"`
Content string `json:"content"`
IsDM bool `json:"isDm,omitempty"`
DMTarget string `json:"dmTarget,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// GetMessages returns messages for a channel, optionally
// after a given ID.
func (s *Database) GetMessages(
ctx context.Context,
channelID int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = defaultMessageLimit
}
rows, err := s.db.QueryContext(ctx,
`SELECT m.id, c.name, u.nick, m.content, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id > ?
ORDER BY m.id ASC LIMIT ?`, channelID, afterID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }() defer func() { _ = rows.Close() }()
var ids []int64 var msgs []MessageInfo
for rows.Next() { for rows.Next() {
var id int64 var m MessageInfo
err := rows.Scan(&id) err := rows.Scan(
&m.ID, &m.Channel, &m.Nick,
&m.Content, &m.CreatedAt,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ids = append(ids, id) msgs = append(msgs, m)
} }
err := rows.Err() err = rows.Err()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ids, nil if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
} }
// GetChannelMemberIDs returns user IDs in a channel. // SendMessage inserts a channel message.
func (s *Database) GetChannelMemberIDs( func (s *Database) SendMessage(
ctx context.Context, ctx context.Context,
channelID int64, channelID, userID int64,
) ([]int64, error) { content string,
rows, err := s.db.QueryContext(ctx, ) (int64, error) {
`SELECT user_id FROM channel_members
WHERE channel_id = ?`, channelID)
if err != nil {
return nil, err
}
return scanInt64s(rows)
}
// GetUserChannelIDs returns channel IDs the user is in.
func (s *Database) GetUserChannelIDs(
ctx context.Context,
userID int64,
) ([]int64, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT channel_id FROM channel_members
WHERE user_id = ?`, userID)
if err != nil {
return nil, err
}
return scanInt64s(rows)
}
// InsertMessage stores a message and returns its DB ID.
func (s *Database) InsertMessage(
ctx context.Context,
command, from, to string,
body json.RawMessage,
meta json.RawMessage,
) (int64, string, error) {
msgUUID := uuid.New().String()
now := time.Now().UTC()
if body == nil {
body = json.RawMessage("[]")
}
if meta == nil {
meta = json.RawMessage("{}")
}
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
`INSERT INTO messages "INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)",
(uuid, command, msg_from, msg_to, channelID, userID, content, time.Now())
body, meta, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
msgUUID, command, from, to,
string(body), string(meta), now)
if err != nil { if err != nil {
return 0, "", err return 0, err
} }
id, _ := res.LastInsertId() return res.LastInsertId()
return id, msgUUID, nil
} }
// EnqueueMessage adds a message to a user's queue. // SendDM inserts a direct message.
func (s *Database) EnqueueMessage( func (s *Database) SendDM(
ctx context.Context, ctx context.Context,
userID, messageID int64, fromID, toID int64,
) error { content string,
_, err := s.db.ExecContext(ctx, ) (int64, error) {
`INSERT OR IGNORE INTO client_queues res, err := s.db.ExecContext(ctx,
(user_id, message_id, created_at) "INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)",
VALUES (?, ?, ?)`, fromID, content, toID, time.Now())
userID, messageID, time.Now()) if err != nil {
return 0, err
}
return err return res.LastInsertId()
} }
// PollMessages returns queued messages for a user. // GetDMs returns direct messages between two users after a
// given ID.
func (s *Database) GetDMs(
ctx context.Context,
userA, userB int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = defaultMessageLimit
}
rows, err := s.db.QueryContext(ctx,
`SELECT m.id, u.nick, m.content, t.nick, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id > ?
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id ASC LIMIT ?`,
afterID, userA, userB, userB, userA, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
err = rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
// PollMessages returns all new messages (channel + DM) for
// a user after a given ID.
func (s *Database) PollMessages( func (s *Database) PollMessages(
ctx context.Context, ctx context.Context,
userID, afterQueueID int64, userID int64,
afterID int64,
limit int, limit int,
) ([]IRCMessage, int64, error) { ) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultPollLimit limit = defaultPollLimit
} }
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT cq.id, m.uuid, m.command, `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content,
m.msg_from, m.msg_to, m.is_dm, COALESCE(t.nick, ''), m.created_at
m.body, m.meta, m.created_at FROM messages m
FROM client_queues cq INNER JOIN users u ON u.id = m.user_id
INNER JOIN messages m LEFT JOIN channels c ON c.id = m.channel_id
ON m.id = cq.message_id LEFT JOIN users t ON t.id = m.dm_target_id
WHERE cq.user_id = ? AND cq.id > ? WHERE m.id > ? AND (
ORDER BY cq.id ASC LIMIT ?`, (m.is_dm = 0 AND m.channel_id IN
userID, afterQueueID, limit) (SELECT channel_id FROM channel_members
if err != nil { WHERE user_id = ?))
return nil, afterQueueID, err OR (m.is_dm = 1
} AND (m.user_id = ? OR m.dm_target_id = ?))
)
msgs, lastQID, scanErr := scanMessages( ORDER BY m.id ASC LIMIT ?`,
rows, afterQueueID, afterID, userID, userID, userID, limit)
)
if scanErr != nil {
return nil, afterQueueID, scanErr
}
return msgs, lastQID, nil
}
// GetHistory returns message history for a target.
func (s *Database) GetHistory(
ctx context.Context,
target string,
beforeID int64,
limit int,
) ([]IRCMessage, error) {
if limit <= 0 {
limit = defaultHistLimit
}
rows, err := s.queryHistory(
ctx, target, beforeID, limit,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msgs, _, scanErr := scanMessages(rows, 0) defer func() { _ = rows.Close() }()
if scanErr != nil {
return nil, scanErr msgs := make([]MessageInfo, 0)
for rows.Next() {
var (
m MessageInfo
isDM int
)
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick, &m.Content,
&isDM, &m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = isDM == 1
msgs = append(msgs, m)
} }
if msgs == nil { err = rows.Err()
msgs = []IRCMessage{} if err != nil {
return nil, err
} }
reverseMessages(msgs)
return msgs, nil return msgs, nil
} }
func (s *Database) queryHistory( func scanChannelMessages(
ctx context.Context,
target string,
beforeID int64,
limit int,
) (*sql.Rows, error) {
if beforeID > 0 {
return s.db.QueryContext(ctx,
`SELECT id, uuid, command, msg_from,
msg_to, body, meta, created_at
FROM messages
WHERE msg_to = ? AND id < ?
AND command = 'PRIVMSG'
ORDER BY id DESC LIMIT ?`,
target, beforeID, limit)
}
return s.db.QueryContext(ctx,
`SELECT id, uuid, command, msg_from,
msg_to, body, meta, created_at
FROM messages
WHERE msg_to = ?
AND command = 'PRIVMSG'
ORDER BY id DESC LIMIT ?`,
target, limit)
}
func scanMessages(
rows *sql.Rows, rows *sql.Rows,
fallbackQID int64, ) ([]MessageInfo, error) {
) ([]IRCMessage, int64, error) { var msgs []MessageInfo
defer func() { _ = rows.Close() }()
var msgs []IRCMessage
lastQID := fallbackQID
for rows.Next() { for rows.Next() {
var ( var m MessageInfo
m IRCMessage
qID int64
body, meta string
ts time.Time
)
err := rows.Scan( err := rows.Scan(
&qID, &m.ID, &m.Command, &m.ID, &m.Channel, &m.Nick,
&m.From, &m.To, &m.Content, &m.CreatedAt,
&body, &meta, &ts,
) )
if err != nil { if err != nil {
return nil, fallbackQID, err return nil, err
} }
m.Body = json.RawMessage(body)
m.Meta = json.RawMessage(meta)
m.TS = ts.Format(time.RFC3339Nano)
m.DBID = qID
lastQID = qID
msgs = append(msgs, m) msgs = append(msgs, m)
} }
err := rows.Err() err := rows.Err()
if err != nil { if err != nil {
return nil, fallbackQID, err return nil, err
} }
if msgs == nil { if msgs == nil {
msgs = []IRCMessage{} msgs = []MessageInfo{}
} }
return msgs, lastQID, nil return msgs, nil
} }
func reverseMessages(msgs []IRCMessage) { func scanDMMessages(
rows *sql.Rows,
) ([]MessageInfo, error) {
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func reverseMessages(msgs []MessageInfo) {
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i] msgs[i], msgs[j] = msgs[j], msgs[i]
} }
} }
// GetMessagesBefore returns channel messages before a given
// ID (for history scrollback).
func (s *Database) GetMessagesBefore(
ctx context.Context,
channelID int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = defaultMessageLimit
}
var query string
var args []any
if beforeID > 0 {
query = `SELECT m.id, c.name, u.nick, m.content,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
AND m.id < ?
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, beforeID, limit}
} else {
query = `SELECT m.id, c.name, u.nick, m.content,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs, scanErr := scanChannelMessages(rows)
if scanErr != nil {
return nil, scanErr
}
// Reverse to ascending order.
reverseMessages(msgs)
return msgs, nil
}
// GetDMsBefore returns DMs between two users before a given
// ID (for history scrollback).
func (s *Database) GetDMsBefore(
ctx context.Context,
userA, userB int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = defaultMessageLimit
}
var query string
var args []any
if beforeID > 0 {
query = `SELECT m.id, u.nick, m.content, t.nick,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id < ?
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{
beforeID, userA, userB, userB, userA, limit,
}
} else {
query = `SELECT m.id, u.nick, m.content, t.nick,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{userA, userB, userB, userA, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs, scanErr := scanDMMessages(rows)
if scanErr != nil {
return nil, scanErr
}
// Reverse to ascending order.
reverseMessages(msgs)
return msgs, nil
}
// ChangeNick updates a user's nickname. // ChangeNick updates a user's nickname.
func (s *Database) ChangeNick( func (s *Database) ChangeNick(
ctx context.Context, ctx context.Context,
@@ -579,45 +654,58 @@ func (s *Database) ChangeNick(
// SetTopic sets the topic for a channel. // SetTopic sets the topic for a channel.
func (s *Database) SetTopic( func (s *Database) SetTopic(
ctx context.Context, ctx context.Context,
channelName, topic string, channelName string,
_ int64,
topic string,
) error { ) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
`UPDATE channels SET topic = ?, "UPDATE channels SET topic = ? WHERE name = ?",
updated_at = ? WHERE name = ?`, topic, channelName)
topic, time.Now(), channelName)
return err return err
} }
// DeleteUser removes a user and all their data. // GetServerName returns the server name (unused, config
func (s *Database) DeleteUser( // provides this).
ctx context.Context, func (s *Database) GetServerName() string {
userID int64, return ""
) error {
_, err := s.db.ExecContext(
ctx,
"DELETE FROM users WHERE id = ?",
userID,
)
return err
} }
// GetAllChannelMembershipsForUser returns channels // ListAllChannels returns all channels.
// a user belongs to. func (s *Database) ListAllChannels(
func (s *Database) GetAllChannelMembershipsForUser(
ctx context.Context, ctx context.Context,
userID int64,
) ([]ChannelInfo, error) { ) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic "SELECT id, name, topic FROM channels ORDER BY name")
FROM channels c
INNER JOIN channel_members cm
ON cm.channel_id = c.id
WHERE cm.user_id = ?`, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return scanChannels(rows) defer func() { _ = rows.Close() }()
var channels []ChannelInfo
for rows.Next() {
var ch ChannelInfo
err := rows.Scan(
&ch.ID, &ch.Name, &ch.Topic,
)
if err != nil {
return nil, err
}
channels = append(channels, ch)
}
err = rows.Err()
if err != nil {
return nil, err
}
if channels == nil {
channels = []ChannelInfo{}
}
return channels, nil
} }

View File

@@ -1,550 +0,0 @@
package db_test
import (
"context"
"encoding/json"
"testing"
"git.eeqj.de/sneak/chat/internal/db"
_ "modernc.org/sqlite"
)
func setupTestDB(t *testing.T) *db.Database {
t.Helper()
d, err := db.NewTestDatabase()
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
closeErr := d.Close()
if closeErr != nil {
t.Logf("close db: %v", closeErr)
}
})
return d
}
func TestCreateUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
id, token, err := database.CreateUser(ctx, "alice")
if err != nil {
t.Fatal(err)
}
if id == 0 || token == "" {
t.Fatal("expected valid id and token")
}
_, _, err = database.CreateUser(ctx, "alice")
if err == nil {
t.Fatal("expected error for duplicate nick")
}
}
func TestGetUserByToken(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
_, token, err := database.CreateUser(ctx, "bob")
if err != nil {
t.Fatal(err)
}
id, nick, err := database.GetUserByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if nick != "bob" || id == 0 {
t.Fatalf("expected bob, got %s", nick)
}
_, _, err = database.GetUserByToken(ctx, "badtoken")
if err == nil {
t.Fatal("expected error for bad token")
}
}
func TestGetUserByNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
_, _, err := database.CreateUser(ctx, "charlie")
if err != nil {
t.Fatal(err)
}
id, err := database.GetUserByNick(ctx, "charlie")
if err != nil || id == 0 {
t.Fatal("expected to find charlie")
}
_, err = database.GetUserByNick(ctx, "nobody")
if err == nil {
t.Fatal("expected error for unknown nick")
}
}
func TestChannelOperations(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
chID, err := database.GetOrCreateChannel(ctx, "#test")
if err != nil || chID == 0 {
t.Fatal("expected channel id")
}
chID2, err := database.GetOrCreateChannel(ctx, "#test")
if err != nil || chID2 != chID {
t.Fatal("expected same channel id")
}
chID3, err := database.GetChannelByName(ctx, "#test")
if err != nil || chID3 != chID {
t.Fatal("expected same channel id")
}
_, err = database.GetChannelByName(ctx, "#nope")
if err == nil {
t.Fatal("expected error for nonexistent channel")
}
}
func TestJoinAndPart(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, "user1")
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(ctx, "#chan")
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, chID, uid)
if err != nil {
t.Fatal(err)
}
ids, err := database.GetChannelMemberIDs(ctx, chID)
if err != nil || len(ids) != 1 || ids[0] != uid {
t.Fatal("expected user in channel")
}
err = database.JoinChannel(ctx, chID, uid)
if err != nil {
t.Fatal(err)
}
err = database.PartChannel(ctx, chID, uid)
if err != nil {
t.Fatal(err)
}
ids, _ = database.GetChannelMemberIDs(ctx, chID)
if len(ids) != 0 {
t.Fatal("expected empty channel")
}
}
func TestDeleteChannelIfEmpty(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
chID, err := database.GetOrCreateChannel(
ctx, "#empty",
)
if err != nil {
t.Fatal(err)
}
uid, _, err := database.CreateUser(ctx, "temp")
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, chID, uid)
if err != nil {
t.Fatal(err)
}
err = database.PartChannel(ctx, chID, uid)
if err != nil {
t.Fatal(err)
}
err = database.DeleteChannelIfEmpty(ctx, chID)
if err != nil {
t.Fatal(err)
}
_, err = database.GetChannelByName(ctx, "#empty")
if err == nil {
t.Fatal("expected channel to be deleted")
}
}
func createUserWithChannels(
t *testing.T,
database *db.Database,
nick, ch1Name, ch2Name string,
) (int64, int64, int64) {
t.Helper()
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, nick)
if err != nil {
t.Fatal(err)
}
ch1, err := database.GetOrCreateChannel(
ctx, ch1Name,
)
if err != nil {
t.Fatal(err)
}
ch2, err := database.GetOrCreateChannel(
ctx, ch2Name,
)
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, ch1, uid)
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, ch2, uid)
if err != nil {
t.Fatal(err)
}
return uid, ch1, ch2
}
func TestListChannels(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
uid, _, _ := createUserWithChannels(
t, database, "lister", "#a", "#b",
)
channels, err := database.ListChannels(
context.Background(), uid,
)
if err != nil || len(channels) != 2 {
t.Fatalf(
"expected 2 channels, got %d",
len(channels),
)
}
}
func TestListAllChannels(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
_, err := database.GetOrCreateChannel(ctx, "#x")
if err != nil {
t.Fatal(err)
}
_, err = database.GetOrCreateChannel(ctx, "#y")
if err != nil {
t.Fatal(err)
}
channels, err := database.ListAllChannels(ctx)
if err != nil || len(channels) < 2 {
t.Fatalf(
"expected >= 2 channels, got %d",
len(channels),
)
}
}
func TestChangeNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
uid, token, err := database.CreateUser(ctx, "old")
if err != nil {
t.Fatal(err)
}
err = database.ChangeNick(ctx, uid, "new")
if err != nil {
t.Fatal(err)
}
_, nick, err := database.GetUserByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if nick != "new" {
t.Fatalf("expected new, got %s", nick)
}
}
func TestSetTopic(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
_, err := database.GetOrCreateChannel(
ctx, "#topictest",
)
if err != nil {
t.Fatal(err)
}
err = database.SetTopic(ctx, "#topictest", "Hello")
if err != nil {
t.Fatal(err)
}
channels, err := database.ListAllChannels(ctx)
if err != nil {
t.Fatal(err)
}
for _, ch := range channels {
if ch.Name == "#topictest" &&
ch.Topic != "Hello" {
t.Fatalf(
"expected topic Hello, got %s",
ch.Topic,
)
}
}
}
func TestInsertAndPollMessages(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, "poller")
if err != nil {
t.Fatal(err)
}
body := json.RawMessage(`["hello"]`)
dbID, msgUUID, err := database.InsertMessage(
ctx, "PRIVMSG", "poller", "#test", body, nil,
)
if err != nil || dbID == 0 || msgUUID == "" {
t.Fatal("insert failed")
}
err = database.EnqueueMessage(ctx, uid, dbID)
if err != nil {
t.Fatal(err)
}
const batchSize = 10
msgs, lastQID, err := database.PollMessages(
ctx, uid, 0, batchSize,
)
if err != nil {
t.Fatal(err)
}
if len(msgs) != 1 {
t.Fatalf(
"expected 1 message, got %d", len(msgs),
)
}
if msgs[0].Command != "PRIVMSG" {
t.Fatalf(
"expected PRIVMSG, got %s", msgs[0].Command,
)
}
if lastQID == 0 {
t.Fatal("expected nonzero lastQID")
}
msgs, _, _ = database.PollMessages(
ctx, uid, lastQID, batchSize,
)
if len(msgs) != 0 {
t.Fatalf(
"expected 0 messages, got %d", len(msgs),
)
}
}
func TestGetHistory(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
const msgCount = 10
for range msgCount {
_, _, err := database.InsertMessage(
ctx, "PRIVMSG", "user", "#hist",
json.RawMessage(`["msg"]`), nil,
)
if err != nil {
t.Fatal(err)
}
}
const histLimit = 5
msgs, err := database.GetHistory(
ctx, "#hist", 0, histLimit,
)
if err != nil {
t.Fatal(err)
}
if len(msgs) != histLimit {
t.Fatalf("expected %d, got %d",
histLimit, len(msgs))
}
if msgs[0].DBID > msgs[histLimit-1].DBID {
t.Fatal("expected ascending order")
}
}
func TestDeleteUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, "deleteme")
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(
ctx, "#delchan",
)
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, chID, uid)
if err != nil {
t.Fatal(err)
}
err = database.DeleteUser(ctx, uid)
if err != nil {
t.Fatal(err)
}
_, err = database.GetUserByNick(ctx, "deleteme")
if err == nil {
t.Fatal("user should be deleted")
}
ids, _ := database.GetChannelMemberIDs(ctx, chID)
if len(ids) != 0 {
t.Fatal("expected no members after deletion")
}
}
func TestChannelMembers(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := context.Background()
uid1, _, err := database.CreateUser(ctx, "m1")
if err != nil {
t.Fatal(err)
}
uid2, _, err := database.CreateUser(ctx, "m2")
if err != nil {
t.Fatal(err)
}
chID, err := database.GetOrCreateChannel(
ctx, "#members",
)
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, chID, uid1)
if err != nil {
t.Fatal(err)
}
err = database.JoinChannel(ctx, chID, uid2)
if err != nil {
t.Fatal(err)
}
members, err := database.ChannelMembers(ctx, chID)
if err != nil || len(members) != 2 {
t.Fatalf(
"expected 2 members, got %d",
len(members),
)
}
}
func TestGetAllChannelMembershipsForUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
uid, _, _ := createUserWithChannels(
t, database, "multi", "#m1", "#m2",
)
channels, err :=
database.GetAllChannelMembershipsForUser(
context.Background(), uid,
)
if err != nil || len(channels) != 2 {
t.Fatalf(
"expected 2 channels, got %d",
len(channels),
)
}
}

View File

@@ -1,54 +1,4 @@
-- Chat server schema (pre-1.0 consolidated) CREATE TABLE IF NOT EXISTS schema_migrations (
PRAGMA foreign_keys = ON; version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
-- Users: IRC-style sessions (no passwords, just nick + token)
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
nick TEXT NOT NULL UNIQUE,
token TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
); );
CREATE INDEX IF NOT EXISTS idx_users_token ON users(token);
-- Channels
CREATE TABLE IF NOT EXISTS channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
topic TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- Channel members
CREATE TABLE IF NOT EXISTS channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, user_id)
);
-- Messages: IRC envelope format
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT NOT NULL UNIQUE,
command TEXT NOT NULL DEFAULT 'PRIVMSG',
msg_from TEXT NOT NULL DEFAULT '',
msg_to TEXT NOT NULL DEFAULT '',
body TEXT NOT NULL DEFAULT '[]',
meta TEXT NOT NULL DEFAULT '{}',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_messages_to_id ON messages(msg_to, id);
CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at);
-- Per-client message queues for fan-out delivery
CREATE TABLE IF NOT EXISTS client_queues (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, message_id)
);
CREATE INDEX IF NOT EXISTS idx_client_queues_user ON client_queues(user_id, id);

View File

@@ -0,0 +1,89 @@
-- All schema changes go into this file until 1.0.0 is tagged.
-- There will not be migrations during the early development phase.
-- After 1.0.0, new changes get their own numbered migration files.
-- Users: accounts and authentication
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY, -- UUID
nick TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_seen_at DATETIME
);
-- Auth tokens: one user can have multiple active tokens (multiple devices)
CREATE TABLE IF NOT EXISTS auth_tokens (
token TEXT PRIMARY KEY, -- random token string
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME, -- NULL = no expiry
last_used_at DATETIME
);
CREATE INDEX IF NOT EXISTS idx_auth_tokens_user_id ON auth_tokens(user_id);
-- Channels: chat rooms
CREATE TABLE IF NOT EXISTS channels (
id TEXT PRIMARY KEY, -- UUID
name TEXT NOT NULL UNIQUE, -- #general, etc.
topic TEXT NOT NULL DEFAULT '',
modes TEXT NOT NULL DEFAULT '', -- +i, +m, +s, +t, +n
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Channel members: who is in which channel, with per-user modes
CREATE TABLE IF NOT EXISTS channel_members (
channel_id TEXT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
modes TEXT NOT NULL DEFAULT '', -- +o (operator), +v (voice)
joined_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (channel_id, user_id)
);
CREATE INDEX IF NOT EXISTS idx_channel_members_user_id ON channel_members(user_id);
-- Messages: channel and DM history (rotated per MAX_HISTORY)
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY, -- UUID
ts DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
from_user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
from_nick TEXT NOT NULL, -- denormalized for history
target TEXT NOT NULL, -- #channel name or user UUID for DMs
type TEXT NOT NULL DEFAULT 'message', -- message, action, notice, join, part, quit, topic, mode, nick, system
body TEXT NOT NULL DEFAULT '',
meta TEXT NOT NULL DEFAULT '{}', -- JSON extensible metadata
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_messages_target_ts ON messages(target, ts);
CREATE INDEX IF NOT EXISTS idx_messages_from_user ON messages(from_user_id);
-- Message queue: per-user pending delivery (unread messages)
CREATE TABLE IF NOT EXISTS message_queue (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
message_id TEXT NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
queued_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, message_id)
);
CREATE INDEX IF NOT EXISTS idx_message_queue_user_id ON message_queue(user_id, queued_at);
-- Sessions: server-held session state
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY, -- UUID
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_active_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME -- idle timeout
);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
-- Server links: federation peer configuration
CREATE TABLE IF NOT EXISTS server_links (
id TEXT PRIMARY KEY, -- UUID
name TEXT NOT NULL UNIQUE, -- human-readable peer name
url TEXT NOT NULL, -- base URL of peer server
shared_key_hash TEXT NOT NULL, -- hashed shared secret
is_active INTEGER NOT NULL DEFAULT 1,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_seen_at DATETIME
);

View File

@@ -0,0 +1,53 @@
-- Migration 003: Replace UUID-based tables with simple integer-keyed
-- tables for the HTTP API. Drops the 002 tables and recreates them.
PRAGMA foreign_keys = OFF;
DROP TABLE IF EXISTS message_queue;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS server_links;
DROP TABLE IF EXISTS messages;
DROP TABLE IF EXISTS channel_members;
DROP TABLE IF EXISTS auth_tokens;
DROP TABLE IF EXISTS channels;
DROP TABLE IF EXISTS users;
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
nick TEXT NOT NULL UNIQUE,
token TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
topic TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, user_id)
);
CREATE TABLE messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
is_dm INTEGER NOT NULL DEFAULT 0,
dm_target_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX idx_messages_channel ON messages(channel_id, created_at);
CREATE INDEX idx_messages_dm ON messages(user_id, dm_target_id, created_at);
CREATE INDEX idx_users_token ON users(token);
PRAGMA foreign_keys = ON;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -4,11 +4,9 @@ package handlers
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"log/slog" "log/slog"
"net/http" "net/http"
"git.eeqj.de/sneak/chat/internal/broker"
"git.eeqj.de/sneak/chat/internal/config" "git.eeqj.de/sneak/chat/internal/config"
"git.eeqj.de/sneak/chat/internal/db" "git.eeqj.de/sneak/chat/internal/db"
"git.eeqj.de/sneak/chat/internal/globals" "git.eeqj.de/sneak/chat/internal/globals"
@@ -17,8 +15,6 @@ import (
"go.uber.org/fx" "go.uber.org/fx"
) )
var errUnauthorized = errors.New("unauthorized")
// Params defines the dependencies for creating Handlers. // Params defines the dependencies for creating Handlers.
type Params struct { type Params struct {
fx.In fx.In
@@ -35,19 +31,14 @@ type Handlers struct {
params *Params params *Params
log *slog.Logger log *slog.Logger
hc *healthcheck.Healthcheck hc *healthcheck.Healthcheck
broker *broker.Broker
} }
// New creates a new Handlers instance. // New creates a new Handlers instance.
func New( func New(lc fx.Lifecycle, params Params) (*Handlers, error) {
lc fx.Lifecycle,
params Params,
) (*Handlers, error) {
s := new(Handlers) s := new(Handlers)
s.params = &params s.params = &params
s.log = params.Logger.Get() s.log = params.Logger.Get()
s.hc = params.Healthcheck s.hc = params.Healthcheck
s.broker = broker.New()
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(_ context.Context) error {
@@ -58,17 +49,9 @@ func New(
return s, nil return s, nil
} }
func (s *Handlers) respondJSON( func (s *Handlers) respondJSON(w http.ResponseWriter, _ *http.Request, data any, status int) {
w http.ResponseWriter,
_ *http.Request,
data any,
status int,
) {
w.Header().Set(
"Content-Type",
"application/json; charset=utf-8",
)
w.WriteHeader(status) w.WriteHeader(status)
w.Header().Set("Content-Type", "application/json")
if data != nil { if data != nil {
err := json.NewEncoder(w).Encode(data) err := json.NewEncoder(w).Encode(data)

View File

@@ -0,0 +1,26 @@
package models
import (
"context"
"time"
)
// AuthToken represents an authentication token for a user session.
type AuthToken struct {
Base
Token string `json:"-"`
UserID string `json:"userId"`
CreatedAt time.Time `json:"createdAt"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastUsedAt *time.Time `json:"lastUsedAt,omitempty"`
}
// User returns the user who owns this token.
func (t *AuthToken) User(ctx context.Context) (*User, error) {
if ul := t.GetUserLookup(); ul != nil {
return ul.GetUserByID(ctx, t.UserID)
}
return nil, ErrUserLookupNotAvailable
}

View File

@@ -0,0 +1,96 @@
package models
import (
"context"
"time"
)
// Channel represents a chat channel.
type Channel struct {
Base
ID string `json:"id"`
Name string `json:"name"`
Topic string `json:"topic"`
Modes string `json:"modes"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// Members returns all users who are members of this channel.
func (c *Channel) Members(ctx context.Context) ([]*ChannelMember, error) {
rows, err := c.GetDB().QueryContext(ctx, `
SELECT cm.channel_id, cm.user_id, cm.modes, cm.joined_at,
u.nick
FROM channel_members cm
JOIN users u ON u.id = cm.user_id
WHERE cm.channel_id = ?
ORDER BY cm.joined_at`,
c.ID,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
members := []*ChannelMember{}
for rows.Next() {
m := &ChannelMember{}
m.SetDB(c.db)
err = rows.Scan(
&m.ChannelID, &m.UserID, &m.Modes,
&m.JoinedAt, &m.Nick,
)
if err != nil {
return nil, err
}
members = append(members, m)
}
return members, rows.Err()
}
// RecentMessages returns the most recent messages in this channel.
func (c *Channel) RecentMessages(
ctx context.Context,
limit int,
) ([]*Message, error) {
rows, err := c.GetDB().QueryContext(ctx, `
SELECT id, ts, from_user_id, from_nick,
target, type, body, meta, created_at
FROM messages
WHERE target = ?
ORDER BY ts DESC
LIMIT ?`,
c.Name, limit,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
messages := []*Message{}
for rows.Next() {
msg := &Message{}
msg.SetDB(c.db)
err = rows.Scan(
&msg.ID, &msg.Timestamp, &msg.FromUserID,
&msg.FromNick, &msg.Target, &msg.Type,
&msg.Body, &msg.Meta, &msg.CreatedAt,
)
if err != nil {
return nil, err
}
messages = append(messages, msg)
}
return messages, rows.Err()
}

View File

@@ -0,0 +1,35 @@
package models
import (
"context"
"time"
)
// ChannelMember represents a user's membership in a channel.
type ChannelMember struct {
Base
ChannelID string `json:"channelId"`
UserID string `json:"userId"`
Modes string `json:"modes"`
JoinedAt time.Time `json:"joinedAt"`
Nick string `json:"nick"` // denormalized from users table
}
// User returns the full User for this membership.
func (cm *ChannelMember) User(ctx context.Context) (*User, error) {
if ul := cm.GetUserLookup(); ul != nil {
return ul.GetUserByID(ctx, cm.UserID)
}
return nil, ErrUserLookupNotAvailable
}
// Channel returns the full Channel for this membership.
func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) {
if cl := cm.GetChannelLookup(); cl != nil {
return cl.GetChannelByID(ctx, cm.ChannelID)
}
return nil, ErrChannelLookupNotAvailable
}

View File

@@ -0,0 +1,20 @@
package models
import (
"time"
)
// Message represents a chat message (channel or DM).
type Message struct {
Base
ID string `json:"id"`
Timestamp time.Time `json:"ts"`
FromUserID string `json:"fromUserId"`
FromNick string `json:"from"`
Target string `json:"to"`
Type string `json:"type"`
Body string `json:"body"`
Meta string `json:"meta"`
CreatedAt time.Time `json:"createdAt"`
}

View File

@@ -0,0 +1,15 @@
package models
import (
"time"
)
// MessageQueueEntry represents a pending message delivery for a user.
type MessageQueueEntry struct {
Base
ID int64 `json:"id"`
UserID string `json:"userId"`
MessageID string `json:"messageId"`
QueuedAt time.Time `json:"queuedAt"`
}

65
internal/models/model.go Normal file
View File

@@ -0,0 +1,65 @@
// Package models defines the data models used by the chat application.
// All model structs embed Base, which provides database access for
// relation-fetching methods directly on model instances.
package models
import (
"context"
"database/sql"
"errors"
)
// DB is the interface that models use to query the database.
// This avoids a circular import with the db package.
type DB interface {
GetDB() *sql.DB
}
// UserLookup provides user lookup by ID without circular imports.
type UserLookup interface {
GetUserByID(ctx context.Context, id string) (*User, error)
}
// ChannelLookup provides channel lookup by ID without circular imports.
type ChannelLookup interface {
GetChannelByID(ctx context.Context, id string) (*Channel, error)
}
// Sentinel errors for model lookup methods.
var (
ErrUserLookupNotAvailable = errors.New("user lookup not available")
ErrChannelLookupNotAvailable = errors.New("channel lookup not available")
)
// Base is embedded in all model structs to provide database access.
type Base struct {
db DB
}
// SetDB injects the database reference into a model.
func (b *Base) SetDB(d DB) {
b.db = d
}
// GetDB returns the database interface for use in model methods.
func (b *Base) GetDB() *sql.DB {
return b.db.GetDB()
}
// GetUserLookup returns the DB as a UserLookup if it implements the interface.
func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn,nolintlint // returns interface by design
if ul, ok := b.db.(UserLookup); ok {
return ul
}
return nil
}
// GetChannelLookup returns the DB as a ChannelLookup if it implements the interface.
func (b *Base) GetChannelLookup() ChannelLookup { //nolint:ireturn,nolintlint // returns interface by design
if cl, ok := b.db.(ChannelLookup); ok {
return cl
}
return nil
}

View File

@@ -0,0 +1,18 @@
package models
import (
"time"
)
// ServerLink represents a federation peer server configuration.
type ServerLink struct {
Base
ID string `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
SharedKeyHash string `json:"-"`
IsActive bool `json:"isActive"`
CreatedAt time.Time `json:"createdAt"`
LastSeenAt *time.Time `json:"lastSeenAt,omitempty"`
}

View File

@@ -0,0 +1,26 @@
package models
import (
"context"
"time"
)
// Session represents a server-held user session.
type Session struct {
Base
ID string `json:"id"`
UserID string `json:"userId"`
CreatedAt time.Time `json:"createdAt"`
LastActiveAt time.Time `json:"lastActiveAt"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
}
// User returns the user who owns this session.
func (s *Session) User(ctx context.Context) (*User, error) {
if ul := s.GetUserLookup(); ul != nil {
return ul.GetUserByID(ctx, s.UserID)
}
return nil, ErrUserLookupNotAvailable
}

92
internal/models/user.go Normal file
View File

@@ -0,0 +1,92 @@
package models
import (
"context"
"time"
)
// User represents a registered user account.
type User struct {
Base
ID string `json:"id"`
Nick string `json:"nick"`
PasswordHash string `json:"-"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
LastSeenAt *time.Time `json:"lastSeenAt,omitempty"`
}
// Channels returns all channels the user is a member of.
func (u *User) Channels(ctx context.Context) ([]*Channel, error) {
rows, err := u.GetDB().QueryContext(ctx, `
SELECT c.id, c.name, c.topic, c.modes, c.created_at, c.updated_at
FROM channels c
JOIN channel_members cm ON cm.channel_id = c.id
WHERE cm.user_id = ?
ORDER BY c.name`,
u.ID,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
channels := []*Channel{}
for rows.Next() {
c := &Channel{}
c.SetDB(u.db)
err = rows.Scan(
&c.ID, &c.Name, &c.Topic, &c.Modes,
&c.CreatedAt, &c.UpdatedAt,
)
if err != nil {
return nil, err
}
channels = append(channels, c)
}
return channels, rows.Err()
}
// QueuedMessages returns undelivered messages for this user.
func (u *User) QueuedMessages(ctx context.Context) ([]*Message, error) {
rows, err := u.GetDB().QueryContext(ctx, `
SELECT m.id, m.ts, m.from_user_id, m.from_nick,
m.target, m.type, m.body, m.meta, m.created_at
FROM messages m
JOIN message_queue mq ON mq.message_id = m.id
WHERE mq.user_id = ?
ORDER BY mq.queued_at ASC`,
u.ID,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
messages := []*Message{}
for rows.Next() {
msg := &Message{}
msg.SetDB(u.db)
err = rows.Scan(
&msg.ID, &msg.Timestamp, &msg.FromUserID,
&msg.FromNick, &msg.Target, &msg.Type,
&msg.Body, &msg.Meta, &msg.CreatedAt,
)
if err != nil {
return nil, err
}
messages = append(messages, msg)
}
return messages, rows.Err()
}

View File

@@ -4,6 +4,6 @@ import "time"
const ( const (
httpReadTimeout = 10 * time.Second httpReadTimeout = 10 * time.Second
httpWriteTimeout = 60 * time.Second httpWriteTimeout = 10 * time.Second
maxHeaderBytes = 1 << 20 maxHeaderBytes = 1 << 20
) )

View File

@@ -16,7 +16,7 @@ import (
const routeTimeout = 60 * time.Second const routeTimeout = 60 * time.Second
// SetupRoutes configures the HTTP routes and middleware. // SetupRoutes configures the HTTP routes and middleware chain.
func (s *Server) SetupRoutes() { func (s *Server) SetupRoutes() {
s.router = chi.NewRouter() s.router = chi.NewRouter()
@@ -39,19 +39,13 @@ func (s *Server) SetupRoutes() {
} }
// Health check // Health check
s.router.Get( s.router.Get("/.well-known/healthcheck.json", s.h.HandleHealthCheck())
"/.well-known/healthcheck.json",
s.h.HandleHealthCheck(),
)
// Protected metrics endpoint // Protected metrics endpoint
if viper.GetString("METRICS_USERNAME") != "" { if viper.GetString("METRICS_USERNAME") != "" {
s.router.Group(func(r chi.Router) { s.router.Group(func(r chi.Router) {
r.Use(s.mw.MetricsAuth()) r.Use(s.mw.MetricsAuth())
r.Get("/metrics", r.Get("/metrics", http.HandlerFunc(promhttp.Handler().ServeHTTP))
http.HandlerFunc(
promhttp.Handler().ServeHTTP,
))
}) })
} }
@@ -59,66 +53,55 @@ func (s *Server) SetupRoutes() {
s.router.Route("/api/v1", func(r chi.Router) { s.router.Route("/api/v1", func(r chi.Router) {
r.Get("/server", s.h.HandleServerInfo()) r.Get("/server", s.h.HandleServerInfo())
r.Post("/session", s.h.HandleCreateSession()) r.Post("/session", s.h.HandleCreateSession())
// Unified state and message endpoints
r.Get("/state", s.h.HandleState()) r.Get("/state", s.h.HandleState())
r.Get("/messages", s.h.HandleGetMessages()) r.Get("/messages", s.h.HandleGetMessages())
r.Post("/messages", s.h.HandleSendCommand()) r.Post("/messages", s.h.HandleSendCommand())
r.Get("/history", s.h.HandleGetHistory()) r.Get("/history", s.h.HandleGetHistory())
// Channels
r.Get("/channels", s.h.HandleListAllChannels()) r.Get("/channels", s.h.HandleListAllChannels())
r.Get( r.Get("/channels/{channel}/members", s.h.HandleChannelMembers())
"/channels/{channel}/members",
s.h.HandleChannelMembers(),
)
}) })
// Serve embedded SPA // Serve embedded SPA
s.setupSPA()
}
func (s *Server) setupSPA() {
distFS, err := fs.Sub(web.Dist, "dist") distFS, err := fs.Sub(web.Dist, "dist")
if err != nil { if err != nil {
s.log.Error( s.log.Error("failed to get web dist filesystem", "error", err)
"failed to get web dist filesystem", } else {
"error", err, fileServer := http.FileServer(http.FS(distFS))
)
s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
s.serveSPA(distFS, fileServer, w, r)
})
}
}
func (s *Server) serveSPA(
distFS fs.FS,
fileServer http.Handler,
w http.ResponseWriter,
r *http.Request,
) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
http.Error(w, "filesystem error", http.StatusInternalServerError)
return return
} }
fileServer := http.FileServer(http.FS(distFS)) // Try to serve the file; fall back to index.html for SPA routing.
f, err := readFS.ReadFile(r.URL.Path[1:])
if err != nil || len(f) == 0 {
indexHTML, _ := readFS.ReadFile("index.html")
s.router.Get("/*", func( w.Header().Set("Content-Type", "text/html; charset=utf-8")
w http.ResponseWriter, w.WriteHeader(http.StatusOK)
r *http.Request, _, _ = w.Write(indexHTML)
) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
fileServer.ServeHTTP(w, r)
return return
} }
f, readErr := readFS.ReadFile(r.URL.Path[1:]) fileServer.ServeHTTP(w, r)
if readErr != nil || len(f) == 0 {
indexHTML, indexErr := readFS.ReadFile(
"index.html",
)
if indexErr != nil {
http.NotFound(w, r)
return
}
w.Header().Set(
"Content-Type",
"text/html; charset=utf-8",
)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(indexHTML)
return
}
fileServer.ServeHTTP(w, r)
})
} }

View File

@@ -148,9 +148,7 @@ func (s *Server) cleanupForExit() {
func (s *Server) cleanShutdown() { func (s *Server) cleanShutdown() {
s.exitCode = 0 s.exitCode = 0
ctxShutdown, shutdownCancel := context.WithTimeout( ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
context.Background(), shutdownTimeout,
)
err := s.httpServer.Shutdown(ctxShutdown) err := s.httpServer.Shutdown(ctxShutdown)
if err != nil { if err != nil {

465
web/dist/app.js vendored

File diff suppressed because one or more lines are too long