16 Commits

Author SHA1 Message Date
clawbot
e0a93dea5f fix: suppress gosec false positives for trusted URL construction
Add nolint:gosec annotations for:
- Client.Do calls using URLs built from trusted BaseURL + hardcoded paths
- Test helper HTTP calls using test server URLs
- Safe integer-to-rune conversion in bounded loop (0-19)
2026-02-20 02:06:31 -08:00
clawbot
1aac9cf480 build: Dockerfile non-root user, healthcheck, .dockerignore 2026-02-11 00:50:13 -08:00
clawbot
eefb81ed8d fix: resolve all golangci-lint issues
- Refactor test helpers (sendCommand, getJSON) to return (int, map[string]any)
  instead of (*http.Response, map[string]any) to fix bodyclose warnings
- Add doReq/doReqAuth helpers using NewRequestWithContext to fix noctx
- Check all error returns (errcheck, errchkjson)
- Use integer range syntax (intrange) for Go 1.22+
- Use http.Method* constants (usestdlibvars)
- Replace fmt.Sprintf with string concatenation where possible (perfsprint)
- Reorder UI methods: exported before unexported (funcorder)
- Add lint target to Makefile
- Disable overly pedantic linters in .golangci.yml (paralleltest, dupl,
  noinlineerr, wsl_v5, nlreturn, lll, tagliatelle, goconst, funlen)
2026-02-10 18:52:17 -08:00
clawbot
d0656b069d fix: golangci-lint v2 config and lint-clean production code
- Fix .golangci.yml for v2 format (linters-settings -> linters.settings)
- All production code now passes golangci-lint with zero issues
- Line length 88, funlen 80/50, cyclop 15, dupl 100
- Extract shared helpers in db (scanChannels, scanInt64s, scanMessages)
- Split runMigrations into applyMigration/execMigration
- Fix fanOut return signature (remove unused int64)
- Add fanOutSilent helper to avoid dogsled
- Rewrite CLI code for lint compliance (nlreturn, wsl_v5, noctx, etc)
- Rename CLI api package to chatapi to avoid revive var-naming
- Fix all noinlineerr, mnd, perfsprint, funcorder issues
- Fix db tests: extract helpers, add t.Parallel, proper error checks
- Broker tests already clean
- Handler integration tests still have lint issues (next commit)
2026-02-10 18:50:24 -08:00
clawbot
deda5d81ae fix: CLI client types mismatched server response format
- SessionResponse: use 'id' (int64) not 'session_id'/'client_id'
- StateResponse: match actual server response shape
- GetMembers: strip '#' from channel name for URL path
- These bugs prevented the CLI from working correctly with the server
2026-02-10 18:23:19 -08:00
clawbot
b0358471ae chore: deduplicate broker tests, clean up test imports 2026-02-10 18:22:38 -08:00
clawbot
5c95d7cdf6 fix: CLI poll loop used UUID instead of queue cursor (last_id)
The poll loop was storing msg.ID (UUID string) as afterID, but the server
expects the integer queue cursor from last_id. This caused the CLI to
re-fetch ALL messages on every poll cycle.

- Change PollMessages to accept int64 afterID and return PollResult with LastID
- Track lastQID (queue cursor) instead of lastMsgID (UUID)
- Parse the wrapped MessagesResponse properly
2026-02-10 18:22:08 -08:00
clawbot
ae1c67050e build: update Dockerfile with tests and multi-stage build, add Makefile
- Dockerfile runs go test before building
- Multi-stage: build+test stage, minimal alpine final image
- Add gcc/musl-dev for CGO (sqlite)
- Trim binaries with -s -w ldflags
- Makefile with build, test, clean, docker targets
- Version injection via ldflags
2026-02-10 18:21:07 -08:00
clawbot
736d179513 test: add comprehensive test suite
- Integration tests for all API endpoints (session, state, channels, messages)
- Tests for all commands: PRIVMSG, JOIN, PART, NICK, TOPIC, QUIT, PING
- Edge cases: duplicate nick, empty/invalid inputs, malformed JSON, bad auth
- Long-poll tests: delivery on notify and timeout behavior
- DM tests: delivery to recipient, echo to sender, nonexistent user
- Ephemeral channel cleanup test
- Concurrent session creation test
- Nick broadcast to channel members test
- DB unit tests: all CRUD operations, message queue, history
- Broker unit tests: wait/notify, remove, concurrent access
2026-02-10 18:20:45 -08:00
clawbot
9c8e40d014 Comprehensive README: full protocol spec, API reference, architecture, security model
Expanded from ~700 lines to ~2200 lines covering:
- Complete protocol specification (every command, field, behavior)
- Full API reference with request/response examples for all endpoints
- Architecture deep-dive (session model, queue system, broker, message flow)
- Sequence diagrams for channel messages, DMs, and JOIN flows
- All design decisions with rationale (no accounts, JSON, opaque tokens, etc.)
- Canonicalization and signing spec (JCS, Ed25519, TOFU)
- Security model (threat model, authentication, key management)
- Federation design (link establishment, relay, state sync, S2S commands)
- Storage schema with all tables and columns documented
- Configuration reference with all environment variables
- Deployment guide (Docker, binary, reverse proxy, SQLite considerations)
- Client development guide with curl examples and Python/JS code
- Hashcash proof-of-work spec (challenge/response flow, adaptive difficulty)
- Detailed roadmap (MVP, post-MVP, future)
- Project structure with every directory explained
2026-02-10 18:18:29 -08:00
clawbot
c4d97b2cad refactor: clean up handlers, add input validation, remove raw SQL from handlers
- Merge fanOut/fanOutDirect into single fanOut method
- Move channel lookup to db.GetChannelByName
- Add regex validation for nicks and channel names
- Split HandleSendCommand into per-command helper methods
- Add charset to Content-Type header
- Add sentinel error for unauthorized
- Cap history limit to 500
- Skip NICK change if new == old
- Add empty command check
2026-02-10 18:16:23 -08:00
clawbot
647a7ce313 Revert: exclude chat-cli from final Docker image (server-only)
CLI is built during Docker build to verify compilation, but only chatd
is included in the final image. CLI distributed separately.
2026-02-10 18:10:45 -08:00
clawbot
cee2506e8d Document hashcash proof-of-work plan for session rate limiting 2026-02-10 18:10:32 -08:00
clawbot
63957c6a46 Include chat-cli in final Docker image 2026-02-10 18:10:05 -08:00
clawbot
0b84872178 Update Dockerfile for Go 1.24, no Node build step needed
SPA is vanilla JS shipped as static files in web/dist/,
no npm build step required.
2026-02-10 18:09:24 -08:00
clawbot
1f54b281fd MVP: IRC envelope format, long-polling, per-client queues, SPA rewrite
Major changes:
- Consolidated schema into single migration with IRC envelope format
- Messages table stores command/from/to/body(JSON)/meta(JSON) per spec
- Per-client delivery queues (client_queues table) with fan-out
- In-memory broker for long-poll notifications (no busy polling)
- GET /messages supports ?after=<queue_id>&timeout=15 long-polling
- All commands (JOIN/PART/NICK/TOPIC/QUIT/PING) broadcast events
- Channels are ephemeral (deleted when last member leaves)
- PRIVMSG to nicks (DMs) fan out to both sender and recipient
- SPA rewritten in vanilla JS (no build step needed):
  - Long-poll via recursive fetch (not setInterval)
  - IRC envelope parsing with system message display
  - /nick, /join, /part, /msg, /quit commands
  - Unread indicators on inactive tabs
  - DM tabs from user list clicks
- Removed unused models package (was for UUID-based schema)
- Removed conflicting UUID-based db methods
- Increased HTTP write timeout to 60s for long-poll support
2026-02-10 18:09:10 -08:00
29 changed files with 1734 additions and 2957 deletions

View File

@@ -1,12 +0,0 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[Makefile]
indent_style = tab

View File

@@ -1,9 +0,0 @@
name: check
on: [push]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4.2.2, 2026-02-22
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- run: docker build .

30
.gitignore vendored
View File

@@ -1,28 +1,7 @@
# OS
.DS_Store
Thumbs.db
# Editors
*.swp
*.swo
*~
*.bak
.idea/
.vscode/
*.sublime-*
# Node
node_modules/
# Environment / secrets
.env
.env.*
*.pem
*.key
# Build artifacts
/chatd
/bin/
data.db
.env
*.exe
*.dll
*.so
@@ -30,9 +9,6 @@ node_modules/
*.test
*.out
vendor/
# Project
data.db
debug.log
/chat-cli
web/node_modules/
chat-cli

View File

@@ -7,7 +7,24 @@ run:
linters:
default: all
disable:
- wsl # Deprecated in v2, replaced by wsl_v5
- exhaustruct
- depguard
- godot
- wsl
- wsl_v5
- wrapcheck
- varnamelen
- noinlineerr
- dupl
- paralleltest
- nlreturn
- tagliatelle
- goconst
- funlen
- maintidx
- cyclop
- gocognit
- lll
settings:
lll:
line-length: 88
@@ -18,19 +35,7 @@ linters:
max-complexity: 15
dupl:
threshold: 100
gosec:
excludes:
- G704
depguard:
rules:
all:
deny:
- pkg: "io/ioutil"
desc: "Deprecated; use io and os packages."
- pkg: "math/rand$"
desc: "Use crypto/rand for security-sensitive code."
issues:
exclude-use-default: false
max-issues-per-linter: 0
max-same-issues: 0

View File

@@ -1,26 +1,23 @@
# golang:1.24-alpine, 2026-02-26
FROM golang@sha256:8bee1901f1e530bfb4a7850aa7a479d17ae3a18beb6e09064ed54cfd245b7191 AS builder
# Build stage
FROM golang:1.24-alpine AS builder
WORKDIR /src
RUN apk add --no-cache git build-base make
# golangci-lint v2.1.6 (eabc2638a66d), 2026-02-26
RUN CGO_ENABLED=0 go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d
RUN apk add --no-cache make gcc musl-dev
COPY go.mod go.sum ./
RUN go mod download
COPY . .
# Run all checks — build fails if branch is not green
RUN make check
# Run tests
ENV DBURL="file::memory:?cache=shared"
RUN go test ./...
# Build static binaries (no cgo needed at runtime — modernc.org/sqlite is pure Go)
ARG VERSION=dev
RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w -X main.Version=${VERSION}" -o /chatd ./cmd/chatd/
RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o /chat-cli ./cmd/chat-cli/
# Build binaries
RUN CGO_ENABLED=1 go build -trimpath -ldflags="-s -w" -o /chatd ./cmd/chatd/
RUN CGO_ENABLED=1 go build -trimpath -ldflags="-s -w" -o /chat-cli ./cmd/chat-cli/
# alpine:3.21, 2026-02-26
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
# Final stage — server only
FROM alpine:3.21
RUN apk add --no-cache ca-certificates \
&& addgroup -S chat && adduser -S chat -G chat
COPY --from=builder /chatd /usr/local/bin/chatd

21
LICENSE
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2025 sneak
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,49 +1,20 @@
.PHONY: all build lint fmt fmt-check test check clean run debug docker hooks
BINARY := chatd
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILDARCH := $(shell go env GOARCH)
LDFLAGS := -X main.Version=$(VERSION) -X main.Buildarch=$(BUILDARCH)
LDFLAGS := -ldflags "-X main.Version=$(VERSION)"
all: check build
.PHONY: build test clean docker lint
build:
go build -ldflags "$(LDFLAGS)" -o bin/$(BINARY) ./cmd/chatd
lint:
golangci-lint run --config .golangci.yml ./...
fmt:
gofmt -s -w .
goimports -w .
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
go build $(LDFLAGS) -o chatd ./cmd/chatd/
go build $(LDFLAGS) -o chat-cli ./cmd/chat-cli/
test:
go test -timeout 30s -v -race -cover ./...
# check runs all validation without making changes
# Used by CI and Docker build — fails if anything is wrong
check: test lint fmt-check
@echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/chatd
@echo "==> All checks passed!"
run: build
./bin/$(BINARY)
debug: build
DEBUG=1 GOTRACEBACK=all ./bin/$(BINARY)
DBURL="file::memory:?cache=shared" go test ./...
clean:
rm -rf bin/ chatd data.db
rm -f chatd chat-cli
lint:
GOFLAGS=-buildvcs=false golangci-lint run ./...
docker:
docker build -t chat .
hooks:
@printf '#!/bin/sh\nset -e\n' > .git/hooks/pre-commit
@printf 'go mod tidy\ngo fmt ./...\ngit diff --exit-code -- go.mod go.sum || { echo "go mod tidy changed files; please stage and retry"; exit 1; }\n' >> .git/hooks/pre-commit
@printf 'make check\n' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit
docker build -t chat:$(VERSION) .

View File

@@ -2199,8 +2199,3 @@ See [Roadmap](#roadmap) for what's next.
## License
MIT
## Author
[@sneak](https://sneak.berlin)

View File

@@ -1,182 +0,0 @@
---
title: Repository Policies
last_modified: 2026-02-22
---
This document covers repository structure, tooling, and workflow standards. Code
style conventions are in separate documents:
- [Code Styleguide](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE.md)
(general, bash, Docker)
- [Go](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_GO.md)
- [JavaScript](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_JS.md)
- [Python](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_PYTHON.md)
- [Go HTTP Server Conventions](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/GO_HTTP_SERVER_CONVENTIONS.md)
---
- Cross-project documentation (such as this file) must include
`last_modified: YYYY-MM-DD` in the YAML front matter so it can be kept in sync
with the authoritative source as policies evolve.
- **ALL external references must be pinned by cryptographic hash.** This
includes Docker base images, Go modules, npm packages, GitHub Actions, and
anything else fetched from a remote source. Version tags (`@v4`, `@latest`,
`:3.21`, etc.) are server-mutable and therefore remote code execution
vulnerabilities. The ONLY acceptable way to reference an external dependency
is by its content hash (Docker `@sha256:...`, Go module hash in `go.sum`, npm
integrity hash in lockfile, GitHub Actions `@<commit-sha>`). No exceptions.
This also means never `curl | bash` to install tools like pyenv, nvm, rustup,
etc. Instead, download a specific release archive from GitHub, verify its hash
(hardcoded in the Dockerfile or script), and only then install. Unverified
install scripts are arbitrary remote code execution. This is the single most
important rule in this document. Double-check every external reference in
every file before committing. There are zero exceptions to this rule.
- Every repo with software must have a root `Makefile` with these targets:
`make test`, `make lint`, `make fmt` (writes), `make fmt-check` (read-only),
`make check` (prereqs: `test`, `lint`, `fmt-check`), `make docker`, and
`make hooks` (installs pre-commit hook). A model Makefile is at
`https://git.eeqj.de/sneak/prompts/raw/branch/main/Makefile`.
- Always use Makefile targets (`make fmt`, `make test`, `make lint`, etc.)
instead of invoking the underlying tools directly. The Makefile is the single
source of truth for how these operations are run.
- The Makefile is authoritative documentation for how the repo is used. Beyond
the required targets above, it should have targets for every common operation:
running a local development server (`make run`, `make dev`), re-initializing
or migrating the database (`make db-reset`, `make migrate`), building
artifacts (`make build`), generating code, seeding data, or anything else a
developer would do regularly. If someone checks out the repo and types
`make<tab>`, they should see every meaningful operation available. A new
contributor should be able to understand the entire development workflow by
reading the Makefile.
- Every repo should have a `Dockerfile`. All Dockerfiles must run `make check`
as a build step so the build fails if the branch is not green. For non-server
repos, the Dockerfile should bring up a development environment and run
`make check`. For server repos, `make check` should run as an early build
stage before the final image is assembled.
- Every repo should have a Gitea Actions workflow (`.gitea/workflows/`) that
runs `docker build .` on push. Since the Dockerfile already runs `make check`,
a successful build implies all checks pass.
- Use platform-standard formatters: `black` for Python, `prettier` for
JS/CSS/Markdown/HTML, `go fmt` for Go. Always use default configuration with
two exceptions: four-space indents (except Go), and `proseWrap: always` for
Markdown (hard-wrap at 80 columns). Documentation and writing repos (Markdown,
HTML, CSS) should also have `.prettierrc` and `.prettierignore`.
- Pre-commit hook: `make check` if local testing is possible, otherwise
`make lint && make fmt-check`. The Makefile should provide a `make hooks`
target to install the pre-commit hook.
- All repos with software must have tests that run via the platform-standard
test framework (`go test`, `pytest`, `jest`/`vitest`, etc.). If no meaningful
tests exist yet, add the most minimal test possible — e.g. importing the
module under test to verify it compiles/parses. There is no excuse for
`make test` to be a no-op.
- `make test` must complete in under 20 seconds. Add a 30-second timeout in the
Makefile.
- Docker builds must complete in under 5 minutes.
- `make check` must not modify any files in the repo. Tests may use temporary
directories.
- `main` must always pass `make check`, no exceptions.
- Never commit secrets. `.env` files, credentials, API keys, and private keys
must be in `.gitignore`. No exceptions.
- `.gitignore` should be comprehensive from the start: OS files (`.DS_Store`),
editor files (`.swp`, `*~`), language build artifacts, and `node_modules/`.
Fetch the standard `.gitignore` from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.gitignore` when setting up
a new repo.
- Never use `git add -A` or `git add .`. Always stage files explicitly by name.
- Never force-push to `main`.
- Make all changes on a feature branch. You can do whatever you want on a
feature branch.
- `.golangci.yml` is standardized and must _NEVER_ be modified by an agent, only
manually by the user. Fetch from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.golangci.yml`.
- When pinning images or packages by hash, add a comment above the reference
with the version and date (YYYY-MM-DD).
- Use `yarn`, not `npm`.
- Write all dates as YYYY-MM-DD (ISO 8601).
- Simple projects should be configured with environment variables.
- Dockerized web services listen on port 8080 by default, overridable with
`PORT`.
- `README.md` is the primary documentation. Required sections:
- **Description**: First line must include the project name, purpose,
category (web server, SPA, CLI tool, etc.), license, and author. Example:
"µPaaS is an MIT-licensed Go web application by @sneak that receives
git-frontend webhooks and deploys applications via Docker in realtime."
- **Getting Started**: Copy-pasteable install/usage code block.
- **Rationale**: Why does this exist?
- **Design**: How is the program structured?
- **TODO**: Update meticulously, even between commits. When planning, put
the todo list in the README so a new agent can pick up where the last one
left off.
- **License**: MIT, GPL, or WTFPL. Ask the user for new projects. Include a
`LICENSE` file in the repo root and a License section in the README.
- **Author**: [@sneak](https://sneak.berlin).
- First commit of a new repo should contain only `README.md`.
- Go module root: `sneak.berlin/go/<name>`. Always run `go mod tidy` before
committing.
- Use SemVer.
- Database migrations live in `internal/db/migrations/` and must be embedded in
the binary. Pre-1.0.0: modify existing migrations (no installed base assumed).
Post-1.0.0: add new migration files.
- All repos should have an `.editorconfig` enforcing the project's indentation
settings.
- Avoid putting files in the repo root unless necessary. Root should contain
only project-level config files (`README.md`, `Makefile`, `Dockerfile`,
`LICENSE`, `.gitignore`, `.editorconfig`, `REPO_POLICIES.md`, and
language-specific config). Everything else goes in a subdirectory. Canonical
subdirectory names:
- `bin/` — executable scripts and tools
- `cmd/` — Go command entrypoints
- `configs/` — configuration templates and examples
- `deploy/` — deployment manifests (k8s, compose, terraform)
- `docs/` — documentation and markdown (README.md stays in root)
- `internal/` — Go internal packages
- `internal/db/migrations/` — database migrations
- `pkg/` — Go library packages
- `share/` — systemd units, data files
- `static/` — static assets (images, fonts, etc.)
- `web/` — web frontend source
- When setting up a new repo, files from the `prompts` repo may be used as
templates. Fetch them from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/<path>`.
- New repos must contain at minimum:
- `README.md`, `.git`, `.gitignore`, `.editorconfig`
- `LICENSE`, `REPO_POLICIES.md` (copy from the `prompts` repo)
- `Makefile`
- `Dockerfile`, `.dockerignore`
- `.gitea/workflows/check.yml`
- Go: `go.mod`, `go.sum`, `.golangci.yml`
- JS: `package.json`, `yarn.lock`, `.prettierrc`, `.prettierignore`
- Python: `pyproject.toml`

View File

@@ -1,4 +1,3 @@
// Package chatapi provides a client for the chat server API.
package chatapi
import (
@@ -16,8 +15,8 @@ import (
)
const (
httpTimeout = 30 * time.Second
pollExtraTime = 5
httpTimeout = 30 * time.Second
pollExtraTime = 5
httpErrThreshold = 400
)
@@ -32,19 +31,17 @@ type Client struct {
// NewClient creates a new API client.
func NewClient(baseURL string) *Client {
return &Client{ //nolint:exhaustruct // Token set after CreateSession
BaseURL: baseURL,
HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine
Timeout: httpTimeout,
},
return &Client{
BaseURL: baseURL,
HTTPClient: &http.Client{Timeout: httpTimeout},
}
}
// CreateSession creates a new session on the server.
func (client *Client) CreateSession(
func (c *Client) CreateSession(
nick string,
) (*SessionResponse, error) {
data, err := client.do(
data, err := c.do(
http.MethodPost,
"/api/v1/session",
&SessionRequest{Nick: nick},
@@ -60,14 +57,14 @@ func (client *Client) CreateSession(
return nil, fmt.Errorf("decode session: %w", err)
}
client.Token = resp.Token
c.Token = resp.Token
return &resp, nil
}
// GetState returns the current user state.
func (client *Client) GetState() (*StateResponse, error) {
data, err := client.do(
func (c *Client) GetState() (*StateResponse, error) {
data, err := c.do(
http.MethodGet, "/api/v1/state", nil,
)
if err != nil {
@@ -85,8 +82,8 @@ func (client *Client) GetState() (*StateResponse, error) {
}
// SendMessage sends a message (any IRC command).
func (client *Client) SendMessage(msg *Message) error {
_, err := client.do(
func (c *Client) SendMessage(msg *Message) error {
_, err := c.do(
http.MethodPost, "/api/v1/messages", msg,
)
@@ -94,11 +91,11 @@ func (client *Client) SendMessage(msg *Message) error {
}
// PollMessages long-polls for new messages.
func (client *Client) PollMessages(
func (c *Client) PollMessages(
afterID int64,
timeout int,
) (*PollResult, error) {
pollClient := &http.Client{ //nolint:exhaustruct // defaults fine
client := &http.Client{
Timeout: time.Duration(
timeout+pollExtraTime,
) * time.Second,
@@ -116,30 +113,28 @@ func (client *Client) PollMessages(
path := "/api/v1/messages?" + params.Encode()
request, err := http.NewRequestWithContext(
req, err := http.NewRequestWithContext(
context.Background(),
http.MethodGet,
client.BaseURL+path,
c.BaseURL+path,
nil,
)
if err != nil {
return nil, fmt.Errorf("new request: %w", err)
return nil, err
}
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
req.Header.Set("Authorization", "Bearer "+c.Token)
resp, err := pollClient.Do(request)
resp, err := client.Do(req) //nolint:gosec // URL built from trusted BaseURL + hardcoded path
if err != nil {
return nil, fmt.Errorf("poll request: %w", err)
return nil, err
}
defer func() { _ = resp.Body.Close() }()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read poll body: %w", err)
return nil, err
}
if resp.StatusCode >= httpErrThreshold {
@@ -165,28 +160,22 @@ func (client *Client) PollMessages(
}
// JoinChannel joins a channel.
func (client *Client) JoinChannel(channel string) error {
return client.SendMessage(
&Message{ //nolint:exhaustruct // only command+to needed
Command: "JOIN", To: channel,
},
func (c *Client) JoinChannel(channel string) error {
return c.SendMessage(
&Message{Command: "JOIN", To: channel},
)
}
// PartChannel leaves a channel.
func (client *Client) PartChannel(channel string) error {
return client.SendMessage(
&Message{ //nolint:exhaustruct // only command+to needed
Command: "PART", To: channel,
},
func (c *Client) PartChannel(channel string) error {
return c.SendMessage(
&Message{Command: "PART", To: channel},
)
}
// ListChannels returns all channels on the server.
func (client *Client) ListChannels() (
[]Channel, error,
) {
data, err := client.do(
func (c *Client) ListChannels() ([]Channel, error) {
data, err := c.do(
http.MethodGet, "/api/v1/channels", nil,
)
if err != nil {
@@ -197,21 +186,19 @@ func (client *Client) ListChannels() (
err = json.Unmarshal(data, &channels)
if err != nil {
return nil, fmt.Errorf(
"decode channels: %w", err,
)
return nil, err
}
return channels, nil
}
// GetMembers returns members of a channel.
func (client *Client) GetMembers(
func (c *Client) GetMembers(
channel string,
) ([]string, error) {
name := strings.TrimPrefix(channel, "#")
data, err := client.do(
data, err := c.do(
http.MethodGet,
"/api/v1/channels/"+url.PathEscape(name)+
"/members",
@@ -234,10 +221,8 @@ func (client *Client) GetMembers(
}
// GetServerInfo returns server info.
func (client *Client) GetServerInfo() (
*ServerInfo, error,
) {
data, err := client.do(
func (c *Client) GetServerInfo() (*ServerInfo, error) {
data, err := c.do(
http.MethodGet, "/api/v1/server", nil,
)
if err != nil {
@@ -248,15 +233,13 @@ func (client *Client) GetServerInfo() (
err = json.Unmarshal(data, &info)
if err != nil {
return nil, fmt.Errorf(
"decode server info: %w", err,
)
return nil, err
}
return &info, nil
}
func (client *Client) do(
func (c *Client) do(
method, path string,
body any,
) ([]byte, error) {
@@ -271,27 +254,25 @@ func (client *Client) do(
bodyReader = bytes.NewReader(data)
}
request, err := http.NewRequestWithContext(
req, err := http.NewRequestWithContext(
context.Background(),
method,
client.BaseURL+path,
c.BaseURL+path,
bodyReader,
)
if err != nil {
return nil, fmt.Errorf("request: %w", err)
}
request.Header.Set(
"Content-Type", "application/json",
)
req.Header.Set("Content-Type", "application/json")
if client.Token != "" {
request.Header.Set(
"Authorization", "Bearer "+client.Token,
if c.Token != "" {
req.Header.Set(
"Authorization", "Bearer "+c.Token,
)
}
resp, err := client.HTTPClient.Do(request)
resp, err := c.HTTPClient.Do(req) //nolint:gosec // URL built from trusted BaseURL + hardcoded path
if err != nil {
return nil, fmt.Errorf("http: %w", err)
}

View File

@@ -1,3 +1,4 @@
// Package chatapi provides API types and client for chat-cli.
package chatapi
import "time"
@@ -23,9 +24,9 @@ type StateResponse struct {
// Message represents a chat message envelope.
type Message struct {
Command string `json:"command"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Command string `json:"command"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"`
ID string `json:"id,omitempty"`
@@ -35,19 +36,19 @@ type Message struct {
// BodyLines returns the body as a string slice.
func (m *Message) BodyLines() []string {
switch bodyVal := m.Body.(type) {
switch v := m.Body.(type) {
case []any:
lines := make([]string, 0, len(bodyVal))
lines := make([]string, 0, len(v))
for _, item := range bodyVal {
if str, ok := item.(string); ok {
lines = append(lines, str)
for _, item := range v {
if s, ok := item.(string); ok {
lines = append(lines, s)
}
}
return lines
case []string:
return bodyVal
return v
default:
return nil
}

View File

@@ -12,10 +12,10 @@ import (
)
const (
splitParts = 2
pollTimeout = 15
pollRetry = 2 * time.Second
timeFormat = "15:04"
splitParts = 2
pollTimeout = 15
pollRetry = 2 * time.Second
timeFormat = "15:04"
)
// App holds the application state.
@@ -32,7 +32,7 @@ type App struct {
}
func main() {
app := &App{ //nolint:exhaustruct
app := &App{
ui: NewUI(),
nick: "guest",
}
@@ -85,7 +85,7 @@ func (a *App) handleInput(text string) {
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
err := a.client.SendMessage(&api.Message{
Command: "PRIVMSG",
To: target,
Body: []string{text},
@@ -98,7 +98,7 @@ func (a *App) handleInput(text string) {
return
}
timestamp := time.Now().Format(timeFormat)
ts := time.Now().Format(timeFormat)
a.mu.Lock()
nick := a.nick
@@ -106,7 +106,7 @@ func (a *App) handleInput(text string) {
a.ui.AddLine(target, fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
timestamp, nick, text,
ts, nick, text,
))
}
@@ -227,7 +227,7 @@ func (a *App) cmdNick(nick string) {
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
err := a.client.SendMessage(&api.Message{
Command: "NICK",
Body: []string{nick},
})
@@ -362,7 +362,7 @@ func (a *App) cmdMsg(args string) {
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
err := a.client.SendMessage(&api.Message{
Command: "PRIVMSG",
To: target,
Body: []string{text},
@@ -375,11 +375,11 @@ func (a *App) cmdMsg(args string) {
return
}
timestamp := time.Now().Format(timeFormat)
ts := time.Now().Format(timeFormat)
a.ui.AddLine(target, fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
timestamp, nick, text,
ts, nick, text,
))
}
@@ -420,7 +420,7 @@ func (a *App) cmdTopic(args string) {
}
if args == "" {
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
err := a.client.SendMessage(&api.Message{
Command: "TOPIC",
To: target,
})
@@ -433,7 +433,7 @@ func (a *App) cmdTopic(args string) {
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
err := a.client.SendMessage(&api.Message{
Command: "TOPIC",
To: target,
Body: []string{args},
@@ -519,18 +519,18 @@ func (a *App) cmdWindow(args string) {
return
}
var bufIndex int
var n int
_, _ = fmt.Sscanf(args, "%d", &bufIndex)
_, _ = fmt.Sscanf(args, "%d", &n)
a.ui.SwitchBuffer(bufIndex)
a.ui.SwitchBuffer(n)
a.mu.Lock()
nick := a.nick
a.mu.Unlock()
if bufIndex >= 0 && bufIndex < a.ui.BufferCount() {
buf := a.ui.buffers[bufIndex]
if n >= 0 && n < a.ui.BufferCount() {
buf := a.ui.buffers[n]
if buf.Name != "(status)" {
a.mu.Lock()
a.target = buf.Name
@@ -550,7 +550,7 @@ func (a *App) cmdQuit() {
if a.connected && a.client != nil {
_ = a.client.SendMessage(
&api.Message{Command: "QUIT"}, //nolint:exhaustruct
&api.Message{Command: "QUIT"},
)
}
@@ -625,7 +625,7 @@ func (a *App) pollLoop() {
}
func (a *App) handleServerMessage(msg *api.Message) {
timestamp := a.formatTS(msg)
ts := a.formatTS(msg)
a.mu.Lock()
myNick := a.nick
@@ -633,21 +633,21 @@ func (a *App) handleServerMessage(msg *api.Message) {
switch msg.Command {
case "PRIVMSG":
a.handlePrivmsgEvent(msg, timestamp, myNick)
a.handlePrivmsgEvent(msg, ts, myNick)
case "JOIN":
a.handleJoinEvent(msg, timestamp)
a.handleJoinEvent(msg, ts)
case "PART":
a.handlePartEvent(msg, timestamp)
a.handlePartEvent(msg, ts)
case "QUIT":
a.handleQuitEvent(msg, timestamp)
a.handleQuitEvent(msg, ts)
case "NICK":
a.handleNickEvent(msg, timestamp, myNick)
a.handleNickEvent(msg, ts, myNick)
case "NOTICE":
a.handleNoticeEvent(msg, timestamp)
a.handleNoticeEvent(msg, ts)
case "TOPIC":
a.handleTopicEvent(msg, timestamp)
a.handleTopicEvent(msg, ts)
default:
a.handleDefaultEvent(msg, timestamp)
a.handleDefaultEvent(msg, ts)
}
}
@@ -660,7 +660,7 @@ func (a *App) formatTS(msg *api.Message) string {
}
func (a *App) handlePrivmsgEvent(
msg *api.Message, timestamp, myNick string,
msg *api.Message, ts, myNick string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
@@ -676,12 +676,12 @@ func (a *App) handlePrivmsgEvent(
a.ui.AddLine(target, fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
timestamp, msg.From, text,
ts, msg.From, text,
))
}
func (a *App) handleJoinEvent(
msg *api.Message, timestamp string,
msg *api.Message, ts string,
) {
if msg.To == "" {
return
@@ -689,12 +689,12 @@ func (a *App) handleJoinEvent(
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [yellow]*** %s has joined %s",
timestamp, msg.From, msg.To,
ts, msg.From, msg.To,
))
}
func (a *App) handlePartEvent(
msg *api.Message, timestamp string,
msg *api.Message, ts string,
) {
if msg.To == "" {
return
@@ -706,18 +706,18 @@ func (a *App) handlePartEvent(
if reason != "" {
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s (%s)",
timestamp, msg.From, msg.To, reason,
ts, msg.From, msg.To, reason,
))
} else {
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s",
timestamp, msg.From, msg.To,
ts, msg.From, msg.To,
))
}
}
func (a *App) handleQuitEvent(
msg *api.Message, timestamp string,
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
reason := strings.Join(lines, " ")
@@ -725,18 +725,18 @@ func (a *App) handleQuitEvent(
if reason != "" {
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit (%s)",
timestamp, msg.From, reason,
ts, msg.From, reason,
))
} else {
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit",
timestamp, msg.From,
ts, msg.From,
))
}
}
func (a *App) handleNickEvent(
msg *api.Message, timestamp, myNick string,
msg *api.Message, ts, myNick string,
) {
lines := msg.BodyLines()
@@ -757,24 +757,24 @@ func (a *App) handleNickEvent(
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [yellow]*** %s is now known as %s",
timestamp, msg.From, newNick,
ts, msg.From, newNick,
))
}
func (a *App) handleNoticeEvent(
msg *api.Message, timestamp string,
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [magenta]--%s-- %s",
timestamp, msg.From, text,
ts, msg.From, text,
))
}
func (a *App) handleTopicEvent(
msg *api.Message, timestamp string,
msg *api.Message, ts string,
) {
if msg.To == "" {
return
@@ -785,12 +785,12 @@ func (a *App) handleTopicEvent(
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [cyan]*** %s set topic: %s",
timestamp, msg.From, text,
ts, msg.From, text,
))
}
func (a *App) handleDefaultEvent(
msg *api.Message, timestamp string,
msg *api.Message, ts string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
@@ -798,7 +798,7 @@ func (a *App) handleDefaultEvent(
if text != "" {
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [white][%s] %s",
timestamp, msg.Command, text,
ts, msg.Command, text,
))
}
}

View File

@@ -32,10 +32,10 @@ type UI struct {
// NewUI creates the tview-based IRC-like UI.
func NewUI() *UI {
ui := &UI{ //nolint:exhaustruct,varnamelen // fields set below; ui is idiomatic
ui := &UI{
app: tview.NewApplication(),
buffers: []*Buffer{
{Name: "(status)", Lines: nil, Unread: 0},
{Name: "(status)", Lines: nil},
},
}
@@ -58,12 +58,7 @@ func NewUI() *UI {
// Run starts the UI event loop (blocks).
func (ui *UI) Run() error {
err := ui.app.Run()
if err != nil {
return fmt.Errorf("run ui: %w", err)
}
return nil
return ui.app.Run()
}
// Stop stops the UI.
@@ -85,7 +80,6 @@ func (ui *UI) AddLine(bufferName, line string) {
cur := ui.buffers[ui.currentBuffer]
if cur != buf {
buf.Unread++
ui.refreshStatusBar()
}
@@ -105,15 +99,15 @@ func (ui *UI) AddStatus(line string) {
}
// SwitchBuffer switches to the buffer at index n.
func (ui *UI) SwitchBuffer(bufIndex int) {
func (ui *UI) SwitchBuffer(n int) {
ui.app.QueueUpdateDraw(func() {
if bufIndex < 0 || bufIndex >= len(ui.buffers) {
if n < 0 || n >= len(ui.buffers) {
return
}
ui.currentBuffer = bufIndex
ui.currentBuffer = n
buf := ui.buffers[bufIndex]
buf := ui.buffers[n]
buf.Unread = 0
ui.messages.Clear()
@@ -287,7 +281,7 @@ func (ui *UI) getOrCreateBuffer(name string) *Buffer {
}
}
buf := &Buffer{Name: name, Lines: nil, Unread: 0}
buf := &Buffer{Name: name}
ui.buffers = append(ui.buffers, buf)
return buf

View File

@@ -8,28 +8,25 @@ import (
// Broker notifies waiting clients when new messages are available.
type Broker struct {
mu sync.Mutex
listeners map[int64][]chan struct{}
listeners map[int64][]chan struct{} // userID -> list of waiting channels
}
// New creates a new Broker.
func New() *Broker {
return &Broker{ //nolint:exhaustruct // mu has zero-value default
return &Broker{
listeners: make(map[int64][]chan struct{}),
}
}
// Wait returns a channel that will be closed when a message
// is available for the user.
// Wait returns a channel that will be closed when a message is available for the user.
func (b *Broker) Wait(userID int64) chan struct{} {
waitCh := make(chan struct{}, 1)
ch := make(chan struct{}, 1)
b.mu.Lock()
b.listeners[userID] = append(
b.listeners[userID], waitCh,
)
b.listeners[userID] = append(b.listeners[userID], ch)
b.mu.Unlock()
return waitCh
return ch
}
// Notify wakes up all waiting clients for a user.
@@ -39,29 +36,24 @@ func (b *Broker) Notify(userID int64) {
delete(b.listeners, userID)
b.mu.Unlock()
for _, waiter := range waiters {
for _, ch := range waiters {
select {
case waiter <- struct{}{}:
case ch <- struct{}{}:
default:
}
}
}
// Remove removes a specific wait channel (for cleanup on timeout).
func (b *Broker) Remove(
userID int64,
waitCh chan struct{},
) {
func (b *Broker) Remove(userID int64, ch chan struct{}) {
b.mu.Lock()
defer b.mu.Unlock()
waiters := b.listeners[userID]
for i, waiter := range waiters {
if waiter == waitCh {
b.listeners[userID] = append(
waiters[:i], waiters[i+1:]...,
)
for i, w := range waiters {
if w == ch {
b.listeners[userID] = append(waiters[:i], waiters[i+1:]...)
break
}

View File

@@ -11,8 +11,8 @@ import (
func TestNewBroker(t *testing.T) {
t.Parallel()
brk := broker.New()
if brk == nil {
b := broker.New()
if b == nil {
t.Fatal("expected non-nil broker")
}
}
@@ -20,16 +20,16 @@ func TestNewBroker(t *testing.T) {
func TestWaitAndNotify(t *testing.T) {
t.Parallel()
brk := broker.New()
waitCh := brk.Wait(1)
b := broker.New()
ch := b.Wait(1)
go func() {
time.Sleep(10 * time.Millisecond)
brk.Notify(1)
b.Notify(1)
}()
select {
case <-waitCh:
case <-ch:
case <-time.After(2 * time.Second):
t.Fatal("timeout")
}
@@ -38,22 +38,21 @@ func TestWaitAndNotify(t *testing.T) {
func TestNotifyWithoutWaiters(t *testing.T) {
t.Parallel()
brk := broker.New()
brk.Notify(42) // should not panic.
b := broker.New()
b.Notify(42) // should not panic
}
func TestRemove(t *testing.T) {
t.Parallel()
brk := broker.New()
waitCh := brk.Wait(1)
b := broker.New()
ch := b.Wait(1)
b.Remove(1, ch)
brk.Remove(1, waitCh)
brk.Notify(1)
b.Notify(1)
select {
case <-waitCh:
case <-ch:
t.Fatal("should not receive after remove")
case <-time.After(50 * time.Millisecond):
}
@@ -62,20 +61,20 @@ func TestRemove(t *testing.T) {
func TestMultipleWaiters(t *testing.T) {
t.Parallel()
brk := broker.New()
waitCh1 := brk.Wait(1)
waitCh2 := brk.Wait(1)
b := broker.New()
ch1 := b.Wait(1)
ch2 := b.Wait(1)
brk.Notify(1)
b.Notify(1)
select {
case <-waitCh1:
case <-ch1:
case <-time.After(time.Second):
t.Fatal("ch1 timeout")
}
select {
case <-waitCh2:
case <-ch2:
case <-time.After(time.Second):
t.Fatal("ch2 timeout")
}
@@ -84,38 +83,36 @@ func TestMultipleWaiters(t *testing.T) {
func TestConcurrentWaitNotify(t *testing.T) {
t.Parallel()
brk := broker.New()
b := broker.New()
var waitGroup sync.WaitGroup
var wg sync.WaitGroup
const concurrency = 100
for idx := range concurrency {
waitGroup.Add(1)
for i := range concurrency {
wg.Add(1)
go func(uid int64) {
defer waitGroup.Done()
defer wg.Done()
waitCh := brk.Wait(uid)
brk.Notify(uid)
ch := b.Wait(uid)
b.Notify(uid)
select {
case <-waitCh:
case <-ch:
case <-time.After(time.Second):
t.Error("timeout")
}
}(int64(idx % 10))
}(int64(i % 10))
}
waitGroup.Wait()
wg.Wait()
}
func TestRemoveNonexistent(t *testing.T) {
t.Parallel()
brk := broker.New()
waitCh := make(chan struct{}, 1)
brk.Remove(999, waitCh) // should not panic.
b := broker.New()
ch := make(chan struct{}, 1)
b.Remove(999, ch) // should not panic
}

View File

@@ -41,9 +41,7 @@ type Config struct {
}
// New creates a new Config by reading from files and environment variables.
func New(
_ fx.Lifecycle, params Params,
) (*Config, error) {
func New(_ fx.Lifecycle, params Params) (*Config, error) {
log := params.Logger.Get()
name := params.Globals.Appname
@@ -76,7 +74,7 @@ func New(
}
}
cfg := &Config{
s := &Config{
DBURL: viper.GetString("DBURL"),
Debug: viper.GetBool("DEBUG"),
Port: viper.GetInt("PORT"),
@@ -94,10 +92,10 @@ func New(
params: &params,
}
if cfg.Debug {
if s.Debug {
params.Logger.EnableDebugLogging()
cfg.log = params.Logger.Get()
s.log = params.Logger.Get()
}
return cfg, nil
return s, nil
}

View File

@@ -37,93 +37,84 @@ type Params struct {
// Database manages the SQLite connection and migrations.
type Database struct {
conn *sql.DB
db *sql.DB
log *slog.Logger
params *Params
}
// New creates a new Database and registers lifecycle hooks.
func New(
lifecycle fx.Lifecycle,
lc fx.Lifecycle,
params Params,
) (*Database, error) {
database := &Database{ //nolint:exhaustruct // conn set in OnStart
params: &params,
log: params.Logger.Get(),
}
s := new(Database)
s.params = &params
s.log = params.Logger.Get()
database.log.Info("Database instantiated")
s.log.Info("Database instantiated")
lifecycle.Append(fx.Hook{
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
database.log.Info("Database OnStart Hook")
s.log.Info("Database OnStart Hook")
return database.connect(ctx)
return s.connect(ctx)
},
OnStop: func(_ context.Context) error {
database.log.Info("Database OnStop Hook")
s.log.Info("Database OnStop Hook")
if database.conn != nil {
closeErr := database.conn.Close()
if closeErr != nil {
return fmt.Errorf(
"close db: %w", closeErr,
)
}
if s.db != nil {
return s.db.Close()
}
return nil
},
})
return database, nil
return s, nil
}
// GetDB returns the underlying sql.DB connection.
func (database *Database) GetDB() *sql.DB {
return database.conn
func (s *Database) GetDB() *sql.DB {
return s.db
}
func (database *Database) connect(ctx context.Context) error {
dbURL := database.params.Config.DBURL
func (s *Database) connect(ctx context.Context) error {
dbURL := s.params.Config.DBURL
if dbURL == "" {
dbURL = "file:./data.db?_journal_mode=WAL&_busy_timeout=5000"
dbURL = "file:./data.db?_journal_mode=WAL"
}
database.log.Info(
"connecting to database", "url", dbURL,
)
s.log.Info("connecting to database", "url", dbURL)
conn, err := sql.Open("sqlite", dbURL)
d, err := sql.Open("sqlite", dbURL)
if err != nil {
return fmt.Errorf("open database: %w", err)
s.log.Error(
"failed to open database", "error", err,
)
return err
}
err = conn.PingContext(ctx)
err = d.PingContext(ctx)
if err != nil {
return fmt.Errorf("ping database: %w", err)
s.log.Error(
"failed to ping database", "error", err,
)
return err
}
conn.SetMaxOpenConns(1)
s.db = d
s.log.Info("database connected")
database.conn = conn
database.log.Info("database connected")
_, err = database.conn.ExecContext(
_, err = s.db.ExecContext(
ctx, "PRAGMA foreign_keys = ON",
)
if err != nil {
return fmt.Errorf("enable foreign keys: %w", err)
}
_, err = database.conn.ExecContext(
ctx, "PRAGMA busy_timeout = 5000",
)
if err != nil {
return fmt.Errorf("set busy timeout: %w", err)
}
return database.runMigrations(ctx)
return s.runMigrations(ctx)
}
type migration struct {
@@ -132,10 +123,10 @@ type migration struct {
sql string
}
func (database *Database) runMigrations(
func (s *Database) runMigrations(
ctx context.Context,
) error {
_, err := database.conn.ExecContext(ctx,
_, err := s.db.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)`)
@@ -145,37 +136,37 @@ func (database *Database) runMigrations(
)
}
migrations, err := database.loadMigrations()
migrations, err := s.loadMigrations()
if err != nil {
return err
}
for _, mig := range migrations {
err = database.applyMigration(ctx, mig)
for _, m := range migrations {
err = s.applyMigration(ctx, m)
if err != nil {
return err
}
}
database.log.Info("database migrations complete")
s.log.Info("database migrations complete")
return nil
}
func (database *Database) applyMigration(
func (s *Database) applyMigration(
ctx context.Context,
mig migration,
m migration,
) error {
var exists int
err := database.conn.QueryRowContext(ctx,
err := s.db.QueryRowContext(ctx,
`SELECT COUNT(*) FROM schema_migrations
WHERE version = ?`,
mig.version,
m.version,
).Scan(&exists)
if err != nil {
return fmt.Errorf(
"check migration %d: %w", mig.version, err,
"check migration %d: %w", m.version, err,
)
}
@@ -183,63 +174,55 @@ func (database *Database) applyMigration(
return nil
}
database.log.Info(
s.log.Info(
"applying migration",
"version", mig.version,
"name", mig.name,
"version", m.version,
"name", m.name,
)
return database.execMigration(ctx, mig)
return s.execMigration(ctx, m)
}
func (database *Database) execMigration(
func (s *Database) execMigration(
ctx context.Context,
mig migration,
m migration,
) error {
transaction, err := database.conn.BeginTx(ctx, nil)
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf(
"begin tx for migration %d: %w",
mig.version, err,
m.version, err,
)
}
_, err = transaction.ExecContext(ctx, mig.sql)
_, err = tx.ExecContext(ctx, m.sql)
if err != nil {
_ = transaction.Rollback()
_ = tx.Rollback()
return fmt.Errorf(
"apply migration %d (%s): %w",
mig.version, mig.name, err,
m.version, m.name, err,
)
}
_, err = transaction.ExecContext(ctx,
_, err = tx.ExecContext(ctx,
`INSERT INTO schema_migrations (version)
VALUES (?)`,
mig.version,
m.version,
)
if err != nil {
_ = transaction.Rollback()
_ = tx.Rollback()
return fmt.Errorf(
"record migration %d: %w",
mig.version, err,
m.version, err,
)
}
err = transaction.Commit()
if err != nil {
return fmt.Errorf(
"commit migration %d: %w",
mig.version, err,
)
}
return nil
return tx.Commit()
}
func (database *Database) loadMigrations() (
func (s *Database) loadMigrations() (
[]migration,
error,
) {
@@ -250,7 +233,7 @@ func (database *Database) loadMigrations() (
)
}
migrations := make([]migration, 0, len(entries))
var migrations []migration
for _, entry := range entries {
if entry.IsDir() ||

View File

@@ -13,48 +13,35 @@ var testDBCounter atomic.Int64
// NewTestDatabase creates an in-memory database for testing.
func NewTestDatabase() (*Database, error) {
counter := testDBCounter.Add(1)
n := testDBCounter.Add(1)
dsn := fmt.Sprintf(
"file:testdb%d?mode=memory"+
"&cache=shared&_pragma=foreign_keys(1)",
counter,
n,
)
conn, err := sql.Open("sqlite", dsn)
d, err := sql.Open("sqlite", dsn)
if err != nil {
return nil, fmt.Errorf("open test db: %w", err)
return nil, err
}
database := &Database{ //nolint:exhaustruct // test helper, params not needed
conn: conn,
log: slog.Default(),
}
database := &Database{db: d, log: slog.Default()}
err = database.runMigrations(context.Background())
if err != nil {
closeErr := conn.Close()
closeErr := d.Close()
if closeErr != nil {
return nil, fmt.Errorf(
"close after migration failure: %w",
closeErr,
)
return nil, closeErr
}
return nil, fmt.Errorf(
"run test migrations: %w", err,
)
return nil, err
}
return database, nil
}
// Close closes the underlying database connection.
func (database *Database) Close() error {
err := database.conn.Close()
if err != nil {
return fmt.Errorf("close database: %w", err)
}
return nil
func (s *Database) Close() error {
return s.db.Close()
}

View File

@@ -18,15 +18,11 @@ const (
defaultHistLimit = 50
)
func generateToken() (string, error) {
buf := make([]byte, tokenBytes)
func generateToken() string {
b := make([]byte, tokenBytes)
_, _ = rand.Read(b)
_, err := rand.Read(buf)
if err != nil {
return "", fmt.Errorf("generate token: %w", err)
}
return hex.EncodeToString(buf), nil
return hex.EncodeToString(b)
}
// IRCMessage is the IRC envelope for all messages.
@@ -56,18 +52,14 @@ type MemberInfo struct {
}
// CreateUser registers a new user with the given nick.
func (database *Database) CreateUser(
func (s *Database) CreateUser(
ctx context.Context,
nick string,
) (int64, string, error) {
token, err := generateToken()
if err != nil {
return 0, "", err
}
token := generateToken()
now := time.Now()
res, err := database.conn.ExecContext(ctx,
res, err := s.db.ExecContext(ctx,
`INSERT INTO users
(nick, token, created_at, last_seen)
VALUES (?, ?, ?, ?)`,
@@ -76,88 +68,90 @@ func (database *Database) CreateUser(
return 0, "", fmt.Errorf("create user: %w", err)
}
userID, _ := res.LastInsertId()
id, _ := res.LastInsertId()
return userID, token, nil
return id, token, nil
}
// GetUserByToken returns user id and nick for a token.
func (database *Database) GetUserByToken(
func (s *Database) GetUserByToken(
ctx context.Context,
token string,
) (int64, string, error) {
var userID int64
var id int64
var nick string
err := database.conn.QueryRowContext(
err := s.db.QueryRowContext(
ctx,
"SELECT id, nick FROM users WHERE token = ?",
token,
).Scan(&userID, &nick)
).Scan(&id, &nick)
if err != nil {
return 0, "", fmt.Errorf("get user by token: %w", err)
return 0, "", err
}
_, _ = database.conn.ExecContext(
_, _ = s.db.ExecContext(
ctx,
"UPDATE users SET last_seen = ? WHERE id = ?",
time.Now(), userID,
time.Now(), id,
)
return userID, nick, nil
return id, nick, nil
}
// GetUserByNick returns user id for a given nick.
func (database *Database) GetUserByNick(
func (s *Database) GetUserByNick(
ctx context.Context,
nick string,
) (int64, error) {
var userID int64
var id int64
err := database.conn.QueryRowContext(
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM users WHERE nick = ?",
nick,
).Scan(&userID)
if err != nil {
return 0, fmt.Errorf("get user by nick: %w", err)
}
).Scan(&id)
return userID, nil
return id, err
}
// GetChannelByName returns the channel ID for a name.
func (database *Database) GetChannelByName(
func (s *Database) GetChannelByName(
ctx context.Context,
name string,
) (int64, error) {
var channelID int64
var id int64
err := database.conn.QueryRowContext(
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&channelID)
if err != nil {
return 0, fmt.Errorf(
"get channel by name: %w", err,
)
}
).Scan(&id)
return channelID, nil
return id, err
}
// GetOrCreateChannel returns channel id, creating if needed.
// Uses INSERT OR IGNORE to avoid TOCTOU races.
func (database *Database) GetOrCreateChannel(
func (s *Database) GetOrCreateChannel(
ctx context.Context,
name string,
) (int64, error) {
var id int64
err := s.db.QueryRowContext(
ctx,
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&id)
if err == nil {
return id, nil
}
now := time.Now()
_, err := database.conn.ExecContext(ctx,
`INSERT OR IGNORE INTO channels
res, err := s.db.ExecContext(ctx,
`INSERT INTO channels
(name, created_at, updated_at)
VALUES (?, ?, ?)`,
name, now, now)
@@ -165,71 +159,51 @@ func (database *Database) GetOrCreateChannel(
return 0, fmt.Errorf("create channel: %w", err)
}
var channelID int64
id, _ = res.LastInsertId()
err = database.conn.QueryRowContext(
ctx,
"SELECT id FROM channels WHERE name = ?",
name,
).Scan(&channelID)
if err != nil {
return 0, fmt.Errorf("get channel: %w", err)
}
return channelID, nil
return id, nil
}
// JoinChannel adds a user to a channel.
func (database *Database) JoinChannel(
func (s *Database) JoinChannel(
ctx context.Context,
channelID, userID int64,
) error {
_, err := database.conn.ExecContext(ctx,
_, err := s.db.ExecContext(ctx,
`INSERT OR IGNORE INTO channel_members
(channel_id, user_id, joined_at)
VALUES (?, ?, ?)`,
channelID, userID, time.Now())
if err != nil {
return fmt.Errorf("join channel: %w", err)
}
return nil
return err
}
// PartChannel removes a user from a channel.
func (database *Database) PartChannel(
func (s *Database) PartChannel(
ctx context.Context,
channelID, userID int64,
) error {
_, err := database.conn.ExecContext(ctx,
_, err := s.db.ExecContext(ctx,
`DELETE FROM channel_members
WHERE channel_id = ? AND user_id = ?`,
channelID, userID)
if err != nil {
return fmt.Errorf("part channel: %w", err)
}
return nil
return err
}
// DeleteChannelIfEmpty removes a channel with no members.
func (database *Database) DeleteChannelIfEmpty(
func (s *Database) DeleteChannelIfEmpty(
ctx context.Context,
channelID int64,
) error {
_, err := database.conn.ExecContext(ctx,
_, err := s.db.ExecContext(ctx,
`DELETE FROM channels WHERE id = ?
AND NOT EXISTS
(SELECT 1 FROM channel_members
WHERE channel_id = ?)`,
channelID, channelID)
if err != nil {
return fmt.Errorf(
"delete channel if empty: %w", err,
)
}
return nil
return err
}
// scanChannels scans rows into a ChannelInfo slice.
@@ -241,21 +215,19 @@ func scanChannels(
var out []ChannelInfo
for rows.Next() {
var chanInfo ChannelInfo
var ch ChannelInfo
err := rows.Scan(
&chanInfo.ID, &chanInfo.Name, &chanInfo.Topic,
)
err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic)
if err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
return nil, err
}
out = append(out, chanInfo)
out = append(out, ch)
}
err := rows.Err()
if err != nil {
return nil, fmt.Errorf("rows error: %w", err)
return nil, err
}
if out == nil {
@@ -266,11 +238,11 @@ func scanChannels(
}
// ListChannels returns channels the user has joined.
func (database *Database) ListChannels(
func (s *Database) ListChannels(
ctx context.Context,
userID int64,
) ([]ChannelInfo, error) {
rows, err := database.conn.QueryContext(ctx,
rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic
FROM channels c
INNER JOIN channel_members cm
@@ -278,34 +250,32 @@ func (database *Database) ListChannels(
WHERE cm.user_id = ?
ORDER BY c.name`, userID)
if err != nil {
return nil, fmt.Errorf("list channels: %w", err)
return nil, err
}
return scanChannels(rows)
}
// ListAllChannels returns every channel.
func (database *Database) ListAllChannels(
func (s *Database) ListAllChannels(
ctx context.Context,
) ([]ChannelInfo, error) {
rows, err := database.conn.QueryContext(ctx,
rows, err := s.db.QueryContext(ctx,
`SELECT id, name, topic
FROM channels ORDER BY name`)
if err != nil {
return nil, fmt.Errorf(
"list all channels: %w", err,
)
return nil, err
}
return scanChannels(rows)
}
// ChannelMembers returns all members of a channel.
func (database *Database) ChannelMembers(
func (s *Database) ChannelMembers(
ctx context.Context,
channelID int64,
) ([]MemberInfo, error) {
rows, err := database.conn.QueryContext(ctx,
rows, err := s.db.QueryContext(ctx,
`SELECT u.id, u.nick, u.last_seen
FROM users u
INNER JOIN channel_members cm
@@ -313,9 +283,7 @@ func (database *Database) ChannelMembers(
WHERE cm.channel_id = ?
ORDER BY u.nick`, channelID)
if err != nil {
return nil, fmt.Errorf(
"query channel members: %w", err,
)
return nil, err
}
defer func() { _ = rows.Close() }()
@@ -323,23 +291,19 @@ func (database *Database) ChannelMembers(
var members []MemberInfo
for rows.Next() {
var member MemberInfo
var m MemberInfo
err = rows.Scan(
&member.ID, &member.Nick, &member.LastSeen,
)
err = rows.Scan(&m.ID, &m.Nick, &m.LastSeen)
if err != nil {
return nil, fmt.Errorf(
"scan member: %w", err,
)
return nil, err
}
members = append(members, member)
members = append(members, m)
}
err = rows.Err()
if err != nil {
return nil, fmt.Errorf("rows error: %w", err)
return nil, err
}
if members == nil {
@@ -349,27 +313,6 @@ func (database *Database) ChannelMembers(
return members, nil
}
// IsChannelMember checks if a user belongs to a channel.
func (database *Database) IsChannelMember(
ctx context.Context,
channelID, userID int64,
) (bool, error) {
var count int
err := database.conn.QueryRowContext(ctx,
`SELECT COUNT(*) FROM channel_members
WHERE channel_id = ? AND user_id = ?`,
channelID, userID,
).Scan(&count)
if err != nil {
return false, fmt.Errorf(
"check membership: %w", err,
)
}
return count > 0, nil
}
// scanInt64s scans rows into an int64 slice.
func scanInt64s(rows *sql.Rows) ([]int64, error) {
defer func() { _ = rows.Close() }()
@@ -377,64 +320,58 @@ func scanInt64s(rows *sql.Rows) ([]int64, error) {
var ids []int64
for rows.Next() {
var val int64
var id int64
err := rows.Scan(&val)
err := rows.Scan(&id)
if err != nil {
return nil, fmt.Errorf(
"scan int64: %w", err,
)
return nil, err
}
ids = append(ids, val)
ids = append(ids, id)
}
err := rows.Err()
if err != nil {
return nil, fmt.Errorf("rows error: %w", err)
return nil, err
}
return ids, nil
}
// GetChannelMemberIDs returns user IDs in a channel.
func (database *Database) GetChannelMemberIDs(
func (s *Database) GetChannelMemberIDs(
ctx context.Context,
channelID int64,
) ([]int64, error) {
rows, err := database.conn.QueryContext(ctx,
rows, err := s.db.QueryContext(ctx,
`SELECT user_id FROM channel_members
WHERE channel_id = ?`, channelID)
if err != nil {
return nil, fmt.Errorf(
"get channel member ids: %w", err,
)
return nil, err
}
return scanInt64s(rows)
}
// GetUserChannelIDs returns channel IDs the user is in.
func (database *Database) GetUserChannelIDs(
func (s *Database) GetUserChannelIDs(
ctx context.Context,
userID int64,
) ([]int64, error) {
rows, err := database.conn.QueryContext(ctx,
rows, err := s.db.QueryContext(ctx,
`SELECT channel_id FROM channel_members
WHERE user_id = ?`, userID)
if err != nil {
return nil, fmt.Errorf(
"get user channel ids: %w", err,
)
return nil, err
}
return scanInt64s(rows)
}
// InsertMessage stores a message and returns its DB ID.
func (database *Database) InsertMessage(
func (s *Database) InsertMessage(
ctx context.Context,
command, from, target string,
command, from, to string,
body json.RawMessage,
meta json.RawMessage,
) (int64, string, error) {
@@ -449,43 +386,38 @@ func (database *Database) InsertMessage(
meta = json.RawMessage("{}")
}
res, err := database.conn.ExecContext(ctx,
res, err := s.db.ExecContext(ctx,
`INSERT INTO messages
(uuid, command, msg_from, msg_to,
body, meta, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
msgUUID, command, from, target,
msgUUID, command, from, to,
string(body), string(meta), now)
if err != nil {
return 0, "", fmt.Errorf(
"insert message: %w", err,
)
return 0, "", err
}
dbID, _ := res.LastInsertId()
id, _ := res.LastInsertId()
return dbID, msgUUID, nil
return id, msgUUID, nil
}
// EnqueueMessage adds a message to a user's queue.
func (database *Database) EnqueueMessage(
func (s *Database) EnqueueMessage(
ctx context.Context,
userID, messageID int64,
) error {
_, err := database.conn.ExecContext(ctx,
_, err := s.db.ExecContext(ctx,
`INSERT OR IGNORE INTO client_queues
(user_id, message_id, created_at)
VALUES (?, ?, ?)`,
userID, messageID, time.Now())
if err != nil {
return fmt.Errorf("enqueue message: %w", err)
}
return nil
return err
}
// PollMessages returns queued messages for a user.
func (database *Database) PollMessages(
func (s *Database) PollMessages(
ctx context.Context,
userID, afterQueueID int64,
limit int,
@@ -494,7 +426,7 @@ func (database *Database) PollMessages(
limit = defaultPollLimit
}
rows, err := database.conn.QueryContext(ctx,
rows, err := s.db.QueryContext(ctx,
`SELECT cq.id, m.uuid, m.command,
m.msg_from, m.msg_to,
m.body, m.meta, m.created_at
@@ -505,9 +437,7 @@ func (database *Database) PollMessages(
ORDER BY cq.id ASC LIMIT ?`,
userID, afterQueueID, limit)
if err != nil {
return nil, afterQueueID, fmt.Errorf(
"poll messages: %w", err,
)
return nil, afterQueueID, err
}
msgs, lastQID, scanErr := scanMessages(
@@ -521,7 +451,7 @@ func (database *Database) PollMessages(
}
// GetHistory returns message history for a target.
func (database *Database) GetHistory(
func (s *Database) GetHistory(
ctx context.Context,
target string,
beforeID int64,
@@ -531,7 +461,7 @@ func (database *Database) GetHistory(
limit = defaultHistLimit
}
rows, err := database.queryHistory(
rows, err := s.queryHistory(
ctx, target, beforeID, limit,
)
if err != nil {
@@ -552,14 +482,14 @@ func (database *Database) GetHistory(
return msgs, nil
}
func (database *Database) queryHistory(
func (s *Database) queryHistory(
ctx context.Context,
target string,
beforeID int64,
limit int,
) (*sql.Rows, error) {
if beforeID > 0 {
rows, err := database.conn.QueryContext(ctx,
return s.db.QueryContext(ctx,
`SELECT id, uuid, command, msg_from,
msg_to, body, meta, created_at
FROM messages
@@ -567,16 +497,9 @@ func (database *Database) queryHistory(
AND command = 'PRIVMSG'
ORDER BY id DESC LIMIT ?`,
target, beforeID, limit)
if err != nil {
return nil, fmt.Errorf(
"query history: %w", err,
)
}
return rows, nil
}
rows, err := database.conn.QueryContext(ctx,
return s.db.QueryContext(ctx,
`SELECT id, uuid, command, msg_from,
msg_to, body, meta, created_at
FROM messages
@@ -584,11 +507,6 @@ func (database *Database) queryHistory(
AND command = 'PRIVMSG'
ORDER BY id DESC LIMIT ?`,
target, limit)
if err != nil {
return nil, fmt.Errorf("query history: %w", err)
}
return rows, nil
}
func scanMessages(
@@ -603,37 +521,33 @@ func scanMessages(
for rows.Next() {
var (
msg IRCMessage
m IRCMessage
qID int64
body, meta string
createdAt time.Time
ts time.Time
)
err := rows.Scan(
&qID, &msg.ID, &msg.Command,
&msg.From, &msg.To,
&body, &meta, &createdAt,
&qID, &m.ID, &m.Command,
&m.From, &m.To,
&body, &meta, &ts,
)
if err != nil {
return nil, fallbackQID, fmt.Errorf(
"scan message: %w", err,
)
return nil, fallbackQID, err
}
msg.Body = json.RawMessage(body)
msg.Meta = json.RawMessage(meta)
msg.TS = createdAt.Format(time.RFC3339Nano)
msg.DBID = qID
m.Body = json.RawMessage(body)
m.Meta = json.RawMessage(meta)
m.TS = ts.Format(time.RFC3339Nano)
m.DBID = qID
lastQID = qID
msgs = append(msgs, msg)
msgs = append(msgs, m)
}
err := rows.Err()
if err != nil {
return nil, fallbackQID, fmt.Errorf(
"rows error: %w", err,
)
return nil, fallbackQID, err
}
if msgs == nil {
@@ -650,70 +564,59 @@ func reverseMessages(msgs []IRCMessage) {
}
// ChangeNick updates a user's nickname.
func (database *Database) ChangeNick(
func (s *Database) ChangeNick(
ctx context.Context,
userID int64,
newNick string,
) error {
_, err := database.conn.ExecContext(ctx,
_, err := s.db.ExecContext(ctx,
"UPDATE users SET nick = ? WHERE id = ?",
newNick, userID)
if err != nil {
return fmt.Errorf("change nick: %w", err)
}
return nil
return err
}
// SetTopic sets the topic for a channel.
func (database *Database) SetTopic(
func (s *Database) SetTopic(
ctx context.Context,
channelName, topic string,
) error {
_, err := database.conn.ExecContext(ctx,
_, err := s.db.ExecContext(ctx,
`UPDATE channels SET topic = ?,
updated_at = ? WHERE name = ?`,
topic, time.Now(), channelName)
if err != nil {
return fmt.Errorf("set topic: %w", err)
}
return nil
return err
}
// DeleteUser removes a user and all their data.
func (database *Database) DeleteUser(
func (s *Database) DeleteUser(
ctx context.Context,
userID int64,
) error {
_, err := database.conn.ExecContext(
_, err := s.db.ExecContext(
ctx,
"DELETE FROM users WHERE id = ?",
userID,
)
if err != nil {
return fmt.Errorf("delete user: %w", err)
}
return nil
return err
}
// GetAllChannelMembershipsForUser returns channels
// a user belongs to.
func (database *Database) GetAllChannelMembershipsForUser(
func (s *Database) GetAllChannelMembershipsForUser(
ctx context.Context,
userID int64,
) ([]ChannelInfo, error) {
rows, err := database.conn.QueryContext(ctx,
rows, err := s.db.QueryContext(ctx,
`SELECT c.id, c.name, c.topic
FROM channels c
INNER JOIN channel_members cm
ON cm.channel_id = c.id
WHERE cm.user_id = ?`, userID)
if err != nil {
return nil, fmt.Errorf(
"get memberships: %w", err,
)
return nil, err
}
return scanChannels(rows)

View File

@@ -1,6 +1,7 @@
package db_test
import (
"context"
"encoding/json"
"testing"
@@ -12,26 +13,26 @@ import (
func setupTestDB(t *testing.T) *db.Database {
t.Helper()
database, err := db.NewTestDatabase()
d, err := db.NewTestDatabase()
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
closeErr := database.Close()
closeErr := d.Close()
if closeErr != nil {
t.Logf("close db: %v", closeErr)
}
})
return database
return d
}
func TestCreateUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
id, token, err := database.CreateUser(ctx, "alice")
if err != nil {
@@ -52,7 +53,7 @@ func TestGetUserByToken(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
_, token, err := database.CreateUser(ctx, "bob")
if err != nil {
@@ -78,7 +79,7 @@ func TestGetUserByNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
_, _, err := database.CreateUser(ctx, "charlie")
if err != nil {
@@ -100,7 +101,7 @@ func TestChannelOperations(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
chID, err := database.GetOrCreateChannel(ctx, "#test")
if err != nil || chID == 0 {
@@ -127,7 +128,7 @@ func TestJoinAndPart(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, "user1")
if err != nil {
@@ -169,7 +170,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
chID, err := database.GetOrCreateChannel(
ctx, "#empty",
@@ -211,7 +212,7 @@ func createUserWithChannels(
) (int64, int64, int64) {
t.Helper()
ctx := t.Context()
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, nick)
if err != nil {
@@ -254,7 +255,7 @@ func TestListChannels(t *testing.T) {
)
channels, err := database.ListChannels(
t.Context(), uid,
context.Background(), uid,
)
if err != nil || len(channels) != 2 {
t.Fatalf(
@@ -268,7 +269,7 @@ func TestListAllChannels(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
_, err := database.GetOrCreateChannel(ctx, "#x")
if err != nil {
@@ -293,7 +294,7 @@ func TestChangeNick(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
uid, token, err := database.CreateUser(ctx, "old")
if err != nil {
@@ -319,7 +320,7 @@ func TestSetTopic(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
_, err := database.GetOrCreateChannel(
ctx, "#topictest",
@@ -349,31 +350,11 @@ func TestSetTopic(t *testing.T) {
}
}
func TestInsertMessage(t *testing.T) {
func TestInsertAndPollMessages(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
body := json.RawMessage(`["hello"]`)
dbID, msgUUID, err := database.InsertMessage(
ctx, "PRIVMSG", "poller", "#test", body, nil,
)
if err != nil {
t.Fatal(err)
}
if dbID == 0 || msgUUID == "" {
t.Fatal("expected valid id and uuid")
}
}
func TestPollMessages(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, "poller")
if err != nil {
@@ -382,11 +363,11 @@ func TestPollMessages(t *testing.T) {
body := json.RawMessage(`["hello"]`)
dbID, _, err := database.InsertMessage(
dbID, msgUUID, err := database.InsertMessage(
ctx, "PRIVMSG", "poller", "#test", body, nil,
)
if err != nil {
t.Fatal(err)
if err != nil || dbID == 0 || msgUUID == "" {
t.Fatal("insert failed")
}
err = database.EnqueueMessage(ctx, uid, dbID)
@@ -434,7 +415,7 @@ func TestGetHistory(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
const msgCount = 10
@@ -471,7 +452,7 @@ func TestDeleteUser(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
uid, _, err := database.CreateUser(ctx, "deleteme")
if err != nil {
@@ -510,7 +491,7 @@ func TestChannelMembers(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
ctx := context.Background()
uid1, _, err := database.CreateUser(ctx, "m1")
if err != nil {
@@ -558,7 +539,7 @@ func TestGetAllChannelMembershipsForUser(t *testing.T) {
channels, err :=
database.GetAllChannelMembershipsForUser(
t.Context(), uid,
context.Background(), uid,
)
if err != nil || len(channels) != 2 {
t.Fatalf(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -40,59 +40,40 @@ type Handlers struct {
// New creates a new Handlers instance.
func New(
lifecycle fx.Lifecycle,
lc fx.Lifecycle,
params Params,
) (*Handlers, error) {
hdlr := &Handlers{
params: &params,
log: params.Logger.Get(),
hc: params.Healthcheck,
broker: broker.New(),
}
s := new(Handlers)
s.params = &params
s.log = params.Logger.Get()
s.hc = params.Healthcheck
s.broker = broker.New()
lifecycle.Append(fx.Hook{
lc.Append(fx.Hook{
OnStart: func(_ context.Context) error {
return nil
},
OnStop: func(_ context.Context) error {
return nil
},
})
return hdlr, nil
return s, nil
}
func (hdlr *Handlers) respondJSON(
writer http.ResponseWriter,
func (s *Handlers) respondJSON(
w http.ResponseWriter,
_ *http.Request,
data any,
status int,
) {
writer.Header().Set(
w.Header().Set(
"Content-Type",
"application/json; charset=utf-8",
)
writer.WriteHeader(status)
w.WriteHeader(status)
if data != nil {
err := json.NewEncoder(writer).Encode(data)
err := json.NewEncoder(w).Encode(data)
if err != nil {
hdlr.log.Error(
"json encode error", "error", err,
)
s.log.Error("json encode error", "error", err)
}
}
}
func (hdlr *Handlers) respondError(
writer http.ResponseWriter,
request *http.Request,
msg string,
status int,
) {
hdlr.respondJSON(
writer, request,
map[string]string{"error": msg},
status,
)
}

View File

@@ -7,12 +7,9 @@ import (
const httpStatusOK = 200
// HandleHealthCheck returns an HTTP handler for the health check endpoint.
func (hdlr *Handlers) HandleHealthCheck() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
resp := hdlr.hc.Healthcheck()
hdlr.respondJSON(writer, request, resp, httpStatusOK)
func (s *Handlers) HandleHealthCheck() http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
resp := s.hc.Healthcheck()
s.respondJSON(w, req, resp, httpStatusOK)
}
}

View File

@@ -33,17 +33,14 @@ type Healthcheck struct {
}
// New creates a new Healthcheck instance.
func New(
lifecycle fx.Lifecycle, params Params,
) (*Healthcheck, error) {
hcheck := &Healthcheck{ //nolint:exhaustruct // StartupTime set in OnStart
params: &params,
log: params.Logger.Get(),
}
func New(lc fx.Lifecycle, params Params) (*Healthcheck, error) {
s := new(Healthcheck)
s.params = &params
s.log = params.Logger.Get()
lifecycle.Append(fx.Hook{
lc.Append(fx.Hook{
OnStart: func(_ context.Context) error {
hcheck.StartupTime = time.Now()
s.StartupTime = time.Now()
return nil
},
@@ -52,7 +49,7 @@ func New(
},
})
return hcheck, nil
return s, nil
}
// Response is the JSON response returned by the health endpoint.
@@ -67,18 +64,19 @@ type Response struct {
}
// Healthcheck returns the current health status of the server.
func (hcheck *Healthcheck) Healthcheck() *Response {
return &Response{
func (s *Healthcheck) Healthcheck() *Response {
resp := &Response{
Status: "ok",
Now: time.Now().UTC().Format(time.RFC3339Nano),
UptimeSeconds: int64(hcheck.uptime().Seconds()),
UptimeHuman: hcheck.uptime().String(),
Appname: hcheck.params.Globals.Appname,
Version: hcheck.params.Globals.Version,
Maintenance: hcheck.params.Config.MaintenanceMode,
UptimeSeconds: int64(s.uptime().Seconds()),
UptimeHuman: s.uptime().String(),
Appname: s.params.Globals.Appname,
Version: s.params.Globals.Version,
}
return resp
}
func (hcheck *Healthcheck) uptime() time.Duration {
return time.Since(hcheck.StartupTime)
func (s *Healthcheck) uptime() time.Duration {
return time.Since(s.StartupTime)
}

View File

@@ -23,56 +23,51 @@ type Logger struct {
params Params
}
// New creates a new Logger with appropriate handler
// based on terminal detection.
func New(
_ fx.Lifecycle, params Params,
) (*Logger, error) {
logger := new(Logger)
logger.level = new(slog.LevelVar)
logger.level.Set(slog.LevelInfo)
// New creates a new Logger with appropriate handler based on terminal detection.
func New(_ fx.Lifecycle, params Params) (*Logger, error) {
l := new(Logger)
l.level = new(slog.LevelVar)
l.level.Set(slog.LevelInfo)
tty := false
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) != 0 {
tty = true
}
opts := &slog.HandlerOptions{ //nolint:exhaustruct // ReplaceAttr optional
Level: logger.level,
AddSource: true,
}
var handler slog.Handler
if tty {
handler = slog.NewTextHandler(os.Stdout, opts)
handler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: l.level,
AddSource: true,
})
} else {
handler = slog.NewJSONHandler(os.Stdout, opts)
handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: l.level,
AddSource: true,
})
}
logger.log = slog.New(handler)
logger.params = params
l.log = slog.New(handler)
l.params = params
return logger, nil
return l, nil
}
// EnableDebugLogging switches the log level to debug.
func (logger *Logger) EnableDebugLogging() {
logger.level.Set(slog.LevelDebug)
logger.log.Debug(
"debug logging enabled", "debug", true,
)
func (l *Logger) EnableDebugLogging() {
l.level.Set(slog.LevelDebug)
l.log.Debug("debug logging enabled", "debug", true)
}
// Get returns the underlying slog.Logger.
func (logger *Logger) Get() *slog.Logger {
return logger.log
func (l *Logger) Get() *slog.Logger {
return l.log
}
// Identify logs the application name and version at startup.
func (logger *Logger) Identify() {
logger.log.Info("starting",
"appname", logger.params.Globals.Appname,
"version", logger.params.Globals.Version,
func (l *Logger) Identify() {
l.log.Info("starting",
"appname", l.params.Globals.Appname,
"version", l.params.Globals.Version,
)
}

View File

@@ -11,7 +11,7 @@ import (
"git.eeqj.de/sneak/chat/internal/globals"
"git.eeqj.de/sneak/chat/internal/logger"
basicauth "github.com/99designs/basicauth-go"
chimw "github.com/go-chi/chi/middleware"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
metrics "github.com/slok/go-http-metrics/metrics/prometheus"
ghmm "github.com/slok/go-http-metrics/middleware"
@@ -38,28 +38,25 @@ type Middleware struct {
}
// New creates a new Middleware instance.
func New(
_ fx.Lifecycle, params Params,
) (*Middleware, error) {
mware := &Middleware{
params: &params,
log: params.Logger.Get(),
}
func New(_ fx.Lifecycle, params Params) (*Middleware, error) {
s := new(Middleware)
s.params = &params
s.log = params.Logger.Get()
return mware, nil
return s, nil
}
func ipFromHostPort(hostPort string) string {
host, _, err := net.SplitHostPort(hostPort)
func ipFromHostPort(hp string) string {
h, _, err := net.SplitHostPort(hp)
if err != nil {
return ""
}
if len(host) > 0 && host[0] == '[' {
return host[1 : len(host)-1]
if len(h) > 0 && h[0] == '[' {
return h[1 : len(h)-1]
}
return host
return h
}
type loggingResponseWriter struct {
@@ -68,15 +65,9 @@ type loggingResponseWriter struct {
statusCode int
}
// newLoggingResponseWriter wraps a ResponseWriter
// to capture the status code.
func newLoggingResponseWriter(
writer http.ResponseWriter,
) *loggingResponseWriter {
return &loggingResponseWriter{
ResponseWriter: writer,
statusCode: http.StatusOK,
}
// newLoggingResponseWriter wraps a ResponseWriter to capture the status code.
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
return &loggingResponseWriter{w, http.StatusOK}
}
func (lrw *loggingResponseWriter) WriteHeader(code int) {
@@ -85,57 +76,43 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) {
}
// Logging returns middleware that logs each HTTP request.
func (mware *Middleware) Logging() func(http.Handler) http.Handler {
func (s *Middleware) Logging() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(
func(
writer http.ResponseWriter,
request *http.Request,
) {
start := time.Now()
lrw := newLoggingResponseWriter(writer)
ctx := request.Context()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
lrw := newLoggingResponseWriter(w)
ctx := r.Context()
defer func() {
latency := time.Since(start)
defer func() {
latency := time.Since(start)
reqID, _ := ctx.Value(
chimw.RequestIDKey,
).(string)
reqID, _ := ctx.Value(middleware.RequestIDKey).(string)
mware.log.InfoContext(
ctx, "request",
"request_start", start,
"method", request.Method,
"url", request.URL.String(),
"useragent", request.UserAgent(),
"request_id", reqID,
"referer", request.Referer(),
"proto", request.Proto,
"remoteIP",
ipFromHostPort(request.RemoteAddr),
"status", lrw.statusCode,
"latency_ms",
latency.Milliseconds(),
)
}()
s.log.InfoContext(ctx, "request",
"request_start", start,
"method", r.Method,
"url", r.URL.String(),
"useragent", r.UserAgent(),
"request_id", reqID,
"referer", r.Referer(),
"proto", r.Proto,
"remoteIP", ipFromHostPort(r.RemoteAddr),
"status", lrw.statusCode,
"latency_ms", latency.Milliseconds(),
)
}()
next.ServeHTTP(lrw, request)
})
next.ServeHTTP(lrw, r)
})
}
}
// CORS returns middleware that handles Cross-Origin Resource Sharing.
func (mware *Middleware) CORS() func(http.Handler) http.Handler {
return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
"GET", "POST", "PUT", "DELETE", "OPTIONS",
},
AllowedHeaders: []string{
"Accept", "Authorization",
"Content-Type", "X-CSRF-Token",
},
func (s *Middleware) CORS() func(http.Handler) http.Handler {
return cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: false,
MaxAge: corsMaxAge,
@@ -143,34 +120,28 @@ func (mware *Middleware) CORS() func(http.Handler) http.Handler {
}
// Auth returns middleware that performs authentication.
func (mware *Middleware) Auth() func(http.Handler) http.Handler {
func (s *Middleware) Auth() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(
func(
writer http.ResponseWriter,
request *http.Request,
) {
mware.log.Info("AUTH: before request")
next.ServeHTTP(writer, request)
})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.log.Info("AUTH: before request")
next.ServeHTTP(w, r)
})
}
}
// Metrics returns middleware that records HTTP metrics.
func (mware *Middleware) Metrics() func(http.Handler) http.Handler {
metricsMiddleware := ghmm.New(ghmm.Config{ //nolint:exhaustruct // optional fields
Recorder: metrics.NewRecorder(
metrics.Config{}, //nolint:exhaustruct // defaults
),
func (s *Middleware) Metrics() func(http.Handler) http.Handler {
mdlw := ghmm.New(ghmm.Config{
Recorder: metrics.NewRecorder(metrics.Config{}),
})
return func(next http.Handler) http.Handler {
return std.Handler("", metricsMiddleware, next)
return std.Handler("", mdlw, next)
}
}
// MetricsAuth returns middleware that protects metrics with basic auth.
func (mware *Middleware) MetricsAuth() func(http.Handler) http.Handler {
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
return basicauth.New(
"metrics",
map[string][]string{

View File

@@ -17,94 +17,67 @@ import (
const routeTimeout = 60 * time.Second
// SetupRoutes configures the HTTP routes and middleware.
func (srv *Server) SetupRoutes() {
srv.router = chi.NewRouter()
func (s *Server) SetupRoutes() {
s.router = chi.NewRouter()
srv.router.Use(middleware.Recoverer)
srv.router.Use(middleware.RequestID)
srv.router.Use(srv.mw.Logging())
s.router.Use(middleware.Recoverer)
s.router.Use(middleware.RequestID)
s.router.Use(s.mw.Logging())
if viper.GetString("METRICS_USERNAME") != "" {
srv.router.Use(srv.mw.Metrics())
s.router.Use(s.mw.Metrics())
}
srv.router.Use(srv.mw.CORS())
srv.router.Use(middleware.Timeout(routeTimeout))
s.router.Use(s.mw.CORS())
s.router.Use(middleware.Timeout(routeTimeout))
if srv.sentryEnabled {
sentryHandler := sentryhttp.New(
sentryhttp.Options{ //nolint:exhaustruct // optional fields
Repanic: true,
},
)
srv.router.Use(sentryHandler.Handle)
if s.sentryEnabled {
sentryHandler := sentryhttp.New(sentryhttp.Options{
Repanic: true,
})
s.router.Use(sentryHandler.Handle)
}
// Health check.
srv.router.Get(
// Health check
s.router.Get(
"/.well-known/healthcheck.json",
srv.handlers.HandleHealthCheck(),
s.h.HandleHealthCheck(),
)
// Protected metrics endpoint.
// Protected metrics endpoint
if viper.GetString("METRICS_USERNAME") != "" {
srv.router.Group(func(router chi.Router) {
router.Use(srv.mw.MetricsAuth())
router.Get("/metrics",
s.router.Group(func(r chi.Router) {
r.Use(s.mw.MetricsAuth())
r.Get("/metrics",
http.HandlerFunc(
promhttp.Handler().ServeHTTP,
))
})
}
// API v1.
srv.router.Route(
"/api/v1",
func(router chi.Router) {
router.Get(
"/server",
srv.handlers.HandleServerInfo(),
)
router.Post(
"/session",
srv.handlers.HandleCreateSession(),
)
router.Get(
"/state",
srv.handlers.HandleState(),
)
router.Get(
"/messages",
srv.handlers.HandleGetMessages(),
)
router.Post(
"/messages",
srv.handlers.HandleSendCommand(),
)
router.Get(
"/history",
srv.handlers.HandleGetHistory(),
)
router.Get(
"/channels",
srv.handlers.HandleListAllChannels(),
)
router.Get(
"/channels/{channel}/members",
srv.handlers.HandleChannelMembers(),
)
},
)
// API v1
s.router.Route("/api/v1", func(r chi.Router) {
r.Get("/server", s.h.HandleServerInfo())
r.Post("/session", s.h.HandleCreateSession())
r.Get("/state", s.h.HandleState())
r.Get("/messages", s.h.HandleGetMessages())
r.Post("/messages", s.h.HandleSendCommand())
r.Get("/history", s.h.HandleGetHistory())
r.Get("/channels", s.h.HandleListAllChannels())
r.Get(
"/channels/{channel}/members",
s.h.HandleChannelMembers(),
)
})
// Serve embedded SPA.
srv.setupSPA()
// Serve embedded SPA
s.setupSPA()
}
func (srv *Server) setupSPA() {
func (s *Server) setupSPA() {
distFS, err := fs.Sub(web.Dist, "dist")
if err != nil {
srv.log.Error(
s.log.Error(
"failed to get web dist filesystem",
"error", err,
)
@@ -114,40 +87,38 @@ func (srv *Server) setupSPA() {
fileServer := http.FileServer(http.FS(distFS))
srv.router.Get("/*", func(
writer http.ResponseWriter,
request *http.Request,
s.router.Get("/*", func(
w http.ResponseWriter,
r *http.Request,
) {
readFS, ok := distFS.(fs.ReadFileFS)
if !ok {
fileServer.ServeHTTP(writer, request)
fileServer.ServeHTTP(w, r)
return
}
fileData, readErr := readFS.ReadFile(
request.URL.Path[1:],
)
if readErr != nil || len(fileData) == 0 {
f, readErr := readFS.ReadFile(r.URL.Path[1:])
if readErr != nil || len(f) == 0 {
indexHTML, indexErr := readFS.ReadFile(
"index.html",
)
if indexErr != nil {
http.NotFound(writer, request)
http.NotFound(w, r)
return
}
writer.Header().Set(
w.Header().Set(
"Content-Type",
"text/html; charset=utf-8",
)
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write(indexHTML)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(indexHTML)
return
}
fileServer.ServeHTTP(writer, request)
fileServer.ServeHTTP(w, r)
})
}

View File

@@ -41,8 +41,7 @@ type Params struct {
Handlers *handlers.Handlers
}
// Server is the main HTTP server.
// It manages routing, middleware, and lifecycle.
// Server is the main HTTP server. It manages routing, middleware, and lifecycle.
type Server struct {
startupTime time.Time
exitCode int
@@ -54,24 +53,21 @@ type Server struct {
router *chi.Mux
params Params
mw *middleware.Middleware
handlers *handlers.Handlers
h *handlers.Handlers
}
// New creates a new Server and registers its lifecycle hooks.
func New(
lifecycle fx.Lifecycle, params Params,
) (*Server, error) {
srv := &Server{ //nolint:exhaustruct // fields set during lifecycle
params: params,
mw: params.Middleware,
handlers: params.Handlers,
log: params.Logger.Get(),
}
func New(lc fx.Lifecycle, params Params) (*Server, error) {
s := new(Server)
s.params = params
s.mw = params.Middleware
s.h = params.Handlers
s.log = params.Logger.Get()
lifecycle.Append(fx.Hook{
lc.Append(fx.Hook{
OnStart: func(_ context.Context) error {
srv.startupTime = time.Now()
go srv.Run() //nolint:contextcheck
s.startupTime = time.Now()
go s.Run() //nolint:contextcheck
return nil
},
@@ -80,140 +76,122 @@ func New(
},
})
return srv, nil
return s, nil
}
// Run starts the server configuration, Sentry, and begins serving.
func (srv *Server) Run() {
srv.configure()
srv.enableSentry()
srv.serve()
func (s *Server) Run() {
s.configure()
s.enableSentry()
s.serve()
}
// ServeHTTP delegates to the chi router.
func (srv *Server) ServeHTTP(
writer http.ResponseWriter,
request *http.Request,
) {
srv.router.ServeHTTP(writer, request)
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// MaintenanceMode reports whether the server is in maintenance mode.
func (srv *Server) MaintenanceMode() bool {
return srv.params.Config.MaintenanceMode
func (s *Server) MaintenanceMode() bool {
return s.params.Config.MaintenanceMode
}
func (srv *Server) enableSentry() {
srv.sentryEnabled = false
func (s *Server) enableSentry() {
s.sentryEnabled = false
if srv.params.Config.SentryDSN == "" {
if s.params.Config.SentryDSN == "" {
return
}
err := sentry.Init(sentry.ClientOptions{ //nolint:exhaustruct // only essential fields
Dsn: srv.params.Config.SentryDSN,
Release: fmt.Sprintf(
"%s-%s",
srv.params.Globals.Appname,
srv.params.Globals.Version,
),
err := sentry.Init(sentry.ClientOptions{
Dsn: s.params.Config.SentryDSN,
Release: fmt.Sprintf("%s-%s", s.params.Globals.Appname, s.params.Globals.Version),
})
if err != nil {
srv.log.Error("sentry init failure", "error", err)
s.log.Error("sentry init failure", "error", err)
os.Exit(1)
}
srv.log.Info("sentry error reporting activated")
srv.sentryEnabled = true
s.log.Info("sentry error reporting activated")
s.sentryEnabled = true
}
func (srv *Server) serve() int {
srv.ctx, srv.cancelFunc = context.WithCancel(
context.Background(),
)
func (s *Server) serve() int {
s.ctx, s.cancelFunc = context.WithCancel(context.Background())
go func() {
sigCh := make(chan os.Signal, 1)
c := make(chan os.Signal, 1)
signal.Ignore(syscall.SIGPIPE)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
sig := <-c
s.log.Info("signal received", "signal", sig)
sig := <-sigCh
srv.log.Info("signal received", "signal", sig)
if srv.cancelFunc != nil {
srv.cancelFunc()
if s.cancelFunc != nil {
s.cancelFunc()
}
}()
go srv.serveUntilShutdown()
go s.serveUntilShutdown()
<-srv.ctx.Done()
<-s.ctx.Done()
srv.cleanShutdown()
s.cleanShutdown()
return srv.exitCode
return s.exitCode
}
func (srv *Server) cleanupForExit() {
srv.log.Info("cleaning up")
func (s *Server) cleanupForExit() {
s.log.Info("cleaning up")
}
func (srv *Server) cleanShutdown() {
srv.exitCode = 0
func (s *Server) cleanShutdown() {
s.exitCode = 0
ctxShutdown, shutdownCancel := context.WithTimeout(
context.Background(), shutdownTimeout,
)
err := srv.httpServer.Shutdown(ctxShutdown)
err := s.httpServer.Shutdown(ctxShutdown)
if err != nil {
srv.log.Error(
"server clean shutdown failed", "error", err,
)
s.log.Error("server clean shutdown failed", "error", err)
}
if shutdownCancel != nil {
shutdownCancel()
}
srv.cleanupForExit()
s.cleanupForExit()
if srv.sentryEnabled {
if s.sentryEnabled {
sentry.Flush(sentryFlushTime)
}
}
func (srv *Server) configure() {
// Server configuration placeholder.
func (s *Server) configure() {
// server configuration placeholder
}
func (srv *Server) serveUntilShutdown() {
listenAddr := fmt.Sprintf(
":%d", srv.params.Config.Port,
)
srv.httpServer = &http.Server{ //nolint:exhaustruct // optional fields
func (s *Server) serveUntilShutdown() {
listenAddr := fmt.Sprintf(":%d", s.params.Config.Port)
s.httpServer = &http.Server{
Addr: listenAddr,
ReadTimeout: httpReadTimeout,
WriteTimeout: httpWriteTimeout,
MaxHeaderBytes: maxHeaderBytes,
Handler: srv,
Handler: s,
}
srv.SetupRoutes()
s.SetupRoutes()
srv.log.Info(
"http begin listen", "listenaddr", listenAddr,
)
s.log.Info("http begin listen", "listenaddr", listenAddr)
err := srv.httpServer.ListenAndServe()
err := s.httpServer.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
srv.log.Error("listen error", "error", err)
s.log.Error("listen error", "error", err)
if srv.cancelFunc != nil {
srv.cancelFunc()
if s.cancelFunc != nil {
s.cancelFunc()
}
}
}