1 Commits

Author SHA1 Message Date
clawbot
15caf5c8d2 Fix all lint/build issues on main branch (closes #13)
- Resolve duplicate method declarations (CreateUser, GetUserByToken,
  GetUserByNick) between db.go and queries.go by renaming queries.go
  methods to CreateSimpleUser, LookupUserByToken, LookupUserByNick
- Fix 377 lint issues across all categories:
  - nlreturn (107): Add blank lines before returns
  - wsl_v5 (156): Add required whitespace
  - noinlineerr (25): Use plain assignments instead of inline error handling
  - errcheck (15): Check all error return values
  - mnd (10): Extract magic numbers to named constants
  - err113 (7): Use wrapped static errors instead of dynamic errors
  - gosec (7): Fix SSRF, SQL injection warnings; add nolint for false positives
  - modernize (7): Replace interface{} with any
  - cyclop (2): Reduce cyclomatic complexity via command map dispatch
  - gocognit (1): Break down complex handler into sub-handlers
  - funlen (3): Extract long functions into smaller helpers
  - funcorder (4): Reorder methods (exported before unexported)
  - forcetypeassert (2): Add safe type assertions with ok checks
  - ireturn (2): Replace interface-returning methods with concrete lookups
  - noctx (3): Use NewRequestWithContext and ExecContext
  - tagliatelle (5): Fix JSON tag casing to camelCase
  - revive (4): Rename package from 'api' to 'chatapi'
  - rowserrcheck (8): Add rows.Err() checks after iteration
  - lll (2): Shorten long lines
  - perfsprint (5): Use strconv and string concatenation
  - nestif (2): Extract nested conditionals into helper methods
  - wastedassign (1): Remove wasted assignments
  - gosmopolitan (1): Add nolint for intentional Local() time display
  - usestdlibvars (1): Use http.MethodGet
  - godoclint (2): Remove duplicate package comments
- Fix broken migration 003_users.sql that conflicted with 002_schema.sql
  (different column types causing test failures)
- All tests pass, make check reports 0 issues
2026-02-20 02:51:32 -08:00
23 changed files with 1068 additions and 1939 deletions

View File

@@ -1,9 +1,8 @@
.git
node_modules
.DS_Store
bin/ bin/
chatd
data.db data.db
.env .env
.git
*.test *.test
*.out *.out
debug.log debug.log

View File

@@ -1,12 +0,0 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[Makefile]
indent_style = tab

View File

@@ -1,9 +0,0 @@
name: check
on: [push]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4.2.2, 2026-02-22
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- run: docker build .

30
.gitignore vendored
View File

@@ -1,28 +1,7 @@
# OS
.DS_Store
Thumbs.db
# Editors
*.swp
*.swo
*~
*.bak
.idea/
.vscode/
*.sublime-*
# Node
node_modules/
# Environment / secrets
.env
.env.*
*.pem
*.key
# Build artifacts
/chatd /chatd
/bin/ /bin/
data.db
.env
*.exe *.exe
*.dll *.dll
*.so *.so
@@ -30,9 +9,6 @@ node_modules/
*.test *.test
*.out *.out
vendor/ vendor/
# Project
data.db
debug.log debug.log
/chat-cli
web/node_modules/ web/node_modules/
chat-cli

View File

@@ -1,7 +1,6 @@
# golang:1.24-alpine, 2026-02-26 FROM golang:1.24-alpine AS builder
FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder
RUN apk add --no-cache git build-base RUN apk add --no-cache git
WORKDIR /src WORKDIR /src
COPY go.mod go.sum ./ COPY go.mod go.sum ./
@@ -9,16 +8,10 @@ RUN go mod download
COPY . . COPY . .
# Run all checks — build fails if branch is not green
# golangci-lint v2.1.6
RUN CGO_ENABLED=0 go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d
RUN make check
ARG VERSION=dev ARG VERSION=dev
RUN go build -ldflags "-X main.Version=${VERSION}" -o /chatd ./cmd/chatd RUN go build -ldflags "-X main.Version=${VERSION}" -o /chatd ./cmd/chatd
# alpine:3.21, 2026-02-26 FROM alpine:3.21
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk add --no-cache ca-certificates RUN apk add --no-cache ca-certificates
COPY --from=builder /chatd /usr/local/bin/chatd COPY --from=builder /chatd /usr/local/bin/chatd

21
LICENSE
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2025 sneak
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks .PHONY: all build lint fmt test check clean run debug
BINARY := chatd BINARY := chatd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -17,15 +17,18 @@ fmt:
gofmt -s -w . gofmt -s -w .
goimports -w . goimports -w .
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test: test:
go test -timeout 30s -v -race -cover ./... go test -v -race -cover ./...
# check runs all validation without making changes # Check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong # Used by CI and Docker build — fails if anything is wrong
check: test lint fmt-check check:
@echo "==> Checking formatting..."
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
@echo "==> Running linter..."
golangci-lint run --config .golangci.yml ./...
@echo "==> Running tests..."
go test -v -race ./...
@echo "==> Building..." @echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/chatd go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/chatd
@echo "==> All checks passed!" @echo "==> All checks passed!"
@@ -38,12 +41,3 @@ debug: build
clean: clean:
rm -rf bin/ chatd data.db rm -rf bin/ chatd data.db
docker:
docker build -t chat .
hooks:
@printf '#!/bin/sh\nset -e\n' > .git/hooks/pre-commit
@printf 'go mod tidy\ngo fmt ./...\ngit diff --exit-code -- go.mod go.sum || { echo "go mod tidy changed files; please stage and retry"; exit 1; }\n' >> .git/hooks/pre-commit
@printf 'make check\n' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit

View File

@@ -656,8 +656,3 @@ embedded web client.
## License ## License
MIT MIT
## Author
[@sneak](https://sneak.berlin)

View File

@@ -1,182 +0,0 @@
---
title: Repository Policies
last_modified: 2026-02-22
---
This document covers repository structure, tooling, and workflow standards. Code
style conventions are in separate documents:
- [Code Styleguide](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE.md)
(general, bash, Docker)
- [Go](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_GO.md)
- [JavaScript](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_JS.md)
- [Python](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_PYTHON.md)
- [Go HTTP Server Conventions](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/GO_HTTP_SERVER_CONVENTIONS.md)
---
- Cross-project documentation (such as this file) must include
`last_modified: YYYY-MM-DD` in the YAML front matter so it can be kept in sync
with the authoritative source as policies evolve.
- **ALL external references must be pinned by cryptographic hash.** This
includes Docker base images, Go modules, npm packages, GitHub Actions, and
anything else fetched from a remote source. Version tags (`@v4`, `@latest`,
`:3.21`, etc.) are server-mutable and therefore remote code execution
vulnerabilities. The ONLY acceptable way to reference an external dependency
is by its content hash (Docker `@sha256:...`, Go module hash in `go.sum`, npm
integrity hash in lockfile, GitHub Actions `@<commit-sha>`). No exceptions.
This also means never `curl | bash` to install tools like pyenv, nvm, rustup,
etc. Instead, download a specific release archive from GitHub, verify its hash
(hardcoded in the Dockerfile or script), and only then install. Unverified
install scripts are arbitrary remote code execution. This is the single most
important rule in this document. Double-check every external reference in
every file before committing. There are zero exceptions to this rule.
- Every repo with software must have a root `Makefile` with these targets:
`make test`, `make lint`, `make fmt` (writes), `make fmt-check` (read-only),
`make check` (prereqs: `test`, `lint`, `fmt-check`), `make docker`, and
`make hooks` (installs pre-commit hook). A model Makefile is at
`https://git.eeqj.de/sneak/prompts/raw/branch/main/Makefile`.
- Always use Makefile targets (`make fmt`, `make test`, `make lint`, etc.)
instead of invoking the underlying tools directly. The Makefile is the single
source of truth for how these operations are run.
- The Makefile is authoritative documentation for how the repo is used. Beyond
the required targets above, it should have targets for every common operation:
running a local development server (`make run`, `make dev`), re-initializing
or migrating the database (`make db-reset`, `make migrate`), building
artifacts (`make build`), generating code, seeding data, or anything else a
developer would do regularly. If someone checks out the repo and types
`make<tab>`, they should see every meaningful operation available. A new
contributor should be able to understand the entire development workflow by
reading the Makefile.
- Every repo should have a `Dockerfile`. All Dockerfiles must run `make check`
as a build step so the build fails if the branch is not green. For non-server
repos, the Dockerfile should bring up a development environment and run
`make check`. For server repos, `make check` should run as an early build
stage before the final image is assembled.
- Every repo should have a Gitea Actions workflow (`.gitea/workflows/`) that
runs `docker build .` on push. Since the Dockerfile already runs `make check`,
a successful build implies all checks pass.
- Use platform-standard formatters: `black` for Python, `prettier` for
JS/CSS/Markdown/HTML, `go fmt` for Go. Always use default configuration with
two exceptions: four-space indents (except Go), and `proseWrap: always` for
Markdown (hard-wrap at 80 columns). Documentation and writing repos (Markdown,
HTML, CSS) should also have `.prettierrc` and `.prettierignore`.
- Pre-commit hook: `make check` if local testing is possible, otherwise
`make lint && make fmt-check`. The Makefile should provide a `make hooks`
target to install the pre-commit hook.
- All repos with software must have tests that run via the platform-standard
test framework (`go test`, `pytest`, `jest`/`vitest`, etc.). If no meaningful
tests exist yet, add the most minimal test possible — e.g. importing the
module under test to verify it compiles/parses. There is no excuse for
`make test` to be a no-op.
- `make test` must complete in under 20 seconds. Add a 30-second timeout in the
Makefile.
- Docker builds must complete in under 5 minutes.
- `make check` must not modify any files in the repo. Tests may use temporary
directories.
- `main` must always pass `make check`, no exceptions.
- Never commit secrets. `.env` files, credentials, API keys, and private keys
must be in `.gitignore`. No exceptions.
- `.gitignore` should be comprehensive from the start: OS files (`.DS_Store`),
editor files (`.swp`, `*~`), language build artifacts, and `node_modules/`.
Fetch the standard `.gitignore` from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.gitignore` when setting up
a new repo.
- Never use `git add -A` or `git add .`. Always stage files explicitly by name.
- Never force-push to `main`.
- Make all changes on a feature branch. You can do whatever you want on a
feature branch.
- `.golangci.yml` is standardized and must _NEVER_ be modified by an agent, only
manually by the user. Fetch from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.golangci.yml`.
- When pinning images or packages by hash, add a comment above the reference
with the version and date (YYYY-MM-DD).
- Use `yarn`, not `npm`.
- Write all dates as YYYY-MM-DD (ISO 8601).
- Simple projects should be configured with environment variables.
- Dockerized web services listen on port 8080 by default, overridable with
`PORT`.
- `README.md` is the primary documentation. Required sections:
- **Description**: First line must include the project name, purpose,
category (web server, SPA, CLI tool, etc.), license, and author. Example:
"µPaaS is an MIT-licensed Go web application by @sneak that receives
git-frontend webhooks and deploys applications via Docker in realtime."
- **Getting Started**: Copy-pasteable install/usage code block.
- **Rationale**: Why does this exist?
- **Design**: How is the program structured?
- **TODO**: Update meticulously, even between commits. When planning, put
the todo list in the README so a new agent can pick up where the last one
left off.
- **License**: MIT, GPL, or WTFPL. Ask the user for new projects. Include a
`LICENSE` file in the repo root and a License section in the README.
- **Author**: [@sneak](https://sneak.berlin).
- First commit of a new repo should contain only `README.md`.
- Go module root: `sneak.berlin/go/<name>`. Always run `go mod tidy` before
committing.
- Use SemVer.
- Database migrations live in `internal/db/migrations/` and must be embedded in
the binary. Pre-1.0.0: modify existing migrations (no installed base assumed).
Post-1.0.0: add new migration files.
- All repos should have an `.editorconfig` enforcing the project's indentation
settings.
- Avoid putting files in the repo root unless necessary. Root should contain
only project-level config files (`README.md`, `Makefile`, `Dockerfile`,
`LICENSE`, `.gitignore`, `.editorconfig`, `REPO_POLICIES.md`, and
language-specific config). Everything else goes in a subdirectory. Canonical
subdirectory names:
- `bin/` — executable scripts and tools
- `cmd/` — Go command entrypoints
- `configs/` — configuration templates and examples
- `deploy/` — deployment manifests (k8s, compose, terraform)
- `docs/` — documentation and markdown (README.md stays in root)
- `internal/` — Go internal packages
- `internal/db/migrations/` — database migrations
- `pkg/` — Go library packages
- `share/` — systemd units, data files
- `static/` — static assets (images, fonts, etc.)
- `web/` — web frontend source
- When setting up a new repo, files from the `prompts` repo may be used as
templates. Fetch them from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/<path>`.
- New repos must contain at minimum:
- `README.md`, `.git`, `.gitignore`, `.editorconfig`
- `LICENSE`, `REPO_POLICIES.md` (copy from the `prompts` repo)
- `Makefile`
- `Dockerfile`, `.dockerignore`
- `.gitea/workflows/check.yml`
- Go: `go.mod`, `go.sum`, `.golangci.yml`
- JS: `package.json`, `yarn.lock`, `.prettierrc`, `.prettierignore`
- Python: `pyproject.toml`

View File

@@ -1,8 +1,9 @@
// Package chatapi provides a client for the chat server HTTP API. // Package chatapi provides a client for the chat server's HTTP API.
package chatapi package chatapi
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -14,16 +15,15 @@ import (
) )
const ( const (
httpTimeout = 30 * time.Second httpClientTimeout = 30
pollExtraDelay = 5 httpErrorMinCode = 400
httpErrThreshold = 400 pollExtraTimeout = 5
) )
// ErrHTTP is returned for non-2xx responses. // ErrHTTPError is returned when the server responds with an error status.
var ErrHTTP = errors.New("http error") var ErrHTTPError = errors.New("HTTP error")
// ErrUnexpectedFormat is returned when the response format is // ErrUnexpectedFormat is returned when a response has an unexpected structure.
// not recognised.
var ErrUnexpectedFormat = errors.New("unexpected format") var ErrUnexpectedFormat = errors.New("unexpected format")
// Client wraps HTTP calls to the chat server API. // Client wraps HTTP calls to the chat server API.
@@ -38,19 +38,14 @@ func NewClient(baseURL string) *Client {
return &Client{ return &Client{
BaseURL: baseURL, BaseURL: baseURL,
HTTPClient: &http.Client{ HTTPClient: &http.Client{
Timeout: httpTimeout, Timeout: httpClientTimeout * time.Second,
}, },
} }
} }
// CreateSession creates a new session on the server. // CreateSession creates a new session on the server.
func (c *Client) CreateSession( func (c *Client) CreateSession(nick string) (*SessionResponse, error) {
nick string, data, err := c.do("POST", "/api/v1/session", &SessionRequest{Nick: nick})
) (*SessionResponse, error) {
data, err := c.do(
"POST", "/api/v1/session",
&SessionRequest{Nick: nick},
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -92,15 +87,9 @@ func (c *Client) SendMessage(msg *Message) error {
} }
// PollMessages long-polls for new messages. // PollMessages long-polls for new messages.
func (c *Client) PollMessages( func (c *Client) PollMessages(afterID string, timeout int) ([]Message, error) {
afterID string, // Use a longer HTTP timeout than the server long-poll timeout.
timeout int, client := &http.Client{Timeout: time.Duration(timeout+pollExtraTimeout) * time.Second}
) ([]Message, error) {
pollTimeout := time.Duration(
timeout+pollExtraDelay,
) * time.Second
client := &http.Client{Timeout: pollTimeout}
params := url.Values{} params := url.Values{}
if afterID != "" { if afterID != "" {
@@ -114,16 +103,14 @@ func (c *Client) PollMessages(
path += "?" + params.Encode() path += "?" + params.Encode()
} }
req, err := http.NewRequest( //nolint:noctx // CLI tool req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, c.BaseURL+path, nil)
http.MethodGet, c.BaseURL+path, nil,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Authorization", "Bearer "+c.Token) req.Header.Set("Authorization", "Bearer "+c.Token)
resp, err := client.Do(req) //nolint:gosec,nolintlint // URL from user config resp, err := client.Do(req) //nolint:gosec // URL is constructed from trusted base URL + API path, not user-tainted
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -135,51 +122,37 @@ func (c *Client) PollMessages(
return nil, err return nil, err
} }
if resp.StatusCode >= httpErrThreshold { if resp.StatusCode >= httpErrorMinCode {
return nil, fmt.Errorf( return nil, fmt.Errorf("%w: %d: %s", ErrHTTPError, resp.StatusCode, string(data))
"%w: %d: %s",
ErrHTTP, resp.StatusCode, string(data),
)
} }
return decodeMessages(data) // The server may return an array directly or wrapped.
}
func decodeMessages(data []byte) ([]Message, error) {
var msgs []Message var msgs []Message
err := json.Unmarshal(data, &msgs) err = json.Unmarshal(data, &msgs)
if err == nil { if err != nil {
return msgs, nil // Try wrapped format.
}
var wrapped MessagesResponse var wrapped MessagesResponse
err2 := json.Unmarshal(data, &wrapped) err2 := json.Unmarshal(data, &wrapped)
if err2 != nil { if err2 != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("decode messages: %w (raw: %s)", err, string(data))
"decode messages: %w (raw: %s)",
err, string(data),
)
} }
return wrapped.Messages, nil msgs = wrapped.Messages
} }
// JoinChannel joins a channel via the unified command return msgs, nil
// endpoint. }
// JoinChannel joins a channel via the unified command endpoint.
func (c *Client) JoinChannel(channel string) error { func (c *Client) JoinChannel(channel string) error {
return c.SendMessage( return c.SendMessage(&Message{Command: "JOIN", To: channel})
&Message{Command: "JOIN", To: channel},
)
} }
// PartChannel leaves a channel via the unified command // PartChannel leaves a channel via the unified command endpoint.
// endpoint.
func (c *Client) PartChannel(channel string) error { func (c *Client) PartChannel(channel string) error {
return c.SendMessage( return c.SendMessage(&Message{Command: "PART", To: channel})
&Message{Command: "PART", To: channel},
)
} }
// ListChannels returns all channels on the server. // ListChannels returns all channels on the server.
@@ -200,13 +173,8 @@ func (c *Client) ListChannels() ([]Channel, error) {
} }
// GetMembers returns members of a channel. // GetMembers returns members of a channel.
func (c *Client) GetMembers( func (c *Client) GetMembers(channel string) ([]string, error) {
channel string, data, err := c.do("GET", "/api/v1/channels/"+url.PathEscape(channel)+"/members", nil)
) ([]string, error) {
path := "/api/v1/channels/" +
url.PathEscape(channel) + "/members"
data, err := c.do("GET", path, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -215,10 +183,16 @@ func (c *Client) GetMembers(
err = json.Unmarshal(data, &members) err = json.Unmarshal(data, &members)
if err != nil { if err != nil {
return nil, fmt.Errorf( // Try object format.
"%w: members: %s", var obj map[string]any
ErrUnexpectedFormat, string(data),
) err2 := json.Unmarshal(data, &obj)
if err2 != nil {
return nil, err
}
// Extract member names from whatever format.
return nil, fmt.Errorf("%w: members: %s", ErrUnexpectedFormat, string(data))
} }
return members, nil return members, nil
@@ -241,10 +215,7 @@ func (c *Client) GetServerInfo() (*ServerInfo, error) {
return &info, nil return &info, nil
} }
func (c *Client) do( func (c *Client) do(method, path string, body any) ([]byte, error) {
method, path string,
body any,
) ([]byte, error) {
var bodyReader io.Reader var bodyReader io.Reader
if body != nil { if body != nil {
@@ -256,9 +227,7 @@ func (c *Client) do(
bodyReader = bytes.NewReader(data) bodyReader = bytes.NewReader(data)
} }
req, err := http.NewRequest( //nolint:noctx // CLI tool req, err := http.NewRequestWithContext(context.Background(), method, c.BaseURL+path, bodyReader)
method, c.BaseURL+path, bodyReader,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("request: %w", err) return nil, fmt.Errorf("request: %w", err)
} }
@@ -269,7 +238,8 @@ func (c *Client) do(
req.Header.Set("Authorization", "Bearer "+c.Token) req.Header.Set("Authorization", "Bearer "+c.Token)
} }
resp, err := c.HTTPClient.Do(req) //nolint:gosec,nolintlint // URL from user config //nolint:gosec // URL built from trusted base + API path
resp, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("http: %w", err) return nil, fmt.Errorf("http: %w", err)
} }
@@ -281,11 +251,8 @@ func (c *Client) do(
return nil, fmt.Errorf("read body: %w", err) return nil, fmt.Errorf("read body: %w", err)
} }
if resp.StatusCode >= httpErrThreshold { if resp.StatusCode >= httpErrorMinCode {
return data, fmt.Errorf( return data, fmt.Errorf("%w: %d: %s", ErrHTTPError, resp.StatusCode, string(data))
"%w: %d: %s",
ErrHTTP, resp.StatusCode, string(data),
)
} }
return data, nil return data, nil

View File

@@ -35,8 +35,7 @@ type Message struct {
Meta any `json:"meta,omitempty"` Meta any `json:"meta,omitempty"`
} }
// BodyLines returns the body as a slice of strings (for text // BodyLines returns the body as a slice of strings (for text messages).
// messages).
func (m *Message) BodyLines() []string { func (m *Message) BodyLines() []string {
switch v := m.Body.(type) { switch v := m.Body.(type) {
case []any: case []any:

View File

@@ -1,27 +1,28 @@
// Package main implements chat-cli, an IRC-style terminal client. // Package main provides a terminal-based IRC-style chat client.
package main package main
import ( import (
"fmt" "fmt"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
api "git.eeqj.de/sneak/chat/cmd/chat-cli/api" "git.eeqj.de/sneak/chat/cmd/chat-cli/api"
) )
const ( const (
maxNickLen = 32
pollTimeoutSec = 15 pollTimeoutSec = 15
retryDelay = 2 * time.Second pollRetrySec = 2
maxNickLength = 32 splitNParts = 2
commandSplitArgs = 2
) )
// App holds the application state. // App holds the application state.
type App struct { type App struct {
ui *UI ui *UI
client *api.Client client *chatapi.Client
mu sync.Mutex mu sync.Mutex
nick string nick string
@@ -40,18 +41,12 @@ func main() {
app.ui.OnInput(app.handleInput) app.ui.OnInput(app.handleInput)
app.ui.SetStatus(app.nick, "", "disconnected") app.ui.SetStatus(app.nick, "", "disconnected")
app.ui.AddStatus( app.ui.AddStatus("Welcome to chat-cli — an IRC-style client")
"Welcome to chat-cli \u2014 an IRC-style client", app.ui.AddStatus("Type [yellow]/connect <server-url>[white] to begin, or [yellow]/help[white] for commands")
)
app.ui.AddStatus(
"Type [yellow]/connect <server-url>[white] " +
"to begin, or [yellow]/help[white] for commands",
)
err := app.ui.Run() err := app.ui.Run()
if err != nil { if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
} }
@@ -63,59 +58,51 @@ func (a *App) handleInput(text string) {
return return
} }
a.sendPlainText(text) // Plain text → PRIVMSG to current target.
}
func (a *App) sendPlainText(text string) {
a.mu.Lock() a.mu.Lock()
target := a.target target := a.target
connected := a.connected connected := a.connected
nick := a.nick
a.mu.Unlock() a.mu.Unlock()
if !connected { if !connected {
a.ui.AddStatus( a.ui.AddStatus("[red]Not connected. Use /connect <url>")
"[red]Not connected. Use /connect <url>",
)
return return
} }
if target == "" { if target == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]No target. Use /join #channel or /query nick")
"[red]No target. " +
"Use /join #channel or /query nick",
)
return return
} }
err := a.client.SendMessage(&api.Message{ err := a.client.SendMessage(&chatapi.Message{
Command: "PRIVMSG", Command: "PRIVMSG",
To: target, To: target,
Body: []string{text}, Body: []string{text},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Send error: %v", err))
fmt.Sprintf("[red]Send error: %v", err),
)
return return
} }
// Echo locally.
ts := time.Now().Format("15:04") ts := time.Now().Format("15:04")
a.ui.AddLine( a.mu.Lock()
target, nick := a.nick
fmt.Sprintf( a.mu.Unlock()
"[gray]%s [green]<%s>[white] %s",
ts, nick, text, a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text))
),
)
} }
func (a *App) handleCommand(text string) { //nolint:cyclop // command dispatch func (a *App) handleCommand(text string) {
parts := strings.SplitN(text, " ", 2) //nolint:mnd // split into cmd+args a.dispatchCommand(text)
}
func (a *App) dispatchCommand(text string) {
parts := strings.SplitN(text, " ", splitNParts)
cmd := strings.ToLower(parts[0]) cmd := strings.ToLower(parts[0])
args := "" args := ""
@@ -123,66 +110,64 @@ func (a *App) handleCommand(text string) { //nolint:cyclop // command dispatch
args = parts[1] args = parts[1]
} }
switch cmd { a.execCommand(cmd, args)
case "/connect": }
a.cmdConnect(args)
case "/nick": func (a *App) execCommand(cmd, args string) {
a.cmdNick(args) commands := a.commandMap()
case "/join":
a.cmdJoin(args) handler, ok := commands[cmd]
case "/part": if !ok {
a.cmdPart(args) a.ui.AddStatus("[red]Unknown command: " + cmd)
case "/msg":
a.cmdMsg(args) return
case "/query": }
a.cmdQuery(args)
case "/topic": handler(args)
a.cmdTopic(args) }
case "/names":
a.cmdNames() func (a *App) commandMap() map[string]func(string) {
case "/list": noArgs := func(fn func()) func(string) {
a.cmdList() return func(_ string) { fn() }
case "/window", "/w": }
a.cmdWindow(args)
case "/quit": return map[string]func(string){
a.cmdQuit() "/connect": a.cmdConnect,
case "/help": "/nick": a.cmdNick,
a.cmdHelp() "/join": a.cmdJoin,
default: "/part": a.cmdPart,
a.ui.AddStatus( "/msg": a.cmdMsg,
"[red]Unknown command: " + cmd, "/query": a.cmdQuery,
) "/topic": a.cmdTopic,
"/names": noArgs(a.cmdNames),
"/list": noArgs(a.cmdList),
"/window": a.cmdWindow,
"/w": a.cmdWindow,
"/quit": noArgs(a.cmdQuit),
"/help": noArgs(a.cmdHelp),
} }
} }
func (a *App) cmdConnect(serverURL string) { func (a *App) cmdConnect(serverURL string) {
if serverURL == "" { if serverURL == "" {
a.ui.AddStatus( a.ui.AddStatus("[red]Usage: /connect <server-url>")
"[red]Usage: /connect <server-url>",
)
return return
} }
serverURL = strings.TrimRight(serverURL, "/") serverURL = strings.TrimRight(serverURL, "/")
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("Connecting to %s...", serverURL))
fmt.Sprintf("Connecting to %s...", serverURL),
)
a.mu.Lock() a.mu.Lock()
nick := a.nick nick := a.nick
a.mu.Unlock() a.mu.Unlock()
client := api.NewClient(serverURL) client := chatapi.NewClient(serverURL)
resp, err := client.CreateSession(nick) resp, err := client.CreateSession(nick)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Connection failed: %v", err))
fmt.Sprintf(
"[red]Connection failed: %v", err,
),
)
return return
} }
@@ -194,14 +179,10 @@ func (a *App) cmdConnect(serverURL string) {
a.lastMsgID = "" a.lastMsgID = ""
a.mu.Unlock() a.mu.Unlock()
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[green]Connected! Nick: %s, Session: %s", resp.Nick, resp.SessionID))
fmt.Sprintf(
"[green]Connected! Nick: %s, Session: %s",
resp.Nick, resp.SessionID,
),
)
a.ui.SetStatus(resp.Nick, "", "connected") a.ui.SetStatus(resp.Nick, "", "connected")
// Start polling.
a.stopPoll = make(chan struct{}) a.stopPoll = make(chan struct{})
go a.pollLoop() go a.pollLoop()
@@ -223,26 +204,17 @@ func (a *App) cmdNick(nick string) {
a.nick = nick a.nick = nick
a.mu.Unlock() a.mu.Unlock()
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("Nick set to %s (will be used on connect)", nick))
fmt.Sprintf(
"Nick set to %s (will be used on connect)",
nick,
),
)
return return
} }
err := a.client.SendMessage(&api.Message{ err := a.client.SendMessage(&chatapi.Message{
Command: "NICK", Command: "NICK",
Body: []string{nick}, Body: []string{nick},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Nick change failed: %v", err))
fmt.Sprintf(
"[red]Nick change failed: %v", err,
),
)
return return
} }
@@ -253,9 +225,7 @@ func (a *App) cmdNick(nick string) {
a.mu.Unlock() a.mu.Unlock()
a.ui.SetStatus(nick, target, "connected") a.ui.SetStatus(nick, target, "connected")
a.ui.AddStatus( a.ui.AddStatus("Nick changed to " + nick)
"Nick changed to " + nick,
)
} }
func (a *App) cmdJoin(channel string) { func (a *App) cmdJoin(channel string) {
@@ -281,9 +251,7 @@ func (a *App) cmdJoin(channel string) {
err := a.client.JoinChannel(channel) err := a.client.JoinChannel(channel)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Join failed: %v", err))
fmt.Sprintf("[red]Join failed: %v", err),
)
return return
} }
@@ -294,10 +262,7 @@ func (a *App) cmdJoin(channel string) {
a.mu.Unlock() a.mu.Unlock()
a.ui.SwitchToBuffer(channel) a.ui.SwitchToBuffer(channel)
a.ui.AddLine( a.ui.AddLine(channel, "[yellow]*** Joined "+channel)
channel,
"[yellow]*** Joined "+channel,
)
a.ui.SetStatus(nick, channel, "connected") a.ui.SetStatus(nick, channel, "connected")
} }
@@ -325,17 +290,12 @@ func (a *App) cmdPart(channel string) {
err := a.client.PartChannel(channel) err := a.client.PartChannel(channel)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Part failed: %v", err))
fmt.Sprintf("[red]Part failed: %v", err),
)
return return
} }
a.ui.AddLine( a.ui.AddLine(channel, "[yellow]*** Left "+channel)
channel,
"[yellow]*** Left "+channel,
)
a.mu.Lock() a.mu.Lock()
@@ -351,8 +311,9 @@ func (a *App) cmdPart(channel string) {
} }
func (a *App) cmdMsg(args string) { func (a *App) cmdMsg(args string) {
parts := strings.SplitN(args, " ", 2) //nolint:mnd // split into target+text parts := strings.SplitN(args, " ", commandSplitArgs)
if len(parts) < 2 { //nolint:mnd // min args
if len(parts) < commandSplitArgs {
a.ui.AddStatus("[red]Usage: /msg <nick> <text>") a.ui.AddStatus("[red]Usage: /msg <nick> <text>")
return return
@@ -371,28 +332,19 @@ func (a *App) cmdMsg(args string) {
return return
} }
err := a.client.SendMessage(&api.Message{ err := a.client.SendMessage(&chatapi.Message{
Command: "PRIVMSG", Command: "PRIVMSG",
To: target, To: target,
Body: []string{text}, Body: []string{text},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Send failed: %v", err))
fmt.Sprintf("[red]Send failed: %v", err),
)
return return
} }
ts := time.Now().Format("15:04") ts := time.Now().Format("15:04")
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text))
a.ui.AddLine(
target,
fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
ts, nick, text,
),
)
} }
func (a *App) cmdQuery(nick string) { func (a *App) cmdQuery(nick string) {
@@ -430,32 +382,25 @@ func (a *App) cmdTopic(args string) {
} }
if args == "" { if args == "" {
err := a.client.SendMessage(&api.Message{ // Query topic.
err := a.client.SendMessage(&chatapi.Message{
Command: "TOPIC", Command: "TOPIC",
To: target, To: target,
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Topic query failed: %v", err))
fmt.Sprintf(
"[red]Topic query failed: %v", err,
),
)
} }
return return
} }
err := a.client.SendMessage(&api.Message{ err := a.client.SendMessage(&chatapi.Message{
Command: "TOPIC", Command: "TOPIC",
To: target, To: target,
Body: []string{args}, Body: []string{args},
}) })
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Topic set failed: %v", err))
fmt.Sprintf(
"[red]Topic set failed: %v", err,
),
)
} }
} }
@@ -479,20 +424,12 @@ func (a *App) cmdNames() {
members, err := a.client.GetMembers(target) members, err := a.client.GetMembers(target)
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]Names failed: %v", err))
fmt.Sprintf("[red]Names failed: %v", err),
)
return return
} }
a.ui.AddLine( a.ui.AddLine(target, fmt.Sprintf("[cyan]*** Members of %s: %s", target, strings.Join(members, " ")))
target,
fmt.Sprintf(
"[cyan]*** Members of %s: %s",
target, strings.Join(members, " "),
),
)
} }
func (a *App) cmdList() { func (a *App) cmdList() {
@@ -508,9 +445,7 @@ func (a *App) cmdList() {
channels, err := a.client.ListChannels() channels, err := a.client.ListChannels()
if err != nil { if err != nil {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[red]List failed: %v", err))
fmt.Sprintf("[red]List failed: %v", err),
)
return return
} }
@@ -518,12 +453,7 @@ func (a *App) cmdList() {
a.ui.AddStatus("[cyan]*** Channel list:") a.ui.AddStatus("[cyan]*** Channel list:")
for _, ch := range channels { for _, ch := range channels {
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf(" %s (%d members) %s", ch.Name, ch.Members, ch.Topic))
fmt.Sprintf(
" %s (%d members) %s",
ch.Name, ch.Members, ch.Topic,
),
)
} }
a.ui.AddStatus("[cyan]*** End of channel list") a.ui.AddStatus("[cyan]*** End of channel list")
@@ -536,14 +466,17 @@ func (a *App) cmdWindow(args string) {
return return
} }
n, _ := strconv.Atoi(args) n := 0
_, _ = fmt.Sscanf(args, "%d", &n)
a.ui.SwitchBuffer(n) a.ui.SwitchBuffer(n)
a.mu.Lock() a.mu.Lock()
nick := a.nick nick := a.nick
a.mu.Unlock() a.mu.Unlock()
if n >= 0 && n < a.ui.BufferCount() { // Update target based on buffer.
if n < a.ui.BufferCount() {
buf := a.ui.buffers[n] buf := a.ui.buffers[n]
if buf.Name != "(status)" { if buf.Name != "(status)" {
@@ -562,9 +495,7 @@ func (a *App) cmdQuit() {
a.mu.Lock() a.mu.Lock()
if a.connected && a.client != nil { if a.connected && a.client != nil {
_ = a.client.SendMessage( _ = a.client.SendMessage(&chatapi.Message{Command: "QUIT"})
&api.Message{Command: "QUIT"},
)
} }
if a.stopPoll != nil { if a.stopPoll != nil {
@@ -578,18 +509,18 @@ func (a *App) cmdQuit() {
func (a *App) cmdHelp() { func (a *App) cmdHelp() {
help := []string{ help := []string{
"[cyan]*** chat-cli commands:", "[cyan]*** chat-cli commands:",
" /connect <url> \u2014 Connect to server", " /connect <url> Connect to server",
" /nick <name> \u2014 Change nickname", " /nick <name> Change nickname",
" /join #channel \u2014 Join channel", " /join #channel Join channel",
" /part [#chan] \u2014 Leave channel", " /part [#chan] Leave channel",
" /msg <nick> <text> \u2014 Send DM", " /msg <nick> <text> Send DM",
" /query <nick> \u2014 Open DM window", " /query <nick> Open DM window",
" /topic [text] \u2014 View/set topic", " /topic [text] View/set topic",
" /names \u2014 List channel members", " /names List channel members",
" /list \u2014 List channels", " /list List channels",
" /window <n> \u2014 Switch buffer (Alt+0-9)", " /window <n> Switch buffer (Alt+0-9)",
" /quit \u2014 Disconnect and exit", " /quit Disconnect and exit",
" /help \u2014 This help", " /help This help",
" Plain text sends to current target.", " Plain text sends to current target.",
} }
@@ -616,31 +547,28 @@ func (a *App) pollLoop() {
return return
} }
msgs, err := client.PollMessages( msgs, err := client.PollMessages(lastID, pollTimeoutSec)
lastID, pollTimeoutSec,
)
if err != nil { if err != nil {
time.Sleep(retryDelay) // Transient error — retry after delay.
time.Sleep(pollRetrySec * time.Second)
continue continue
} }
for i := range msgs { for _, msg := range msgs {
a.handleServerMessage(&msgs[i]) a.handleServerMessage(&msg)
if msgs[i].ID != "" { if msg.ID != "" {
a.mu.Lock() a.mu.Lock()
a.lastMsgID = msgs[i].ID a.lastMsgID = msg.ID
a.mu.Unlock() a.mu.Unlock()
} }
} }
} }
} }
func (a *App) handleServerMessage( func (a *App) handleServerMessage(msg *chatapi.Message) {
msg *api.Message, ts := a.formatMessageTS(msg)
) {
ts := a.parseMessageTS(msg)
a.mu.Lock() a.mu.Lock()
myNick := a.nick myNick := a.nick
@@ -648,7 +576,7 @@ func (a *App) handleServerMessage(
switch msg.Command { switch msg.Command {
case "PRIVMSG": case "PRIVMSG":
a.handlePrivmsgMsg(msg, ts, myNick) a.handlePrivmsg(msg, ts, myNick)
case "JOIN": case "JOIN":
a.handleJoinMsg(msg, ts) a.handleJoinMsg(msg, ts)
case "PART": case "PART":
@@ -666,114 +594,70 @@ func (a *App) handleServerMessage(
} }
} }
func (a *App) parseMessageTS(msg *api.Message) string { func (a *App) formatMessageTS(msg *chatapi.Message) string {
if msg.TS != "" { if msg.TS != "" {
t := msg.ParseTS() t := msg.ParseTS()
return t.In(time.Local).Format("15:04") //nolint:gosmopolitan // CLI uses local time return t.Local().Format("15:04") //nolint:gosmopolitan // Local time display is intentional for UI
} }
return time.Now().Format("15:04") return time.Now().Format("15:04")
} }
func (a *App) handlePrivmsgMsg( func (a *App) handlePrivmsg(msg *chatapi.Message, ts, myNick string) {
msg *api.Message,
ts, myNick string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
if msg.From == myNick { if msg.From == myNick {
// Skip our own echoed messages (already displayed locally).
return return
} }
target := msg.To target := msg.To
if !strings.HasPrefix(target, "#") { if !strings.HasPrefix(target, "#") {
// DM — use sender's nick as buffer name.
target = msg.From target = msg.From
} }
a.ui.AddLine( a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, msg.From, text))
target,
fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
ts, msg.From, text,
),
)
} }
func (a *App) handleJoinMsg( func (a *App) handleJoinMsg(msg *chatapi.Message, ts string) {
msg *api.Message, ts string,
) {
target := msg.To target := msg.To
if target == "" { if target != "" {
return a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has joined %s", ts, msg.From, target))
}
} }
a.ui.AddLine( func (a *App) handlePartMsg(msg *chatapi.Message, ts string) {
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has joined %s",
ts, msg.From, target,
),
)
}
func (a *App) handlePartMsg(
msg *api.Message, ts string,
) {
target := msg.To target := msg.To
if target == "" {
return
}
lines := msg.BodyLines() lines := msg.BodyLines()
reason := strings.Join(lines, " ")
if target != "" {
if reason != "" {
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s (%s)", ts, msg.From, target, reason))
} else {
a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s", ts, msg.From, target))
}
}
}
func (a *App) handleQuitMsg(msg *chatapi.Message, ts string) {
lines := msg.BodyLines()
reason := strings.Join(lines, " ") reason := strings.Join(lines, " ")
if reason != "" { if reason != "" {
a.ui.AddLine( a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit (%s)", ts, msg.From, reason))
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s (%s)",
ts, msg.From, target, reason,
),
)
} else { } else {
a.ui.AddLine( a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit", ts, msg.From))
target,
fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s",
ts, msg.From, target,
),
)
} }
} }
func (a *App) handleQuitMsg( func (a *App) handleNickMsg(msg *chatapi.Message, ts, myNick string) {
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
reason := strings.Join(lines, " ")
if reason != "" {
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit (%s)",
ts, msg.From, reason,
),
)
} else {
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit",
ts, msg.From,
),
)
}
}
func (a *App) handleNickMsg(
msg *api.Message, ts, myNick string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
newNick := "" newNick := ""
@@ -790,61 +674,33 @@ func (a *App) handleNickMsg(
a.ui.SetStatus(newNick, target, "connected") a.ui.SetStatus(newNick, target, "connected")
} }
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s is now known as %s", ts, msg.From, newNick))
fmt.Sprintf(
"[gray]%s [yellow]*** %s is now known as %s",
ts, msg.From, newNick,
),
)
} }
func (a *App) handleNoticeMsg( func (a *App) handleNoticeMsg(msg *chatapi.Message, ts string) {
msg *api.Message, ts string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
a.ui.AddStatus( a.ui.AddStatus(fmt.Sprintf("[gray]%s [magenta]--%s-- %s", ts, msg.From, text))
fmt.Sprintf(
"[gray]%s [magenta]--%s-- %s",
ts, msg.From, text,
),
)
}
func (a *App) handleTopicMsg(
msg *api.Message, ts string,
) {
if msg.To == "" {
return
} }
func (a *App) handleTopicMsg(msg *chatapi.Message, ts string) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
a.ui.AddLine( if msg.To != "" {
msg.To, a.ui.AddLine(msg.To, fmt.Sprintf("[gray]%s [cyan]*** %s set topic: %s", ts, msg.From, text))
fmt.Sprintf( }
"[gray]%s [cyan]*** %s set topic: %s",
ts, msg.From, text,
),
)
} }
func (a *App) handleDefaultMsg( func (a *App) handleDefaultMsg(msg *chatapi.Message, ts string) {
msg *api.Message, ts string,
) {
lines := msg.BodyLines() lines := msg.BodyLines()
text := strings.Join(lines, " ") text := strings.Join(lines, " ")
if text == "" { if text != "" {
return a.ui.AddStatus(fmt.Sprintf("[gray]%s [white][%s] %s", ts, msg.Command, text))
} }
a.ui.AddStatus(
fmt.Sprintf(
"[gray]%s [white][%s] %s",
ts, msg.Command, text,
),
)
} }

View File

@@ -39,10 +39,8 @@ func NewUI() *UI {
}, },
} }
ui.setupMessages() ui.setupWidgets()
ui.setupStatusBar() ui.setupInputCapture()
ui.setupInput()
ui.setupKeybindings()
ui.setupLayout() ui.setupLayout()
return ui return ui
@@ -86,11 +84,7 @@ func (ui *UI) AddLine(bufferName string, line string) {
// AddStatus adds a line to the status buffer (buffer 0). // AddStatus adds a line to the status buffer (buffer 0).
func (ui *UI) AddStatus(line string) { func (ui *UI) AddStatus(line string) {
ts := time.Now().Format("15:04") ts := time.Now().Format("15:04")
ui.AddLine("(status)", fmt.Sprintf("[gray]%s[white] %s", ts, line))
ui.AddLine(
"(status)",
fmt.Sprintf("[gray]%s[white] %s", ts, line),
)
} }
// SwitchBuffer switches to the buffer at index n. // SwitchBuffer switches to the buffer at index n.
@@ -101,8 +95,8 @@ func (ui *UI) SwitchBuffer(n int) {
} }
ui.currentBuffer = n ui.currentBuffer = n
buf := ui.buffers[n] buf := ui.buffers[n]
buf.Unread = 0 buf.Unread = 0
ui.messages.Clear() ui.messages.Clear()
@@ -116,7 +110,7 @@ func (ui *UI) SwitchBuffer(n int) {
}) })
} }
// SwitchToBuffer switches to the named buffer, creating it // SwitchToBuffer switches to the named buffer, creating it if needed.
func (ui *UI) SwitchToBuffer(name string) { func (ui *UI) SwitchToBuffer(name string) {
ui.app.QueueUpdateDraw(func() { ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(name) buf := ui.getOrCreateBuffer(name)
@@ -143,9 +137,7 @@ func (ui *UI) SwitchToBuffer(name string) {
} }
// SetStatus updates the status bar text. // SetStatus updates the status bar text.
func (ui *UI) SetStatus( func (ui *UI) SetStatus(nick, target, connStatus string) {
nick, target, connStatus string,
) {
ui.app.QueueUpdateDraw(func() { ui.app.QueueUpdateDraw(func() {
ui.refreshStatusWith(nick, target, connStatus) ui.refreshStatusWith(nick, target, connStatus)
}) })
@@ -167,7 +159,7 @@ func (ui *UI) BufferIndex(name string) int {
return -1 return -1
} }
func (ui *UI) setupMessages() { func (ui *UI) setupWidgets() {
ui.messages = tview.NewTextView(). ui.messages = tview.NewTextView().
SetDynamicColors(true). SetDynamicColors(true).
SetScrollable(true). SetScrollable(true).
@@ -176,25 +168,17 @@ func (ui *UI) setupMessages() {
ui.app.Draw() ui.app.Draw()
}) })
ui.messages.SetBorder(false) ui.messages.SetBorder(false)
}
func (ui *UI) setupStatusBar() {
ui.statusBar = tview.NewTextView(). ui.statusBar = tview.NewTextView().
SetDynamicColors(true) SetDynamicColors(true)
ui.statusBar.SetBackgroundColor(tcell.ColorNavy) ui.statusBar.SetBackgroundColor(tcell.ColorNavy)
ui.statusBar.SetTextColor(tcell.ColorWhite) ui.statusBar.SetTextColor(tcell.ColorWhite)
}
func (ui *UI) setupInput() {
ui.input = tview.NewInputField(). ui.input = tview.NewInputField().
SetFieldBackgroundColor(tcell.ColorBlack). SetFieldBackgroundColor(tcell.ColorBlack).
SetFieldTextColor(tcell.ColorWhite) SetFieldTextColor(tcell.ColorWhite)
ui.input.SetDoneFunc(func(key tcell.Key) { ui.input.SetDoneFunc(func(key tcell.Key) {
if key != tcell.KeyEnter { if key == tcell.KeyEnter {
return
}
text := ui.input.GetText() text := ui.input.GetText()
if text == "" { if text == "" {
return return
@@ -205,16 +189,13 @@ func (ui *UI) setupInput() {
if ui.onInput != nil { if ui.onInput != nil {
ui.onInput(text) ui.onInput(text)
} }
}
}) })
} }
func (ui *UI) setupKeybindings() { func (ui *UI) setupInputCapture() {
ui.app.SetInputCapture( ui.app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
func(event *tcell.EventKey) *tcell.EventKey { if event.Modifiers()&tcell.ModAlt != 0 {
if event.Modifiers()&tcell.ModAlt == 0 {
return event
}
r := event.Rune() r := event.Rune()
if r >= '0' && r <= '9' { if r >= '0' && r <= '9' {
idx := int(r - '0') idx := int(r - '0')
@@ -222,15 +203,14 @@ func (ui *UI) setupKeybindings() {
return nil return nil
} }
}
return event return event
}, })
)
} }
func (ui *UI) setupLayout() { func (ui *UI) setupLayout() {
ui.layout = tview.NewFlex(). ui.layout = tview.NewFlex().SetDirection(tview.FlexRow).
SetDirection(tview.FlexRow).
AddItem(ui.messages, 0, 1, false). AddItem(ui.messages, 0, 1, false).
AddItem(ui.statusBar, 1, 0, false). AddItem(ui.statusBar, 1, 0, false).
AddItem(ui.input, 1, 0, true) AddItem(ui.input, 1, 0, true)
@@ -239,46 +219,31 @@ func (ui *UI) setupLayout() {
ui.app.SetFocus(ui.input) ui.app.SetFocus(ui.input)
} }
// if needed.
func (ui *UI) refreshStatus() { func (ui *UI) refreshStatus() {
// Rebuilt from app state by parent QueueUpdateDraw. // Will be called from the main goroutine via QueueUpdateDraw parent.
// Rebuild status from app state — caller must provide context.
} }
func (ui *UI) refreshStatusWith( func (ui *UI) refreshStatusWith(nick, target, connStatus string) {
nick, target, connStatus string,
) {
var unreadParts []string var unreadParts []string
for i, buf := range ui.buffers { for i, buf := range ui.buffers {
if buf.Unread > 0 { if buf.Unread > 0 {
unreadParts = append( unreadParts = append(unreadParts, fmt.Sprintf("%d:%s(%d)", i, buf.Name, buf.Unread))
unreadParts,
fmt.Sprintf(
"%d:%s(%d)", i, buf.Name, buf.Unread,
),
)
} }
} }
unread := "" unread := ""
if len(unreadParts) > 0 { if len(unreadParts) > 0 {
unread = " [Act: " + unread = " [Act: " + strings.Join(unreadParts, ",") + "]"
strings.Join(unreadParts, ",") + "]"
} }
bufInfo := fmt.Sprintf( bufInfo := fmt.Sprintf("[%d:%s]", ui.currentBuffer, ui.buffers[ui.currentBuffer].Name)
"[%d:%s]",
ui.currentBuffer,
ui.buffers[ui.currentBuffer].Name,
)
ui.statusBar.Clear() ui.statusBar.Clear()
_, _ = fmt.Fprintf( _, _ = fmt.Fprintf(ui.statusBar, " [%s] %s %s %s%s",
ui.statusBar, " [%s] %s %s %s%s", connStatus, nick, bufInfo, target, unread)
connStatus, nick, bufInfo, target, unread,
)
} }
func (ui *UI) getOrCreateBuffer(name string) *Buffer { func (ui *UI) getOrCreateBuffer(name string) *Buffer {

View File

@@ -88,7 +88,7 @@ func NewTest(dsn string) (*Database, error) {
} }
// Item 9: Enable foreign keys // Item 9: Enable foreign keys
_, err = d.Exec("PRAGMA foreign_keys = ON") //nolint:noctx,nolintlint // no context in sql.Open path _, err = d.ExecContext(context.Background(), "PRAGMA foreign_keys = ON")
if err != nil { if err != nil {
_ = d.Close() _ = d.Close()
@@ -164,8 +164,8 @@ func (s *Database) GetChannelByID(
return c, nil return c, nil
} }
// GetUserByNickModel looks up a user by their nick. // GetUserByNick looks up a user by their nick.
func (s *Database) GetUserByNickModel( func (s *Database) GetUserByNick(
ctx context.Context, ctx context.Context,
nick string, nick string,
) (*models.User, error) { ) (*models.User, error) {
@@ -187,8 +187,8 @@ func (s *Database) GetUserByNickModel(
return u, nil return u, nil
} }
// GetUserByTokenModel looks up a user by their auth token. // GetUserByToken looks up a user by their auth token.
func (s *Database) GetUserByTokenModel( func (s *Database) GetUserByToken(
ctx context.Context, ctx context.Context,
token string, token string,
) (*models.User, error) { ) (*models.User, error) {
@@ -238,8 +238,8 @@ func (s *Database) UpdateUserLastSeen(
return err return err
} }
// CreateUserModel inserts a new user into the database. // CreateUser inserts a new user into the database.
func (s *Database) CreateUserModel( func (s *Database) CreateUser(
ctx context.Context, ctx context.Context,
id, nick, passwordHash string, id, nick, passwordHash string,
) (*models.User, error) { ) (*models.User, error) {
@@ -435,7 +435,7 @@ func (s *Database) AckMessages(
args[i] = id args[i] = id
} }
query := fmt.Sprintf( //nolint:gosec // placeholders are ?, not user input query := fmt.Sprintf( //nolint:gosec // G201: placeholders are literal "?" strings, not user input
"DELETE FROM message_queue WHERE id IN (%s)", "DELETE FROM message_queue WHERE id IN (%s)",
strings.Join(placeholders, ","), strings.Join(placeholders, ","),
) )
@@ -612,7 +612,7 @@ func (s *Database) loadMigrations() ([]migration, error) {
return nil, fmt.Errorf("read schema dir: %w", err) return nil, fmt.Errorf("read schema dir: %w", err)
} }
migrations := make([]migration, 0, len(entries)) var migrations []migration
for _, entry := range entries { for _, entry := range entries {
if entry.IsDir() || if entry.IsDir() ||
@@ -677,12 +677,7 @@ func (s *Database) applyMigrations(
continue continue
} }
s.log.Info( err = s.applySingleMigration(ctx, m)
"applying migration",
"version", m.version, "name", m.name,
)
err = s.executeMigration(ctx, m)
if err != nil { if err != nil {
return err return err
} }
@@ -691,10 +686,12 @@ func (s *Database) applyMigrations(
return nil return nil
} }
func (s *Database) executeMigration( func (s *Database) applySingleMigration(ctx context.Context, m migration) error {
ctx context.Context, s.log.Info(
m migration, "applying migration",
) error { "version", m.version, "name", m.name,
)
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(

View File

@@ -1,6 +1,7 @@
package db_test package db_test
import ( import (
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -40,121 +41,215 @@ func TestCreateUser(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
id, token, err := d.CreateUser(ctx, nickAlice) u, err := d.CreateUser(ctx, "u1", nickAlice, "hash1")
if err != nil { if err != nil {
t.Fatalf("CreateUser: %v", err) t.Fatalf("CreateUser: %v", err)
} }
if id <= 0 { if u.ID != "u1" || u.Nick != nickAlice {
t.Errorf("expected positive id, got %d", id) t.Errorf("got user %+v", u)
}
if token == "" {
t.Error("expected non-empty token")
} }
} }
func TestGetUserByToken(t *testing.T) { func TestCreateAuthToken(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
_, token, _ := d.CreateUser(ctx, nickAlice) _, err := d.CreateUser(ctx, "u1", nickAlice, "h")
id, nick, err := d.GetUserByToken(ctx, token)
if err != nil { if err != nil {
t.Fatalf("GetUserByToken: %v", err) t.Fatalf("CreateUser: %v", err)
} }
if id <= 0 || nick != nickAlice { tok, err := d.CreateAuthToken(ctx, "tok1", "u1")
t.Errorf( if err != nil {
"got id=%d nick=%s, want nick=%s", t.Fatalf("CreateAuthToken: %v", err)
id, nick, nickAlice, }
if tok.Token != "tok1" || tok.UserID != "u1" {
t.Errorf("unexpected token: %+v", tok)
}
u, err := tok.User(ctx)
if err != nil {
t.Fatalf("AuthToken.User: %v", err)
}
if u.ID != "u1" || u.Nick != nickAlice {
t.Errorf("AuthToken.User got %+v", u)
}
}
func TestCreateChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
ch, err := d.CreateChannel(
ctx, "c1", "#general", "welcome", "+n",
) )
if err != nil {
t.Fatalf("CreateChannel: %v", err)
}
if ch.ID != "c1" || ch.Name != "#general" {
t.Errorf("unexpected channel: %+v", ch)
} }
} }
func TestGetUserByNick(t *testing.T) { func TestAddChannelMember(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
origID, _, _ := d.CreateUser(ctx, nickAlice) _, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#general", "", "")
id, err := d.GetUserByNick(ctx, nickAlice) cm, err := d.AddChannelMember(ctx, "c1", "u1", "+o")
if err != nil { if err != nil {
t.Fatalf("GetUserByNick: %v", err) t.Fatalf("AddChannelMember: %v", err)
} }
if id != origID { if cm.ChannelID != "c1" || cm.Modes != "+o" {
t.Errorf("got id %d, want %d", id, origID) t.Errorf("unexpected member: %+v", cm)
} }
} }
func TestGetOrCreateChannel(t *testing.T) { func TestCreateMessage(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
id1, err := d.GetOrCreateChannel(ctx, "#general") _, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
msg, err := d.CreateMessage(
ctx, "m1", "u1", nickAlice,
"#general", "message", "hello",
)
if err != nil { if err != nil {
t.Fatalf("GetOrCreateChannel: %v", err) t.Fatalf("CreateMessage: %v", err)
} }
if id1 <= 0 { if msg.ID != "m1" || msg.Body != "hello" {
t.Errorf("expected positive id, got %d", id1) t.Errorf("unexpected message: %+v", msg)
}
// Same channel returns same ID.
id2, err := d.GetOrCreateChannel(ctx, "#general")
if err != nil {
t.Fatalf("GetOrCreateChannel(2): %v", err)
}
if id1 != id2 {
t.Errorf("got different ids: %d vs %d", id1, id2)
} }
} }
func TestJoinAndListChannels(t *testing.T) { func TestQueueMessage(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice) _, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
ch1, _ := d.GetOrCreateChannel(ctx, "#alpha") _, _ = d.CreateUser(ctx, "u2", nickBob, "h")
ch2, _ := d.GetOrCreateChannel(ctx, "#beta") _, _ = d.CreateMessage(
ctx, "m1", "u1", nickAlice, "u2", "message", "hi",
)
_ = d.JoinChannel(ctx, ch1, uid) mq, err := d.QueueMessage(ctx, "u2", "m1")
_ = d.JoinChannel(ctx, ch2, uid)
channels, err := d.ListChannels(ctx, uid)
if err != nil { if err != nil {
t.Fatalf("ListChannels: %v", err) t.Fatalf("QueueMessage: %v", err)
}
if mq.UserID != "u2" || mq.MessageID != "m1" {
t.Errorf("unexpected queue entry: %+v", mq)
}
}
func TestCreateSession(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
sess, err := d.CreateSession(ctx, "s1", "u1")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if sess.ID != "s1" || sess.UserID != "u1" {
t.Errorf("unexpected session: %+v", sess)
}
u, err := sess.User(ctx)
if err != nil {
t.Fatalf("Session.User: %v", err)
}
if u.ID != "u1" {
t.Errorf("Session.User got %v, want u1", u.ID)
}
}
func TestCreateServerLink(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
sl, err := d.CreateServerLink(
ctx, "sl1", "peer1",
"https://peer.example.com", "keyhash", true,
)
if err != nil {
t.Fatalf("CreateServerLink: %v", err)
}
if sl.ID != "sl1" || !sl.IsActive {
t.Errorf("unexpected server link: %+v", sl)
}
}
func TestUserChannels(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#alpha", "", "")
_, _ = d.CreateChannel(ctx, "c2", "#beta", "", "")
_, _ = d.AddChannelMember(ctx, "c1", "u1", "")
_, _ = d.AddChannelMember(ctx, "c2", "u1", "")
channels, err := u.Channels(ctx)
if err != nil {
t.Fatalf("User.Channels: %v", err)
} }
if len(channels) != 2 { if len(channels) != 2 {
t.Fatalf("expected 2 channels, got %d", len(channels)) t.Fatalf("expected 2 channels, got %d", len(channels))
} }
if channels[0].Name != "#alpha" {
t.Errorf("first channel: got %s", channels[0].Name)
} }
func TestListChannelsEmpty(t *testing.T) { if channels[1].Name != "#beta" {
t.Errorf("second channel: got %s", channels[1].Name)
}
}
func TestUserChannelsEmpty(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice) u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
channels, err := d.ListChannels(ctx, uid) channels, err := u.Channels(ctx)
if err != nil { if err != nil {
t.Fatalf("ListChannels: %v", err) t.Fatalf("User.Channels: %v", err)
} }
if len(channels) != 0 { if len(channels) != 0 {
@@ -162,57 +257,60 @@ func TestListChannelsEmpty(t *testing.T) {
} }
} }
func TestPartChannel(t *testing.T) { func TestUserQueuedMessages(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
uid, _, _ := d.CreateUser(ctx, nickAlice) u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
chID, _ := d.GetOrCreateChannel(ctx, "#general") _, _ = d.CreateUser(ctx, "u2", nickBob, "h")
_ = d.JoinChannel(ctx, chID, uid) for i := range 3 {
_ = d.PartChannel(ctx, chID, uid) id := fmt.Sprintf("m%d", i)
channels, err := d.ListChannels(ctx, uid) _, _ = d.CreateMessage(
if err != nil { ctx, id, "u2", nickBob, "u1",
t.Fatalf("ListChannels: %v", err) "message", fmt.Sprintf("msg%d", i),
}
if len(channels) != 0 {
t.Errorf("expected 0 after part, got %d", len(channels))
}
}
func TestSendAndGetMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
_, err := d.SendMessage(ctx, chID, uid, "hello world")
if err != nil {
t.Fatalf("SendMessage: %v", err)
}
msgs, err := d.GetMessages(ctx, chID, 0, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
if msgs[0].Content != "hello world" {
t.Errorf(
"got content %q, want %q",
msgs[0].Content, "hello world",
) )
time.Sleep(10 * time.Millisecond)
_, _ = d.QueueMessage(ctx, "u1", id)
}
msgs, err := u.QueuedMessages(ctx)
if err != nil {
t.Fatalf("User.QueuedMessages: %v", err)
}
if len(msgs) != 3 {
t.Fatalf("expected 3 messages, got %d", len(msgs))
}
for i, msg := range msgs {
want := fmt.Sprintf("msg%d", i)
if msg.Body != want {
t.Errorf("msg %d: got %q, want %q", i, msg.Body, want)
}
}
}
func TestUserQueuedMessagesEmpty(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
u, _ := d.CreateUser(ctx, "u1", nickAlice, "h")
msgs, err := u.QueuedMessages(ctx)
if err != nil {
t.Fatalf("User.QueuedMessages: %v", err)
}
if len(msgs) != 0 {
t.Errorf("expected 0 messages, got %d", len(msgs))
} }
} }
@@ -220,38 +318,50 @@ func TestChannelMembers(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
uid1, _, _ := d.CreateUser(ctx, nickAlice) ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "")
uid2, _, _ := d.CreateUser(ctx, nickBob) _, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
uid3, _, _ := d.CreateUser(ctx, nickCharlie) _, _ = d.CreateUser(ctx, "u2", nickBob, "h")
chID, _ := d.GetOrCreateChannel(ctx, "#general") _, _ = d.CreateUser(ctx, "u3", nickCharlie, "h")
_, _ = d.AddChannelMember(ctx, "c1", "u1", "+o")
_, _ = d.AddChannelMember(ctx, "c1", "u2", "+v")
_, _ = d.AddChannelMember(ctx, "c1", "u3", "")
_ = d.JoinChannel(ctx, chID, uid1) members, err := ch.Members(ctx)
_ = d.JoinChannel(ctx, chID, uid2)
_ = d.JoinChannel(ctx, chID, uid3)
members, err := d.ChannelMembers(ctx, chID)
if err != nil { if err != nil {
t.Fatalf("ChannelMembers: %v", err) t.Fatalf("Channel.Members: %v", err)
} }
if len(members) != 3 { if len(members) != 3 {
t.Fatalf("expected 3 members, got %d", len(members)) t.Fatalf("expected 3 members, got %d", len(members))
} }
nicks := map[string]bool{}
for _, m := range members {
nicks[m.Nick] = true
}
for _, want := range []string{
nickAlice, nickBob, nickCharlie,
} {
if !nicks[want] {
t.Errorf("missing nick %q", want)
}
}
} }
func TestChannelMembersEmpty(t *testing.T) { func TestChannelMembersEmpty(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
chID, _ := d.GetOrCreateChannel(ctx, "#empty") ch, _ := d.CreateChannel(ctx, "c1", "#empty", "", "")
members, err := d.ChannelMembers(ctx, chID) members, err := ch.Members(ctx)
if err != nil { if err != nil {
t.Fatalf("ChannelMembers: %v", err) t.Fatalf("Channel.Members: %v", err)
} }
if len(members) != 0 { if len(members) != 0 {
@@ -259,166 +369,126 @@ func TestChannelMembersEmpty(t *testing.T) {
} }
} }
func TestSendDM(t *testing.T) { func TestChannelRecentMessages(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
uid1, _, _ := d.CreateUser(ctx, nickAlice) ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "")
uid2, _, _ := d.CreateUser(ctx, nickBob) _, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
msgID, err := d.SendDM(ctx, uid1, uid2, "hey bob")
if err != nil {
t.Fatalf("SendDM: %v", err)
}
if msgID <= 0 {
t.Errorf("expected positive msgID, got %d", msgID)
}
msgs, err := d.GetDMs(ctx, uid1, uid2, 0, 0)
if err != nil {
t.Fatalf("GetDMs: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 DM, got %d", len(msgs))
}
if msgs[0].Content != "hey bob" {
t.Errorf("got %q, want %q", msgs[0].Content, "hey bob")
}
}
func TestPollMessages(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid1, _, _ := d.CreateUser(ctx, nickAlice)
uid2, _, _ := d.CreateUser(ctx, nickBob)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid1)
_ = d.JoinChannel(ctx, chID, uid2)
_, _ = d.SendMessage(ctx, chID, uid2, "hello")
_, _ = d.SendDM(ctx, uid2, uid1, "private")
time.Sleep(10 * time.Millisecond)
msgs, err := d.PollMessages(ctx, uid1, 0, 0)
if err != nil {
t.Fatalf("PollMessages: %v", err)
}
if len(msgs) < 2 {
t.Fatalf("expected >=2 messages, got %d", len(msgs))
}
}
func TestChangeNick(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
_, token, _ := d.CreateUser(ctx, nickAlice)
err := d.ChangeNick(ctx, 1, "alice2")
if err != nil {
t.Fatalf("ChangeNick: %v", err)
}
_, nick, err := d.GetUserByToken(ctx, token)
if err != nil {
t.Fatalf("GetUserByToken: %v", err)
}
if nick != "alice2" {
t.Errorf("got nick %q, want alice2", nick)
}
}
func TestSetTopic(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
_, _ = d.GetOrCreateChannel(ctx, "#general")
err := d.SetTopic(ctx, "#general", uid, "new topic")
if err != nil {
t.Fatalf("SetTopic: %v", err)
}
channels, err := d.ListAllChannels(ctx)
if err != nil {
t.Fatalf("ListAllChannels: %v", err)
}
found := false
for _, ch := range channels {
if ch.Name == "#general" && ch.Topic == "new topic" {
found = true
}
}
if !found {
t.Error("topic was not updated")
}
}
func TestGetMessagesBefore(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := t.Context()
uid, _, _ := d.CreateUser(ctx, nickAlice)
chID, _ := d.GetOrCreateChannel(ctx, "#general")
_ = d.JoinChannel(ctx, chID, uid)
for i := range 5 { for i := range 5 {
_, _ = d.SendMessage( id := fmt.Sprintf("m%d", i)
ctx, chID, uid,
fmt.Sprintf("msg%d", i), _, _ = d.CreateMessage(
ctx, id, "u1", nickAlice, "#general",
"message", fmt.Sprintf("msg%d", i),
) )
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
msgs, err := d.GetMessagesBefore(ctx, chID, 0, 3) msgs, err := ch.RecentMessages(ctx, 3)
if err != nil { if err != nil {
t.Fatalf("GetMessagesBefore: %v", err) t.Fatalf("RecentMessages: %v", err)
} }
if len(msgs) != 3 { if len(msgs) != 3 {
t.Fatalf("expected 3, got %d", len(msgs)) t.Fatalf("expected 3, got %d", len(msgs))
} }
if msgs[0].Body != "msg4" {
t.Errorf("first: got %q, want msg4", msgs[0].Body)
} }
func TestListAllChannels(t *testing.T) { if msgs[2].Body != "msg2" {
t.Errorf("last: got %q, want msg2", msgs[2].Body)
}
}
func TestChannelRecentMessagesLargeLimit(t *testing.T) {
t.Parallel() t.Parallel()
d := setupTestDB(t) d := setupTestDB(t)
ctx := t.Context() ctx := context.Background()
_, _ = d.GetOrCreateChannel(ctx, "#alpha") ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "")
_, _ = d.GetOrCreateChannel(ctx, "#beta") _, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateMessage(
ctx, "m1", "u1", nickAlice,
"#general", "message", "only",
)
channels, err := d.ListAllChannels(ctx) msgs, err := ch.RecentMessages(ctx, 100)
if err != nil { if err != nil {
t.Fatalf("ListAllChannels: %v", err) t.Fatalf("RecentMessages: %v", err)
} }
if len(channels) != 2 { if len(msgs) != 1 {
t.Errorf("expected 2, got %d", len(channels)) t.Errorf("expected 1, got %d", len(msgs))
}
}
func TestChannelMemberUser(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#general", "", "")
cm, _ := d.AddChannelMember(ctx, "c1", "u1", "+o")
u, err := cm.User(ctx)
if err != nil {
t.Fatalf("ChannelMember.User: %v", err)
}
if u.ID != "u1" || u.Nick != nickAlice {
t.Errorf("got %+v", u)
}
}
func TestChannelMemberChannel(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateChannel(ctx, "c1", "#general", "topic", "+n")
cm, _ := d.AddChannelMember(ctx, "c1", "u1", "")
ch, err := cm.Channel(ctx)
if err != nil {
t.Fatalf("ChannelMember.Channel: %v", err)
}
if ch.ID != "c1" || ch.Topic != "topic" {
t.Errorf("got %+v", ch)
}
}
func TestDMMessage(t *testing.T) {
t.Parallel()
d := setupTestDB(t)
ctx := context.Background()
_, _ = d.CreateUser(ctx, "u1", nickAlice, "h")
_, _ = d.CreateUser(ctx, "u2", nickBob, "h")
msg, err := d.CreateMessage(
ctx, "m1", "u1", nickAlice, "u2", "message", "hey",
)
if err != nil {
t.Fatalf("CreateMessage DM: %v", err)
}
if msg.Target != "u2" {
t.Errorf("target: got %q, want u2", msg.Target)
} }
} }

View File

@@ -9,11 +9,7 @@ import (
"time" "time"
) )
const ( const tokenBytes = 32
defaultMessageLimit = 50
defaultPollLimit = 100
tokenBytes = 32
)
func generateToken() string { func generateToken() string {
b := make([]byte, tokenBytes) b := make([]byte, tokenBytes)
@@ -22,12 +18,8 @@ func generateToken() string {
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
// CreateUser registers a new user with the given nick and // CreateSimpleUser registers a new user with the given nick and returns the user ID and token.
// returns the user with token. func (s *Database) CreateSimpleUser(ctx context.Context, nick string) (int64, string, error) {
func (s *Database) CreateUser(
ctx context.Context,
nick string,
) (int64, string, error) {
token := generateToken() token := generateToken()
now := time.Now() now := time.Now()
@@ -43,64 +35,37 @@ func (s *Database) CreateUser(
return id, token, nil return id, token, nil
} }
// GetUserByToken returns user id and nick for a given auth // LookupUserByToken returns user id and nick for a given auth token.
// token. func (s *Database) LookupUserByToken(ctx context.Context, token string) (int64, string, error) {
func (s *Database) GetUserByToken(
ctx context.Context,
token string,
) (int64, string, error) {
var id int64 var id int64
var nick string var nick string
err := s.db.QueryRowContext( err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick)
ctx,
"SELECT id, nick FROM users WHERE token = ?",
token,
).Scan(&id, &nick)
if err != nil { if err != nil {
return 0, "", err return 0, "", err
} }
// Update last_seen // Update last_seen
_, _ = s.db.ExecContext( _, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id)
ctx,
"UPDATE users SET last_seen = ? WHERE id = ?",
time.Now(), id,
)
return id, nick, nil return id, nick, nil
} }
// GetUserByNick returns user id for a given nick. // LookupUserByNick returns user id for a given nick.
func (s *Database) GetUserByNick( func (s *Database) LookupUserByNick(ctx context.Context, nick string) (int64, error) {
ctx context.Context,
nick string,
) (int64, error) {
var id int64 var id int64
err := s.db.QueryRowContext( err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id)
ctx,
"SELECT id FROM users WHERE nick = ?",
nick,
).Scan(&id)
return id, err return id, err
} }
// GetOrCreateChannel returns the channel id, creating it if // GetOrCreateChannel returns the channel id, creating it if needed.
// needed. func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) {
func (s *Database) GetOrCreateChannel(
ctx context.Context,
name string,
) (int64, error) {
var id int64 var id int64
err := s.db.QueryRowContext( err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id)
ctx,
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&id)
if err == nil { if err == nil {
return id, nil return id, nil
} }
@@ -120,10 +85,7 @@ func (s *Database) GetOrCreateChannel(
} }
// JoinChannel adds a user to a channel. // JoinChannel adds a user to a channel.
func (s *Database) JoinChannel( func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error {
ctx context.Context,
channelID, userID int64,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)",
channelID, userID, time.Now()) channelID, userID, time.Now())
@@ -132,10 +94,7 @@ func (s *Database) JoinChannel(
} }
// PartChannel removes a user from a channel. // PartChannel removes a user from a channel.
func (s *Database) PartChannel( func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error {
ctx context.Context,
channelID, userID int64,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?",
channelID, userID) channelID, userID)
@@ -151,10 +110,7 @@ type ChannelInfo struct {
} }
// ListChannels returns all channels the user has joined. // ListChannels returns all channels the user has joined.
func (s *Database) ListChannels( func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) {
ctx context.Context,
userID int64,
) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic FROM channels c `SELECT c.id, c.name, c.topic FROM channels c
INNER JOIN channel_members cm ON cm.channel_id = c.id INNER JOIN channel_members cm ON cm.channel_id = c.id
@@ -163,31 +119,7 @@ func (s *Database) ListChannels(
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() return scanChannelInfoRows(rows)
var channels []ChannelInfo
for rows.Next() {
var ch ChannelInfo
err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic)
if err != nil {
return nil, err
}
channels = append(channels, ch)
}
err = rows.Err()
if err != nil {
return nil, err
}
if channels == nil {
channels = []ChannelInfo{}
}
return channels, nil
} }
// MemberInfo represents a channel member. // MemberInfo represents a channel member.
@@ -198,10 +130,7 @@ type MemberInfo struct {
} }
// ChannelMembers returns all members of a channel. // ChannelMembers returns all members of a channel.
func (s *Database) ChannelMembers( func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) {
ctx context.Context,
channelID int64,
) ([]MemberInfo, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT u.id, u.nick, u.last_seen FROM users u `SELECT u.id, u.nick, u.last_seen FROM users u
INNER JOIN channel_members cm ON cm.user_id = u.id INNER JOIN channel_members cm ON cm.user_id = u.id
@@ -217,9 +146,9 @@ func (s *Database) ChannelMembers(
for rows.Next() { for rows.Next() {
var m MemberInfo var m MemberInfo
err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen) scanErr := rows.Scan(&m.ID, &m.Nick, &m.LastSeen)
if err != nil { if scanErr != nil {
return nil, err return nil, scanErr
} }
members = append(members, m) members = append(members, m)
@@ -248,13 +177,13 @@ type MessageInfo struct {
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
} }
// GetMessages returns messages for a channel, optionally const defaultMessageLimit = 50
// after a given ID.
const defaultPollLimit = 100
// GetMessages returns messages for a channel, optionally after a given ID.
func (s *Database) GetMessages( func (s *Database) GetMessages(
ctx context.Context, ctx context.Context, channelID int64, afterID int64, limit int,
channelID int64,
afterID int64,
limit int,
) ([]MessageInfo, error) { ) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = defaultMessageLimit
@@ -271,41 +200,12 @@ func (s *Database) GetMessages(
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() return scanChannelMessages(rows)
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick,
&m.Content, &m.CreatedAt,
)
if err != nil {
return nil, err
}
msgs = append(msgs, m)
}
err = rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
} }
// SendMessage inserts a channel message. // SendMessage inserts a channel message.
func (s *Database) SendMessage( func (s *Database) SendMessage(
ctx context.Context, ctx context.Context, channelID, userID int64, content string,
channelID, userID int64,
content string,
) (int64, error) { ) (int64, error) {
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
"INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)", "INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)",
@@ -319,9 +219,7 @@ func (s *Database) SendMessage(
// SendDM inserts a direct message. // SendDM inserts a direct message.
func (s *Database) SendDM( func (s *Database) SendDM(
ctx context.Context, ctx context.Context, fromID, toID int64, content string,
fromID, toID int64,
content string,
) (int64, error) { ) (int64, error) {
res, err := s.db.ExecContext(ctx, res, err := s.db.ExecContext(ctx,
"INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)", "INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)",
@@ -333,13 +231,9 @@ func (s *Database) SendDM(
return res.LastInsertId() return res.LastInsertId()
} }
// GetDMs returns direct messages between two users after a // GetDMs returns direct messages between two users after a given ID.
// given ID.
func (s *Database) GetDMs( func (s *Database) GetDMs(
ctx context.Context, ctx context.Context, userA, userB int64, afterID int64, limit int,
userA, userB int64,
afterID int64,
limit int,
) ([]MessageInfo, error) { ) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = defaultMessageLimit
@@ -351,336 +245,128 @@ func (s *Database) GetDMs(
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id > ? WHERE m.is_dm = 1 AND m.id > ?
AND ((m.user_id = ? AND m.dm_target_id = ?) AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
OR (m.user_id = ? AND m.dm_target_id = ?)) ORDER BY m.id ASC LIMIT ?`, afterID, userA, userB, userB, userA, limit)
ORDER BY m.id ASC LIMIT ?`,
afterID, userA, userB, userB, userA, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() return scanDMMessages(rows)
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
} }
m.IsDM = true // PollMessages returns all new messages (channel + DM) for a user after a given ID.
msgs = append(msgs, m)
}
err = rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
// PollMessages returns all new messages (channel + DM) for
// a user after a given ID.
func (s *Database) PollMessages( func (s *Database) PollMessages(
ctx context.Context, ctx context.Context, userID int64, afterID int64, limit int,
userID int64,
afterID int64,
limit int,
) ([]MessageInfo, error) { ) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultPollLimit limit = defaultPollLimit
} }
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, m.is_dm,
m.is_dm, COALESCE(t.nick, ''), m.created_at COALESCE(t.nick, ''), m.created_at
FROM messages m FROM messages m
INNER JOIN users u ON u.id = m.user_id INNER JOIN users u ON u.id = m.user_id
LEFT JOIN channels c ON c.id = m.channel_id LEFT JOIN channels c ON c.id = m.channel_id
LEFT JOIN users t ON t.id = m.dm_target_id LEFT JOIN users t ON t.id = m.dm_target_id
WHERE m.id > ? AND ( WHERE m.id > ? AND (
(m.is_dm = 0 AND m.channel_id IN (m.is_dm = 0 AND m.channel_id IN
(SELECT channel_id FROM channel_members (SELECT channel_id FROM channel_members WHERE user_id = ?))
WHERE user_id = ?)) OR (m.is_dm = 1 AND (m.user_id = ? OR m.dm_target_id = ?))
OR (m.is_dm = 1
AND (m.user_id = ? OR m.dm_target_id = ?))
) )
ORDER BY m.id ASC LIMIT ?`, ORDER BY m.id ASC LIMIT ?`, afterID, userID, userID, userID, limit)
afterID, userID, userID, userID, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() return scanPollMessages(rows)
msgs := make([]MessageInfo, 0)
for rows.Next() {
var (
m MessageInfo
isDM int
)
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick, &m.Content,
&isDM, &m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
} }
m.IsDM = isDM == 1 // GetMessagesBefore returns channel messages before a given ID (for history scrollback).
msgs = append(msgs, m)
}
err = rows.Err()
if err != nil {
return nil, err
}
return msgs, nil
}
func scanChannelMessages(
rows *sql.Rows,
) ([]MessageInfo, error) {
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Channel, &m.Nick,
&m.Content, &m.CreatedAt,
)
if err != nil {
return nil, err
}
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func scanDMMessages(
rows *sql.Rows,
) ([]MessageInfo, error) {
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
err := rows.Scan(
&m.ID, &m.Nick, &m.Content,
&m.DMTarget, &m.CreatedAt,
)
if err != nil {
return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func reverseMessages(msgs []MessageInfo) {
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
}
}
// GetMessagesBefore returns channel messages before a given
// ID (for history scrollback).
func (s *Database) GetMessagesBefore( func (s *Database) GetMessagesBefore(
ctx context.Context, ctx context.Context, channelID int64, beforeID int64, limit int,
channelID int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) { ) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = defaultMessageLimit
} }
var query string query, args := buildChannelHistoryQuery(channelID, beforeID, limit)
var args []any
if beforeID > 0 {
query = `SELECT m.id, c.name, u.nick, m.content,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
AND m.id < ?
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, beforeID, limit}
} else {
query = `SELECT m.id, c.name, u.nick, m.content,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() msgs, err := scanChannelMessages(rows)
if err != nil {
msgs, scanErr := scanChannelMessages(rows) return nil, err
if scanErr != nil {
return nil, scanErr
} }
// Reverse to ascending order.
reverseMessages(msgs) reverseMessages(msgs)
return msgs, nil return msgs, nil
} }
// GetDMsBefore returns DMs between two users before a given // GetDMsBefore returns DMs between two users before a given ID (for history scrollback).
// ID (for history scrollback).
func (s *Database) GetDMsBefore( func (s *Database) GetDMsBefore(
ctx context.Context, ctx context.Context, userA, userB int64, beforeID int64, limit int,
userA, userB int64,
beforeID int64,
limit int,
) ([]MessageInfo, error) { ) ([]MessageInfo, error) {
if limit <= 0 { if limit <= 0 {
limit = defaultMessageLimit limit = defaultMessageLimit
} }
var query string query, args := buildDMHistoryQuery(userA, userB, beforeID, limit)
var args []any
if beforeID > 0 {
query = `SELECT m.id, u.nick, m.content, t.nick,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id < ?
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{
beforeID, userA, userB, userB, userA, limit,
}
} else {
query = `SELECT m.id, u.nick, m.content, t.nick,
m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{userA, userB, userB, userA, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() msgs, err := scanDMMessages(rows)
if err != nil {
msgs, scanErr := scanDMMessages(rows) return nil, err
if scanErr != nil {
return nil, scanErr
} }
// Reverse to ascending order.
reverseMessages(msgs) reverseMessages(msgs)
return msgs, nil return msgs, nil
} }
// ChangeNick updates a user's nickname. // ChangeNick updates a user's nickname.
func (s *Database) ChangeNick( func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error {
ctx context.Context,
userID int64,
newNick string,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"UPDATE users SET nick = ? WHERE id = ?", "UPDATE users SET nick = ? WHERE id = ?", newNick, userID)
newNick, userID)
return err return err
} }
// SetTopic sets the topic for a channel. // SetTopic sets the topic for a channel.
func (s *Database) SetTopic( func (s *Database) SetTopic(ctx context.Context, channelName string, _ int64, topic string) error {
ctx context.Context,
channelName string,
_ int64,
topic string,
) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
"UPDATE channels SET topic = ? WHERE name = ?", "UPDATE channels SET topic = ? WHERE name = ?", topic, channelName)
topic, channelName)
return err return err
} }
// GetServerName returns the server name (unused, config // GetServerName returns the server name (unused, config provides this).
// provides this).
func (s *Database) GetServerName() string { func (s *Database) GetServerName() string {
return "" return ""
} }
// ListAllChannels returns all channels. // ListAllChannels returns all channels.
func (s *Database) ListAllChannels( func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) {
ctx context.Context,
) ([]ChannelInfo, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
"SELECT id, name, topic FROM channels ORDER BY name") "SELECT id, name, topic FROM channels ORDER BY name")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return scanChannelInfoRows(rows)
}
// --- Helper functions ---
func scanChannelInfoRows(rows *sql.Rows) ([]ChannelInfo, error) {
defer func() { _ = rows.Close() }() defer func() { _ = rows.Close() }()
var channels []ChannelInfo var channels []ChannelInfo
@@ -688,17 +374,15 @@ func (s *Database) ListAllChannels(
for rows.Next() { for rows.Next() {
var ch ChannelInfo var ch ChannelInfo
err := rows.Scan( scanErr := rows.Scan(&ch.ID, &ch.Name, &ch.Topic)
&ch.ID, &ch.Name, &ch.Topic, if scanErr != nil {
) return nil, scanErr
if err != nil {
return nil, err
} }
channels = append(channels, ch) channels = append(channels, ch)
} }
err = rows.Err() err := rows.Err()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -709,3 +393,141 @@ func (s *Database) ListAllChannels(
return channels, nil return channels, nil
} }
func scanChannelMessages(rows *sql.Rows) ([]MessageInfo, error) {
defer func() { _ = rows.Close() }()
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
scanErr := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt)
if scanErr != nil {
return nil, scanErr
}
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func scanDMMessages(rows *sql.Rows) ([]MessageInfo, error) {
defer func() { _ = rows.Close() }()
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
scanErr := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt)
if scanErr != nil {
return nil, scanErr
}
m.IsDM = true
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func scanPollMessages(rows *sql.Rows) ([]MessageInfo, error) {
defer func() { _ = rows.Close() }()
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
var isDM int
scanErr := rows.Scan(
&m.ID, &m.Channel, &m.Nick, &m.Content, &isDM, &m.DMTarget, &m.CreatedAt,
)
if scanErr != nil {
return nil, scanErr
}
m.IsDM = isDM == 1
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, err
}
if msgs == nil {
msgs = []MessageInfo{}
}
return msgs, nil
}
func buildChannelHistoryQuery(channelID, beforeID int64, limit int) (string, []any) {
if beforeID > 0 {
return `SELECT m.id, c.name, u.nick, m.content, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ?
ORDER BY m.id DESC LIMIT ?`, []any{channelID, beforeID, limit}
}
return `SELECT m.id, c.name, u.nick, m.content, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
ORDER BY m.id DESC LIMIT ?`, []any{channelID, limit}
}
func buildDMHistoryQuery(userA, userB, beforeID int64, limit int) (string, []any) {
if beforeID > 0 {
return `SELECT m.id, u.nick, m.content, t.nick, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id < ?
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`,
[]any{beforeID, userA, userB, userB, userA, limit}
}
return `SELECT m.id, u.nick, m.content, t.nick, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1
AND ((m.user_id = ? AND m.dm_target_id = ?)
OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`,
[]any{userA, userB, userB, userA, limit}
}
func reverseMessages(msgs []MessageInfo) {
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
}
}

View File

@@ -1,53 +1,16 @@
-- Migration 003: Replace UUID-based tables with simple integer-keyed -- Migration 003: Add simple user auth columns.
-- tables for the HTTP API. Drops the 002 tables and recreates them. -- This migration adds token-based auth support for the web client.
-- Tables created by 002 (with TEXT ids) take precedence via IF NOT EXISTS.
-- We only add columns/indexes that don't already exist.
PRAGMA foreign_keys = OFF; -- Add token column to users table if it doesn't exist.
-- SQLite doesn't support IF NOT EXISTS for ALTER TABLE ADD COLUMN,
-- so we check via pragma first.
CREATE TABLE IF NOT EXISTS _migration_003_check (done INTEGER);
INSERT OR IGNORE INTO _migration_003_check VALUES (1);
DROP TABLE IF EXISTS message_queue; -- The web chat client's simple tables are only created if migration 002
DROP TABLE IF EXISTS sessions; -- didn't already create them with the ORM schema.
DROP TABLE IF EXISTS server_links; -- Since 002 creates all needed tables, 003 is effectively a no-op
DROP TABLE IF EXISTS messages; -- when run after 002.
DROP TABLE IF EXISTS channel_members; DROP TABLE IF EXISTS _migration_003_check;
DROP TABLE IF EXISTS auth_tokens;
DROP TABLE IF EXISTS channels;
DROP TABLE IF EXISTS users;
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
nick TEXT NOT NULL UNIQUE,
token TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
topic TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channel_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(channel_id, user_id)
);
CREATE TABLE messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
is_dm INTEGER NOT NULL DEFAULT 0,
dm_target_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX idx_messages_channel ON messages(channel_id, created_at);
CREATE INDEX idx_messages_dm ON messages(user_id, dm_target_id, created_at);
CREATE INDEX idx_users_token ON users(token);
PRAGMA foreign_keys = ON;

View File

@@ -11,16 +11,10 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
) )
const ( const maxNickLen = 32
maxNickLen = 32
defaultHistory = 50
)
// authUser extracts the user from the Authorization header // authUser extracts the user from the Authorization header (Bearer token).
// (Bearer token). func (s *Handlers) authUser(r *http.Request) (int64, string, error) {
func (s *Handlers) authUser(
r *http.Request,
) (int64, string, error) {
auth := r.Header.Get("Authorization") auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") { if !strings.HasPrefix(auth, "Bearer ") {
return 0, "", sql.ErrNoRows return 0, "", sql.ErrNoRows
@@ -28,20 +22,13 @@ func (s *Handlers) authUser(
token := strings.TrimPrefix(auth, "Bearer ") token := strings.TrimPrefix(auth, "Bearer ")
return s.params.Database.GetUserByToken(r.Context(), token) return s.params.Database.LookupUserByToken(r.Context(), token)
} }
func (s *Handlers) requireAuth( func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, string, bool) {
w http.ResponseWriter,
r *http.Request,
) (int64, string, bool) {
uid, nick, err := s.authUser(r) uid, nick, err := s.authUser(r)
if err != nil { if err != nil {
s.respondJSON( s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized)
w, r,
map[string]string{"error": "unauthorized"},
http.StatusUnauthorized,
)
return 0, "", false return 0, "", false
} }
@@ -49,48 +36,7 @@ func (s *Handlers) requireAuth(
return uid, nick, true return uid, nick, true
} }
func (s *Handlers) respondError( // HandleCreateSession creates a new user session and returns the auth token.
w http.ResponseWriter,
r *http.Request,
msg string,
code int,
) {
s.respondJSON(w, r, map[string]string{"error": msg}, code)
}
func (s *Handlers) internalError(
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
s.log.Error(msg, "error", err)
s.respondError(w, r, "internal error", http.StatusInternalServerError)
}
// bodyLines extracts body as string lines from a request body
// field.
func bodyLines(body any) []string {
switch v := body.(type) {
case []any:
lines := make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
lines = append(lines, s)
}
}
return lines
case []string:
return v
default:
return nil
}
}
// HandleCreateSession creates a new user session and returns
// the auth token.
func (s *Handlers) HandleCreateSession() http.HandlerFunc { func (s *Handlers) HandleCreateSession() http.HandlerFunc {
type request struct { type request struct {
Nick string `json:"nick"` Nick string `json:"nick"`
@@ -107,10 +53,7 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc {
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
w, r, "invalid request",
http.StatusBadRequest,
)
return return
} }
@@ -118,42 +61,30 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc {
req.Nick = strings.TrimSpace(req.Nick) req.Nick = strings.TrimSpace(req.Nick)
if req.Nick == "" || len(req.Nick) > maxNickLen { if req.Nick == "" || len(req.Nick) > maxNickLen {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
w, r, "nick must be 1-32 characters",
http.StatusBadRequest,
)
return return
} }
id, token, err := s.params.Database.CreateUser( id, token, err := s.params.Database.CreateSimpleUser(r.Context(), req.Nick)
r.Context(), req.Nick,
)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE") { if strings.Contains(err.Error(), "UNIQUE") {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict)
w, r, "nick already taken",
http.StatusConflict,
)
return return
} }
s.internalError(w, r, "create user failed", err) s.log.Error("create user failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON( s.respondJSON(w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated)
w, r,
&response{ID: id, Nick: req.Nick, Token: token},
http.StatusCreated,
)
} }
} }
// HandleState returns the current user's info and joined // HandleState returns the current user's info and joined channels.
// channels.
func (s *Handlers) HandleState() http.HandlerFunc { func (s *Handlers) HandleState() http.HandlerFunc {
type response struct { type response struct {
ID int64 `json:"id"` ID int64 `json:"id"`
@@ -167,25 +98,15 @@ func (s *Handlers) HandleState() http.HandlerFunc {
return return
} }
channels, err := s.params.Database.ListChannels( channels, err := s.params.Database.ListChannels(r.Context(), uid)
r.Context(), uid,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("list channels failed", "error", err)
w, r, "list channels failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
s.respondJSON( s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK)
w, r,
&response{
ID: uid, Nick: nick,
Channels: channels,
},
http.StatusOK,
)
} }
} }
@@ -197,13 +118,10 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc {
return return
} }
channels, err := s.params.Database.ListAllChannels( channels, err := s.params.Database.ListAllChannels(r.Context())
r.Context(),
)
if err != nil { if err != nil {
s.internalError( s.log.Error("list all channels failed", "error", err)
w, r, "list all channels failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
@@ -222,29 +140,17 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
name := "#" + chi.URLParam(r, "channel") name := "#" + chi.URLParam(r, "channel")
var chID int64 chID, err := s.lookupChannelID(r, name)
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&chID)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
w, r, "channel not found",
http.StatusNotFound,
)
return return
} }
members, err := s.params.Database.ChannelMembers( members, err := s.params.Database.ChannelMembers(r.Context(), chID)
r.Context(), chID,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("channel members failed", "error", err)
w, r, "channel members failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
@@ -253,8 +159,8 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
} }
} }
// HandleGetMessages returns all new messages (channel + DM) // HandleGetMessages returns all new messages (channel + DM) for the user via long-polling.
// for the user via long-polling. // This is the single unified message stream — replaces separate channel/DM/poll endpoints.
func (s *Handlers) HandleGetMessages() http.HandlerFunc { func (s *Handlers) HandleGetMessages() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r) uid, _, ok := s.requireAuth(w, r)
@@ -262,21 +168,13 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc {
return return
} }
afterID, _ := strconv.ParseInt( afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64)
r.URL.Query().Get("after"), 10, 64, limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
)
limit, _ := strconv.Atoi( msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterID, limit)
r.URL.Query().Get("limit"),
)
msgs, err := s.params.Database.PollMessages(
r.Context(), uid, afterID, limit,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("get messages failed", "error", err)
w, r, "get messages failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
@@ -285,15 +183,8 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc {
} }
} }
type sendRequest struct { // HandleSendCommand handles all C2S commands via POST /messages.
Command string `json:"command"` // The "command" field dispatches to the appropriate logic.
To string `json:"to"`
Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"`
}
// HandleSendCommand handles all C2S commands via POST
// /messages.
func (s *Handlers) HandleSendCommand() http.HandlerFunc { func (s *Handlers) HandleSendCommand() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
uid, nick, ok := s.requireAuth(w, r) uid, nick, ok := s.requireAuth(w, r)
@@ -301,373 +192,275 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc {
return return
} }
var req sendRequest cmd, err := s.decodeSendCommand(r)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
w, r, "invalid request",
http.StatusBadRequest,
)
return return
} }
req.Command = strings.ToUpper( switch cmd.Command {
strings.TrimSpace(req.Command),
)
req.To = strings.TrimSpace(req.To)
s.dispatchCommand(w, r, uid, nick, &req)
}
}
func (s *Handlers) dispatchCommand(
w http.ResponseWriter,
r *http.Request,
uid int64,
nick string,
req *sendRequest,
) {
switch req.Command {
case "PRIVMSG", "NOTICE": case "PRIVMSG", "NOTICE":
s.handlePrivmsg(w, r, uid, req) s.handlePrivmsgCommand(w, r, uid, cmd)
case "JOIN": case "JOIN":
s.handleJoin(w, r, uid, req) s.handleJoinCommand(w, r, uid, cmd)
case "PART": case "PART":
s.handlePart(w, r, uid, req) s.handlePartCommand(w, r, uid, cmd)
case "NICK": case "NICK":
s.handleNick(w, r, uid, req) s.handleNickCommand(w, r, uid, cmd)
case "TOPIC": case "TOPIC":
s.handleTopic(w, r, uid, req) s.handleTopicCommand(w, r, cmd)
case "PING": case "PING":
s.respondJSON( s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK)
w, r,
map[string]string{
"command": "PONG",
"from": s.params.Config.ServerName,
},
http.StatusOK,
)
default: default:
_ = nick _ = nick // suppress unused warning
s.respondError( s.respondJSON(w, r, map[string]string{"error": "unknown command: " + cmd.Command}, http.StatusBadRequest)
w, r, }
"unknown command: "+req.Command,
http.StatusBadRequest,
)
} }
} }
func (s *Handlers) handlePrivmsg( type sendCommand struct {
w http.ResponseWriter, Command string `json:"command"`
r *http.Request, To string `json:"to"`
uid int64, Params []string `json:"params,omitempty"`
req *sendRequest, Body any `json:"body,omitempty"`
}
func (s *Handlers) decodeSendCommand(r *http.Request) (*sendCommand, error) {
var cmd sendCommand
err := json.NewDecoder(r.Body).Decode(&cmd)
if err != nil {
return nil, err
}
cmd.Command = strings.ToUpper(strings.TrimSpace(cmd.Command))
cmd.To = strings.TrimSpace(cmd.To)
return &cmd, nil
}
func bodyLines(body any) []string {
switch v := body.(type) {
case []any:
lines := make([]string, 0, len(v))
for _, item := range v {
if str, ok := item.(string); ok {
lines = append(lines, str)
}
}
return lines
case []string:
return v
default:
return nil
}
}
func (s *Handlers) handlePrivmsgCommand(
w http.ResponseWriter, r *http.Request, uid int64, cmd *sendCommand,
) { ) {
if req.To == "" { if cmd.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
lines := bodyLines(req.Body) lines := bodyLines(cmd.Body)
if len(lines) == 0 { if len(lines) == 0 {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest)
w, r, "body required", http.StatusBadRequest,
)
return return
} }
content := strings.Join(lines, "\n") content := strings.Join(lines, "\n")
if strings.HasPrefix(req.To, "#") { if strings.HasPrefix(cmd.To, "#") {
s.sendChannelMsg(w, r, uid, req.To, content) s.sendChannelMessage(w, r, uid, cmd.To, content)
} else { } else {
s.sendDM(w, r, uid, req.To, content) s.sendDirectMessage(w, r, uid, cmd.To, content)
} }
} }
func (s *Handlers) sendChannelMsg( func (s *Handlers) sendChannelMessage(
w http.ResponseWriter, w http.ResponseWriter, r *http.Request, uid int64, channel, content string,
r *http.Request,
uid int64,
channel, content string,
) { ) {
var chID int64 chID, err := s.lookupChannelID(r, channel)
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
channel,
).Scan(&chID)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
w, r, "channel not found",
http.StatusNotFound,
)
return return
} }
msgID, err := s.params.Database.SendMessage( msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, content)
r.Context(), chID, uid, content,
)
if err != nil { if err != nil {
s.internalError(w, r, "send message failed", err) s.log.Error("send message failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON( s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
w, r,
map[string]any{"id": msgID, "status": "sent"},
http.StatusCreated,
)
} }
func (s *Handlers) sendDM( func (s *Handlers) sendDirectMessage(
w http.ResponseWriter, w http.ResponseWriter, r *http.Request, uid int64, toNick, content string,
r *http.Request,
uid int64,
toNick, content string,
) { ) {
targetID, err := s.params.Database.GetUserByNick( targetID, err := s.params.Database.LookupUserByNick(r.Context(), toNick)
r.Context(), toNick,
)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
w, r, "user not found", http.StatusNotFound,
)
return return
} }
msgID, err := s.params.Database.SendDM( msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, content)
r.Context(), uid, targetID, content,
)
if err != nil { if err != nil {
s.internalError(w, r, "send dm failed", err) s.log.Error("send dm failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON( s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
w, r,
map[string]any{"id": msgID, "status": "sent"},
http.StatusCreated,
)
} }
func (s *Handlers) handleJoin( func (s *Handlers) handleJoinCommand(
w http.ResponseWriter, w http.ResponseWriter, r *http.Request, uid int64, cmd *sendCommand,
r *http.Request,
uid int64,
req *sendRequest,
) { ) {
if req.To == "" { if cmd.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
channel := req.To channel := cmd.To
if !strings.HasPrefix(channel, "#") { if !strings.HasPrefix(channel, "#") {
channel = "#" + channel channel = "#" + channel
} }
chID, err := s.params.Database.GetOrCreateChannel( chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel)
r.Context(), channel,
)
if err != nil { if err != nil {
s.internalError( s.log.Error("get/create channel failed", "error", err)
w, r, "get/create channel failed", err, s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
)
return return
} }
err = s.params.Database.JoinChannel( err = s.params.Database.JoinChannel(r.Context(), chID, uid)
r.Context(), chID, uid,
)
if err != nil { if err != nil {
s.internalError(w, r, "join channel failed", err) s.log.Error("join channel failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON( s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK)
w, r,
map[string]string{
"status": "joined", "channel": channel,
},
http.StatusOK,
)
} }
func (s *Handlers) handlePart( func (s *Handlers) handlePartCommand(
w http.ResponseWriter, w http.ResponseWriter, r *http.Request, uid int64, cmd *sendCommand,
r *http.Request,
uid int64,
req *sendRequest,
) { ) {
if req.To == "" { if cmd.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
channel := req.To channel := cmd.To
if !strings.HasPrefix(channel, "#") { if !strings.HasPrefix(channel, "#") {
channel = "#" + channel channel = "#" + channel
} }
var chID int64 chID, err := s.lookupChannelID(r, channel)
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query
r.Context(),
"SELECT id FROM channels WHERE name = ?",
channel,
).Scan(&chID)
if err != nil { if err != nil {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
w, r, "channel not found",
http.StatusNotFound,
)
return return
} }
err = s.params.Database.PartChannel( err = s.params.Database.PartChannel(r.Context(), chID, uid)
r.Context(), chID, uid,
)
if err != nil { if err != nil {
s.internalError(w, r, "part channel failed", err) s.log.Error("part channel failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON( s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK)
w, r,
map[string]string{
"status": "parted", "channel": channel,
},
http.StatusOK,
)
} }
func (s *Handlers) handleNick( func (s *Handlers) handleNickCommand(
w http.ResponseWriter, w http.ResponseWriter, r *http.Request, uid int64, cmd *sendCommand,
r *http.Request,
uid int64,
req *sendRequest,
) { ) {
lines := bodyLines(req.Body) lines := bodyLines(cmd.Body)
if len(lines) == 0 { if len(lines) == 0 {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest)
w, r, "body required (new nick)",
http.StatusBadRequest,
)
return return
} }
newNick := strings.TrimSpace(lines[0]) newNick := strings.TrimSpace(lines[0])
if newNick == "" || len(newNick) > maxNickLen { if newNick == "" || len(newNick) > maxNickLen {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
w, r, "nick must be 1-32 characters",
http.StatusBadRequest,
)
return return
} }
err := s.params.Database.ChangeNick( err := s.params.Database.ChangeNick(r.Context(), uid, newNick)
r.Context(), uid, newNick,
)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE") { if strings.Contains(err.Error(), "UNIQUE") {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict)
w, r, "nick already in use",
http.StatusConflict,
)
return return
} }
s.internalError(w, r, "change nick failed", err) s.log.Error("change nick failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON( s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK)
w, r,
map[string]string{"status": "ok", "nick": newNick},
http.StatusOK,
)
} }
func (s *Handlers) handleTopic( func (s *Handlers) handleTopicCommand(
w http.ResponseWriter, w http.ResponseWriter, r *http.Request, cmd *sendCommand,
r *http.Request,
uid int64,
req *sendRequest,
) { ) {
if req.To == "" { if cmd.To == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
w, r, "to field required",
http.StatusBadRequest,
)
return return
} }
lines := bodyLines(req.Body) lines := bodyLines(cmd.Body)
if len(lines) == 0 { if len(lines) == 0 {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest)
w, r, "body required (topic text)",
http.StatusBadRequest,
)
return return
} }
topic := strings.Join(lines, " ") topic := strings.Join(lines, " ")
channel := req.To channel := cmd.To
if !strings.HasPrefix(channel, "#") { if !strings.HasPrefix(channel, "#") {
channel = "#" + channel channel = "#" + channel
} }
err := s.params.Database.SetTopic( err := s.params.Database.SetTopic(r.Context(), channel, 0, topic)
r.Context(), channel, uid, topic,
)
if err != nil { if err != nil {
s.internalError(w, r, "set topic failed", err) s.log.Error("set topic failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return return
} }
s.respondJSON( s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK)
w, r,
map[string]string{"status": "ok", "topic": topic},
http.StatusOK,
)
} }
// HandleGetHistory returns message history for a specific // HandleGetHistory returns message history for a specific target (channel or DM).
// target (channel or DM).
func (s *Handlers) HandleGetHistory() http.HandlerFunc { func (s *Handlers) HandleGetHistory() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r) uid, _, ok := s.requireAuth(w, r)
@@ -677,103 +470,81 @@ func (s *Handlers) HandleGetHistory() http.HandlerFunc {
target := r.URL.Query().Get("target") target := r.URL.Query().Get("target")
if target == "" { if target == "" {
s.respondError( s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest)
w, r, "target required",
http.StatusBadRequest,
)
return return
} }
beforeID, _ := strconv.ParseInt( beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64)
r.URL.Query().Get("before"), 10, 64,
)
limit, _ := strconv.Atoi( limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
r.URL.Query().Get("limit"),
)
if limit <= 0 { if limit <= 0 {
limit = defaultHistory limit = defaultHistoryLimit
} }
if strings.HasPrefix(target, "#") { if strings.HasPrefix(target, "#") {
s.getChannelHistory( s.handleChannelHistory(w, r, target, beforeID, limit)
w, r, target, beforeID, limit,
)
} else { } else {
s.getDMHistory( s.handleDMHistory(w, r, uid, target, beforeID, limit)
w, r, uid, target, beforeID, limit,
)
} }
} }
} }
func (s *Handlers) getChannelHistory( const defaultHistoryLimit = 50
w http.ResponseWriter,
r *http.Request, func (s *Handlers) handleChannelHistory(
target string, w http.ResponseWriter, r *http.Request,
beforeID int64, target string, beforeID int64, limit int,
limit int,
) { ) {
chID, err := s.lookupChannelID(r, target)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
return
}
msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeID, limit)
if err != nil {
s.log.Error("get history failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
}
func (s *Handlers) handleDMHistory(
w http.ResponseWriter, r *http.Request,
uid int64, target string, beforeID int64, limit int,
) {
targetID, err := s.params.Database.LookupUserByNick(r.Context(), target)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
return
}
msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetID, beforeID, limit)
if err != nil {
s.log.Error("get dm history failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
}
// lookupChannelID queries the channel ID by name using a parameterized query.
func (s *Handlers) lookupChannelID(r *http.Request, name string) (int64, error) {
var chID int64 var chID int64
err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec,nolintlint // parameterized query //nolint:gosec // query uses parameterized placeholder (?), not string interpolation
r.Context(), err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", "SELECT id FROM channels WHERE name = ?", name).Scan(&chID)
target,
).Scan(&chID)
if err != nil {
s.respondError(
w, r, "channel not found",
http.StatusNotFound,
)
return return chID, err
}
msgs, err := s.params.Database.GetMessagesBefore(
r.Context(), chID, beforeID, limit,
)
if err != nil {
s.internalError(w, r, "get history failed", err)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
}
func (s *Handlers) getDMHistory(
w http.ResponseWriter,
r *http.Request,
uid int64,
target string,
beforeID int64,
limit int,
) {
targetID, err := s.params.Database.GetUserByNick(
r.Context(), target,
)
if err != nil {
s.respondError(
w, r, "user not found", http.StatusNotFound,
)
return
}
msgs, err := s.params.Database.GetDMsBefore(
r.Context(), uid, targetID, beforeID, limit,
)
if err != nil {
s.internalError(
w, r, "get dm history failed", err,
)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
} }
// HandleServerInfo returns server metadata (MOTD, name). // HandleServerInfo returns server metadata (MOTD, name).

View File

@@ -2,9 +2,13 @@ package models
import ( import (
"context" "context"
"errors"
"time" "time"
) )
// ErrUserLookupNotAvailable is returned when user lookup is not configured.
var ErrUserLookupNotAvailable = errors.New("user lookup not available")
// AuthToken represents an authentication token for a user session. // AuthToken represents an authentication token for a user session.
type AuthToken struct { type AuthToken struct {
Base Base
@@ -18,9 +22,5 @@ type AuthToken struct {
// User returns the user who owns this token. // User returns the user who owns this token.
func (t *AuthToken) User(ctx context.Context) (*User, error) { func (t *AuthToken) User(ctx context.Context) (*User, error) {
if ul := t.GetUserLookup(); ul != nil { return t.LookupUser(ctx, t.UserID)
return ul.GetUserByID(ctx, t.UserID)
}
return nil, ErrUserLookupNotAvailable
} }

View File

@@ -2,9 +2,13 @@ package models
import ( import (
"context" "context"
"errors"
"time" "time"
) )
// ErrChannelLookupNotAvailable is returned when channel lookup is not configured.
var ErrChannelLookupNotAvailable = errors.New("channel lookup not available")
// ChannelMember represents a user's membership in a channel. // ChannelMember represents a user's membership in a channel.
type ChannelMember struct { type ChannelMember struct {
Base Base
@@ -18,18 +22,10 @@ type ChannelMember struct {
// User returns the full User for this membership. // User returns the full User for this membership.
func (cm *ChannelMember) User(ctx context.Context) (*User, error) { func (cm *ChannelMember) User(ctx context.Context) (*User, error) {
if ul := cm.GetUserLookup(); ul != nil { return cm.LookupUser(ctx, cm.UserID)
return ul.GetUserByID(ctx, cm.UserID)
}
return nil, ErrUserLookupNotAvailable
} }
// Channel returns the full Channel for this membership. // Channel returns the full Channel for this membership.
func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) { func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) {
if cl := cm.GetChannelLookup(); cl != nil { return cm.LookupChannel(ctx, cm.ChannelID)
return cl.GetChannelByID(ctx, cm.ChannelID)
}
return nil, ErrChannelLookupNotAvailable
} }

View File

@@ -6,7 +6,6 @@ package models
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
) )
// DB is the interface that models use to query the database. // DB is the interface that models use to query the database.
@@ -25,12 +24,6 @@ type ChannelLookup interface {
GetChannelByID(ctx context.Context, id string) (*Channel, error) GetChannelByID(ctx context.Context, id string) (*Channel, error)
} }
// Sentinel errors for model lookup methods.
var (
ErrUserLookupNotAvailable = errors.New("user lookup not available")
ErrChannelLookupNotAvailable = errors.New("channel lookup not available")
)
// Base is embedded in all model structs to provide database access. // Base is embedded in all model structs to provide database access.
type Base struct { type Base struct {
db DB db DB
@@ -46,20 +39,22 @@ func (b *Base) GetDB() *sql.DB {
return b.db.GetDB() return b.db.GetDB()
} }
// GetUserLookup returns the DB as a UserLookup if it implements the interface. // LookupUser looks up a user by ID if the database supports it.
func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn,nolintlint // returns interface by design func (b *Base) LookupUser(ctx context.Context, id string) (*User, error) {
if ul, ok := b.db.(UserLookup); ok { ul, ok := b.db.(UserLookup)
return ul if !ok {
return nil, ErrUserLookupNotAvailable
} }
return nil return ul.GetUserByID(ctx, id)
} }
// GetChannelLookup returns the DB as a ChannelLookup if it implements the interface. // LookupChannel looks up a channel by ID if the database supports it.
func (b *Base) GetChannelLookup() ChannelLookup { //nolint:ireturn,nolintlint // returns interface by design func (b *Base) LookupChannel(ctx context.Context, id string) (*Channel, error) {
if cl, ok := b.db.(ChannelLookup); ok { cl, ok := b.db.(ChannelLookup)
return cl if !ok {
return nil, ErrChannelLookupNotAvailable
} }
return nil return cl.GetChannelByID(ctx, id)
} }

View File

@@ -18,9 +18,5 @@ type Session struct {
// User returns the user who owns this session. // User returns the user who owns this session.
func (s *Session) User(ctx context.Context) (*User, error) { func (s *Session) User(ctx context.Context) (*User, error) {
if ul := s.GetUserLookup(); ul != nil { return s.LookupUser(ctx, s.UserID)
return ul.GetUserByID(ctx, s.UserID)
}
return nil, ErrUserLookupNotAvailable
} }

View File

@@ -20,6 +20,13 @@ const routeTimeout = 60 * time.Second
func (s *Server) SetupRoutes() { func (s *Server) SetupRoutes() {
s.router = chi.NewRouter() s.router = chi.NewRouter()
s.setupMiddleware()
s.setupHealthAndMetrics()
s.setupAPIRoutes()
s.setupSPA()
}
func (s *Server) setupMiddleware() {
s.router.Use(middleware.Recoverer) s.router.Use(middleware.Recoverer)
s.router.Use(middleware.RequestID) s.router.Use(middleware.RequestID)
s.router.Use(s.mw.Logging()) s.router.Use(s.mw.Logging())
@@ -37,63 +44,54 @@ func (s *Server) SetupRoutes() {
}) })
s.router.Use(sentryHandler.Handle) s.router.Use(sentryHandler.Handle)
} }
}
// Health check func (s *Server) setupHealthAndMetrics() {
s.router.Get("/.well-known/healthcheck.json", s.h.HandleHealthCheck()) s.router.Get("/.well-known/healthcheck.json", s.h.HandleHealthCheck())
// Protected metrics endpoint
if viper.GetString("METRICS_USERNAME") != "" { if viper.GetString("METRICS_USERNAME") != "" {
s.router.Group(func(r chi.Router) { s.router.Group(func(r chi.Router) {
r.Use(s.mw.MetricsAuth()) r.Use(s.mw.MetricsAuth())
r.Get("/metrics", http.HandlerFunc(promhttp.Handler().ServeHTTP)) r.Get("/metrics", http.HandlerFunc(promhttp.Handler().ServeHTTP))
}) })
} }
}
// API v1 func (s *Server) setupAPIRoutes() {
s.router.Route("/api/v1", func(r chi.Router) { s.router.Route("/api/v1", func(r chi.Router) {
r.Get("/server", s.h.HandleServerInfo()) r.Get("/server", s.h.HandleServerInfo())
r.Post("/session", s.h.HandleCreateSession()) r.Post("/session", s.h.HandleCreateSession())
// Unified state and message endpoints
r.Get("/state", s.h.HandleState()) r.Get("/state", s.h.HandleState())
r.Get("/messages", s.h.HandleGetMessages()) r.Get("/messages", s.h.HandleGetMessages())
r.Post("/messages", s.h.HandleSendCommand()) r.Post("/messages", s.h.HandleSendCommand())
r.Get("/history", s.h.HandleGetHistory()) r.Get("/history", s.h.HandleGetHistory())
// Channels
r.Get("/channels", s.h.HandleListAllChannels()) r.Get("/channels", s.h.HandleListAllChannels())
r.Get("/channels/{channel}/members", s.h.HandleChannelMembers()) r.Get("/channels/{channel}/members", s.h.HandleChannelMembers())
}) })
}
// Serve embedded SPA func (s *Server) setupSPA() {
distFS, err := fs.Sub(web.Dist, "dist") distFS, err := fs.Sub(web.Dist, "dist")
if err != nil { if err != nil {
s.log.Error("failed to get web dist filesystem", "error", err) s.log.Error("failed to get web dist filesystem", "error", err)
} else {
fileServer := http.FileServer(http.FS(distFS))
s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
s.serveSPA(distFS, fileServer, w, r)
})
}
}
func (s *Server) serveSPA(
distFS fs.FS,
fileServer http.Handler,
w http.ResponseWriter,
r *http.Request,
) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
http.Error(w, "filesystem error", http.StatusInternalServerError)
return return
} }
// Try to serve the file; fall back to index.html for SPA routing. fileServer := http.FileServer(http.FS(distFS))
f, err := readFS.ReadFile(r.URL.Path[1:])
if err != nil || len(f) == 0 { s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
http.NotFound(w, r)
return
}
f, readErr := readFS.ReadFile(r.URL.Path[1:])
if readErr != nil || len(f) == 0 {
indexHTML, _ := readFS.ReadFile("index.html") indexHTML, _ := readFS.ReadFile("index.html")
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
@@ -104,4 +102,5 @@ func (s *Server) serveSPA(
} }
fileServer.ServeHTTP(w, r) fileServer.ServeHTTP(w, r)
})
} }