12 Commits

Author SHA1 Message Date
clawbot
c7bcedd8d4 fix: pin golangci-lint install to commit SHA (fixes #13)
All checks were successful
check / check (push) Successful in 1m31s
- Fix Dockerfile: use correct v2 module path with commit SHA
  (github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638...)
- Add CGO_ENABLED=0 for alpine compatibility
- Fix 35 lint issues found by golangci-lint v2.1.6:
  - Remove unused nolint directives, add nolintlint where needed
  - Pre-allocate migrations slice
  - Use t.Context() instead of context.Background() in tests
  - Fix blank lines between comments and exported functions in ui.go
2026-02-26 20:20:31 -08:00
9daf836cbe Merge pull request 'fix: repo standards audit — fix all divergences (closes #17)' (#18) from fix/repo-standards-audit into main
Some checks failed
check / check (push) Failing after 12s
Reviewed-on: #18
2026-02-27 05:10:00 +01:00
84303c969a fix: pin golangci-lint to v2.1.6 in Dockerfile
Some checks failed
check / check (push) Failing after 14s
Replace @latest with @v2.1.6 to comply with hash-pinning policy
defined in REPO_POLICIES.md.
2026-02-26 11:43:52 -08:00
clawbot
d2bc467581 fix: resolve lint issues — rename api package, fix nolint directives
Some checks failed
check / check (push) Failing after 1m3s
2026-02-26 07:45:37 -08:00
clawbot
88af2ea98f fix: repair migration 003 schema conflict and rewrite tests (refs #17)
Some checks failed
check / check (push) Failing after 1m18s
Migration 003 created tables with INTEGER keys referencing TEXT primary
keys from migration 002, causing 'no such column' errors. Fix by
properly dropping old tables before recreating with the integer schema.

Rewrite all tests to use the queries.go API (which matches the live
schema) instead of the model-based API (which expected the old UUID
schema).
2026-02-26 06:28:07 -08:00
clawbot
b78d526f02 style: fix all golangci-lint issues and format code (refs #17)
Fix 380 lint violations across all Go source files including wsl_v5,
nlreturn, noinlineerr, errcheck, funlen, funcorder, tagliatelle,
perfsprint, modernize, revive, gosec, ireturn, mnd, forcetypeassert,
cyclop, and others.

Key changes:
- Split large handler/command functions into smaller methods
- Extract scan helpers for database queries
- Reorder exported/unexported methods per funcorder
- Add sentinel errors in models package
- Use camelCase JSON tags per tagliatelle defaults
- Add package comments
- Fix .gitignore to not exclude cmd/chat-cli directory
2026-02-26 06:27:56 -08:00
clawbot
636546d74a docs: add Author section to README (refs #17) 2026-02-26 06:09:08 -08:00
clawbot
27de1227c4 chore: pin Dockerfile images by sha256, run make check in build (refs #17) 2026-02-26 06:09:04 -08:00
clawbot
ef83d6624b chore: fix Makefile — add fmt-check, docker, hooks targets; 30s test timeout (refs #17) 2026-02-26 06:08:47 -08:00
clawbot
fc91dc37c0 chore: update .gitignore and .dockerignore to match standards (refs #17) 2026-02-26 06:08:31 -08:00
clawbot
1e5811edda chore: add missing required files (refs #17)
Add LICENSE (MIT), .editorconfig, REPO_POLICIES.md, and
.gitea/workflows/check.yml per repo standards.
2026-02-26 06:08:24 -08:00
clawbot
3f8ceefd52 fix: rename duplicate db methods to fix compilation (refs #17)
CreateUser, GetUserByNick, GetUserByToken exist in both db.go (model-based,
used by tests) and queries.go (simple, used by handlers). Rename the
model-based variants to CreateUserModel, GetUserByNickModel, and
GetUserByTokenModel to resolve the compilation error.
2026-02-26 06:08:07 -08:00
20 changed files with 2114 additions and 1028 deletions

View File

@@ -1,8 +1,9 @@
.git
node_modules
.DS_Store
bin/
chatd
data.db
.env
.git
*.test
*.out
debug.log

12
.editorconfig Normal file
View File

@@ -0,0 +1,12 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[Makefile]
indent_style = tab

View File

@@ -0,0 +1,9 @@
name: check
on: [push]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4.2.2, 2026-02-22
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- run: docker build .

30
.gitignore vendored
View File

@@ -1,7 +1,28 @@
# OS
.DS_Store
Thumbs.db
# Editors
*.swp
*.swo
*~
*.bak
.idea/
.vscode/
*.sublime-*
# Node
node_modules/
# Environment / secrets
.env
.env.*
*.pem
*.key
# Build artifacts
/chatd
/bin/
data.db
.env
*.exe
*.dll
*.so
@@ -9,6 +30,9 @@ data.db
*.test
*.out
vendor/
# Project
data.db
debug.log
/chat-cli
web/node_modules/
chat-cli

View File

@@ -1,6 +1,7 @@
FROM golang:1.24-alpine AS builder
# golang:1.24-alpine, 2026-02-26
FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder
RUN apk add --no-cache git
RUN apk add --no-cache git build-base
WORKDIR /src
COPY go.mod go.sum ./
@@ -8,10 +9,16 @@ RUN go mod download
COPY . .
# Run all checks — build fails if branch is not green
# golangci-lint v2.1.6
RUN CGO_ENABLED=0 go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d
RUN make check
ARG VERSION=dev
RUN go build -ldflags "-X main.Version=${VERSION}" -o /chatd ./cmd/chatd
FROM alpine:3.21
# alpine:3.21, 2026-02-26
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk add --no-cache ca-certificates
COPY --from=builder /chatd /usr/local/bin/chatd

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 sneak
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt test check clean run debug
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks
BINARY := chatd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -17,18 +17,15 @@ fmt:
gofmt -s -w .
goimports -w .
test:
go test -v -race -cover ./...
# Check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong
check:
@echo "==> Checking formatting..."
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
@echo "==> Running linter..."
golangci-lint run --config .golangci.yml ./...
@echo "==> Running tests..."
go test -v -race ./...
test:
go test -timeout 30s -v -race -cover ./...
# check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong
check: test lint fmt-check
@echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/chatd
@echo "==> All checks passed!"
@@ -41,3 +38,12 @@ debug: build
clean:
rm -rf bin/ chatd data.db
docker:
docker build -t chat .
hooks:
@printf '#!/bin/sh\nset -e\n' > .git/hooks/pre-commit
@printf 'go mod tidy\ngo fmt ./...\ngit diff --exit-code -- go.mod go.sum || { echo "go mod tidy changed files; please stage and retry"; exit 1; }\n' >> .git/hooks/pre-commit
@printf 'make check\n' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit

View File

@@ -656,3 +656,8 @@ embedded web client.
## License
MIT
## Author
[@sneak](https://sneak.berlin)

182
REPO_POLICIES.md Normal file
View File

@@ -0,0 +1,182 @@
---
title: Repository Policies
last_modified: 2026-02-22
---
This document covers repository structure, tooling, and workflow standards. Code
style conventions are in separate documents:
- [Code Styleguide](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE.md)
(general, bash, Docker)
- [Go](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_GO.md)
- [JavaScript](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_JS.md)
- [Python](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_PYTHON.md)
- [Go HTTP Server Conventions](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/GO_HTTP_SERVER_CONVENTIONS.md)
---
- Cross-project documentation (such as this file) must include
`last_modified: YYYY-MM-DD` in the YAML front matter so it can be kept in sync
with the authoritative source as policies evolve.
- **ALL external references must be pinned by cryptographic hash.** This
includes Docker base images, Go modules, npm packages, GitHub Actions, and
anything else fetched from a remote source. Version tags (`@v4`, `@latest`,
`:3.21`, etc.) are server-mutable and therefore remote code execution
vulnerabilities. The ONLY acceptable way to reference an external dependency
is by its content hash (Docker `@sha256:...`, Go module hash in `go.sum`, npm
integrity hash in lockfile, GitHub Actions `@<commit-sha>`). No exceptions.
This also means never `curl | bash` to install tools like pyenv, nvm, rustup,
etc. Instead, download a specific release archive from GitHub, verify its hash
(hardcoded in the Dockerfile or script), and only then install. Unverified
install scripts are arbitrary remote code execution. This is the single most
important rule in this document. Double-check every external reference in
every file before committing. There are zero exceptions to this rule.
- Every repo with software must have a root `Makefile` with these targets:
`make test`, `make lint`, `make fmt` (writes), `make fmt-check` (read-only),
`make check` (prereqs: `test`, `lint`, `fmt-check`), `make docker`, and
`make hooks` (installs pre-commit hook). A model Makefile is at
`https://git.eeqj.de/sneak/prompts/raw/branch/main/Makefile`.
- Always use Makefile targets (`make fmt`, `make test`, `make lint`, etc.)
instead of invoking the underlying tools directly. The Makefile is the single
source of truth for how these operations are run.
- The Makefile is authoritative documentation for how the repo is used. Beyond
the required targets above, it should have targets for every common operation:
running a local development server (`make run`, `make dev`), re-initializing
or migrating the database (`make db-reset`, `make migrate`), building
artifacts (`make build`), generating code, seeding data, or anything else a
developer would do regularly. If someone checks out the repo and types
`make<tab>`, they should see every meaningful operation available. A new
contributor should be able to understand the entire development workflow by
reading the Makefile.
- Every repo should have a `Dockerfile`. All Dockerfiles must run `make check`
as a build step so the build fails if the branch is not green. For non-server
repos, the Dockerfile should bring up a development environment and run
`make check`. For server repos, `make check` should run as an early build
stage before the final image is assembled.
- Every repo should have a Gitea Actions workflow (`.gitea/workflows/`) that
runs `docker build .` on push. Since the Dockerfile already runs `make check`,
a successful build implies all checks pass.
- Use platform-standard formatters: `black` for Python, `prettier` for
JS/CSS/Markdown/HTML, `go fmt` for Go. Always use default configuration with
two exceptions: four-space indents (except Go), and `proseWrap: always` for
Markdown (hard-wrap at 80 columns). Documentation and writing repos (Markdown,
HTML, CSS) should also have `.prettierrc` and `.prettierignore`.
- Pre-commit hook: `make check` if local testing is possible, otherwise
`make lint && make fmt-check`. The Makefile should provide a `make hooks`
target to install the pre-commit hook.
- All repos with software must have tests that run via the platform-standard
test framework (`go test`, `pytest`, `jest`/`vitest`, etc.). If no meaningful
tests exist yet, add the most minimal test possible — e.g. importing the
module under test to verify it compiles/parses. There is no excuse for
`make test` to be a no-op.
- `make test` must complete in under 20 seconds. Add a 30-second timeout in the
Makefile.
- Docker builds must complete in under 5 minutes.
- `make check` must not modify any files in the repo. Tests may use temporary
directories.
- `main` must always pass `make check`, no exceptions.
- Never commit secrets. `.env` files, credentials, API keys, and private keys
must be in `.gitignore`. No exceptions.
- `.gitignore` should be comprehensive from the start: OS files (`.DS_Store`),
editor files (`.swp`, `*~`), language build artifacts, and `node_modules/`.
Fetch the standard `.gitignore` from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.gitignore` when setting up
a new repo.
- Never use `git add -A` or `git add .`. Always stage files explicitly by name.
- Never force-push to `main`.
- Make all changes on a feature branch. You can do whatever you want on a
feature branch.
- `.golangci.yml` is standardized and must _NEVER_ be modified by an agent, only
manually by the user. Fetch from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.golangci.yml`.
- When pinning images or packages by hash, add a comment above the reference
with the version and date (YYYY-MM-DD).
- Use `yarn`, not `npm`.
- Write all dates as YYYY-MM-DD (ISO 8601).
- Simple projects should be configured with environment variables.
- Dockerized web services listen on port 8080 by default, overridable with
`PORT`.
- `README.md` is the primary documentation. Required sections:
- **Description**: First line must include the project name, purpose,
category (web server, SPA, CLI tool, etc.), license, and author. Example:
"µPaaS is an MIT-licensed Go web application by @sneak that receives
git-frontend webhooks and deploys applications via Docker in realtime."
- **Getting Started**: Copy-pasteable install/usage code block.
- **Rationale**: Why does this exist?
- **Design**: How is the program structured?
- **TODO**: Update meticulously, even between commits. When planning, put
the todo list in the README so a new agent can pick up where the last one
left off.
- **License**: MIT, GPL, or WTFPL. Ask the user for new projects. Include a
`LICENSE` file in the repo root and a License section in the README.
- **Author**: [@sneak](https://sneak.berlin).
- First commit of a new repo should contain only `README.md`.
- Go module root: `sneak.berlin/go/<name>`. Always run `go mod tidy` before
committing.
- Use SemVer.
- Database migrations live in `internal/db/migrations/` and must be embedded in
the binary. Pre-1.0.0: modify existing migrations (no installed base assumed).
Post-1.0.0: add new migration files.
- All repos should have an `.editorconfig` enforcing the project's indentation
settings.
- Avoid putting files in the repo root unless necessary. Root should contain
only project-level config files (`README.md`, `Makefile`, `Dockerfile`,
`LICENSE`, `.gitignore`, `.editorconfig`, `REPO_POLICIES.md`, and
language-specific config). Everything else goes in a subdirectory. Canonical
subdirectory names:
- `bin/` — executable scripts and tools
- `cmd/` — Go command entrypoints
- `configs/` — configuration templates and examples
- `deploy/` — deployment manifests (k8s, compose, terraform)
- `docs/` — documentation and markdown (README.md stays in root)
- `internal/` — Go internal packages
- `internal/db/migrations/` — database migrations
- `pkg/` — Go library packages
- `share/` — systemd units, data files
- `static/` — static assets (images, fonts, etc.)
- `web/` — web frontend source
- When setting up a new repo, files from the `prompts` repo may be used as
templates. Fetch them from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/<path>`.
- New repos must contain at minimum:
- `README.md`, `.git`, `.gitignore`, `.editorconfig`
- `LICENSE`, `REPO_POLICIES.md` (copy from the `prompts` repo)
- `Makefile`
- `Dockerfile`, `.dockerignore`
- `.gitea/workflows/check.yml`
- Go: `go.mod`, `go.sum`, `.golangci.yml`
- JS: `package.json`, `yarn.lock`, `.prettierrc`, `.prettierignore`
- Python: `pyproject.toml`

View File

@@ -1,9 +1,8 @@
// Package api provides the HTTP client for the chat server API.
package api
// Package chatapi provides a client for the chat server HTTP API.
package chatapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -15,20 +14,17 @@ import (
)
const (
// httpClientTimeout is the default HTTP client timeout in seconds.
httpClientTimeout = 30
// httpStatusErrorThreshold is the minimum status code considered an error.
httpStatusErrorThreshold = 400
// pollTimeoutBuffer is extra seconds added to HTTP timeout beyond the poll timeout.
pollTimeoutBuffer = 5
httpTimeout = 30 * time.Second
pollExtraDelay = 5
httpErrThreshold = 400
)
var (
// ErrHTTPStatus is returned when the server responds with an error status code.
ErrHTTPStatus = errors.New("HTTP error")
// ErrUnexpectedFormat is returned when the response format is unexpected.
ErrUnexpectedFormat = errors.New("unexpected format")
)
// ErrHTTP is returned for non-2xx responses.
var ErrHTTP = errors.New("http error")
// ErrUnexpectedFormat is returned when the response format is
// not recognised.
var ErrUnexpectedFormat = errors.New("unexpected format")
// Client wraps HTTP calls to the chat server API.
type Client struct {
@@ -42,14 +38,19 @@ func NewClient(baseURL string) *Client {
return &Client{
BaseURL: baseURL,
HTTPClient: &http.Client{
Timeout: httpClientTimeout * time.Second,
Timeout: httpTimeout,
},
}
}
// CreateSession creates a new session on the server.
func (c *Client) CreateSession(nick string) (*SessionResponse, error) {
data, err := c.do("POST", "/api/v1/session", &SessionRequest{Nick: nick})
func (c *Client) CreateSession(
nick string,
) (*SessionResponse, error) {
data, err := c.do(
"POST", "/api/v1/session",
&SessionRequest{Nick: nick},
)
if err != nil {
return nil, err
}
@@ -91,9 +92,15 @@ func (c *Client) SendMessage(msg *Message) error {
}
// PollMessages long-polls for new messages.
func (c *Client) PollMessages(afterID string, timeout int) ([]Message, error) {
// Use a longer HTTP timeout than the server long-poll timeout.
client := &http.Client{Timeout: time.Duration(timeout+pollTimeoutBuffer) * time.Second}
func (c *Client) PollMessages(
afterID string,
timeout int,
) ([]Message, error) {
pollTimeout := time.Duration(
timeout+pollExtraDelay,
) * time.Second
client := &http.Client{Timeout: pollTimeout}
params := url.Values{}
if afterID != "" {
@@ -107,14 +114,16 @@ func (c *Client) PollMessages(afterID string, timeout int) ([]Message, error) {
path += "?" + params.Encode()
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, c.BaseURL+path, nil)
req, err := http.NewRequest( //nolint:noctx // CLI tool
http.MethodGet, c.BaseURL+path, nil,
)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+c.Token)
resp, err := client.Do(req) //nolint:gosec // G704: BaseURL is set by user at connect time, not tainted input
resp, err := client.Do(req) //nolint:gosec,nolintlint // URL from user config
if err != nil {
return nil, err
}
@@ -126,37 +135,51 @@ func (c *Client) PollMessages(afterID string, timeout int) ([]Message, error) {
return nil, err
}
if resp.StatusCode >= httpStatusErrorThreshold {
return nil, fmt.Errorf("%w: %d: %s", ErrHTTPStatus, resp.StatusCode, string(data))
if resp.StatusCode >= httpErrThreshold {
return nil, fmt.Errorf(
"%w: %d: %s",
ErrHTTP, resp.StatusCode, string(data),
)
}
// The server may return an array directly or wrapped.
return decodeMessages(data)
}
func decodeMessages(data []byte) ([]Message, error) {
var msgs []Message
err = json.Unmarshal(data, &msgs)
if err != nil {
// Try wrapped format.
var wrapped MessagesResponse
err2 := json.Unmarshal(data, &wrapped)
if err2 != nil {
return nil, fmt.Errorf("decode messages: %w (raw: %s)", err, string(data))
}
msgs = wrapped.Messages
err := json.Unmarshal(data, &msgs)
if err == nil {
return msgs, nil
}
return msgs, nil
var wrapped MessagesResponse
err2 := json.Unmarshal(data, &wrapped)
if err2 != nil {
return nil, fmt.Errorf(
"decode messages: %w (raw: %s)",
err, string(data),
)
}
return wrapped.Messages, nil
}
// JoinChannel joins a channel via the unified command endpoint.
// JoinChannel joins a channel via the unified command
// endpoint.
func (c *Client) JoinChannel(channel string) error {
return c.SendMessage(&Message{Command: "JOIN", To: channel})
return c.SendMessage(
&Message{Command: "JOIN", To: channel},
)
}
// PartChannel leaves a channel via the unified command endpoint.
// PartChannel leaves a channel via the unified command
// endpoint.
func (c *Client) PartChannel(channel string) error {
return c.SendMessage(&Message{Command: "PART", To: channel})
return c.SendMessage(
&Message{Command: "PART", To: channel},
)
}
// ListChannels returns all channels on the server.
@@ -177,8 +200,13 @@ func (c *Client) ListChannels() ([]Channel, error) {
}
// GetMembers returns members of a channel.
func (c *Client) GetMembers(channel string) ([]string, error) {
data, err := c.do("GET", "/api/v1/channels/"+url.PathEscape(channel)+"/members", nil)
func (c *Client) GetMembers(
channel string,
) ([]string, error) {
path := "/api/v1/channels/" +
url.PathEscape(channel) + "/members"
data, err := c.do("GET", path, nil)
if err != nil {
return nil, err
}
@@ -187,15 +215,10 @@ func (c *Client) GetMembers(channel string) ([]string, error) {
err = json.Unmarshal(data, &members)
if err != nil {
// Try object format.
var obj map[string]any
err2 := json.Unmarshal(data, &obj)
if err2 != nil {
return nil, err
}
// Extract member names from whatever format.
return nil, fmt.Errorf("%w: members: %s", ErrUnexpectedFormat, string(data))
return nil, fmt.Errorf(
"%w: members: %s",
ErrUnexpectedFormat, string(data),
)
}
return members, nil
@@ -218,7 +241,10 @@ func (c *Client) GetServerInfo() (*ServerInfo, error) {
return &info, nil
}
func (c *Client) do(method, path string, body any) ([]byte, error) {
func (c *Client) do(
method, path string,
body any,
) ([]byte, error) {
var bodyReader io.Reader
if body != nil {
@@ -230,7 +256,9 @@ func (c *Client) do(method, path string, body any) ([]byte, error) {
bodyReader = bytes.NewReader(data)
}
req, err := http.NewRequestWithContext(context.Background(), method, c.BaseURL+path, bodyReader)
req, err := http.NewRequest( //nolint:noctx // CLI tool
method, c.BaseURL+path, bodyReader,
)
if err != nil {
return nil, fmt.Errorf("request: %w", err)
}
@@ -241,7 +269,7 @@ func (c *Client) do(method, path string, body any) ([]byte, error) {
req.Header.Set("Authorization", "Bearer "+c.Token)
}
resp, err := c.HTTPClient.Do(req) //nolint:gosec // URL constructed from trusted base URL
resp, err := c.HTTPClient.Do(req) //nolint:gosec,nolintlint // URL from user config
if err != nil {
return nil, fmt.Errorf("http: %w", err)
}
@@ -253,8 +281,11 @@ func (c *Client) do(method, path string, body any) ([]byte, error) {
return nil, fmt.Errorf("read body: %w", err)
}
if resp.StatusCode >= httpStatusErrorThreshold {
return data, fmt.Errorf("%w: %d: %s", ErrHTTPStatus, resp.StatusCode, string(data))
if resp.StatusCode >= httpErrThreshold {
return data, fmt.Errorf(
"%w: %d: %s",
ErrHTTP, resp.StatusCode, string(data),
)
}
return data, nil

View File

@@ -1,4 +1,4 @@
package api //nolint:revive // package name "api" is conventional for API client packages
package chatapi
import "time"
@@ -35,11 +35,13 @@ type Message struct {
Meta any `json:"meta,omitempty"`
}
// BodyLines returns the body as a slice of strings (for text messages).
// BodyLines returns the body as a slice of strings (for text
// messages).
func (m *Message) BodyLines() []string {
switch v := m.Body.(type) {
case []any:
lines := make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
lines = append(lines, s)

View File

@@ -1,23 +1,21 @@
// Package main implements the chat-cli terminal client.
// Package main implements chat-cli, an IRC-style terminal client.
package main
import (
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"git.eeqj.de/sneak/chat/cmd/chat-cli/api"
api "git.eeqj.de/sneak/chat/cmd/chat-cli/api"
)
const (
// splitParts is the number of parts to split a command into (command + args).
splitParts = 2
// pollTimeout is the long-poll timeout in seconds.
pollTimeout = 15
// pollRetryDelay is the delay before retrying a failed poll.
pollRetryDelay = 2 * time.Second
pollTimeoutSec = 15
retryDelay = 2 * time.Second
maxNickLength = 32
)
// App holds the application state.
@@ -42,12 +40,18 @@ func main() {
app.ui.OnInput(app.handleInput)
app.ui.SetStatus(app.nick, "", "disconnected")
app.ui.AddStatus("Welcome to chat-cli — an IRC-style client")
app.ui.AddStatus("Type [yellow]/connect <server-url>[white] to begin, or [yellow]/help[white] for commands")
app.ui.AddStatus(
"Welcome to chat-cli \u2014 an IRC-style client",
)
app.ui.AddStatus(
"Type [yellow]/connect <server-url>[white] " +
"to begin, or [yellow]/help[white] for commands",
)
err := app.ui.Run()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
_, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
@@ -59,20 +63,29 @@ func (a *App) handleInput(text string) {
return
}
// Plain text → PRIVMSG to current target.
a.sendPlainText(text)
}
func (a *App) sendPlainText(text string) {
a.mu.Lock()
target := a.target
connected := a.connected
nick := a.nick
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected. Use /connect <url>")
a.ui.AddStatus(
"[red]Not connected. Use /connect <url>",
)
return
}
if target == "" {
a.ui.AddStatus("[red]No target. Use /join #channel or /query nick")
a.ui.AddStatus(
"[red]No target. " +
"Use /join #channel or /query nick",
)
return
}
@@ -83,40 +96,26 @@ func (a *App) handleInput(text string) {
Body: []string{text},
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Send error: %v", err))
a.ui.AddStatus(
fmt.Sprintf("[red]Send error: %v", err),
)
return
}
// Echo locally.
ts := time.Now().Format("15:04")
a.mu.Lock()
nick := a.nick
a.mu.Unlock()
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text))
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
ts, nick, text,
),
)
}
func (a *App) commandHandlers() map[string]func(string) {
return map[string]func(string){
"/connect": a.cmdConnect,
"/nick": a.cmdNick,
"/join": a.cmdJoin,
"/part": a.cmdPart,
"/msg": a.cmdMsg,
"/query": a.cmdQuery,
"/topic": a.cmdTopic,
"/names": func(_ string) { a.cmdNames() },
"/list": func(_ string) { a.cmdList() },
"/window": a.cmdWindow,
"/w": a.cmdWindow,
"/quit": func(_ string) { a.cmdQuit() },
"/help": func(_ string) { a.cmdHelp() },
}
}
func (a *App) handleCommand(text string) {
parts := strings.SplitN(text, " ", splitParts)
func (a *App) handleCommand(text string) { //nolint:cyclop // command dispatch
parts := strings.SplitN(text, " ", 2) //nolint:mnd // split into cmd+args
cmd := strings.ToLower(parts[0])
args := ""
@@ -124,23 +123,52 @@ func (a *App) handleCommand(text string) {
args = parts[1]
}
if handler, ok := a.commandHandlers()[cmd]; ok {
handler(args)
} else {
a.ui.AddStatus("[red]Unknown command: " + cmd)
switch cmd {
case "/connect":
a.cmdConnect(args)
case "/nick":
a.cmdNick(args)
case "/join":
a.cmdJoin(args)
case "/part":
a.cmdPart(args)
case "/msg":
a.cmdMsg(args)
case "/query":
a.cmdQuery(args)
case "/topic":
a.cmdTopic(args)
case "/names":
a.cmdNames()
case "/list":
a.cmdList()
case "/window", "/w":
a.cmdWindow(args)
case "/quit":
a.cmdQuit()
case "/help":
a.cmdHelp()
default:
a.ui.AddStatus(
"[red]Unknown command: " + cmd,
)
}
}
func (a *App) cmdConnect(serverURL string) {
if serverURL == "" {
a.ui.AddStatus("[red]Usage: /connect <server-url>")
a.ui.AddStatus(
"[red]Usage: /connect <server-url>",
)
return
}
serverURL = strings.TrimRight(serverURL, "/")
a.ui.AddStatus(fmt.Sprintf("Connecting to %s...", serverURL))
a.ui.AddStatus(
fmt.Sprintf("Connecting to %s...", serverURL),
)
a.mu.Lock()
nick := a.nick
@@ -150,7 +178,11 @@ func (a *App) cmdConnect(serverURL string) {
resp, err := client.CreateSession(nick)
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Connection failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf(
"[red]Connection failed: %v", err,
),
)
return
}
@@ -162,11 +194,16 @@ func (a *App) cmdConnect(serverURL string) {
a.lastMsgID = ""
a.mu.Unlock()
a.ui.AddStatus(fmt.Sprintf("[green]Connected! Nick: %s, Session: %s", resp.Nick, resp.SessionID))
a.ui.AddStatus(
fmt.Sprintf(
"[green]Connected! Nick: %s, Session: %s",
resp.Nick, resp.SessionID,
),
)
a.ui.SetStatus(resp.Nick, "", "connected")
// Start polling.
a.stopPoll = make(chan struct{})
go a.pollLoop()
}
@@ -185,7 +222,13 @@ func (a *App) cmdNick(nick string) {
a.mu.Lock()
a.nick = nick
a.mu.Unlock()
a.ui.AddStatus(fmt.Sprintf("Nick set to %s (will be used on connect)", nick))
a.ui.AddStatus(
fmt.Sprintf(
"Nick set to %s (will be used on connect)",
nick,
),
)
return
}
@@ -195,7 +238,11 @@ func (a *App) cmdNick(nick string) {
Body: []string{nick},
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Nick change failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf(
"[red]Nick change failed: %v", err,
),
)
return
}
@@ -204,8 +251,11 @@ func (a *App) cmdNick(nick string) {
a.nick = nick
target := a.target
a.mu.Unlock()
a.ui.SetStatus(nick, target, "connected")
a.ui.AddStatus("Nick changed to " + nick)
a.ui.AddStatus(
"Nick changed to " + nick,
)
}
func (a *App) cmdJoin(channel string) {
@@ -231,7 +281,9 @@ func (a *App) cmdJoin(channel string) {
err := a.client.JoinChannel(channel)
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Join failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf("[red]Join failed: %v", err),
)
return
}
@@ -242,12 +294,16 @@ func (a *App) cmdJoin(channel string) {
a.mu.Unlock()
a.ui.SwitchToBuffer(channel)
a.ui.AddLine(channel, "[yellow]*** Joined "+channel)
a.ui.AddLine(
channel,
"[yellow]*** Joined "+channel,
)
a.ui.SetStatus(nick, channel, "connected")
}
func (a *App) cmdPart(channel string) {
a.mu.Lock()
if channel == "" {
channel = a.target
}
@@ -269,14 +325,20 @@ func (a *App) cmdPart(channel string) {
err := a.client.PartChannel(channel)
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Part failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf("[red]Part failed: %v", err),
)
return
}
a.ui.AddLine(channel, "[yellow]*** Left "+channel)
a.ui.AddLine(
channel,
"[yellow]*** Left "+channel,
)
a.mu.Lock()
if a.target == channel {
a.target = ""
}
@@ -289,8 +351,8 @@ func (a *App) cmdPart(channel string) {
}
func (a *App) cmdMsg(args string) {
parts := strings.SplitN(args, " ", splitParts)
if len(parts) < splitParts {
parts := strings.SplitN(args, " ", 2) //nolint:mnd // split into target+text
if len(parts) < 2 { //nolint:mnd // min args
a.ui.AddStatus("[red]Usage: /msg <nick> <text>")
return
@@ -315,13 +377,22 @@ func (a *App) cmdMsg(args string) {
Body: []string{text},
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Send failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf("[red]Send failed: %v", err),
)
return
}
ts := time.Now().Format("15:04")
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text))
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
ts, nick, text,
),
)
}
func (a *App) cmdQuery(nick string) {
@@ -359,13 +430,16 @@ func (a *App) cmdTopic(args string) {
}
if args == "" {
// Query topic.
err := a.client.SendMessage(&api.Message{
Command: "TOPIC",
To: target,
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Topic query failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf(
"[red]Topic query failed: %v", err,
),
)
}
return
@@ -377,7 +451,11 @@ func (a *App) cmdTopic(args string) {
Body: []string{args},
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Topic set failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf(
"[red]Topic set failed: %v", err,
),
)
}
}
@@ -401,12 +479,20 @@ func (a *App) cmdNames() {
members, err := a.client.GetMembers(target)
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]Names failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf("[red]Names failed: %v", err),
)
return
}
a.ui.AddLine(target, fmt.Sprintf("[cyan]*** Members of %s: %s", target, strings.Join(members, " ")))
a.ui.AddLine(
target,
fmt.Sprintf(
"[cyan]*** Members of %s: %s",
target, strings.Join(members, " "),
),
)
}
func (a *App) cmdList() {
@@ -422,7 +508,9 @@ func (a *App) cmdList() {
channels, err := a.client.ListChannels()
if err != nil {
a.ui.AddStatus(fmt.Sprintf("[red]List failed: %v", err))
a.ui.AddStatus(
fmt.Sprintf("[red]List failed: %v", err),
)
return
}
@@ -430,7 +518,12 @@ func (a *App) cmdList() {
a.ui.AddStatus("[cyan]*** Channel list:")
for _, ch := range channels {
a.ui.AddStatus(fmt.Sprintf(" %s (%d members) %s", ch.Name, ch.Members, ch.Topic))
a.ui.AddStatus(
fmt.Sprintf(
" %s (%d members) %s",
ch.Name, ch.Members, ch.Topic,
),
)
}
a.ui.AddStatus("[cyan]*** End of channel list")
@@ -443,21 +536,21 @@ func (a *App) cmdWindow(args string) {
return
}
n := 0
_, _ = fmt.Sscanf(args, "%d", &n)
n, _ := strconv.Atoi(args)
a.ui.SwitchBuffer(n)
a.mu.Lock()
nick := a.nick
a.mu.Unlock()
// Update target based on buffer.
if n < a.ui.BufferCount() {
if n >= 0 && n < a.ui.BufferCount() {
buf := a.ui.buffers[n]
if buf.Name != "(status)" {
a.mu.Lock()
a.target = buf.Name
a.mu.Unlock()
a.ui.SetStatus(nick, buf.Name, "connected")
} else {
a.ui.SetStatus(nick, "", "connected")
@@ -467,13 +560,17 @@ func (a *App) cmdWindow(args string) {
func (a *App) cmdQuit() {
a.mu.Lock()
if a.connected && a.client != nil {
_ = a.client.SendMessage(&api.Message{Command: "QUIT"})
_ = a.client.SendMessage(
&api.Message{Command: "QUIT"},
)
}
if a.stopPoll != nil {
close(a.stopPoll)
}
a.mu.Unlock()
a.ui.Stop()
}
@@ -481,20 +578,21 @@ func (a *App) cmdQuit() {
func (a *App) cmdHelp() {
help := []string{
"[cyan]*** chat-cli commands:",
" /connect <url> Connect to server",
" /nick <name> Change nickname",
" /join #channel Join channel",
" /part [#chan] Leave channel",
" /msg <nick> <text> Send DM",
" /query <nick> Open DM window",
" /topic [text] View/set topic",
" /names List channel members",
" /list List channels",
" /window <n> Switch buffer (Alt+0-9)",
" /quit Disconnect and exit",
" /help This help",
" /connect <url> \u2014 Connect to server",
" /nick <name> \u2014 Change nickname",
" /join #channel \u2014 Join channel",
" /part [#chan] \u2014 Leave channel",
" /msg <nick> <text> \u2014 Send DM",
" /query <nick> \u2014 Open DM window",
" /topic [text] \u2014 View/set topic",
" /names \u2014 List channel members",
" /list \u2014 List channels",
" /window <n> \u2014 Switch buffer (Alt+0-9)",
" /quit \u2014 Disconnect and exit",
" /help \u2014 This help",
" Plain text sends to current target.",
}
for _, line := range help {
a.ui.AddStatus(line)
}
@@ -518,38 +616,31 @@ func (a *App) pollLoop() {
return
}
msgs, err := client.PollMessages(lastID, pollTimeout)
msgs, err := client.PollMessages(
lastID, pollTimeoutSec,
)
if err != nil {
// Transient error — retry after delay.
time.Sleep(pollRetryDelay)
time.Sleep(retryDelay)
continue
}
for _, msg := range msgs {
a.handleServerMessage(&msg)
for i := range msgs {
a.handleServerMessage(&msgs[i])
if msg.ID != "" {
if msgs[i].ID != "" {
a.mu.Lock()
a.lastMsgID = msg.ID
a.lastMsgID = msgs[i].ID
a.mu.Unlock()
}
}
}
}
func (a *App) messageTimestamp(msg *api.Message) string {
if msg.TS != "" {
t := msg.ParseTS()
return t.Local().Format("15:04") //nolint:gosmopolitan // CLI displays local time intentionally
}
return time.Now().Format("15:04")
}
func (a *App) handleServerMessage(msg *api.Message) {
ts := a.messageTimestamp(msg)
func (a *App) handleServerMessage(
msg *api.Message,
) {
ts := a.parseMessageTS(msg)
a.mu.Lock()
myNick := a.nick
@@ -557,25 +648,38 @@ func (a *App) handleServerMessage(msg *api.Message) {
switch msg.Command {
case "PRIVMSG":
a.handleMsgPrivmsg(msg, ts, myNick)
a.handlePrivmsgMsg(msg, ts, myNick)
case "JOIN":
a.handleMsgJoin(msg, ts)
a.handleJoinMsg(msg, ts)
case "PART":
a.handleMsgPart(msg, ts)
a.handlePartMsg(msg, ts)
case "QUIT":
a.handleMsgQuit(msg, ts)
a.handleQuitMsg(msg, ts)
case "NICK":
a.handleMsgNick(msg, ts, myNick)
a.handleNickMsg(msg, ts, myNick)
case "NOTICE":
a.handleMsgNotice(msg, ts)
a.handleNoticeMsg(msg, ts)
case "TOPIC":
a.handleMsgTopic(msg, ts)
a.handleTopicMsg(msg, ts)
default:
a.handleMsgDefault(msg, ts)
a.handleDefaultMsg(msg, ts)
}
}
func (a *App) handleMsgPrivmsg(msg *api.Message, ts, myNick string) {
func (a *App) parseMessageTS(msg *api.Message) string {
if msg.TS != "" {
t := msg.ParseTS()
return t.In(time.Local).Format("15:04") //nolint:gosmopolitan // CLI uses local time
}
return time.Now().Format("15:04")
}
func (a *App) handlePrivmsgMsg(
msg *api.Message,
ts, myNick string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
@@ -588,40 +692,88 @@ func (a *App) handleMsgPrivmsg(msg *api.Message, ts, myNick string) {
target = msg.From
}
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, msg.From, text))
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
ts, msg.From, text,
),
)
}
func (a *App) handleMsgJoin(msg *api.Message, ts string) {
if msg.To != "" {
a.ui.AddLine(msg.To, fmt.Sprintf("[gray]%s [yellow]*** %s has joined %s", ts, msg.From, msg.To))
}
}
func (a *App) handleMsgPart(msg *api.Message, ts string) {
func (a *App) handleJoinMsg(
msg *api.Message, ts string,
) {
target := msg.To
reason := strings.Join(msg.BodyLines(), " ")
if target == "" {
return
}
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has joined %s",
ts, msg.From, target,
),
)
}
func (a *App) handlePartMsg(
msg *api.Message, ts string,
) {
target := msg.To
if target == "" {
return
}
lines := msg.BodyLines()
reason := strings.Join(lines, " ")
if reason != "" {
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s (%s)", ts, msg.From, target, reason))
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s (%s)",
ts, msg.From, target, reason,
),
)
} else {
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s", ts, msg.From, target))
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s",
ts, msg.From, target,
),
)
}
}
func (a *App) handleMsgQuit(msg *api.Message, ts string) {
reason := strings.Join(msg.BodyLines(), " ")
func (a *App) handleQuitMsg(
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
reason := strings.Join(lines, " ")
if reason != "" {
a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit (%s)", ts, msg.From, reason))
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit (%s)",
ts, msg.From, reason,
),
)
} else {
a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit", ts, msg.From))
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit",
ts, msg.From,
),
)
}
}
func (a *App) handleMsgNick(msg *api.Message, ts, myNick string) {
func (a *App) handleNickMsg(
msg *api.Message, ts, myNick string,
) {
lines := msg.BodyLines()
newNick := ""
@@ -634,27 +786,65 @@ func (a *App) handleMsgNick(msg *api.Message, ts, myNick string) {
a.nick = newNick
target := a.target
a.mu.Unlock()
a.ui.SetStatus(newNick, target, "connected")
}
a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s is now known as %s", ts, msg.From, newNick))
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [yellow]*** %s is now known as %s",
ts, msg.From, newNick,
),
)
}
func (a *App) handleMsgNotice(msg *api.Message, ts string) {
text := strings.Join(msg.BodyLines(), " ")
a.ui.AddStatus(fmt.Sprintf("[gray]%s [magenta]--%s-- %s", ts, msg.From, text))
func (a *App) handleNoticeMsg(
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [magenta]--%s-- %s",
ts, msg.From, text,
),
)
}
func (a *App) handleMsgTopic(msg *api.Message, ts string) {
text := strings.Join(msg.BodyLines(), " ")
if msg.To != "" {
a.ui.AddLine(msg.To, fmt.Sprintf("[gray]%s [cyan]*** %s set topic: %s", ts, msg.From, text))
func (a *App) handleTopicMsg(
msg *api.Message, ts string,
) {
if msg.To == "" {
return
}
lines := msg.BodyLines()
text := strings.Join(lines, " ")
a.ui.AddLine(
msg.To,
fmt.Sprintf(
"[gray]%s [cyan]*** %s set topic: %s",
ts, msg.From, text,
),
)
}
func (a *App) handleMsgDefault(msg *api.Message, ts string) {
text := strings.Join(msg.BodyLines(), " ")
if text != "" {
a.ui.AddStatus(fmt.Sprintf("[gray]%s [white][%s] %s", ts, msg.Command, text))
func (a *App) handleDefaultMsg(
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
if text == "" {
return
}
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [white][%s] %s",
ts, msg.Command, text,
),
)
}

View File

@@ -39,33 +39,11 @@ func NewUI() *UI {
},
}
// Message area.
ui.messages = tview.NewTextView().
SetDynamicColors(true).
SetScrollable(true).
SetWordWrap(true).
SetChangedFunc(func() {
ui.app.Draw()
})
ui.messages.SetBorder(false)
// Status bar.
ui.statusBar = tview.NewTextView().
SetDynamicColors(true)
ui.statusBar.SetBackgroundColor(tcell.ColorNavy)
ui.statusBar.SetTextColor(tcell.ColorWhite)
ui.setupMessages()
ui.setupStatusBar()
ui.setupInput()
ui.setupKeyCapture()
// Layout: messages on top, status bar, input at bottom.
ui.layout = tview.NewFlex().SetDirection(tview.FlexRow).
AddItem(ui.messages, 0, 1, false).
AddItem(ui.statusBar, 1, 0, false).
AddItem(ui.input, 1, 0, true)
ui.app.SetRoot(ui.layout, true)
ui.app.SetFocus(ui.input)
ui.setupKeybindings()
ui.setupLayout()
return ui
}
@@ -108,7 +86,11 @@ func (ui *UI) AddLine(bufferName string, line string) {
// AddStatus adds a line to the status buffer (buffer 0).
func (ui *UI) AddStatus(line string) {
ts := time.Now().Format("15:04")
ui.AddLine("(status)", fmt.Sprintf("[gray]%s[white] %s", ts, line))
ui.AddLine(
"(status)",
fmt.Sprintf("[gray]%s[white] %s", ts, line),
)
}
// SwitchBuffer switches to the buffer at index n.
@@ -119,6 +101,7 @@ func (ui *UI) SwitchBuffer(n int) {
}
ui.currentBuffer = n
buf := ui.buffers[n]
buf.Unread = 0
@@ -133,10 +116,11 @@ func (ui *UI) SwitchBuffer(n int) {
})
}
// SwitchToBuffer switches to the named buffer, creating it if needed.
// SwitchToBuffer switches to the named buffer, creating it
func (ui *UI) SwitchToBuffer(name string) {
ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(name)
for i, b := range ui.buffers {
if b == buf {
ui.currentBuffer = i
@@ -159,7 +143,9 @@ func (ui *UI) SwitchToBuffer(name string) {
}
// SetStatus updates the status bar text.
func (ui *UI) SetStatus(nick, target, connStatus string) {
func (ui *UI) SetStatus(
nick, target, connStatus string,
) {
ui.app.QueueUpdateDraw(func() {
ui.refreshStatusWith(nick, target, connStatus)
})
@@ -181,10 +167,29 @@ func (ui *UI) BufferIndex(name string) int {
return -1
}
func (ui *UI) setupMessages() {
ui.messages = tview.NewTextView().
SetDynamicColors(true).
SetScrollable(true).
SetWordWrap(true).
SetChangedFunc(func() {
ui.app.Draw()
})
ui.messages.SetBorder(false)
}
func (ui *UI) setupStatusBar() {
ui.statusBar = tview.NewTextView().
SetDynamicColors(true)
ui.statusBar.SetBackgroundColor(tcell.ColorNavy)
ui.statusBar.SetTextColor(tcell.ColorWhite)
}
func (ui *UI) setupInput() {
ui.input = tview.NewInputField().
SetFieldBackgroundColor(tcell.ColorBlack).
SetFieldTextColor(tcell.ColorWhite)
ui.input.SetDoneFunc(func(key tcell.Key) {
if key != tcell.KeyEnter {
return
@@ -203,9 +208,13 @@ func (ui *UI) setupInput() {
})
}
func (ui *UI) setupKeyCapture() {
ui.app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Modifiers()&tcell.ModAlt != 0 {
func (ui *UI) setupKeybindings() {
ui.app.SetInputCapture(
func(event *tcell.EventKey) *tcell.EventKey {
if event.Modifiers()&tcell.ModAlt == 0 {
return event
}
r := event.Rune()
if r >= '0' && r <= '9' {
idx := int(r - '0')
@@ -213,36 +222,63 @@ func (ui *UI) setupKeyCapture() {
return nil
}
}
return event
})
return event
},
)
}
func (ui *UI) setupLayout() {
ui.layout = tview.NewFlex().
SetDirection(tview.FlexRow).
AddItem(ui.messages, 0, 1, false).
AddItem(ui.statusBar, 1, 0, false).
AddItem(ui.input, 1, 0, true)
ui.app.SetRoot(ui.layout, true)
ui.app.SetFocus(ui.input)
}
// if needed.
func (ui *UI) refreshStatus() {
// Will be called from the main goroutine via QueueUpdateDraw parent.
// Rebuild status from app state — caller must provide context.
// Rebuilt from app state by parent QueueUpdateDraw.
}
func (ui *UI) refreshStatusWith(nick, target, connStatus string) {
func (ui *UI) refreshStatusWith(
nick, target, connStatus string,
) {
var unreadParts []string
for i, buf := range ui.buffers {
if buf.Unread > 0 {
unreadParts = append(unreadParts, fmt.Sprintf("%d:%s(%d)", i, buf.Name, buf.Unread))
unreadParts = append(
unreadParts,
fmt.Sprintf(
"%d:%s(%d)", i, buf.Name, buf.Unread,
),
)
}
}
unread := ""
if len(unreadParts) > 0 {
unread = " [Act: " + strings.Join(unreadParts, ",") + "]"
unread = " [Act: " +
strings.Join(unreadParts, ",") + "]"
}
bufInfo := fmt.Sprintf("[%d:%s]", ui.currentBuffer, ui.buffers[ui.currentBuffer].Name)
bufInfo := fmt.Sprintf(
"[%d:%s]",
ui.currentBuffer,
ui.buffers[ui.currentBuffer].Name,
)
ui.statusBar.Clear()
_, _ = fmt.Fprintf(ui.statusBar, " [%s] %s %s %s%s",
connStatus, nick, bufInfo, target, unread)
_, _ = fmt.Fprintf(
ui.statusBar, " [%s] %s %s %s%s",
connStatus, nick, bufInfo, target, unread,
)
}
func (ui *UI) getOrCreateBuffer(name string) *Buffer {

View File

@@ -88,7 +88,7 @@ func NewTest(dsn string) (*Database, error) {
}
// Item 9: Enable foreign keys
_, err = d.ExecContext(context.Background(), "PRAGMA foreign_keys = ON")
_, err = d.Exec("PRAGMA foreign_keys = ON") //nolint:noctx,nolintlint // no context in sql.Open path
if err != nil {
_ = d.Close()
@@ -164,8 +164,8 @@ func (s *Database) GetChannelByID(
return c, nil
}
// GetUserByNick looks up a user by their nick.
func (s *Database) GetUserByNick(
// GetUserByNickModel looks up a user by their nick.
func (s *Database) GetUserByNickModel(
ctx context.Context,
nick string,
) (*models.User, error) {
@@ -187,8 +187,8 @@ func (s *Database) GetUserByNick(
return u, nil
}
// GetUserByToken looks up a user by their auth token.
func (s *Database) GetUserByToken(
// GetUserByTokenModel looks up a user by their auth token.
func (s *Database) GetUserByTokenModel(
ctx context.Context,
token string,
) (*models.User, error) {
@@ -238,8 +238,8 @@ func (s *Database) UpdateUserLastSeen(
return err
}
// CreateUser inserts a new user into the database.
func (s *Database) CreateUser(
// CreateUserModel inserts a new user into the database.
func (s *Database) CreateUserModel(
ctx context.Context,
id, nick, passwordHash string,
) (*models.User, error) {
@@ -435,8 +435,7 @@ func (s *Database) AckMessages(
args[i] = id
}
//nolint:gosec // G201: placeholders are all "?" literals, not user input
query := fmt.Sprintf(
query := fmt.Sprintf( //nolint:gosec // placeholders are ?, not user input
"DELETE FROM message_queue WHERE id IN (%s)",
strings.Join(placeholders, ","),
)
@@ -613,7 +612,7 @@ func (s *Database) loadMigrations() ([]migration, error) {
return nil, fmt.Errorf("read schema dir: %w", err)
}
var migrations []migration
migrations := make([]migration, 0, len(entries))
for _, entry := range entries {
if entry.IsDir() ||
@@ -662,7 +661,28 @@ func (s *Database) applyMigrations(
migrations []migration,
) error {
for _, m := range migrations {
err := s.applyOneMigration(ctx, m)
var exists int
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
m.version,
).Scan(&exists)
if err != nil {
return fmt.Errorf(
"check migration %d: %w", m.version, err,
)
}
if exists > 0 {
continue
}
s.log.Info(
"applying migration",
"version", m.version, "name", m.name,
)
err = s.executeMigration(ctx, m)
if err != nil {
return err
}
@@ -671,33 +691,25 @@ func (s *Database) applyMigrations(
return nil
}
func (s *Database) applyOneMigration(ctx context.Context, m migration) error {
var exists int
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
m.version,
).Scan(&exists)
if err != nil {
return fmt.Errorf("check migration %d: %w", m.version, err)
}
if exists > 0 {
return nil
}
s.log.Info("applying migration", "version", m.version, "name", m.name)
func (s *Database) executeMigration(
ctx context.Context,
m migration,
) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx for migration %d: %w", m.version, err)
return fmt.Errorf(
"begin tx for migration %d: %w", m.version, err,
)
}
_, err = tx.ExecContext(ctx, m.sql)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("apply migration %d (%s): %w", m.version, m.name, err)
return fmt.Errorf(
"apply migration %d (%s): %w",
m.version, m.name, err,
)
}
_, err = tx.ExecContext(ctx,
@@ -707,8 +719,17 @@ func (s *Database) applyOneMigration(ctx context.Context, m migration) error {
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("record migration %d: %w", m.version, err)
return fmt.Errorf(
"record migration %d: %w", m.version, err,
)
}
return tx.Commit()
err = tx.Commit()
if err != nil {
return fmt.Errorf(
"commit migration %d: %w", m.version, err,
)
}
return nil
}

View File

@@ -1,7 +1,6 @@
package db_test
import (
"context"
"fmt"
"path/filepath"
"testing"
@@ -41,215 +40,121 @@ func TestCreateUser(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
u, err := d.CreateUser(ctx, "u1", nickAlice, "hash1")
id, token, err := d.CreateUser(ctx, nickAlice)
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if u.ID != "u1" || u.Nick != nickAlice {
t.Errorf("got user %+v", u)
if id <= 0 {
t.Errorf("expected positive id, got %d", id)
}
if token == "" {
t.Error("expected non-empty token")
}
}
func TestCreateAuthToken(t *testing.T) {
func TestGetUserByToken(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
_, err := d.CreateUser(ctx, "u1", nickAlice, "h")
_, token, _ := d.CreateUser(ctx, nickAlice)
id, nick, err := d.GetUserByToken(ctx, token)
if err != nil {
t.Fatalf("CreateUser: %v", err)
t.Fatalf("GetUserByToken: %v", err)
}
tok, err := d.CreateAuthToken(ctx, "tok1", "u1")
if err != nil {
t.Fatalf("CreateAuthToken: %v", err)
}
if tok.Token != "tok1" || tok.UserID != "u1" {
t.Errorf("unexpected token: %+v", tok)
}
u, err := tok.User(ctx)
if err != nil {
t.Fatalf("AuthToken.User: %v", err)
}
if u.ID != "u1" || u.Nick != nickAlice {
t.Errorf("AuthToken.User got %+v", u)
if id <= 0 || nick != nickAlice {
t.Errorf(
"got id=%d nick=%s, want nick=%s",
id, nick, nickAlice,
)
}
}
func TestCreateChannel(t *testing.T) {
func TestGetUserByNick(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
ch, err := d.CreateChannel(
ctx, "c1", "#general", "welcome", "+n",
)
origID, _, _ := d.CreateUser(ctx, nickAlice)
id, err := d.GetUserByNick(ctx, nickAlice)
if err != nil {
t.Fatalf("CreateChannel: %v", err)
t.Fatalf("GetUserByNick: %v", err)
}
if ch.ID != "c1" || ch.Name != "#general" {
t.Errorf("unexpected channel: %+v", ch)
if id != origID {
t.Errorf("got id %d, want %d", id, origID)
}
}
func TestAddChannelMember(t *testing.T) {
func TestGetOrCreateChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#general", "", "")
cm, err := d.AddChannelMember(ctx, "c1", "u1", "+o")
id1, err := d.GetOrCreateChannel(ctx, "#general")
if err != nil {
t.Fatalf("AddChannelMember: %v", err)
t.Fatalf("GetOrCreateChannel: %v", err)
}
if cm.ChannelID != "c1" || cm.Modes != "+o" {
t.Errorf("unexpected member: %+v", cm)
if id1 <= 0 {
t.Errorf("expected positive id, got %d", id1)
}
// Same channel returns same ID.
id2, err := d.GetOrCreateChannel(ctx, "#general")
if err != nil {
t.Fatalf("GetOrCreateChannel(2): %v", err)
}
if id1 != id2 {
t.Errorf("got different ids: %d vs %d", id1, id2)
}
}
func TestCreateMessage(t *testing.T) {
func TestJoinAndListChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
uid, _, _ := d.CreateUser(ctx, nickAlice)
ch1, _ := d.GetOrCreateChannel(ctx, "#alpha")
ch2, _ := d.GetOrCreateChannel(ctx, "#beta")
msg, err := d.CreateMessage(
ctx, "m1", "u1", nickAlice,
"#general", "message", "hello",
)
_ = d.JoinChannel(ctx, ch1, uid)
_ = d.JoinChannel(ctx, ch2, uid)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("CreateMessage: %v", err)
}
if msg.ID != "m1" || msg.Body != "hello" {
t.Errorf("unexpected message: %+v", msg)
}
}
func TestQueueMessage(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateUser(ctx, "u2", nickBob, "h")
_, _ = d.CreateMessage(
ctx, "m1", "u1", nickAlice, "u2", "message", "hi",
)
mq, err := d.QueueMessage(ctx, "u2", "m1")
if err != nil {
t.Fatalf("QueueMessage: %v", err)
}
if mq.UserID != "u2" || mq.MessageID != "m1" {
t.Errorf("unexpected queue entry: %+v", mq)
}
}
func TestCreateSession(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
sess, err := d.CreateSession(ctx, "s1", "u1")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if sess.ID != "s1" || sess.UserID != "u1" {
t.Errorf("unexpected session: %+v", sess)
}
u, err := sess.User(ctx)
if err != nil {
t.Fatalf("Session.User: %v", err)
}
if u.ID != "u1" {
t.Errorf("Session.User got %v, want u1", u.ID)
}
}
func TestCreateServerLink(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
sl, err := d.CreateServerLink(
ctx, "sl1", "peer1",
"https://peer.example.com", "keyhash", true,
)
if err != nil {
t.Fatalf("CreateServerLink: %v", err)
}
if sl.ID != "sl1" || !sl.IsActive {
t.Errorf("unexpected server link: %+v", sl)
}
}
func TestUserChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#alpha", "", "")
_, _ = d.CreateChannel(ctx, "c2", "#beta", "", "")
_, _ = d.AddChannelMember(ctx, "c1", "u1", "")
_, _ = d.AddChannelMember(ctx, "c2", "u1", "")
channels, err := u.Channels(ctx)
if err != nil {
t.Fatalf("User.Channels: %v", err)
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 2 {
t.Fatalf("expected 2 channels, got %d", len(channels))
}
if channels[0].Name != "#alpha" {
t.Errorf("first channel: got %s", channels[0].Name)
}
if channels[1].Name != "#beta" {
t.Errorf("second channel: got %s", channels[1].Name)
}
}
func TestUserChannelsEmpty(t *testing.T) {
func TestListChannelsEmpty(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
uid, _, _ := d.CreateUser(ctx, nickAlice)
channels, err := u.Channels(ctx)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("User.Channels: %v", err)
t.Fatalf("ListChannels: %v", err)
}
if len(channels) != 0 {
@@ -257,60 +162,57 @@ func TestUserChannelsEmpty(t *testing.T) {
}
}
func TestUserQueuedMessages(t *testing.T) {
func TestPartChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateUser(ctx, "u2", nickBob, "h")
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
for i := range 3 {
id := fmt.Sprintf("m%d", i)
_ = d.JoinChannel(ctx, chID, uid)
_ = d.PartChannel(ctx, chID, uid)
_, _ = d.CreateMessage(
ctx, id, "u2", nickBob, "u1",
"message", fmt.Sprintf("msg%d", i),
)
time.Sleep(10 * time.Millisecond)
_, _ = d.QueueMessage(ctx, "u1", id)
}
msgs, err := u.QueuedMessages(ctx)
channels, err := d.ListChannels(ctx, uid)
if err != nil {
t.Fatalf("User.QueuedMessages: %v", err)
t.Fatalf("ListChannels: %v", err)
}
if len(msgs) != 3 {
t.Fatalf("expected 3 messages, got %d", len(msgs))
}
for i, msg := range msgs {
want := fmt.Sprintf("msg%d", i)
if msg.Body != want {
t.Errorf("msg %d: got %q, want %q", i, msg.Body, want)
}
if len(channels) != 0 {
t.Errorf("expected 0 after part, got %d", len(channels))
}
}
func TestUserQueuedMessagesEmpty(t *testing.T) {
func TestSendAndGetMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
msgs, err := u.QueuedMessages(ctx)
_, err := d.SendMessage(ctx, chID, uid, "hello world")
if err != nil {
t.Fatalf("User.QueuedMessages: %v", err)
t.Fatalf("SendMessage: %v", err)
}
if len(msgs) != 0 {
t.Errorf("expected 0 messages, got %d", len(msgs))
msgs, err := d.GetMessages(ctx, chID, 0, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
if msgs[0].Content != "hello world" {
t.Errorf(
"got content %q, want %q",
msgs[0].Content, "hello world",
)
}
}
@@ -318,50 +220,38 @@ func TestChannelMembers(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "")
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateUser(ctx, "u2", nickBob, "h")
_, _ = d.CreateUser(ctx, "u3", nickCharlie, "h")
_, _ = d.AddChannelMember(ctx, "c1", "u1", "+o")
_, _ = d.AddChannelMember(ctx, "c1", "u2", "+v")
_, _ = d.AddChannelMember(ctx, "c1", "u3", "")
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
uid3, _, _ := d.CreateUser(ctx, nickCharlie)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
members, err := ch.Members(ctx)
_ = d.JoinChannel(ctx, chID, uid1)
_ = d.JoinChannel(ctx, chID, uid2)
_ = d.JoinChannel(ctx, chID, uid3)
members, err := d.ChannelMembers(ctx, chID)
if err != nil {
t.Fatalf("Channel.Members: %v", err)
t.Fatalf("ChannelMembers: %v", err)
}
if len(members) != 3 {
t.Fatalf("expected 3 members, got %d", len(members))
}
nicks := map[string]bool{}
for _, m := range members {
nicks[m.Nick] = true
}
for _, want := range []string{
nickAlice, nickBob, nickCharlie,
} {
if !nicks[want] {
t.Errorf("missing nick %q", want)
}
}
}
func TestChannelMembersEmpty(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
ch, _ := d.CreateChannel(ctx, "c1", "#empty", "", "")
chID, _ := d.GetOrCreateChannel(ctx, "#empty")
members, err := ch.Members(ctx)
members, err := d.ChannelMembers(ctx, chID)
if err != nil {
t.Fatalf("Channel.Members: %v", err)
t.Fatalf("ChannelMembers: %v", err)
}
if len(members) != 0 {
@@ -369,126 +259,166 @@ func TestChannelMembersEmpty(t *testing.T) {
}
}
func TestChannelRecentMessages(t *testing.T) {
func TestSendDM(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "")
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
msgID, err := d.SendDM(ctx, uid1, uid2, "hey bob")
if err != nil {
t.Fatalf("SendDM: %v", err)
}
if msgID <= 0 {
t.Errorf("expected positive msgID, got %d", msgID)
}
msgs, err := d.GetDMs(ctx, uid1, uid2, 0, 0)
if err != nil {
t.Fatalf("GetDMs: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 DM, got %d", len(msgs))
}
if msgs[0].Content != "hey bob" {
t.Errorf("got %q, want %q", msgs[0].Content, "hey bob")
}
}
func TestPollMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid1)
_ = d.JoinChannel(ctx, chID, uid2)
_, _ = d.SendMessage(ctx, chID, uid2, "hello")
_, _ = d.SendDM(ctx, uid2, uid1, "private")
time.Sleep(10 * time.Millisecond)
msgs, err := d.PollMessages(ctx, uid1, 0, 0)
if err != nil {
t.Fatalf("PollMessages: %v", err)
}
if len(msgs) < 2 {
t.Fatalf("expected >=2 messages, got %d", len(msgs))
}
}
func TestChangeNick(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
_, token, _ := d.CreateUser(ctx, nickAlice)
err := d.ChangeNick(ctx, 1, "alice2")
if err != nil {
t.Fatalf("ChangeNick: %v", err)
}
_, nick, err := d.GetUserByToken(ctx, token)
if err != nil {
t.Fatalf("GetUserByToken: %v", err)
}
if nick != "alice2" {
t.Errorf("got nick %q, want alice2", nick)
}
}
func TestSetTopic(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
_, _ = d.GetOrCreateChannel(ctx, "#general")
err := d.SetTopic(ctx, "#general", uid, "new topic")
if err != nil {
t.Fatalf("SetTopic: %v", err)
}
channels, err := d.ListAllChannels(ctx)
if err != nil {
t.Fatalf("ListAllChannels: %v", err)
}
found := false
for _, ch := range channels {
if ch.Name == "#general" && ch.Topic == "new topic" {
found = true
}
}
if !found {
t.Error("topic was not updated")
}
}
func TestGetMessagesBefore(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
for i := range 5 {
id := fmt.Sprintf("m%d", i)
_, _ = d.CreateMessage(
ctx, id, "u1", nickAlice, "#general",
"message", fmt.Sprintf("msg%d", i),
_, _ = d.SendMessage(
ctx, chID, uid,
fmt.Sprintf("msg%d", i),
)
time.Sleep(10 * time.Millisecond)
}
msgs, err := ch.RecentMessages(ctx, 3)
msgs, err := d.GetMessagesBefore(ctx, chID, 0, 3)
if err != nil {
t.Fatalf("RecentMessages: %v", err)
t.Fatalf("GetMessagesBefore: %v", err)
}
if len(msgs) != 3 {
t.Fatalf("expected 3, got %d", len(msgs))
}
if msgs[0].Body != "msg4" {
t.Errorf("first: got %q, want msg4", msgs[0].Body)
}
if msgs[2].Body != "msg2" {
t.Errorf("last: got %q, want msg2", msgs[2].Body)
}
}
func TestChannelRecentMessagesLargeLimit(t *testing.T) {
func TestListAllChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ctx := t.Context()
ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "")
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateMessage(
ctx, "m1", "u1", nickAlice,
"#general", "message", "only",
)
_, _ = d.GetOrCreateChannel(ctx, "#alpha")
_, _ = d.GetOrCreateChannel(ctx, "#beta")
msgs, err := ch.RecentMessages(ctx, 100)
channels, err := d.ListAllChannels(ctx)
if err != nil {
t.Fatalf("RecentMessages: %v", err)
t.Fatalf("ListAllChannels: %v", err)
}
if len(msgs) != 1 {
t.Errorf("expected 1, got %d", len(msgs))
}
}
func TestChannelMemberUser(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#general", "", "")
cm, _ := d.AddChannelMember(ctx, "c1", "u1", "+o")
u, err := cm.User(ctx)
if err != nil {
t.Fatalf("ChannelMember.User: %v", err)
}
if u.ID != "u1" || u.Nick != nickAlice {
t.Errorf("got %+v", u)
}
}
func TestChannelMemberChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#general", "topic", "+n")
cm, _ := d.AddChannelMember(ctx, "c1", "u1", "")
ch, err := cm.Channel(ctx)
if err != nil {
t.Fatalf("ChannelMember.Channel: %v", err)
}
if ch.ID != "c1" || ch.Topic != "topic" {
t.Errorf("got %+v", ch)
}
}
func TestDMMessage(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateUser(ctx, "u2", nickBob, "h")
msg, err := d.CreateMessage(
ctx, "m1", "u1", nickAlice, "u2", "message", "hey",
)
if err != nil {
t.Fatalf("CreateMessage DM: %v", err)
}
if msg.Target != "u2" {
t.Errorf("target: got %q, want u2", msg.Target)
if len(channels) != 2 {
t.Errorf("expected 2, got %d", len(channels))
}
}

View File

@@ -2,44 +2,140 @@ package db
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"time"
)
// GetOrCreateChannel returns the channel id, creating it if needed.
func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (string, error) {
var id string
const (
defaultMessageLimit = 50
defaultPollLimit = 100
tokenBytes = 32
)
err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id)
func generateToken() string {
b := make([]byte, tokenBytes)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
// CreateUser registers a new user with the given nick and
// returns the user with token.
func (s *Database) CreateUser(
ctx context.Context,
nick string,
) (int64, string, error) {
token := generateToken()
now := time.Now()
res, err := s.db.ExecContext(ctx,
"INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)",
nick, token, now, now)
if err != nil {
return 0, "", fmt.Errorf("create user: %w", err)
}
id, _ := res.LastInsertId()
return id, token, nil
}
// GetUserByToken returns user id and nick for a given auth
// token.
func (s *Database) GetUserByToken(
ctx context.Context,
token string,
) (int64, string, error) {
var id int64
var nick string
err := s.db.QueryRowContext(
ctx,
"SELECT id, nick FROM users WHERE token = ?",
token,
).Scan(&id, &nick)
if err != nil {
return 0, "", err
}
// Update last_seen
_, _ = s.db.ExecContext(
ctx,
"UPDATE users SET last_seen = ? WHERE id = ?",
time.Now(), id,
)
return id, nick, nil
}
// GetUserByNick returns user id for a given nick.
func (s *Database) GetUserByNick(
ctx context.Context,
nick string,
) (int64, error) {
var id int64
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM users WHERE nick = ?",
nick,
).Scan(&id)
return id, err
}
// GetOrCreateChannel returns the channel id, creating it if
// needed.
func (s *Database) GetOrCreateChannel(
ctx context.Context,
name string,
) (int64, error) {
var id int64
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&id)
if err == nil {
return id, nil
}
now := time.Now()
id = fmt.Sprintf("ch-%d", now.UnixNano())
_, err = s.db.ExecContext(ctx,
"INSERT INTO channels (id, name, topic, modes, created_at, updated_at) VALUES (?, ?, '', '', ?, ?)",
id, name, now, now)
res, err := s.db.ExecContext(ctx,
"INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)",
name, now, now)
if err != nil {
return "", fmt.Errorf("create channel: %w", err)
return 0, fmt.Errorf("create channel: %w", err)
}
id, _ = res.LastInsertId()
return id, nil
}
// JoinChannel adds a user to a channel.
func (s *Database) JoinChannel(ctx context.Context, channelID, userID string) error {
func (s *Database) JoinChannel(
ctx context.Context,
channelID, userID int64,
) error {
_, err := s.db.ExecContext(ctx,
"INSERT OR IGNORE INTO channel_members (channel_id, user_id, modes, joined_at) VALUES (?, ?, '', ?)",
"INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)",
channelID, userID, time.Now())
return err
}
// PartChannel removes a user from a channel.
func (s *Database) PartChannel(ctx context.Context, channelID, userID string) error {
func (s *Database) PartChannel(
ctx context.Context,
channelID, userID int64,
) error {
_, err := s.db.ExecContext(ctx,
"DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?",
channelID, userID)
@@ -47,8 +143,18 @@ func (s *Database) PartChannel(ctx context.Context, channelID, userID string) er
return err
}
// ChannelInfo is a lightweight channel representation.
type ChannelInfo struct {
ID int64 `json:"id"`
Name string `json:"name"`
Topic string `json:"topic"`
}
// ListChannels returns all channels the user has joined.
func (s *Database) ListChannels(ctx context.Context, userID string) ([]ChannelInfo, error) {
func (s *Database) ListChannels(
ctx context.Context,
userID int64,
) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic FROM channels c
INNER JOIN channel_members cm ON cm.channel_id = c.id
@@ -59,7 +165,7 @@ func (s *Database) ListChannels(ctx context.Context, userID string) ([]ChannelIn
defer func() { _ = rows.Close() }()
channels := []ChannelInfo{}
var channels []ChannelInfo
for rows.Next() {
var ch ChannelInfo
@@ -72,20 +178,32 @@ func (s *Database) ListChannels(ctx context.Context, userID string) ([]ChannelIn
channels = append(channels, ch)
}
return channels, rows.Err()
err = rows.Err()
if err != nil {
return nil, err
}
if channels == nil {
channels = []ChannelInfo{}
}
return channels, nil
}
// ChannelInfo is a lightweight channel representation.
type ChannelInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Topic string `json:"topic"`
// MemberInfo represents a channel member.
type MemberInfo struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
LastSeen time.Time `json:"lastSeen"`
}
// ChannelMembers returns all members of a channel.
func (s *Database) ChannelMembers(ctx context.Context, channelID string) ([]MemberInfo, error) {
func (s *Database) ChannelMembers(
ctx context.Context,
channelID int64,
) ([]MemberInfo, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT u.id, u.nick, u.last_seen_at FROM users u
`SELECT u.id, u.nick, u.last_seen FROM users u
INNER JOIN channel_members cm ON cm.user_id = u.id
WHERE cm.channel_id = ? ORDER BY u.nick`, channelID)
if err != nil {
@@ -94,7 +212,7 @@ func (s *Database) ChannelMembers(ctx context.Context, channelID string) ([]Memb
defer func() { _ = rows.Close() }()
members := []MemberInfo{}
var members []MemberInfo
for rows.Next() {
var m MemberInfo
@@ -107,19 +225,21 @@ func (s *Database) ChannelMembers(ctx context.Context, channelID string) ([]Memb
members = append(members, m)
}
return members, rows.Err()
}
err = rows.Err()
if err != nil {
return nil, err
}
// MemberInfo represents a channel member.
type MemberInfo struct {
ID string `json:"id"`
Nick string `json:"nick"`
LastSeen *time.Time `json:"lastSeen"`
if members == nil {
members = []MemberInfo{}
}
return members, nil
}
// MessageInfo represents a chat message.
type MessageInfo struct {
ID string `json:"id"`
ID int64 `json:"id"`
Channel string `json:"channel,omitempty"`
Nick string `json:"nick"`
Content string `json:"content"`
@@ -128,67 +248,40 @@ type MessageInfo struct {
CreatedAt time.Time `json:"createdAt"`
}
// SendMessage inserts a channel message.
func (s *Database) SendMessage(ctx context.Context, channelID, userID, nick, content string) (string, error) {
now := time.Now()
id := fmt.Sprintf("msg-%d", now.UnixNano())
_, err := s.db.ExecContext(ctx,
`INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at)
VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`,
id, now, userID, nick, channelID, content, now)
if err != nil {
return "", err
}
return id, nil
}
// SendDM inserts a direct message.
func (s *Database) SendDM(ctx context.Context, fromID, fromNick, toID, content string) (string, error) {
now := time.Now()
id := fmt.Sprintf("msg-%d", now.UnixNano())
_, err := s.db.ExecContext(ctx,
`INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at)
VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`,
id, now, fromID, fromNick, toID, content, now)
if err != nil {
return "", err
}
return id, nil
}
// PollMessages returns all new messages for a user's joined channels, ordered by timestamp.
func (s *Database) PollMessages(ctx context.Context, userID string, afterTS string, limit int) ([]MessageInfo, error) {
// GetMessages returns messages for a channel, optionally
// after a given ID.
func (s *Database) GetMessages(
ctx context.Context,
channelID int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = 100
limit = defaultMessageLimit
}
rows, err := s.db.QueryContext(ctx,
`SELECT m.id, m.target, m.from_nick, m.body, m.created_at
`SELECT m.id, c.name, u.nick, m.content, m.created_at
FROM messages m
WHERE m.created_at > COALESCE(NULLIF(?, ''), '1970-01-01')
AND (
m.target IN (SELECT cm.channel_id FROM channel_members cm WHERE cm.user_id = ?)
OR m.target = ?
OR m.from_user_id = ?
)
ORDER BY m.created_at ASC LIMIT ?`,
afterTS, userID, userID, userID, limit)
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id > ?
ORDER BY m.id ASC LIMIT ?`, channelID, afterID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs := []MessageInfo{}
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt)
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick,
&m.Content, &m.CreatedAt,
)
if err != nil {
return nil, err
}
@@ -196,144 +289,392 @@ func (s *Database) PollMessages(ctx context.Context, userID string, afterTS stri
msgs = append(msgs, m)
}
return msgs, rows.Err()
err = rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
// GetMessagesBefore returns channel messages before a given timestamp (for history scrollback).
func (s *Database) GetMessagesBefore(
ctx context.Context, target string, beforeTS string, limit int,
// SendMessage inserts a channel message.
func (s *Database) SendMessage(
ctx context.Context,
channelID, userID int64,
content string,
) (int64, error) {
res, err := s.db.ExecContext(ctx,
"INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)",
channelID, userID, content, time.Now())
if err != nil {
return 0, err
}
return res.LastInsertId()
}
// SendDM inserts a direct message.
func (s *Database) SendDM(
ctx context.Context,
fromID, toID int64,
content string,
) (int64, error) {
res, err := s.db.ExecContext(ctx,
"INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)",
fromID, content, toID, time.Now())
if err != nil {
return 0, err
}
return res.LastInsertId()
}
// GetDMs returns direct messages between two users after a
// given ID.
func (s *Database) GetDMs(
ctx context.Context,
userA, userB int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = 50
}
var rows *sql.Rows
var err error
if beforeTS != "" {
rows, err = s.db.QueryContext(ctx,
`SELECT m.id, m.target, m.from_nick, m.body, m.created_at
FROM messages m
WHERE m.target = ? AND m.created_at < ?
ORDER BY m.created_at DESC LIMIT ?`,
target, beforeTS, limit)
} else {
rows, err = s.db.QueryContext(ctx,
`SELECT m.id, m.target, m.from_nick, m.body, m.created_at
FROM messages m
WHERE m.target = ?
ORDER BY m.created_at DESC LIMIT ?`,
target, limit)
limit = defaultMessageLimit
}
rows, err := s.db.QueryContext(ctx,
`SELECT m.id, u.nick, m.content, t.nick, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id > ?
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id ASC LIMIT ?`,
afterID, userA, userB, userB, userA, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs := []MessageInfo{}
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt)
if err != nil {
return nil, err
}
msgs = append(msgs, m)
}
// Reverse to ascending order.
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
}
return msgs, rows.Err()
}
// GetDMsBefore returns DMs between two users before a given timestamp.
func (s *Database) GetDMsBefore(
ctx context.Context, userA, userB string, beforeTS string, limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = 50
}
var rows *sql.Rows
var err error
if beforeTS != "" {
rows, err = s.db.QueryContext(ctx,
`SELECT m.id, m.from_nick, m.body, m.target, m.created_at
FROM messages m
WHERE m.created_at < ?
AND ((m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?))
ORDER BY m.created_at DESC LIMIT ?`,
beforeTS, userA, userB, userB, userA, limit)
} else {
rows, err = s.db.QueryContext(ctx,
`SELECT m.id, m.from_nick, m.body, m.target, m.created_at
FROM messages m
WHERE (m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?)
ORDER BY m.created_at DESC LIMIT ?`,
userA, userB, userB, userA, limit)
}
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs := []MessageInfo{}
for rows.Next() {
var m MessageInfo
err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt)
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
// Reverse to ascending order.
err = rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
// PollMessages returns all new messages (channel + DM) for
// a user after a given ID.
func (s *Database) PollMessages(
ctx context.Context,
userID int64,
afterID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = defaultPollLimit
}
rows, err := s.db.QueryContext(ctx,
`SELECT m.id, COALESCE(c.name, ''), u.nick, m.content,
m.is_dm, COALESCE(t.nick, ''), m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
LEFT JOIN channels c ON c.id = m.channel_id
LEFT JOIN users t ON t.id = m.dm_target_id
WHERE m.id > ? AND (
(m.is_dm = 0 AND m.channel_id IN
(SELECT channel_id FROM channel_members
WHERE user_id = ?))
OR (m.is_dm = 1
AND (m.user_id = ? OR m.dm_target_id = ?))
)
ORDER BY m.id ASC LIMIT ?`,
afterID, userID, userID, userID, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs := make([]MessageInfo, 0)
for rows.Next() {
var (
m MessageInfo
isDM int
)
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick, &m.Content,
&isDM, &m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = isDM == 1
msgs = append(msgs, m)
}
err = rows.Err()
if err != nil {
return nil, err
}
return msgs, nil
}
func scanChannelMessages(
rows *sql.Rows,
) ([]MessageInfo, error) {
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick,
&m.Content, &m.CreatedAt,
)
if err != nil {
return nil, err
}
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func scanDMMessages(
rows *sql.Rows,
) ([]MessageInfo, error) {
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func reverseMessages(msgs []MessageInfo) {
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
}
}
return msgs, rows.Err()
// GetMessagesBefore returns channel messages before a given
// ID (for history scrollback).
func (s *Database) GetMessagesBefore(
ctx context.Context,
channelID int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = defaultMessageLimit
}
var query string
var args []any
if beforeID > 0 {
query = `SELECT m.id, c.name, u.nick, m.content,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
AND m.id < ?
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, beforeID, limit}
} else {
query = `SELECT m.id, c.name, u.nick, m.content,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs, scanErr := scanChannelMessages(rows)
if scanErr != nil {
return nil, scanErr
}
// Reverse to ascending order.
reverseMessages(msgs)
return msgs, nil
}
// GetDMsBefore returns DMs between two users before a given
// ID (for history scrollback).
func (s *Database) GetDMsBefore(
ctx context.Context,
userA, userB int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) {
if limit <= 0 {
limit = defaultMessageLimit
}
var query string
var args []any
if beforeID > 0 {
query = `SELECT m.id, u.nick, m.content, t.nick,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id < ?
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{
beforeID, userA, userB, userB, userA, limit,
}
} else {
query = `SELECT m.id, u.nick, m.content, t.nick,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{userA, userB, userB, userA, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
msgs, scanErr := scanDMMessages(rows)
if scanErr != nil {
return nil, scanErr
}
// Reverse to ascending order.
reverseMessages(msgs)
return msgs, nil
}
// ChangeNick updates a user's nickname.
func (s *Database) ChangeNick(ctx context.Context, userID string, newNick string) error {
func (s *Database) ChangeNick(
ctx context.Context,
userID int64,
newNick string,
) error {
_, err := s.db.ExecContext(ctx,
"UPDATE users SET nick = ? WHERE id = ?", newNick, userID)
"UPDATE users SET nick = ? WHERE id = ?",
newNick, userID)
return err
}
// SetTopic sets the topic for a channel.
func (s *Database) SetTopic(ctx context.Context, channelName string, _ string, topic string) error {
func (s *Database) SetTopic(
ctx context.Context,
channelName string,
_ int64,
topic string,
) error {
_, err := s.db.ExecContext(ctx,
"UPDATE channels SET topic = ? WHERE name = ?", topic, channelName)
"UPDATE channels SET topic = ? WHERE name = ?",
topic, channelName)
return err
}
// GetServerName returns the server name (unused, config provides this).
// GetServerName returns the server name (unused, config
// provides this).
func (s *Database) GetServerName() string {
return ""
}
// ListAllChannels returns all channels.
func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) {
func (s *Database) ListAllChannels(
ctx context.Context,
) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx,
"SELECT id, name, topic FROM channels ORDER BY name")
if err != nil {
@@ -342,12 +683,14 @@ func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) {
defer func() { _ = rows.Close() }()
channels := []ChannelInfo{}
var channels []ChannelInfo
for rows.Next() {
var ch ChannelInfo
err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic)
err := rows.Scan(
&ch.ID, &ch.Name, &ch.Topic,
)
if err != nil {
return nil, err
}
@@ -355,5 +698,14 @@ func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) {
channels = append(channels, ch)
}
return channels, rows.Err()
err = rows.Err()
if err != nil {
return nil, err
}
if channels == nil {
channels = []ChannelInfo{}
}
return channels, nil
}

View File

@@ -1,4 +1,53 @@
-- Migration 003: no-op (schema already created by 002_schema.sql)
-- This migration previously conflicted with 002 by attempting to recreate
-- tables with incompatible column types (INTEGER vs TEXT IDs).
SELECT 1;
-- Migration 003: Replace UUID-based tables with simple integer-keyed
-- tables for the HTTP API. Drops the 002 tables and recreates them.
PRAGMA foreign_keys = OFF;
DROP TABLE IF EXISTS message_queue;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS server_links;
DROP TABLE IF EXISTS messages;
DROP TABLE IF EXISTS channel_members;
DROP TABLE IF EXISTS auth_tokens;
DROP TABLE IF EXISTS channels;
DROP TABLE IF EXISTS users;
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
nick TEXT NOT NULL UNIQUE,
token TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
topic TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, user_id)
);
CREATE TABLE messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
is_dm INTEGER NOT NULL DEFAULT 0,
dm_target_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX idx_messages_channel ON messages(channel_id, created_at);
CREATE INDEX idx_messages_dm ON messages(user_id, dm_target_id, created_at);
CREATE INDEX idx_users_token ON users(token);
PRAGMA foreign_keys = ON;

View File

@@ -1,9 +1,7 @@
package handlers
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"encoding/json"
"net/http"
"strconv"
@@ -13,51 +11,93 @@ import (
"github.com/go-chi/chi"
)
// authUser extracts the user from the Authorization header (Bearer token).
func (s *Handlers) authUser(r *http.Request) (string, string, error) {
const (
maxNickLen = 32
defaultHistory = 50
)
// authUser extracts the user from the Authorization header
// (Bearer token).
func (s *Handlers) authUser(
r *http.Request,
) (int64, string, error) {
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
return "", "", sql.ErrNoRows
return 0, "", sql.ErrNoRows
}
token := strings.TrimPrefix(auth, "Bearer ")
u, err := s.params.Database.GetUserByToken(r.Context(), token)
if err != nil {
return "", "", err
}
return u.ID, u.Nick, nil
return s.params.Database.GetUserByToken(r.Context(), token)
}
func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (string, string, bool) {
func (s *Handlers) requireAuth(
w http.ResponseWriter,
r *http.Request,
) (int64, string, bool) {
uid, nick, err := s.authUser(r)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized)
s.respondJSON(
w, r,
map[string]string{"error": "unauthorized"},
http.StatusUnauthorized,
)
return "", "", false
return 0, "", false
}
return uid, nick, true
}
const idBytes = 16
func generateID() string {
b := make([]byte, idBytes)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
func (s *Handlers) respondError(
w http.ResponseWriter,
r *http.Request,
msg string,
code int,
) {
s.respondJSON(w, r, map[string]string{"error": msg}, code)
}
// HandleCreateSession creates a new user session and returns the auth token.
func (s *Handlers) internalError(
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
s.log.Error(msg, "error", err)
s.respondError(w, r, "internal error", http.StatusInternalServerError)
}
// bodyLines extracts body as string lines from a request body
// field.
func bodyLines(body any) []string {
switch v := body.(type) {
case []any:
lines := make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
lines = append(lines, s)
}
}
return lines
case []string:
return v
default:
return nil
}
}
// HandleCreateSession creates a new user session and returns
// the auth token.
func (s *Handlers) HandleCreateSession() http.HandlerFunc {
type request struct {
Nick string `json:"nick"`
}
type response struct {
ID string `json:"id"`
ID int64 `json:"id"`
Nick string `json:"nick"`
Token string `json:"token"`
}
@@ -67,52 +107,56 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc {
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
s.respondError(
w, r, "invalid request",
http.StatusBadRequest,
)
return
}
req.Nick = strings.TrimSpace(req.Nick)
if req.Nick == "" || len(req.Nick) > 32 {
s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
if req.Nick == "" || len(req.Nick) > maxNickLen {
s.respondError(
w, r, "nick must be 1-32 characters",
http.StatusBadRequest,
)
return
}
id := generateID()
u, err := s.params.Database.CreateUser(r.Context(), id, req.Nick, "")
id, token, err := s.params.Database.CreateUser(
r.Context(), req.Nick,
)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE") {
s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict)
s.respondError(
w, r, "nick already taken",
http.StatusConflict,
)
return
}
s.log.Error("create user failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "create user failed", err)
return
}
tokenStr := generateID()
_, err = s.params.Database.CreateAuthToken(r.Context(), tokenStr, u.ID)
if err != nil {
s.log.Error("create auth token failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, &response{ID: u.ID, Nick: req.Nick, Token: tokenStr}, http.StatusCreated)
s.respondJSON(
w, r,
&response{ID: id, Nick: req.Nick, Token: token},
http.StatusCreated,
)
}
}
// HandleState returns the current user's info and joined channels.
// HandleState returns the current user's info and joined
// channels.
func (s *Handlers) HandleState() http.HandlerFunc {
type response struct {
ID string `json:"id"`
ID int64 `json:"id"`
Nick string `json:"nick"`
Channels []db.ChannelInfo `json:"channels"`
}
@@ -123,15 +167,25 @@ func (s *Handlers) HandleState() http.HandlerFunc {
return
}
channels, err := s.params.Database.ListChannels(r.Context(), uid)
channels, err := s.params.Database.ListChannels(
r.Context(), uid,
)
if err != nil {
s.log.Error("list channels failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(
w, r, "list channels failed", err,
)
return
}
s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK)
s.respondJSON(
w, r,
&response{
ID: uid, Nick: nick,
Channels: channels,
},
http.StatusOK,
)
}
}
@@ -143,10 +197,13 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc {
return
}
channels, err := s.params.Database.ListAllChannels(r.Context())
channels, err := s.params.Database.ListAllChannels(
r.Context(),
)
if err != nil {
s.log.Error("list all channels failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(
w, r, "list all channels failed", err,
)
return
}
@@ -165,21 +222,29 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
name := "#" + chi.URLParam(r, "channel")
var chID string
var chID int64
//nolint:gosec // G701: parameterized query with ? placeholder, not injection
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", name).Scan(&chID)
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
s.respondError(
w, r, "channel not found",
http.StatusNotFound,
)
return
}
members, err := s.params.Database.ChannelMembers(r.Context(), chID)
members, err := s.params.Database.ChannelMembers(
r.Context(), chID,
)
if err != nil {
s.log.Error("channel members failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(
w, r, "channel members failed", err,
)
return
}
@@ -188,8 +253,8 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
}
}
// HandleGetMessages returns all new messages (channel + DM) for the user via long-polling.
// This is the single unified message stream — replaces separate channel/DM/poll endpoints.
// HandleGetMessages returns all new messages (channel + DM)
// for the user via long-polling.
func (s *Handlers) HandleGetMessages() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
@@ -197,13 +262,21 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc {
return
}
afterTS := r.URL.Query().Get("after")
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
afterID, _ := strconv.ParseInt(
r.URL.Query().Get("after"), 10, 64,
)
msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterTS, limit)
limit, _ := strconv.Atoi(
r.URL.Query().Get("limit"),
)
msgs, err := s.params.Database.PollMessages(
r.Context(), uid, afterID, limit,
)
if err != nil {
s.log.Error("get messages failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(
w, r, "get messages failed", err,
)
return
}
@@ -212,284 +285,389 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc {
}
}
// HandleSendCommand handles all C2S commands via POST /messages.
// The "command" field dispatches to the appropriate logic.
func (s *Handlers) HandleSendCommand() http.HandlerFunc {
type request struct {
Command string `json:"command"`
To string `json:"to"`
Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"`
}
type sendRequest struct {
Command string `json:"command"`
To string `json:"to"`
Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"`
}
// HandleSendCommand handles all C2S commands via POST
// /messages.
func (s *Handlers) HandleSendCommand() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
uid, nick, ok := s.requireAuth(w, r)
if !ok {
return
}
var req request
var req sendRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
s.respondError(
w, r, "invalid request",
http.StatusBadRequest,
)
return
}
req.Command = strings.ToUpper(strings.TrimSpace(req.Command))
req.Command = strings.ToUpper(
strings.TrimSpace(req.Command),
)
req.To = strings.TrimSpace(req.To)
lines := extractBodyLines(req.Body)
s.dispatchCommand(w, r, uid, nick, req.Command, req.To, lines)
}
}
// extractBodyLines converts the request body to string lines.
func extractBodyLines(body any) []string {
switch v := body.(type) {
case []any:
lines := make([]string, 0, len(v))
for _, item := range v {
if str, ok := item.(string); ok {
lines = append(lines, str)
}
}
return lines
case []string:
return v
default:
return nil
s.dispatchCommand(w, r, uid, nick, &req)
}
}
func (s *Handlers) dispatchCommand(
w http.ResponseWriter, r *http.Request,
uid, nick, command, to string, lines []string,
w http.ResponseWriter,
r *http.Request,
uid int64,
nick string,
req *sendRequest,
) {
switch command {
switch req.Command {
case "PRIVMSG", "NOTICE":
s.handlePrivmsg(w, r, uid, nick, to, lines)
s.handlePrivmsg(w, r, uid, req)
case "JOIN":
s.handleJoin(w, r, uid, to)
s.handleJoin(w, r, uid, req)
case "PART":
s.handlePart(w, r, uid, to)
s.handlePart(w, r, uid, req)
case "NICK":
s.handleNick(w, r, uid, lines)
s.handleNick(w, r, uid, req)
case "TOPIC":
s.handleTopic(w, r, uid, to, lines)
s.handleTopic(w, r, uid, req)
case "PING":
s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK)
s.respondJSON(
w, r,
map[string]string{
"command": "PONG",
"from": s.params.Config.ServerName,
},
http.StatusOK,
)
default:
_ = nick
s.respondJSON(w, r, map[string]string{"error": "unknown command: " + command}, http.StatusBadRequest)
s.respondError(
w, r,
"unknown command: "+req.Command,
http.StatusBadRequest,
)
}
}
func (s *Handlers) handlePrivmsg(w http.ResponseWriter, r *http.Request, uid, nick, to string, lines []string) {
if to == "" {
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
func (s *Handlers) handlePrivmsg(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" {
s.respondError(
w, r, "to field required",
http.StatusBadRequest,
)
return
}
lines := bodyLines(req.Body)
if len(lines) == 0 {
s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest)
s.respondError(
w, r, "body required", http.StatusBadRequest,
)
return
}
content := strings.Join(lines, "\n")
if strings.HasPrefix(to, "#") {
s.sendChannelMessage(w, r, uid, nick, to, content)
return
if strings.HasPrefix(req.To, "#") {
s.sendChannelMsg(w, r, uid, req.To, content)
} else {
s.sendDM(w, r, uid, req.To, content)
}
// DM.
s.sendDirectMessage(w, r, uid, nick, to, content)
}
func (s *Handlers) sendChannelMessage(
w http.ResponseWriter, r *http.Request,
uid, nick, channel, content string,
func (s *Handlers) sendChannelMsg(
w http.ResponseWriter,
r *http.Request,
uid int64,
channel, content string,
) {
var chID string
var chID int64
//nolint:gosec // G701: parameterized query, not injection
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", channel).Scan(&chID)
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
channel,
).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
s.respondError(
w, r, "channel not found",
http.StatusNotFound,
)
return
}
msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, nick, content)
msgID, err := s.params.Database.SendMessage(
r.Context(), chID, uid, content,
)
if err != nil {
s.log.Error("send message failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "send message failed", err)
return
}
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
s.respondJSON(
w, r,
map[string]any{"id": msgID, "status": "sent"},
http.StatusCreated,
)
}
func (s *Handlers) sendDirectMessage(
w http.ResponseWriter, r *http.Request,
uid, nick, to, content string,
func (s *Handlers) sendDM(
w http.ResponseWriter,
r *http.Request,
uid int64,
toNick, content string,
) {
targetUser, err := s.params.Database.GetUserByNick(r.Context(), to)
targetID, err := s.params.Database.GetUserByNick(
r.Context(), toNick,
)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
s.respondError(
w, r, "user not found", http.StatusNotFound,
)
return
}
msgID, err := s.params.Database.SendDM(r.Context(), uid, nick, targetUser.ID, content)
msgID, err := s.params.Database.SendDM(
r.Context(), uid, targetID, content,
)
if err != nil {
s.log.Error("send dm failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "send dm failed", err)
return
}
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
s.respondJSON(
w, r,
map[string]any{"id": msgID, "status": "sent"},
http.StatusCreated,
)
}
func (s *Handlers) handleJoin(w http.ResponseWriter, r *http.Request, uid, to string) {
if to == "" {
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
func (s *Handlers) handleJoin(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" {
s.respondError(
w, r, "to field required",
http.StatusBadRequest,
)
return
}
channel := to
channel := req.To
if !strings.HasPrefix(channel, "#") {
channel = "#" + channel
}
chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel)
chID, err := s.params.Database.GetOrCreateChannel(
r.Context(), channel,
)
if err != nil {
s.log.Error("get/create channel failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(
w, r, "get/create channel failed", err,
)
return
}
err = s.params.Database.JoinChannel(r.Context(), chID, uid)
err = s.params.Database.JoinChannel(
r.Context(), chID, uid,
)
if err != nil {
s.log.Error("join channel failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "join channel failed", err)
return
}
s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK)
s.respondJSON(
w, r,
map[string]string{
"status": "joined", "channel": channel,
},
http.StatusOK,
)
}
func (s *Handlers) handlePart(w http.ResponseWriter, r *http.Request, uid, to string) {
if to == "" {
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
func (s *Handlers) handlePart(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" {
s.respondError(
w, r, "to field required",
http.StatusBadRequest,
)
return
}
channel := to
channel := req.To
if !strings.HasPrefix(channel, "#") {
channel = "#" + channel
}
var chID string
var chID int64
//nolint:gosec // G701: parameterized query, not injection
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", channel).Scan(&chID)
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
channel,
).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
s.respondError(
w, r, "channel not found",
http.StatusNotFound,
)
return
}
err = s.params.Database.PartChannel(r.Context(), chID, uid)
err = s.params.Database.PartChannel(
r.Context(), chID, uid,
)
if err != nil {
s.log.Error("part channel failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "part channel failed", err)
return
}
s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK)
s.respondJSON(
w, r,
map[string]string{
"status": "parted", "channel": channel,
},
http.StatusOK,
)
}
func (s *Handlers) handleNick(w http.ResponseWriter, r *http.Request, uid string, lines []string) {
func (s *Handlers) handleNick(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
lines := bodyLines(req.Body)
if len(lines) == 0 {
s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest)
s.respondError(
w, r, "body required (new nick)",
http.StatusBadRequest,
)
return
}
newNick := strings.TrimSpace(lines[0])
if newNick == "" || len(newNick) > 32 {
s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
if newNick == "" || len(newNick) > maxNickLen {
s.respondError(
w, r, "nick must be 1-32 characters",
http.StatusBadRequest,
)
return
}
err := s.params.Database.ChangeNick(r.Context(), uid, newNick)
err := s.params.Database.ChangeNick(
r.Context(), uid, newNick,
)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE") {
s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict)
s.respondError(
w, r, "nick already in use",
http.StatusConflict,
)
return
}
s.log.Error("change nick failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "change nick failed", err)
return
}
s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK)
s.respondJSON(
w, r,
map[string]string{"status": "ok", "nick": newNick},
http.StatusOK,
)
}
func (s *Handlers) handleTopic(w http.ResponseWriter, r *http.Request, uid, to string, lines []string) {
if to == "" {
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
func (s *Handlers) handleTopic(
w http.ResponseWriter,
r *http.Request,
uid int64,
req *sendRequest,
) {
if req.To == "" {
s.respondError(
w, r, "to field required",
http.StatusBadRequest,
)
return
}
lines := bodyLines(req.Body)
if len(lines) == 0 {
s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest)
s.respondError(
w, r, "body required (topic text)",
http.StatusBadRequest,
)
return
}
topic := strings.Join(lines, " ")
channel := to
channel := req.To
if !strings.HasPrefix(channel, "#") {
channel = "#" + channel
}
err := s.params.Database.SetTopic(r.Context(), channel, uid, topic)
err := s.params.Database.SetTopic(
r.Context(), channel, uid, topic,
)
if err != nil {
s.log.Error("set topic failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "set topic failed", err)
return
}
s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK)
s.respondJSON(
w, r,
map[string]string{"status": "ok", "topic": topic},
http.StatusOK,
)
}
// HandleGetHistory returns message history for a specific target (channel or DM).
// HandleGetHistory returns message history for a specific
// target (channel or DM).
func (s *Handlers) HandleGetHistory() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
@@ -499,47 +677,65 @@ func (s *Handlers) HandleGetHistory() http.HandlerFunc {
target := r.URL.Query().Get("target")
if target == "" {
s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest)
s.respondError(
w, r, "target required",
http.StatusBadRequest,
)
return
}
beforeTS := r.URL.Query().Get("before")
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
beforeID, _ := strconv.ParseInt(
r.URL.Query().Get("before"), 10, 64,
)
limit, _ := strconv.Atoi(
r.URL.Query().Get("limit"),
)
if limit <= 0 {
limit = 50
limit = defaultHistory
}
if strings.HasPrefix(target, "#") {
s.getChannelHistory(w, r, target, beforeTS, limit)
return
s.getChannelHistory(
w, r, target, beforeID, limit,
)
} else {
s.getDMHistory(
w, r, uid, target, beforeID, limit,
)
}
s.getDMHistory(w, r, uid, target, beforeTS, limit)
}
}
func (s *Handlers) getChannelHistory(
w http.ResponseWriter, r *http.Request,
channel, beforeTS string, limit int,
w http.ResponseWriter,
r *http.Request,
target string,
beforeID int64,
limit int,
) {
var chID string
var chID int64
//nolint:gosec // G701: parameterized query, not injection
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", channel).Scan(&chID)
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
target,
).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
s.respondError(
w, r, "channel not found",
http.StatusNotFound,
)
return
}
msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeTS, limit)
msgs, err := s.params.Database.GetMessagesBefore(
r.Context(), chID, beforeID, limit,
)
if err != nil {
s.log.Error("get history failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(w, r, "get history failed", err)
return
}
@@ -548,20 +744,31 @@ func (s *Handlers) getChannelHistory(
}
func (s *Handlers) getDMHistory(
w http.ResponseWriter, r *http.Request,
uid, target, beforeTS string, limit int,
w http.ResponseWriter,
r *http.Request,
uid int64,
target string,
beforeID int64,
limit int,
) {
targetUser, err := s.params.Database.GetUserByNick(r.Context(), target)
targetID, err := s.params.Database.GetUserByNick(
r.Context(), target,
)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
s.respondError(
w, r, "user not found", http.StatusNotFound,
)
return
}
msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetUser.ID, beforeTS, limit)
msgs, err := s.params.Database.GetDMsBefore(
r.Context(), uid, targetID, beforeID, limit,
)
if err != nil {
s.log.Error("get dm history failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
s.internalError(
w, r, "get dm history failed", err,
)
return
}

View File

@@ -9,13 +9,6 @@ import (
"errors"
)
var (
// ErrUserLookupNotAvailable is returned when the user lookup interface is not set.
ErrUserLookupNotAvailable = errors.New("user lookup not available")
// ErrChannelLookupNotAvailable is returned when the channel lookup interface is not set.
ErrChannelLookupNotAvailable = errors.New("channel lookup not available")
)
// DB is the interface that models use to query the database.
// This avoids a circular import with the db package.
type DB interface {
@@ -32,6 +25,12 @@ type ChannelLookup interface {
GetChannelByID(ctx context.Context, id string) (*Channel, error)
}
// Sentinel errors for model lookup methods.
var (
ErrUserLookupNotAvailable = errors.New("user lookup not available")
ErrChannelLookupNotAvailable = errors.New("channel lookup not available")
)
// Base is embedded in all model structs to provide database access.
type Base struct {
db DB
@@ -48,7 +47,7 @@ func (b *Base) GetDB() *sql.DB {
}
// GetUserLookup returns the DB as a UserLookup if it implements the interface.
func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn // intentional interface return for dependency inversion
func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn,nolintlint // returns interface by design
if ul, ok := b.db.(UserLookup); ok {
return ul
}
@@ -56,11 +55,8 @@ func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn // intentional inte
return nil
}
// GetChannelLookup returns the DB as a ChannelLookup
// if it implements the interface.
//
//nolint:ireturn // intentional interface return for dependency inversion
func (b *Base) GetChannelLookup() ChannelLookup {
// GetChannelLookup returns the DB as a ChannelLookup if it implements the interface.
func (b *Base) GetChannelLookup() ChannelLookup { //nolint:ireturn,nolintlint // returns interface by design
if cl, ok := b.db.(ChannelLookup); ok {
return cl
}

View File

@@ -65,38 +65,43 @@ func (s *Server) SetupRoutes() {
r.Get("/channels/{channel}/members", s.h.HandleChannelMembers())
})
s.setupSPA()
}
func (s *Server) setupSPA() {
// Serve embedded SPA
distFS, err := fs.Sub(web.Dist, "dist")
if err != nil {
s.log.Error("failed to get web dist filesystem", "error", err)
} else {
fileServer := http.FileServer(http.FS(distFS))
s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
s.serveSPA(distFS, fileServer, w, r)
})
}
}
func (s *Server) serveSPA(
distFS fs.FS,
fileServer http.Handler,
w http.ResponseWriter,
r *http.Request,
) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
http.Error(w, "filesystem error", http.StatusInternalServerError)
return
}
fileServer := http.FileServer(http.FS(distFS))
// Try to serve the file; fall back to index.html for SPA routing.
f, err := readFS.ReadFile(r.URL.Path[1:])
if err != nil || len(f) == 0 {
indexHTML, _ := readFS.ReadFile("index.html")
s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
http.Error(w, "internal error", http.StatusInternalServerError)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(indexHTML)
return
}
return
}
f, err := readFS.ReadFile(r.URL.Path[1:])
if err != nil || len(f) == 0 {
indexHTML, _ := readFS.ReadFile("index.html")
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(indexHTML)
return
}
fileServer.ServeHTTP(w, r)
})
fileServer.ServeHTTP(w, r)
}