Compare commits

..

2 Commits

Author SHA1 Message Date
b1343364a3 Merge branch 'main' into fix/45-readme-clarify-env-and-datadir
All checks were successful
check / check (push) Successful in 1m1s
2026-03-17 12:31:27 +01:00
clawbot
ec92c2bdd2 fix: use absolute path for dev DATA_DIR default, clarify env docs
All checks were successful
check / check (push) Successful in 1m1s
Change the dev-mode DATA_DIR default from the relative path ./data to
$XDG_DATA_HOME/webhooker (falling back to $HOME/.local/share/webhooker).
This ensures the application's data directory does not depend on the
working directory.

Add a table to the README that clearly documents what WEBHOOKER_ENVIRONMENT
actually controls: DATA_DIR default, CORS policy, and session cookie
Secure flag.

Add tests for devDataDir() and verify the dev default is always absolute.
2026-03-17 03:30:30 -07:00
68 changed files with 4067 additions and 8440 deletions

View File

@@ -1,32 +1,46 @@
version: "2"
run: run:
timeout: 5m timeout: 5m
modules-download-mode: readonly tests: true
linters: linters:
default: all enable:
disable: - gofmt
# Genuinely incompatible with project patterns - revive
- exhaustruct # Requires all struct fields - govet
- depguard # Dependency allow/block lists - errcheck
- godot # Requires comments to end with periods - staticcheck
- wsl # Deprecated, replaced by wsl_v5 - unused
- wrapcheck # Too verbose for internal packages - gosimple
- varnamelen # Short names like db, id are idiomatic Go - ineffassign
- typecheck
- gosec
- misspell
- unparam
- prealloc
- copyloopvar
- gocritic
- gochecknoinits
- gochecknoglobals
linters-settings: linters-settings:
lll: gofmt:
line-length: 88 simplify: true
funlen: revive:
lines: 80 confidence: 0.8
statements: 50 govet:
cyclop: enable:
max-complexity: 15 - shadow
dupl: errcheck:
threshold: 100 check-type-assertions: true
check-blank: true
issues: issues:
exclude-use-default: false exclude-rules:
max-issues-per-linter: 0 # Exclude globals check for version variables in main
max-same-issues: 0 - path: cmd/webhooker/main.go
linters:
- gochecknoglobals
# Exclude globals check for version variables in globals package
- path: internal/globals/globals.go
linters:
- gochecknoglobals

View File

@@ -1,58 +1,49 @@
# Lint stage # golang:1.24 (bookworm) — 2026-03-01
# golangci/golangci-lint:v2.11.3 (Debian-based), 2026-03-17
# Using Debian-based image because mattn/go-sqlite3 (CGO) does not
# compile on Alpine musl (off64_t is a glibc type).
FROM golangci/golangci-lint:v2.11.3@sha256:e838e8ab68aaefe83e2408691510867ade9329c0e0b895a3fb35eb93d1c2a4ba AS lint
RUN apt-get update && apt-get install -y --no-install-recommends make && rm -rf /var/lib/apt/lists/*
WORKDIR /src
# Copy go mod files first for better layer caching
COPY go.mod go.sum ./
RUN go mod download
# Copy source code
COPY . .
# Run formatting check and linter
RUN make fmt-check
RUN make lint
# Build stage
# golang:1.26.1-bookworm (Debian-based), 2026-03-17
# Using Debian-based image because gorm.io/driver/sqlite pulls in # Using Debian-based image because gorm.io/driver/sqlite pulls in
# mattn/go-sqlite3 (CGO), which does not compile on Alpine musl. # mattn/go-sqlite3 (CGO), which does not compile on Alpine musl.
FROM golang:1.26.1-bookworm@sha256:4465644228bc2857a954b092167e12aa59c006a3492282a6c820bf4755fd64a4 AS builder FROM golang@sha256:d2d2bc1c84f7e60d7d2438a3836ae7d0c847f4888464e7ec9ba3a1339a1ee804 AS builder
# Depend on lint stage passing
COPY --from=lint /src/go.sum /dev/null
# gcc is pre-installed in the Debian-based golang image
RUN apt-get update && apt-get install -y --no-install-recommends make && rm -rf /var/lib/apt/lists/* RUN apt-get update && apt-get install -y --no-install-recommends make && rm -rf /var/lib/apt/lists/*
WORKDIR /build WORKDIR /build
# Copy go mod files first for better layer caching # Install golangci-lint v1.64.8 — 2026-03-01
# Using v1.x because the repo's .golangci.yml uses v1 config format.
RUN set -eux; \
GOLANGCI_VERSION="1.64.8"; \
ARCH="$(uname -m)"; \
case "${ARCH}" in \
x86_64) \
GOARCH="amd64"; \
GOLANGCI_SHA256="b6270687afb143d019f387c791cd2a6f1cb383be9b3124d241ca11bd3ce2e54e"; \
;; \
aarch64) \
GOARCH="arm64"; \
GOLANGCI_SHA256="a6ab58ebcb1c48572622146cdaec2956f56871038a54ed1149f1386e287789a5"; \
;; \
*) echo "unsupported architecture: ${ARCH}" && exit 1 ;; \
esac; \
wget -q "https://github.com/golangci/golangci-lint/releases/download/v${GOLANGCI_VERSION}/golangci-lint-${GOLANGCI_VERSION}-linux-${GOARCH}.tar.gz" \
-O /tmp/golangci-lint.tar.gz; \
echo "${GOLANGCI_SHA256} /tmp/golangci-lint.tar.gz" | sha256sum -c -; \
tar -xzf /tmp/golangci-lint.tar.gz -C /tmp; \
mv "/tmp/golangci-lint-${GOLANGCI_VERSION}-linux-${GOARCH}/golangci-lint" /usr/local/bin/; \
rm -rf /tmp/golangci-lint*; \
golangci-lint --version
# Copy go module files and download dependencies
COPY go.mod go.sum ./ COPY go.mod go.sum ./
RUN go mod download RUN go mod download
# Copy source code # Copy source code
COPY . . COPY . .
# Run tests and build # Run all checks (fmt-check, lint, test, build)
RUN make test RUN make check
RUN make build
# Rebuild with static linking for Alpine runtime. # alpine:3.21 — 2026-03-01
# make build already verified compilation. FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
# The CGO binary from `make build` is dynamically linked against glibc,
# which doesn't exist on Alpine (musl). Rebuild with static linking so
# the binary runs on Alpine without glibc.
RUN CGO_ENABLED=1 go build -ldflags '-extldflags "-static"' -o bin/webhooker ./cmd/webhooker
# Runtime stage
# alpine:3.21, 2026-03-17
FROM alpine:3.21@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk --no-cache add ca-certificates RUN apk --no-cache add ca-certificates
@@ -63,13 +54,13 @@ RUN addgroup -g 1000 -S webhooker && \
WORKDIR /app WORKDIR /app
# Copy binary from builder # Copy binary from builder
COPY --from=builder /build/bin/webhooker /app/webhooker COPY --from=builder /build/bin/webhooker .
# Create data directory for all SQLite databases (main app DB + # Create data directory for all SQLite databases (main app DB +
# per-webhook event DBs). DATA_DIR defaults to /var/lib/webhooker. # per-webhook event DBs). DATA_DIR defaults to /data in production.
RUN mkdir -p /var/lib/webhooker RUN mkdir -p /data
RUN chown -R webhooker:webhooker /app /var/lib/webhooker RUN chown -R webhooker:webhooker /app /data
USER webhooker USER webhooker
@@ -78,4 +69,4 @@ EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8080/.well-known/healthcheck || exit 1 CMD wget --no-verbose --tries=1 --spider http://localhost:8080/.well-known/healthcheck || exit 1
CMD ["/app/webhooker"] CMD ["./webhooker"]

View File

@@ -11,8 +11,8 @@ with retry support, logging, and observability. Category: infrastructure
### Prerequisites ### Prerequisites
- Go 1.26+ - Go 1.24+
- golangci-lint v2.11+ - golangci-lint v1.64+
- Docker (for containerized deployment) - Docker (for containerized deployment)
### Quick Start ### Quick Start
@@ -59,24 +59,10 @@ or `prod` (default: `dev`). The setting controls several behaviors:
| Behavior | `dev` | `prod` | | Behavior | `dev` | `prod` |
| --------------------- | -------------------------------- | ------------------------------- | | --------------------- | -------------------------------- | ------------------------------- |
| Default `DATA_DIR` | `$XDG_DATA_HOME/webhooker` (or `$HOME/.local/share/webhooker`) | `/data` |
| CORS | Allows any origin (`*`) | Disabled (no-op) | | CORS | Allows any origin (`*`) | Disabled (no-op) |
| Session cookie Secure | `false` (works over plain HTTP) | `true` (requires HTTPS) | | Session cookie Secure | `false` (works over plain HTTP) | `true` (requires HTTPS) |
The CSRF cookie's `Secure` flag and Origin/Referer validation mode are
determined per-request based on the actual transport protocol, not the
environment setting. The middleware checks `r.TLS` (direct TLS) and the
`X-Forwarded-Proto` header (TLS-terminating reverse proxy) to decide:
- **Direct TLS or `X-Forwarded-Proto: https`**: Secure cookies, strict
Origin/Referer validation.
- **Plaintext HTTP**: Non-Secure cookies, relaxed Origin/Referer
checks (token validation still enforced).
This means CSRF protection works correctly in all deployment scenarios:
behind a TLS-terminating reverse proxy, with direct TLS, or over plain
HTTP during development. When running behind a reverse proxy, ensure it
sets the `X-Forwarded-Proto: https` header.
All other differences (log format, security headers, etc.) are All other differences (log format, security headers, etc.) are
independent of the environment setting — log format is determined by independent of the environment setting — log format is determined by
TTY detection, and security headers are always applied. TTY detection, and security headers are always applied.
@@ -85,7 +71,7 @@ TTY detection, and security headers are always applied.
| ----------------------- | ----------------------------------- | -------- | | ----------------------- | ----------------------------------- | -------- |
| `WEBHOOKER_ENVIRONMENT` | `dev` or `prod` | `dev` | | `WEBHOOKER_ENVIRONMENT` | `dev` or `prod` | `dev` |
| `PORT` | HTTP listen port | `8080` | | `PORT` | HTTP listen port | `8080` |
| `DATA_DIR` | Directory for all SQLite databases | `/var/lib/webhooker` | | `DATA_DIR` | Directory for all SQLite databases | `$XDG_DATA_HOME/webhooker` (dev) / `/data` (prod) |
| `DEBUG` | Enable debug logging | `false` | | `DEBUG` | Enable debug logging | `false` |
| `METRICS_USERNAME` | Basic auth username for `/metrics` | `""` | | `METRICS_USERNAME` | Basic auth username for `/metrics` | `""` |
| `METRICS_PASSWORD` | Basic auth password for `/metrics` | `""` | | `METRICS_PASSWORD` | Basic auth password for `/metrics` | `""` |
@@ -104,16 +90,16 @@ is only displayed once.
```bash ```bash
docker run -d \ docker run -d \
-p 8080:8080 \ -p 8080:8080 \
-v /path/to/data:/var/lib/webhooker \ -v /path/to/data:/data \
-e WEBHOOKER_ENVIRONMENT=prod \ -e WEBHOOKER_ENVIRONMENT=prod \
webhooker:latest webhooker:latest
``` ```
The container runs as a non-root user (`webhooker`, UID 1000), exposes The container runs as a non-root user (`webhooker`, UID 1000), exposes
port 8080, and includes a health check against port 8080, and includes a health check against
`/.well-known/healthcheck`. The `/var/lib/webhooker` volume holds all `/.well-known/healthcheck`. The `/data` volume holds all SQLite
SQLite databases: the main application database (`webhooker.db`) and databases: the main application database (`webhooker.db`) and the
the per-webhook event databases (`events-{uuid}.db`). Mount this as a per-webhook event databases (`events-{uuid}.db`). Mount this as a
persistent volume to preserve data across container restarts. persistent volume to preserve data across container restarts.
## Rationale ## Rationale
@@ -181,10 +167,6 @@ It uses:
logging with TTY detection (text for dev, JSON for prod) logging with TTY detection (text for dev, JSON for prod)
- **[gorilla/sessions](https://github.com/gorilla/sessions)** for - **[gorilla/sessions](https://github.com/gorilla/sessions)** for
encrypted cookie-based session management encrypted cookie-based session management
- **[gorilla/csrf](https://github.com/gorilla/csrf)** for CSRF
protection (cookie-based double-submit tokens)
- **[go-chi/httprate](https://github.com/go-chi/httprate)** for
per-IP login rate limiting (sliding window counter)
- **[Prometheus](https://prometheus.io)** for metrics, served at - **[Prometheus](https://prometheus.io)** for metrics, served at
`/metrics` behind basic auth `/metrics` behind basic auth
- **[Sentry](https://sentry.io)** for optional error reporting - **[Sentry](https://sentry.io)** for optional error reporting
@@ -667,7 +649,7 @@ against a misbehaving sender).
| Method | Path | Description | | Method | Path | Description |
| ------ | --------------------------- | ----------- | | ------ | --------------------------- | ----------- |
| `GET` | `/` | Root redirect (authenticated → `/sources`, unauthenticated → `/pages/login`) | | `GET` | `/` | Web UI index page (server-rendered) |
| `GET` | `/.well-known/healthcheck` | Health check (JSON: status, uptime, version) | | `GET` | `/.well-known/healthcheck` | Health check (JSON: status, uptime, version) |
| `GET` | `/s/*` | Static file serving (embedded CSS, JS) | | `GET` | `/s/*` | Static file serving (embedded CSS, JS) |
| `ANY` | `/webhook/{uuid}` | Webhook receiver endpoint (accepts all methods) | | `ANY` | `/webhook/{uuid}` | Webhook receiver endpoint (accepts all methods) |
@@ -748,8 +730,7 @@ webhooker/
│ │ └── globals.go # Build-time variables (appname, version, arch) │ │ └── globals.go # Build-time variables (appname, version, arch)
│ ├── delivery/ │ ├── delivery/
│ │ ├── engine.go # Event-driven delivery engine (channel + timer based) │ │ ├── engine.go # Event-driven delivery engine (channel + timer based)
│ │ ── circuit_breaker.go # Per-target circuit breaker for HTTP targets with retries │ │ ── circuit_breaker.go # Per-target circuit breaker for HTTP targets with retries
│ │ └── ssrf.go # SSRF prevention (IP validation, safe HTTP transport)
│ ├── handlers/ │ ├── handlers/
│ │ ├── handlers.go # Base handler struct, JSON helpers, template rendering │ │ ├── handlers.go # Base handler struct, JSON helpers, template rendering
│ │ ├── auth.go # Login, logout handlers │ │ ├── auth.go # Login, logout handlers
@@ -763,9 +744,7 @@ webhooker/
│ ├── logger/ │ ├── logger/
│ │ └── logger.go # slog setup with TTY detection │ │ └── logger.go # slog setup with TTY detection
│ ├── middleware/ │ ├── middleware/
│ │ ── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize │ │ ── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize
│ │ ├── csrf.go # CSRF protection middleware (gorilla/csrf)
│ │ └── ratelimit.go # Per-IP rate limiting middleware (go-chi/httprate)
│ ├── server/ │ ├── server/
│ │ ├── server.go # Server struct, fx lifecycle, signal handling │ │ ├── server.go # Server struct, fx lifecycle, signal handling
│ │ ├── http.go # HTTP server setup with timeouts │ │ ├── http.go # HTTP server setup with timeouts
@@ -777,7 +756,7 @@ webhooker/
│ ├── css/style.css # Custom stylesheet (system font stack, card effects, layout) │ ├── css/style.css # Custom stylesheet (system font stack, card effects, layout)
│ └── js/app.js # Client-side JavaScript (minimal bootstrap) │ └── js/app.js # Client-side JavaScript (minimal bootstrap)
├── templates/ # Go HTML templates (base, index, login, etc.) ├── templates/ # Go HTML templates (base, index, login, etc.)
├── Dockerfile # Multi-stage: lint, build+test, then Alpine runtime ├── Dockerfile # Multi-stage: build + check, then Alpine runtime
├── Makefile # fmt, lint, test, check, build, docker targets ├── Makefile # fmt, lint, test, check, build, docker targets
├── go.mod / go.sum ├── go.mod / go.sum
└── .golangci.yml # Linter configuration └── .golangci.yml # Linter configuration
@@ -852,21 +831,6 @@ Additionally, form endpoints (`/pages`, `/sources`, `/source/*`) apply a
(`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy, (`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy,
and Permissions-Policy and Permissions-Policy
- Request body size limits (1 MB) on all form POST endpoints - Request body size limits (1 MB) on all form POST endpoints
- **CSRF protection** via [gorilla/csrf](https://github.com/gorilla/csrf)
on all state-changing forms (cookie-based double-submit tokens with
HMAC authentication). Applied to `/pages`, `/sources`, `/source`, and
`/user` routes. Excluded from `/webhook` (inbound webhook POSTs) and
`/api` (stateless API). The middleware auto-detects TLS status
per-request (via `r.TLS` and `X-Forwarded-Proto`) to set appropriate
cookie security flags and Origin/Referer validation mode
- **SSRF prevention** for HTTP delivery targets: private/reserved IP
ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked
both at target creation time (URL validation) and at delivery time
(custom HTTP transport with SSRF-safe dialer that validates resolved
IPs before connecting, preventing DNS rebinding attacks)
- **Login rate limiting** via [go-chi/httprate](https://github.com/go-chi/httprate):
per-IP sliding-window rate limiter on the login endpoint (5 POST
attempts per minute per IP) to prevent brute-force attacks
- Prometheus metrics behind basic auth - Prometheus metrics behind basic auth
- Static assets embedded in binary (no filesystem access needed at - Static assets embedded in binary (no filesystem access needed at
runtime) runtime)
@@ -881,8 +845,8 @@ The Dockerfile uses a multi-stage build:
golangci-lint, downloads dependencies, copies source, runs `make golangci-lint, downloads dependencies, copies source, runs `make
check` (format verification, linting, tests, compilation). check` (format verification, linting, tests, compilation).
2. **Runtime stage** (`alpine:3.21`) — copies the binary, creates the 2. **Runtime stage** (`alpine:3.21`) — copies the binary, creates the
`/var/lib/webhooker` directory for all SQLite databases, runs as `/data` directory for all SQLite databases, runs as non-root user,
non-root user, exposes port 8080, includes a health check. exposes port 8080, includes a health check.
The builder uses Debian rather than Alpine because GORM's SQLite The builder uses Debian rather than Alpine because GORM's SQLite
dialect pulls in CGO-dependent headers at compile time. The runtime dialect pulls in CGO-dependent headers at compile time. The runtime
@@ -953,12 +917,7 @@ linted, tested, and compiled.
### Remaining: Core Features ### Remaining: Core Features
- [ ] Per-webhook rate limiting in the receiver handler - [ ] Per-webhook rate limiting in the receiver handler
- [ ] Webhook signature verification (GitHub, Stripe formats) - [ ] Webhook signature verification (GitHub, Stripe formats)
- [x] CSRF protection for forms - [ ] CSRF protection for forms
([#35](https://git.eeqj.de/sneak/webhooker/issues/35))
- [x] SSRF prevention for HTTP delivery targets
([#36](https://git.eeqj.de/sneak/webhooker/issues/36))
- [x] Login rate limiting (per-IP brute-force protection)
([#37](https://git.eeqj.de/sneak/webhooker/issues/37))
- [ ] Session expiration and "remember me" - [ ] Session expiration and "remember me"
- [ ] Password change/reset flow - [ ] Password change/reset flow
- [ ] API key authentication for programmatic access - [ ] API key authentication for programmatic access

View File

@@ -1,4 +1,3 @@
// Package main is the entry point for the webhooker application.
package main package main
import ( import (
@@ -16,8 +15,6 @@ import (
) )
// Build-time variables set via -ldflags. // Build-time variables set via -ldflags.
//
//nolint:gochecknoglobals // Build-time variables injected by the linker.
var ( var (
version = "dev" version = "dev"
appname = "webhooker" appname = "webhooker"

8
go.mod
View File

@@ -1,15 +1,15 @@
module sneak.berlin/go/webhooker module sneak.berlin/go/webhooker
go 1.26.1 go 1.23.0
toolchain go1.24.1
require ( require (
github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8 github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8
github.com/getsentry/sentry-go v0.25.0 github.com/getsentry/sentry-go v0.25.0
github.com/go-chi/chi v1.5.5 github.com/go-chi/chi v1.5.5
github.com/go-chi/cors v1.2.1 github.com/go-chi/cors v1.2.1
github.com/go-chi/httprate v0.15.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/gorilla/csrf v1.7.3
github.com/gorilla/sessions v1.4.0 github.com/gorilla/sessions v1.4.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_golang v1.18.0
@@ -31,7 +31,6 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect
@@ -41,7 +40,6 @@ require (
github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/common v0.45.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect
go.uber.org/dig v1.17.0 // indirect go.uber.org/dig v1.17.0 // indirect
go.uber.org/multierr v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect

10
go.sum
View File

@@ -19,8 +19,6 @@ github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE=
github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw=
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g=
github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
@@ -33,8 +31,6 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
@@ -47,8 +43,6 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -86,10 +80,6 @@ github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdr
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI=

View File

@@ -1,11 +1,10 @@
// Package config loads application configuration from environment variables.
package config package config
import ( import (
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
@@ -19,29 +18,19 @@ import (
) )
const ( const (
// EnvironmentDev represents development environment. // EnvironmentDev represents development environment
EnvironmentDev = "dev" EnvironmentDev = "dev"
// EnvironmentProd represents production environment. // EnvironmentProd represents production environment
EnvironmentProd = "prod" EnvironmentProd = "prod"
// defaultPort is the default HTTP listen port.
defaultPort = 8080
) )
// ErrInvalidEnvironment is returned when WEBHOOKER_ENVIRONMENT // nolint:revive // ConfigParams is a standard fx naming convention
// contains an unrecognised value.
var ErrInvalidEnvironment = errors.New("invalid environment")
//nolint:revive // ConfigParams is a standard fx naming convention.
type ConfigParams struct { type ConfigParams struct {
fx.In fx.In
Globals *globals.Globals Globals *globals.Globals
Logger *logger.Logger Logger *logger.Logger
} }
// Config holds all application configuration loaded from
// environment variables.
type Config struct { type Config struct {
DataDir string DataDir string
Debug bool Debug bool
@@ -55,67 +44,73 @@ type Config struct {
log *slog.Logger log *slog.Logger
} }
// IsDev returns true if running in development environment. // IsDev returns true if running in development environment
func (c *Config) IsDev() bool { func (c *Config) IsDev() bool {
return c.Environment == EnvironmentDev return c.Environment == EnvironmentDev
} }
// IsProd returns true if running in production environment. // IsProd returns true if running in production environment
func (c *Config) IsProd() bool { func (c *Config) IsProd() bool {
return c.Environment == EnvironmentProd return c.Environment == EnvironmentProd
} }
// envString returns the value of the named environment variable, // envString returns the value of the named environment variable, or
// or an empty string if not set. // an empty string if not set.
func envString(key string) string { func envString(key string) string {
return os.Getenv(key) return os.Getenv(key)
} }
// envBool returns the value of the named environment variable // envBool returns the value of the named environment variable parsed as a
// parsed as a boolean. Returns defaultValue if not set. // boolean. Returns defaultValue if not set.
func envBool(key string, defaultValue bool) bool { func envBool(key string, defaultValue bool) bool {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return strings.EqualFold(v, "true") || v == "1" return strings.EqualFold(v, "true") || v == "1"
} }
return defaultValue return defaultValue
} }
// envInt returns the value of the named environment variable // envInt returns the value of the named environment variable parsed as an
// parsed as an integer. Returns defaultValue if not set or // integer. Returns defaultValue if not set or unparseable.
// unparseable.
func envInt(key string, defaultValue int) int { func envInt(key string, defaultValue int) int {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
i, err := strconv.Atoi(v) if i, err := strconv.Atoi(v); err == nil {
if err == nil {
return i return i
} }
} }
return defaultValue return defaultValue
} }
// New creates a Config by reading environment variables. // devDataDir returns the default data directory for the dev
// // environment. It uses $XDG_DATA_HOME/webhooker if set, otherwise
//nolint:revive // lc parameter is required by fx even if unused. // falls back to $HOME/.local/share/webhooker. The result is always
// an absolute path so the application's behavior does not depend on
// the working directory.
func devDataDir() string {
if xdg := os.Getenv("XDG_DATA_HOME"); xdg != "" {
return filepath.Join(xdg, "webhooker")
}
home, err := os.UserHomeDir()
if err != nil {
// Last resort: use /tmp so we still have an absolute path.
return filepath.Join(os.TempDir(), "webhooker")
}
return filepath.Join(home, ".local", "share", "webhooker")
}
// nolint:revive // lc parameter is required by fx even if unused
func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) { func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
log := params.Logger.Get() log := params.Logger.Get()
// Determine environment from WEBHOOKER_ENVIRONMENT env var, // Determine environment from WEBHOOKER_ENVIRONMENT env var, default to dev
// default to dev
environment := os.Getenv("WEBHOOKER_ENVIRONMENT") environment := os.Getenv("WEBHOOKER_ENVIRONMENT")
if environment == "" { if environment == "" {
environment = EnvironmentDev environment = EnvironmentDev
} }
// Validate environment // Validate environment
if environment != EnvironmentDev && if environment != EnvironmentDev && environment != EnvironmentProd {
environment != EnvironmentProd { return nil, fmt.Errorf("WEBHOOKER_ENVIRONMENT must be either '%s' or '%s', got '%s'",
return nil, fmt.Errorf( EnvironmentDev, EnvironmentProd, environment)
"%w: WEBHOOKER_ENVIRONMENT must be '%s' or '%s', got '%s'",
ErrInvalidEnvironment,
EnvironmentDev, EnvironmentProd, environment,
)
} }
// Load configuration values from environment variables // Load configuration values from environment variables
@@ -126,18 +121,22 @@ func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
Environment: environment, Environment: environment,
MetricsUsername: envString("METRICS_USERNAME"), MetricsUsername: envString("METRICS_USERNAME"),
MetricsPassword: envString("METRICS_PASSWORD"), MetricsPassword: envString("METRICS_PASSWORD"),
Port: envInt("PORT", defaultPort), Port: envInt("PORT", 8080),
SentryDSN: envString("SENTRY_DSN"), SentryDSN: envString("SENTRY_DSN"),
log: log, log: log,
params: &params, params: &params,
} }
// Set default DataDir. All SQLite databases (main application // Set default DataDir based on environment. All SQLite databases
// DB and per-webhook event DBs) live here. The same default is // (main application DB and per-webhook event DBs) live here.
// used regardless of environment; override with DATA_DIR if // Both defaults are absolute paths to avoid dependence on the
// needed. // working directory.
if s.DataDir == "" { if s.DataDir == "" {
s.DataDir = "/var/lib/webhooker" if s.IsProd() {
s.DataDir = "/data"
} else {
s.DataDir = devDataDir()
}
} }
if s.Debug { if s.Debug {
@@ -152,8 +151,7 @@ func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
"maintenanceMode", s.MaintenanceMode, "maintenanceMode", s.MaintenanceMode,
"dataDir", s.DataDir, "dataDir", s.DataDir,
"hasSentryDSN", s.SentryDSN != "", "hasSentryDSN", s.SentryDSN != "",
"hasMetricsAuth", "hasMetricsAuth", s.MetricsUsername != "" && s.MetricsPassword != "",
s.MetricsUsername != "" && s.MetricsPassword != "",
) )
return s, nil return s, nil

View File

@@ -1,14 +1,14 @@
package config_test package config
import ( import (
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/fx" "go.uber.org/fx"
"go.uber.org/fx/fxtest" "go.uber.org/fx/fxtest"
"sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/globals"
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
) )
@@ -24,142 +24,135 @@ func TestEnvironmentConfig(t *testing.T) {
}{ }{
{ {
name: "default is dev", name: "default is dev",
envValue: "",
envVars: map[string]string{},
expectError: false,
isDev: true, isDev: true,
isProd: false, isProd: false,
}, },
{ {
name: "explicit dev", name: "explicit dev",
envValue: "dev", envValue: "dev",
envVars: map[string]string{},
expectError: false,
isDev: true, isDev: true,
isProd: false, isProd: false,
}, },
{ {
name: "explicit prod", name: "explicit prod",
envValue: "prod", envValue: "prod",
envVars: map[string]string{},
expectError: false,
isDev: false, isDev: false,
isProd: true, isProd: true,
}, },
{ {
name: "invalid environment", name: "invalid environment",
envValue: "staging", envValue: "staging",
envVars: map[string]string{},
expectError: true, expectError: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Cannot use t.Parallel() here because t.Setenv // Set environment variable if specified
// is incompatible with parallel subtests.
if tt.envValue != "" { if tt.envValue != "" {
t.Setenv( os.Setenv("WEBHOOKER_ENVIRONMENT", tt.envValue)
"WEBHOOKER_ENVIRONMENT", tt.envValue, defer os.Unsetenv("WEBHOOKER_ENVIRONMENT")
)
} else { } else {
require.NoError(t, os.Unsetenv( os.Unsetenv("WEBHOOKER_ENVIRONMENT")
"WEBHOOKER_ENVIRONMENT",
))
} }
// Set additional environment variables
for k, v := range tt.envVars { for k, v := range tt.envVars {
t.Setenv(k, v) os.Setenv(k, v)
defer os.Unsetenv(k)
} }
if tt.expectError { if tt.expectError {
testEnvironmentConfigError(t) // Use regular fx.New for error cases since fxtest doesn't expose errors the same way
} else { var cfg *Config
testEnvironmentConfigSuccess(
t, tt.isDev, tt.isProd,
)
}
})
}
}
func testEnvironmentConfigError(t *testing.T) {
t.Helper()
var cfg *config.Config
app := fx.New( app := fx.New(
fx.NopLogger, fx.NopLogger, // Suppress fx logs in tests
fx.Provide( fx.Provide(
globals.New, globals.New,
logger.New, logger.New,
config.New, New,
), ),
fx.Populate(&cfg), fx.Populate(&cfg),
) )
assert.Error(t, app.Err()) assert.Error(t, app.Err())
}
func testEnvironmentConfigSuccess(
t *testing.T,
isDev, isProd bool,
) {
t.Helper()
var cfg *config.Config
app := fxtest.New(
t,
fx.Provide(
globals.New,
logger.New,
config.New,
),
fx.Populate(&cfg),
)
require.NoError(t, app.Err())
app.RequireStart()
defer app.RequireStop()
assert.Equal(t, isDev, cfg.IsDev())
assert.Equal(t, isProd, cfg.IsProd())
}
func TestDefaultDataDir(t *testing.T) {
for _, env := range []string{"", "dev", "prod"} {
name := env
if name == "" {
name = "unset"
}
t.Run("env="+name, func(t *testing.T) {
// Cannot use t.Parallel() here because t.Setenv
// is incompatible with parallel subtests.
if env != "" {
t.Setenv("WEBHOOKER_ENVIRONMENT", env)
} else { } else {
require.NoError(t, os.Unsetenv( // Use fxtest for success cases
"WEBHOOKER_ENVIRONMENT", var cfg *Config
))
}
require.NoError(t, os.Unsetenv("DATA_DIR"))
var cfg *config.Config
app := fxtest.New( app := fxtest.New(
t, t,
fx.Provide( fx.Provide(
globals.New, globals.New,
logger.New, logger.New,
config.New, New,
), ),
fx.Populate(&cfg), fx.Populate(&cfg),
) )
require.NoError(t, app.Err()) require.NoError(t, app.Err())
app.RequireStart() app.RequireStart()
defer app.RequireStop() defer app.RequireStop()
assert.Equal( assert.Equal(t, tt.isDev, cfg.IsDev())
t, "/var/lib/webhooker", cfg.DataDir, assert.Equal(t, tt.isProd, cfg.IsProd())
) }
}) })
} }
} }
func TestDevDataDir(t *testing.T) {
t.Run("uses XDG_DATA_HOME when set", func(t *testing.T) {
os.Setenv("XDG_DATA_HOME", "/custom/data")
defer os.Unsetenv("XDG_DATA_HOME")
got := devDataDir()
assert.Equal(t, "/custom/data/webhooker", got)
})
t.Run("falls back to HOME/.local/share/webhooker", func(t *testing.T) {
os.Unsetenv("XDG_DATA_HOME")
home, err := os.UserHomeDir()
require.NoError(t, err)
got := devDataDir()
assert.Equal(t, filepath.Join(home, ".local", "share", "webhooker"), got)
})
t.Run("result is always absolute", func(t *testing.T) {
os.Unsetenv("XDG_DATA_HOME")
got := devDataDir()
assert.True(t, filepath.IsAbs(got), "devDataDir() returned relative path: %s", got)
})
}
func TestDevDefaultDataDirIsAbsolute(t *testing.T) {
// Verify that when WEBHOOKER_ENVIRONMENT=dev and DATA_DIR is unset,
// the resulting DataDir is an absolute path.
os.Unsetenv("WEBHOOKER_ENVIRONMENT")
os.Unsetenv("DATA_DIR")
var cfg *Config
app := fxtest.New(
t,
fx.Provide(
globals.New,
logger.New,
New,
),
fx.Populate(&cfg),
)
require.NoError(t, app.Err())
app.RequireStart()
defer app.RequireStop()
assert.True(t, filepath.IsAbs(cfg.DataDir),
"dev default DataDir should be absolute, got: %s", cfg.DataDir)
}

View File

@@ -11,16 +11,15 @@ import (
// This replaces gorm.Model but uses UUID instead of uint for ID // This replaces gorm.Model but uses UUID instead of uint for ID
type BaseModel struct { type BaseModel struct {
ID string `gorm:"type:uuid;primary_key" json:"id"` ID string `gorm:"type:uuid;primary_key" json:"id"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updatedAt"` UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deletedAt,omitzero"` DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
} }
// BeforeCreate hook to set UUID before creating a record. // BeforeCreate hook to set UUID before creating a record
func (b *BaseModel) BeforeCreate(_ *gorm.DB) error { func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
if b.ID == "" { if b.ID == "" {
b.ID = uuid.New().String() b.ID = uuid.New().String()
} }
return nil return nil
} }

View File

@@ -1,4 +1,3 @@
// Package database provides SQLite persistence for webhooks, events, and users.
package database package database
import ( import (
@@ -20,42 +19,30 @@ import (
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
) )
const ( // nolint:revive // DatabaseParams is a standard fx naming convention
dataDirPerm = 0750
randomPasswordLen = 16
sessionKeyLen = 32
)
//nolint:revive // DatabaseParams is a standard fx naming convention.
type DatabaseParams struct { type DatabaseParams struct {
fx.In fx.In
Config *config.Config Config *config.Config
Logger *logger.Logger Logger *logger.Logger
} }
// Database manages the main SQLite connection and schema migrations.
type Database struct { type Database struct {
db *gorm.DB db *gorm.DB
log *slog.Logger log *slog.Logger
params *DatabaseParams params *DatabaseParams
} }
// New creates a Database that connects on fx start and disconnects on stop. func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
func New(
lc fx.Lifecycle,
params DatabaseParams,
) (*Database, error) {
d := &Database{ d := &Database{
params: &params, params: &params,
log: params.Logger.Get(), log: params.Logger.Get(),
} }
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
return d.connect() return d.connect()
}, },
OnStop: func(_ context.Context) error { OnStop: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
return d.close() return d.close()
}, },
}) })
@@ -63,92 +50,21 @@ func New(
return d, nil return d, nil
} }
// DB returns the underlying GORM database handle.
func (d *Database) DB() *gorm.DB {
return d.db
}
// GetOrCreateSessionKey retrieves the session encryption key from the
// settings table. If no key exists, a cryptographically secure random
// 32-byte key is generated, base64-encoded, and stored for future use.
func (d *Database) GetOrCreateSessionKey() (string, error) {
var setting Setting
result := d.db.Where(
&Setting{Key: "session_key"},
).First(&setting)
if result.Error == nil {
return setting.Value, nil
}
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", fmt.Errorf(
"failed to query session key: %w",
result.Error,
)
}
// Generate a new cryptographically secure 32-byte key
keyBytes := make([]byte, sessionKeyLen)
_, err := rand.Read(keyBytes)
if err != nil {
return "", fmt.Errorf(
"failed to generate session key: %w",
err,
)
}
encoded := base64.StdEncoding.EncodeToString(keyBytes)
setting = Setting{
Key: "session_key",
Value: encoded,
}
err = d.db.Create(&setting).Error
if err != nil {
return "", fmt.Errorf(
"failed to store session key: %w",
err,
)
}
d.log.Info(
"generated new session key and stored in database",
)
return encoded, nil
}
func (d *Database) connect() error { func (d *Database) connect() error {
// Ensure the data directory exists before opening the database. // Ensure the data directory exists before opening the database.
dataDir := d.params.Config.DataDir dataDir := d.params.Config.DataDir
if err := os.MkdirAll(dataDir, 0750); err != nil {
err := os.MkdirAll(dataDir, dataDirPerm) return fmt.Errorf("creating data directory %s: %w", dataDir, err)
if err != nil {
return fmt.Errorf(
"creating data directory %s: %w",
dataDir,
err,
)
} }
// Construct the main application database path inside DATA_DIR. // Construct the main application database path inside DATA_DIR.
dbPath := filepath.Join(dataDir, "webhooker.db") dbPath := filepath.Join(dataDir, "webhooker.db")
dbURL := fmt.Sprintf( dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", dbPath)
"file:%s?cache=shared&mode=rwc",
dbPath,
)
// Open the database with the pure Go SQLite driver // Open the database with the pure Go SQLite driver
sqlDB, err := sql.Open("sqlite", dbURL) sqlDB, err := sql.Open("sqlite", dbURL)
if err != nil { if err != nil {
d.log.Error( d.log.Error("failed to open database", "error", err)
"failed to open database",
"error", err,
)
return err return err
} }
@@ -157,11 +73,7 @@ func (d *Database) connect() error {
Conn: sqlDB, Conn: sqlDB,
}, &gorm.Config{}) }, &gorm.Config{})
if err != nil { if err != nil {
d.log.Error( d.log.Error("failed to connect to database", "error", err)
"failed to connect to database",
"error", err,
)
return err return err
} }
@@ -174,62 +86,34 @@ func (d *Database) connect() error {
func (d *Database) migrate() error { func (d *Database) migrate() error {
// Run GORM auto-migrations // Run GORM auto-migrations
err := d.Migrate() if err := d.Migrate(); err != nil {
if err != nil { d.log.Error("failed to run database migrations", "error", err)
d.log.Error(
"failed to run database migrations",
"error", err,
)
return err return err
} }
d.log.Info("database migrations completed") d.log.Info("database migrations completed")
// Check if admin user exists // Check if admin user exists
var userCount int64 var userCount int64
if err := d.db.Model(&User{}).Count(&userCount).Error; err != nil {
err = d.db.Model(&User{}).Count(&userCount).Error d.log.Error("failed to count users", "error", err)
if err != nil {
d.log.Error(
"failed to count users",
"error", err,
)
return err return err
} }
if userCount == 0 { if userCount == 0 {
return d.createAdminUser() // Create admin user
}
return nil
}
func (d *Database) createAdminUser() error {
d.log.Info("no users found, creating admin user") d.log.Info("no users found, creating admin user")
// Generate random password // Generate random password
password, err := GenerateRandomPassword( password, err := GenerateRandomPassword(16)
randomPasswordLen,
)
if err != nil { if err != nil {
d.log.Error( d.log.Error("failed to generate random password", "error", err)
"failed to generate random password",
"error", err,
)
return err return err
} }
// Hash the password // Hash the password
hashedPassword, err := HashPassword(password) hashedPassword, err := HashPassword(password)
if err != nil { if err != nil {
d.log.Error( d.log.Error("failed to hash password", "error", err)
"failed to hash password",
"error", err,
)
return err return err
} }
@@ -239,22 +123,17 @@ func (d *Database) createAdminUser() error {
Password: hashedPassword, Password: hashedPassword,
} }
err = d.db.Create(adminUser).Error if err := d.db.Create(adminUser).Error; err != nil {
if err != nil { d.log.Error("failed to create admin user", "error", err)
d.log.Error(
"failed to create admin user",
"error", err,
)
return err return err
} }
d.log.Info("admin user created", d.log.Info("admin user created",
"username", "admin", "username", "admin",
"password", password, "password", password,
"message", "message", "SAVE THIS PASSWORD - it will not be shown again!",
"SAVE THIS PASSWORD - it will not be shown again!",
) )
}
return nil return nil
} }
@@ -265,9 +144,43 @@ func (d *Database) close() error {
if err != nil { if err != nil {
return err return err
} }
return sqlDB.Close() return sqlDB.Close()
} }
return nil return nil
} }
func (d *Database) DB() *gorm.DB {
return d.db
}
// GetOrCreateSessionKey retrieves the session encryption key from the
// settings table. If no key exists, a cryptographically secure random
// 32-byte key is generated, base64-encoded, and stored for future use.
func (d *Database) GetOrCreateSessionKey() (string, error) {
var setting Setting
result := d.db.Where(&Setting{Key: "session_key"}).First(&setting)
if result.Error == nil {
return setting.Value, nil
}
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", fmt.Errorf("failed to query session key: %w", result.Error)
}
// Generate a new cryptographically secure 32-byte key
keyBytes := make([]byte, 32)
if _, err := rand.Read(keyBytes); err != nil {
return "", fmt.Errorf("failed to generate session key: %w", err)
}
encoded := base64.StdEncoding.EncodeToString(keyBytes)
setting = Setting{
Key: "session_key",
Value: encoded,
}
if err := d.db.Create(&setting).Error; err != nil {
return "", fmt.Errorf("failed to store session key: %w", err)
}
d.log.Info("generated new session key and stored in database")
return encoded, nil
}

View File

@@ -1,4 +1,4 @@
package database_test package database
import ( import (
"context" "context"
@@ -6,37 +6,37 @@ import (
"go.uber.org/fx/fxtest" "go.uber.org/fx/fxtest"
"sneak.berlin/go/webhooker/internal/config" "sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/database"
"sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/globals"
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
) )
func setupTestDB( func TestDatabaseConnection(t *testing.T) {
t *testing.T, // Set up test dependencies
) (*database.Database, *fxtest.Lifecycle) {
t.Helper()
lc := fxtest.NewLifecycle(t) lc := fxtest.NewLifecycle(t)
g := &globals.Globals{ // Create globals
Appname: "webhooker-test", globals.Appname = "webhooker-test"
Version: "test", globals.Version = "test"
g, err := globals.New(lc)
if err != nil {
t.Fatalf("Failed to create globals: %v", err)
} }
l, err := logger.New( // Create logger
lc, l, err := logger.New(lc, logger.LoggerParams{Globals: g})
logger.LoggerParams{Globals: g},
)
if err != nil { if err != nil {
t.Fatalf("Failed to create logger: %v", err) t.Fatalf("Failed to create logger: %v", err)
} }
// Create config with DataDir pointing to a temp directory
c := &config.Config{ c := &config.Config{
DataDir: t.TempDir(), DataDir: t.TempDir(),
Environment: "dev", Environment: "dev",
} }
db, err := database.New(lc, database.DatabaseParams{ // Create database
db, err := New(lc, DatabaseParams{
Config: c, Config: c,
Logger: l, Logger: l,
}) })
@@ -44,45 +44,31 @@ func setupTestDB(
t.Fatalf("Failed to create database: %v", err) t.Fatalf("Failed to create database: %v", err)
} }
return db, lc // Start lifecycle (this will trigger the connection)
}
func TestDatabaseConnection(t *testing.T) {
t.Parallel()
db, lc := setupTestDB(t)
ctx := context.Background() ctx := context.Background()
err = lc.Start(ctx)
err := lc.Start(ctx)
if err != nil { if err != nil {
t.Fatalf("Failed to connect to database: %v", err) t.Fatalf("Failed to connect to database: %v", err)
} }
defer func() { defer func() {
stopErr := lc.Stop(ctx) if stopErr := lc.Stop(ctx); stopErr != nil {
if stopErr != nil { t.Errorf("Failed to stop lifecycle: %v", stopErr)
t.Errorf(
"Failed to stop lifecycle: %v",
stopErr,
)
} }
}() }()
// Verify we can get the DB instance
if db.DB() == nil { if db.DB() == nil {
t.Error("Expected non-nil database connection") t.Error("Expected non-nil database connection")
} }
// Test that we can perform a simple query
var result int var result int
err = db.DB().Raw("SELECT 1").Scan(&result).Error err = db.DB().Raw("SELECT 1").Scan(&result).Error
if err != nil { if err != nil {
t.Fatalf("Failed to execute test query: %v", err) t.Fatalf("Failed to execute test query: %v", err)
} }
if result != 1 { if result != 1 {
t.Errorf( t.Errorf("Expected query result to be 1, got %d", result)
"Expected query result to be 1, got %d",
result,
)
} }
} }

View File

@@ -6,11 +6,11 @@ import "time"
type APIKey struct { type APIKey struct {
BaseModel BaseModel
UserID string `gorm:"type:uuid;not null" json:"userId"` UserID string `gorm:"type:uuid;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;not null" json:"key"` Key string `gorm:"uniqueIndex;not null" json:"key"`
Description string `json:"description"` Description string `json:"description"`
LastUsedAt *time.Time `json:"lastUsedAt,omitempty"` LastUsedAt *time.Time `json:"last_used_at,omitempty"`
// Relations // Relations
User User `json:"user,omitzero"` User User `json:"user,omitempty"`
} }

View File

@@ -3,7 +3,6 @@ package database
// DeliveryStatus represents the status of a delivery // DeliveryStatus represents the status of a delivery
type DeliveryStatus string type DeliveryStatus string
// Delivery status values.
const ( const (
DeliveryStatusPending DeliveryStatus = "pending" DeliveryStatusPending DeliveryStatus = "pending"
DeliveryStatusDelivered DeliveryStatus = "delivered" DeliveryStatusDelivered DeliveryStatus = "delivered"
@@ -15,12 +14,12 @@ const (
type Delivery struct { type Delivery struct {
BaseModel BaseModel
EventID string `gorm:"type:uuid;not null" json:"eventId"` EventID string `gorm:"type:uuid;not null" json:"event_id"`
TargetID string `gorm:"type:uuid;not null" json:"targetId"` TargetID string `gorm:"type:uuid;not null" json:"target_id"`
Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"` Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"`
// Relations // Relations
Event Event `json:"event,omitzero"` Event Event `json:"event,omitempty"`
Target Target `json:"target,omitzero"` Target Target `json:"target,omitempty"`
DeliveryResults []DeliveryResult `json:"deliveryResults,omitempty"` DeliveryResults []DeliveryResult `json:"delivery_results,omitempty"`
} }

View File

@@ -4,14 +4,14 @@ package database
type DeliveryResult struct { type DeliveryResult struct {
BaseModel BaseModel
DeliveryID string `gorm:"type:uuid;not null" json:"deliveryId"` DeliveryID string `gorm:"type:uuid;not null" json:"delivery_id"`
AttemptNum int `gorm:"not null" json:"attemptNum"` AttemptNum int `gorm:"not null" json:"attempt_num"`
Success bool `json:"success"` Success bool `json:"success"`
StatusCode int `json:"statusCode,omitempty"` StatusCode int `json:"status_code,omitempty"`
ResponseBody string `gorm:"type:text" json:"responseBody,omitempty"` ResponseBody string `gorm:"type:text" json:"response_body,omitempty"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
Duration int64 `json:"durationMs"` // Duration in milliseconds Duration int64 `json:"duration_ms"` // Duration in milliseconds
// Relations // Relations
Delivery Delivery `json:"delivery,omitzero"` Delivery Delivery `json:"delivery,omitempty"`
} }

View File

@@ -4,11 +4,11 @@ package database
type Entrypoint struct { type Entrypoint struct {
BaseModel BaseModel
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"` WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this entrypoint Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this entrypoint
Description string `json:"description"` Description string `json:"description"`
Active bool `gorm:"default:true" json:"active"` Active bool `gorm:"default:true" json:"active"`
// Relations // Relations
Webhook Webhook `json:"webhook,omitzero"` Webhook Webhook `json:"webhook,omitempty"`
} }

View File

@@ -4,17 +4,17 @@ package database
type Event struct { type Event struct {
BaseModel BaseModel
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"` WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
EntrypointID string `gorm:"type:uuid;not null" json:"entrypointId"` EntrypointID string `gorm:"type:uuid;not null" json:"entrypoint_id"`
// Request data // Request data
Method string `gorm:"not null" json:"method"` Method string `gorm:"not null" json:"method"`
Headers string `gorm:"type:text" json:"headers"` // JSON Headers string `gorm:"type:text" json:"headers"` // JSON
Body string `gorm:"type:text" json:"body"` Body string `gorm:"type:text" json:"body"`
ContentType string `json:"contentType"` ContentType string `json:"content_type"`
// Relations // Relations
Webhook Webhook `json:"webhook,omitzero"` Webhook Webhook `json:"webhook,omitempty"`
Entrypoint Entrypoint `json:"entrypoint,omitzero"` Entrypoint Entrypoint `json:"entrypoint,omitempty"`
Deliveries []Delivery `json:"deliveries,omitempty"` Deliveries []Delivery `json:"deliveries,omitempty"`
} }

View File

@@ -3,7 +3,6 @@ package database
// TargetType represents the type of delivery target // TargetType represents the type of delivery target
type TargetType string type TargetType string
// Target type values.
const ( const (
TargetTypeHTTP TargetType = "http" TargetTypeHTTP TargetType = "http"
TargetTypeDatabase TargetType = "database" TargetTypeDatabase TargetType = "database"
@@ -15,7 +14,7 @@ const (
type Target struct { type Target struct {
BaseModel BaseModel
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"` WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
Name string `gorm:"not null" json:"name"` Name string `gorm:"not null" json:"name"`
Type TargetType `gorm:"not null" json:"type"` Type TargetType `gorm:"not null" json:"type"`
Active bool `gorm:"default:true" json:"active"` Active bool `gorm:"default:true" json:"active"`
@@ -24,10 +23,10 @@ type Target struct {
Config string `gorm:"type:text" json:"config"` // JSON configuration Config string `gorm:"type:text" json:"config"` // JSON configuration
// For HTTP targets (max_retries=0 means fire-and-forget, >0 enables retries with backoff) // For HTTP targets (max_retries=0 means fire-and-forget, >0 enables retries with backoff)
MaxRetries int `json:"maxRetries,omitempty"` MaxRetries int `json:"max_retries,omitempty"`
MaxQueueSize int `json:"maxQueueSize,omitempty"` MaxQueueSize int `json:"max_queue_size,omitempty"`
// Relations // Relations
Webhook Webhook `json:"webhook,omitzero"` Webhook Webhook `json:"webhook,omitempty"`
Deliveries []Delivery `json:"deliveries,omitempty"` Deliveries []Delivery `json:"deliveries,omitempty"`
} }

View File

@@ -9,5 +9,5 @@ type User struct {
// Relations // Relations
Webhooks []Webhook `json:"webhooks,omitempty"` Webhooks []Webhook `json:"webhooks,omitempty"`
APIKeys []APIKey `json:"apiKeys,omitempty"` APIKeys []APIKey `json:"api_keys,omitempty"`
} }

View File

@@ -4,13 +4,13 @@ package database
type Webhook struct { type Webhook struct {
BaseModel BaseModel
UserID string `gorm:"type:uuid;not null" json:"userId"` UserID string `gorm:"type:uuid;not null" json:"user_id"`
Name string `gorm:"not null" json:"name"` Name string `gorm:"not null" json:"name"`
Description string `json:"description"` Description string `json:"description"`
RetentionDays int `gorm:"default:30" json:"retentionDays"` // Days to retain events RetentionDays int `gorm:"default:30" json:"retention_days"` // Days to retain events
// Relations // Relations
User User `json:"user,omitzero"` User User `json:"user,omitempty"`
Entrypoints []Entrypoint `json:"entrypoints,omitempty"` Entrypoints []Entrypoint `json:"entrypoints,omitempty"`
Targets []Target `json:"targets,omitempty"` Targets []Target `json:"targets,omitempty"`
} }

View File

@@ -4,7 +4,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"strings" "strings"
@@ -21,23 +20,6 @@ const (
argon2SaltLen = 16 argon2SaltLen = 16
) )
// hashParts is the expected number of $-separated segments
// in an encoded Argon2id hash string.
const hashParts = 6
// minPasswordComplexityLen is the minimum password length that
// triggers per-character-class complexity enforcement.
const minPasswordComplexityLen = 4
// Sentinel errors returned by decodeHash.
var (
errInvalidHashFormat = errors.New("invalid hash format")
errInvalidAlgorithm = errors.New("invalid algorithm")
errIncompatibleVersion = errors.New("incompatible argon2 version")
errSaltLengthOutOfRange = errors.New("salt length out of range")
errHashLengthOutOfRange = errors.New("hash length out of range")
)
// PasswordConfig holds Argon2 configuration // PasswordConfig holds Argon2 configuration
type PasswordConfig struct { type PasswordConfig struct {
Time uint32 Time uint32
@@ -64,44 +46,26 @@ func HashPassword(password string) (string, error) {
// Generate a salt // Generate a salt
salt := make([]byte, config.SaltLen) salt := make([]byte, config.SaltLen)
if _, err := rand.Read(salt); err != nil {
_, err := rand.Read(salt)
if err != nil {
return "", err return "", err
} }
// Generate the hash // Generate the hash
hash := argon2.IDKey( hash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
[]byte(password),
salt,
config.Time,
config.Memory,
config.Threads,
config.KeyLen,
)
// Encode the hash and parameters // Encode the hash and parameters
b64Salt := base64.RawStdEncoding.EncodeToString(salt) b64Salt := base64.RawStdEncoding.EncodeToString(salt)
b64Hash := base64.RawStdEncoding.EncodeToString(hash) b64Hash := base64.RawStdEncoding.EncodeToString(hash)
// Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash // Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash
encoded := fmt.Sprintf( encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, config.Memory, config.Time, config.Threads, b64Salt, b64Hash)
argon2.Version,
config.Memory,
config.Time,
config.Threads,
b64Salt,
b64Hash,
)
return encoded, nil return encoded, nil
} }
// VerifyPassword checks if the provided password matches the hash // VerifyPassword checks if the provided password matches the hash
func VerifyPassword( func VerifyPassword(password, encodedHash string) (bool, error) {
password, encodedHash string,
) (bool, error) {
// Extract parameters and hash from encoded string // Extract parameters and hash from encoded string
config, salt, hash, err := decodeHash(encodedHash) config, salt, hash, err := decodeHash(encodedHash)
if err != nil { if err != nil {
@@ -109,119 +73,60 @@ func VerifyPassword(
} }
// Generate hash of the provided password // Generate hash of the provided password
otherHash := argon2.IDKey( otherHash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
[]byte(password),
salt,
config.Time,
config.Memory,
config.Threads,
config.KeyLen,
)
// Compare hashes using constant time comparison // Compare hashes using constant time comparison
return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil
} }
// decodeHash extracts parameters, salt, and hash from an // decodeHash extracts parameters, salt, and hash from an encoded hash string
// encoded hash string. func decodeHash(encodedHash string) (*PasswordConfig, []byte, []byte, error) {
func decodeHash(
encodedHash string,
) (*PasswordConfig, []byte, []byte, error) {
parts := strings.Split(encodedHash, "$") parts := strings.Split(encodedHash, "$")
if len(parts) != hashParts { if len(parts) != 6 {
return nil, nil, nil, errInvalidHashFormat return nil, nil, nil, fmt.Errorf("invalid hash format")
} }
if parts[1] != "argon2id" { if parts[1] != "argon2id" {
return nil, nil, nil, errInvalidAlgorithm return nil, nil, nil, fmt.Errorf("invalid algorithm")
} }
version, err := parseVersion(parts[2]) var version int
if err != nil { if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
if version != argon2.Version { if version != argon2.Version {
return nil, nil, nil, errIncompatibleVersion return nil, nil, nil, fmt.Errorf("incompatible argon2 version")
} }
config, err := parseParams(parts[3]) config := &PasswordConfig{}
if err != nil { if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &config.Memory, &config.Time, &config.Threads); err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
salt, err := decodeSalt(parts[4]) salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
saltLen := len(salt)
if saltLen < 0 || saltLen > int(^uint32(0)) {
return nil, nil, nil, fmt.Errorf("salt length out of range")
}
config.SaltLen = uint32(saltLen) // nolint:gosec // checked above
config.SaltLen = uint32(len(salt)) //nolint:gosec // validated in decodeSalt hash, err := base64.RawStdEncoding.DecodeString(parts[5])
hash, err := decodeHashBytes(parts[5])
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
hashLen := len(hash)
config.KeyLen = uint32(len(hash)) //nolint:gosec // validated in decodeHashBytes if hashLen < 0 || hashLen > int(^uint32(0)) {
return nil, nil, nil, fmt.Errorf("hash length out of range")
}
config.KeyLen = uint32(hashLen) // nolint:gosec // checked above
return config, salt, hash, nil return config, salt, hash, nil
} }
func parseVersion(s string) (int, error) { // GenerateRandomPassword generates a cryptographically secure random password
var version int
_, err := fmt.Sscanf(s, "v=%d", &version)
if err != nil {
return 0, fmt.Errorf("parsing version: %w", err)
}
return version, nil
}
func parseParams(s string) (*PasswordConfig, error) {
config := &PasswordConfig{}
_, err := fmt.Sscanf(
s, "m=%d,t=%d,p=%d",
&config.Memory, &config.Time, &config.Threads,
)
if err != nil {
return nil, fmt.Errorf("parsing params: %w", err)
}
return config, nil
}
func decodeSalt(s string) ([]byte, error) {
salt, err := base64.RawStdEncoding.DecodeString(s)
if err != nil {
return nil, fmt.Errorf("decoding salt: %w", err)
}
saltLen := len(salt)
if saltLen < 0 || saltLen > int(^uint32(0)) {
return nil, errSaltLengthOutOfRange
}
return salt, nil
}
func decodeHashBytes(s string) ([]byte, error) {
hash, err := base64.RawStdEncoding.DecodeString(s)
if err != nil {
return nil, fmt.Errorf("decoding hash: %w", err)
}
hashLen := len(hash)
if hashLen < 0 || hashLen > int(^uint32(0)) {
return nil, errHashLengthOutOfRange
}
return hash, nil
}
// GenerateRandomPassword generates a cryptographically secure
// random password.
func GenerateRandomPassword(length int) (string, error) { func GenerateRandomPassword(length int) (string, error) {
const ( const (
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
@@ -236,27 +141,27 @@ func GenerateRandomPassword(length int) (string, error) {
// Create password slice // Create password slice
password := make([]byte, length) password := make([]byte, length)
// Ensure at least one character from each set // Ensure at least one character from each set for password complexity
if length >= minPasswordComplexityLen { if length >= 4 {
// Get one character from each set
password[0] = uppercase[cryptoRandInt(len(uppercase))] password[0] = uppercase[cryptoRandInt(len(uppercase))]
password[1] = lowercase[cryptoRandInt(len(lowercase))] password[1] = lowercase[cryptoRandInt(len(lowercase))]
password[2] = digits[cryptoRandInt(len(digits))] password[2] = digits[cryptoRandInt(len(digits))]
password[3] = special[cryptoRandInt(len(special))] password[3] = special[cryptoRandInt(len(special))]
// Fill the rest randomly from all characters // Fill the rest randomly from all characters
for i := minPasswordComplexityLen; i < length; i++ { for i := 4; i < length; i++ {
password[i] = allChars[cryptoRandInt(len(allChars))] password[i] = allChars[cryptoRandInt(len(allChars))]
} }
// Shuffle the password to avoid predictable pattern // Shuffle the password to avoid predictable pattern
for i := range len(password) - 1 { for i := len(password) - 1; i > 0; i-- {
j := cryptoRandInt(len(password) - i) j := cryptoRandInt(i + 1)
idx := len(password) - 1 - i password[i], password[j] = password[j], password[i]
password[idx], password[j] = password[j], password[idx]
} }
} else { } else {
// For very short passwords, just use all characters // For very short passwords, just use all characters
for i := range length { for i := 0; i < length; i++ {
password[i] = allChars[cryptoRandInt(len(allChars))] password[i] = allChars[cryptoRandInt(len(allChars))]
} }
} }
@@ -264,17 +169,16 @@ func GenerateRandomPassword(length int) (string, error) {
return string(password), nil return string(password), nil
} }
// cryptoRandInt generates a cryptographically secure random // cryptoRandInt generates a cryptographically secure random integer in [0, max)
// integer in [0, upperBound). func cryptoRandInt(max int) int {
func cryptoRandInt(upperBound int) int { if max <= 0 {
if upperBound <= 0 { panic("max must be positive")
panic("upperBound must be positive")
} }
nBig, err := rand.Int( // Calculate the maximum valid value to avoid modulo bias
rand.Reader, // For example, if max=200 and we have 256 possible values,
big.NewInt(int64(upperBound)), // we only accept values 0-199 (reject 200-255)
) nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
if err != nil { if err != nil {
panic(fmt.Sprintf("crypto/rand error: %v", err)) panic(fmt.Sprintf("crypto/rand error: %v", err))
} }

View File

@@ -1,15 +1,11 @@
package database_test package database
import ( import (
"strings" "strings"
"testing" "testing"
"sneak.berlin/go/webhooker/internal/database"
) )
func TestGenerateRandomPassword(t *testing.T) { func TestGenerateRandomPassword(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
length int length int
@@ -22,172 +18,109 @@ func TestGenerateRandomPassword(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() password, err := GenerateRandomPassword(tt.length)
password, err := database.GenerateRandomPassword(
tt.length,
)
if err != nil { if err != nil {
t.Fatalf( t.Fatalf("GenerateRandomPassword() error = %v", err)
"GenerateRandomPassword() error = %v",
err,
)
} }
if len(password) != tt.length { if len(password) != tt.length {
t.Errorf( t.Errorf("Password length = %v, want %v", len(password), tt.length)
"Password length = %v, want %v",
len(password), tt.length,
)
} }
checkPasswordComplexity( // For passwords >= 4 chars, check complexity
t, password, tt.length, if tt.length >= 4 {
) hasUpper := false
hasLower := false
hasDigit := false
hasSpecial := false
for _, char := range password {
switch {
case char >= 'A' && char <= 'Z':
hasUpper = true
case char >= 'a' && char <= 'z':
hasLower = true
case char >= '0' && char <= '9':
hasDigit = true
case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", char):
hasSpecial = true
}
}
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
t.Errorf("Password lacks required complexity: upper=%v, lower=%v, digit=%v, special=%v",
hasUpper, hasLower, hasDigit, hasSpecial)
}
}
}) })
} }
} }
func checkPasswordComplexity(
t *testing.T,
password string,
length int,
) {
t.Helper()
// For passwords >= 4 chars, check complexity
if length < 4 {
return
}
flags := classifyChars(password)
if !flags[0] || !flags[1] || !flags[2] || !flags[3] {
t.Errorf(
"Password lacks required complexity: "+
"upper=%v, lower=%v, digit=%v, special=%v",
flags[0], flags[1], flags[2], flags[3],
)
}
}
func classifyChars(s string) [4]bool {
var flags [4]bool // upper, lower, digit, special
for _, char := range s {
switch {
case char >= 'A' && char <= 'Z':
flags[0] = true
case char >= 'a' && char <= 'z':
flags[1] = true
case char >= '0' && char <= '9':
flags[2] = true
case strings.ContainsRune(
"!@#$%^&*()_+-=[]{}|;:,.<>?",
char,
):
flags[3] = true
}
}
return flags
}
func TestGenerateRandomPasswordUniqueness(t *testing.T) { func TestGenerateRandomPasswordUniqueness(t *testing.T) {
t.Parallel()
// Generate multiple passwords and ensure they're different // Generate multiple passwords and ensure they're different
passwords := make(map[string]bool) passwords := make(map[string]bool)
const numPasswords = 100 const numPasswords = 100
for range numPasswords { for i := 0; i < numPasswords; i++ {
password, err := database.GenerateRandomPassword(16) password, err := GenerateRandomPassword(16)
if err != nil { if err != nil {
t.Fatalf( t.Fatalf("GenerateRandomPassword() error = %v", err)
"GenerateRandomPassword() error = %v",
err,
)
} }
if passwords[password] { if passwords[password] {
t.Errorf( t.Errorf("Duplicate password generated: %s", password)
"Duplicate password generated: %s",
password,
)
} }
passwords[password] = true passwords[password] = true
} }
} }
func TestHashPassword(t *testing.T) { func TestHashPassword(t *testing.T) {
t.Parallel()
password := "testPassword123!" password := "testPassword123!"
hash, err := database.HashPassword(password) hash, err := HashPassword(password)
if err != nil { if err != nil {
t.Fatalf("HashPassword() error = %v", err) t.Fatalf("HashPassword() error = %v", err)
} }
// Check that hash has correct format // Check that hash has correct format
if !strings.HasPrefix(hash, "$argon2id$") { if !strings.HasPrefix(hash, "$argon2id$") {
t.Errorf( t.Errorf("Hash doesn't have correct prefix: %s", hash)
"Hash doesn't have correct prefix: %s",
hash,
)
} }
// Verify password // Verify password
valid, err := database.VerifyPassword(password, hash) valid, err := VerifyPassword(password, hash)
if err != nil { if err != nil {
t.Fatalf("VerifyPassword() error = %v", err) t.Fatalf("VerifyPassword() error = %v", err)
} }
if !valid { if !valid {
t.Error( t.Error("VerifyPassword() returned false for correct password")
"VerifyPassword() returned false " +
"for correct password",
)
} }
// Verify wrong password fails // Verify wrong password fails
valid, err = database.VerifyPassword( valid, err = VerifyPassword("wrongPassword", hash)
"wrongPassword", hash,
)
if err != nil { if err != nil {
t.Fatalf("VerifyPassword() error = %v", err) t.Fatalf("VerifyPassword() error = %v", err)
} }
if valid { if valid {
t.Error( t.Error("VerifyPassword() returned true for wrong password")
"VerifyPassword() returned true " +
"for wrong password",
)
} }
} }
func TestHashPasswordUniqueness(t *testing.T) { func TestHashPasswordUniqueness(t *testing.T) {
t.Parallel()
password := "testPassword123!" password := "testPassword123!"
// Same password should produce different hashes // Same password should produce different hashes due to salt
hash1, err := database.HashPassword(password) hash1, err := HashPassword(password)
if err != nil { if err != nil {
t.Fatalf("HashPassword() error = %v", err) t.Fatalf("HashPassword() error = %v", err)
} }
hash2, err := database.HashPassword(password) hash2, err := HashPassword(password)
if err != nil { if err != nil {
t.Fatalf("HashPassword() error = %v", err) t.Fatalf("HashPassword() error = %v", err)
} }
if hash1 == hash2 { if hash1 == hash2 {
t.Error( t.Error("Same password produced identical hashes (salt not working)")
"Same password produced identical hashes " +
"(salt not working)",
)
} }
} }

View File

@@ -3,7 +3,6 @@ package database
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
@@ -17,82 +16,87 @@ import (
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
) )
// WebhookDBManagerParams holds the fx dependencies for // nolint:revive // WebhookDBManagerParams is a standard fx naming convention
// WebhookDBManager.
type WebhookDBManagerParams struct { type WebhookDBManagerParams struct {
fx.In fx.In
Config *config.Config Config *config.Config
Logger *logger.Logger Logger *logger.Logger
} }
// errInvalidCachedDBType indicates a type assertion failure // WebhookDBManager manages per-webhook SQLite database files for event storage.
// when retrieving a cached database connection. // Each webhook gets its own dedicated database containing Events, Deliveries,
var errInvalidCachedDBType = errors.New( // and DeliveryResults. Database connections are opened lazily and cached.
"invalid cached database type",
)
// WebhookDBManager manages per-webhook SQLite database files
// for event storage. Each webhook gets its own dedicated
// database containing Events, Deliveries, and DeliveryResults.
// Database connections are opened lazily and cached.
type WebhookDBManager struct { type WebhookDBManager struct {
dataDir string dataDir string
dbs sync.Map // map[webhookID]*gorm.DB dbs sync.Map // map[webhookID]*gorm.DB
log *slog.Logger log *slog.Logger
} }
// NewWebhookDBManager creates a new WebhookDBManager and // NewWebhookDBManager creates a new WebhookDBManager and registers lifecycle hooks.
// registers lifecycle hooks. func NewWebhookDBManager(lc fx.Lifecycle, params WebhookDBManagerParams) (*WebhookDBManager, error) {
func NewWebhookDBManager(
lc fx.Lifecycle,
params WebhookDBManagerParams,
) (*WebhookDBManager, error) {
m := &WebhookDBManager{ m := &WebhookDBManager{
dataDir: params.Config.DataDir, dataDir: params.Config.DataDir,
log: params.Logger.Get(), log: params.Logger.Get(),
} }
// Create data directory if it doesn't exist // Create data directory if it doesn't exist
err := os.MkdirAll(m.dataDir, dataDirPerm) if err := os.MkdirAll(m.dataDir, 0750); err != nil {
if err != nil { return nil, fmt.Errorf("creating data directory %s: %w", m.dataDir, err)
return nil, fmt.Errorf(
"creating data directory %s: %w",
m.dataDir,
err,
)
} }
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStop: func(_ context.Context) error { OnStop: func(_ context.Context) error { //nolint:revive // ctx unused but required by fx
return m.CloseAll() return m.CloseAll()
}, },
}) })
m.log.Info( m.log.Info("webhook database manager initialized", "data_dir", m.dataDir)
"webhook database manager initialized",
"data_dir", m.dataDir,
)
return m, nil return m, nil
} }
// GetDB returns the database connection for a webhook, // dbPath returns the filesystem path for a webhook's database file.
// creating the database file lazily if it doesn't exist. func (m *WebhookDBManager) dbPath(webhookID string) string {
func (m *WebhookDBManager) GetDB( return filepath.Join(m.dataDir, fmt.Sprintf("events-%s.db", webhookID))
webhookID string, }
) (*gorm.DB, error) {
// openDB opens (or creates) a per-webhook SQLite database and runs migrations.
func (m *WebhookDBManager) openDB(webhookID string) (*gorm.DB, error) {
path := m.dbPath(webhookID)
dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", path)
sqlDB, err := sql.Open("sqlite", dbURL)
if err != nil {
return nil, fmt.Errorf("opening webhook database %s: %w", webhookID, err)
}
db, err := gorm.Open(sqlite.Dialector{
Conn: sqlDB,
}, &gorm.Config{})
if err != nil {
sqlDB.Close()
return nil, fmt.Errorf("connecting to webhook database %s: %w", webhookID, err)
}
// Run migrations for event-tier models only
if err := db.AutoMigrate(&Event{}, &Delivery{}, &DeliveryResult{}); err != nil {
sqlDB.Close()
return nil, fmt.Errorf("migrating webhook database %s: %w", webhookID, err)
}
m.log.Info("opened per-webhook database", "webhook_id", webhookID, "path", path)
return db, nil
}
// GetDB returns the database connection for a webhook, creating the database
// file lazily if it doesn't exist. This handles both new webhooks and existing
// webhooks that were created before per-webhook databases were introduced.
func (m *WebhookDBManager) GetDB(webhookID string) (*gorm.DB, error) {
// Fast path: already open // Fast path: already open
if val, ok := m.dbs.Load(webhookID); ok { if val, ok := m.dbs.Load(webhookID); ok {
cachedDB, castOK := val.(*gorm.DB) cachedDB, castOK := val.(*gorm.DB)
if !castOK { if !castOK {
return nil, fmt.Errorf( return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
"%w for webhook %s",
errInvalidCachedDBType,
webhookID,
)
} }
return cachedDB, nil return cachedDB, nil
} }
@@ -102,61 +106,44 @@ func (m *WebhookDBManager) GetDB(
return nil, err return nil, err
} }
// Store it; if another goroutine beat us, close ours // Store it; if another goroutine beat us, close ours and use theirs
actual, loaded := m.dbs.LoadOrStore(webhookID, db) actual, loaded := m.dbs.LoadOrStore(webhookID, db)
if loaded { if loaded {
// Another goroutine created it first; close our duplicate // Another goroutine created it first; close our duplicate
sqlDB, closeErr := db.DB() if sqlDB, closeErr := db.DB(); closeErr == nil {
if closeErr == nil { sqlDB.Close()
_ = sqlDB.Close()
} }
existingDB, castOK := actual.(*gorm.DB) existingDB, castOK := actual.(*gorm.DB)
if !castOK { if !castOK {
return nil, fmt.Errorf( return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
"%w for webhook %s",
errInvalidCachedDBType,
webhookID,
)
} }
return existingDB, nil return existingDB, nil
} }
return db, nil return db, nil
} }
// CreateDB explicitly creates a new per-webhook database file // CreateDB explicitly creates a new per-webhook database file and runs migrations.
// and runs migrations. // This is called when a new webhook is created.
func (m *WebhookDBManager) CreateDB( func (m *WebhookDBManager) CreateDB(webhookID string) error {
webhookID string,
) error {
_, err := m.GetDB(webhookID) _, err := m.GetDB(webhookID)
return err return err
} }
// DBExists checks if a per-webhook database file exists on // DBExists checks if a per-webhook database file exists on disk.
// disk. func (m *WebhookDBManager) DBExists(webhookID string) bool {
func (m *WebhookDBManager) DBExists(
webhookID string,
) bool {
_, err := os.Stat(m.dbPath(webhookID)) _, err := os.Stat(m.dbPath(webhookID))
return err == nil return err == nil
} }
// DeleteDB closes the connection and deletes the database file // DeleteDB closes the connection and deletes the database file for a webhook.
// for a webhook. The file is permanently removed. // This performs a hard delete — the file is permanently removed.
func (m *WebhookDBManager) DeleteDB( func (m *WebhookDBManager) DeleteDB(webhookID string) error {
webhookID string,
) error {
// Close and remove from cache // Close and remove from cache
if val, ok := m.dbs.LoadAndDelete(webhookID); ok { if val, ok := m.dbs.LoadAndDelete(webhookID); ok {
if gormDB, castOK := val.(*gorm.DB); castOK { if gormDB, castOK := val.(*gorm.DB); castOK {
sqlDB, err := gormDB.DB() if sqlDB, err := gormDB.DB(); err == nil {
if err == nil { sqlDB.Close()
_ = sqlDB.Close()
} }
} }
} }
@@ -164,20 +151,12 @@ func (m *WebhookDBManager) DeleteDB(
// Delete the main DB file and WAL/SHM files // Delete the main DB file and WAL/SHM files
path := m.dbPath(webhookID) path := m.dbPath(webhookID)
for _, suffix := range []string{"", "-wal", "-shm"} { for _, suffix := range []string{"", "-wal", "-shm"} {
err := os.Remove(path + suffix) if err := os.Remove(path + suffix); err != nil && !os.IsNotExist(err) {
if err != nil && !os.IsNotExist(err) { return fmt.Errorf("deleting webhook database file %s%s: %w", path, suffix, err)
return fmt.Errorf(
"deleting webhook database file %s%s: %w",
path, suffix, err,
)
} }
} }
m.log.Info( m.log.Info("deleted per-webhook database", "webhook_id", webhookID)
"deleted per-webhook database",
"webhook_id", webhookID,
)
return nil return nil
} }
@@ -185,97 +164,20 @@ func (m *WebhookDBManager) DeleteDB(
// Called during application shutdown. // Called during application shutdown.
func (m *WebhookDBManager) CloseAll() error { func (m *WebhookDBManager) CloseAll() error {
var lastErr error var lastErr error
m.dbs.Range(func(key, value interface{}) bool {
m.dbs.Range(func(key, value any) bool {
if gormDB, castOK := value.(*gorm.DB); castOK { if gormDB, castOK := value.(*gorm.DB); castOK {
sqlDB, err := gormDB.DB() if sqlDB, err := gormDB.DB(); err == nil {
if err == nil { if closeErr := sqlDB.Close(); closeErr != nil {
closeErr := sqlDB.Close()
if closeErr != nil {
lastErr = closeErr lastErr = closeErr
m.log.Error( m.log.Error("failed to close webhook database",
"failed to close webhook database",
"webhook_id", key, "webhook_id", key,
"error", closeErr, "error", closeErr,
) )
} }
} }
} }
m.dbs.Delete(key) m.dbs.Delete(key)
return true return true
}) })
return lastErr return lastErr
} }
// DBPath returns the filesystem path for a webhook's database
// file.
func (m *WebhookDBManager) DBPath(
webhookID string,
) string {
return m.dbPath(webhookID)
}
func (m *WebhookDBManager) dbPath(
webhookID string,
) string {
return filepath.Join(
m.dataDir,
fmt.Sprintf("events-%s.db", webhookID),
)
}
// openDB opens (or creates) a per-webhook SQLite database and
// runs migrations.
func (m *WebhookDBManager) openDB(
webhookID string,
) (*gorm.DB, error) {
path := m.dbPath(webhookID)
dbURL := fmt.Sprintf(
"file:%s?cache=shared&mode=rwc",
path,
)
sqlDB, err := sql.Open("sqlite", dbURL)
if err != nil {
return nil, fmt.Errorf(
"opening webhook database %s: %w",
webhookID, err,
)
}
db, err := gorm.Open(sqlite.Dialector{
Conn: sqlDB,
}, &gorm.Config{})
if err != nil {
_ = sqlDB.Close()
return nil, fmt.Errorf(
"connecting to webhook database %s: %w",
webhookID, err,
)
}
// Run migrations for event-tier models only
err = db.AutoMigrate(
&Event{}, &Delivery{}, &DeliveryResult{},
)
if err != nil {
_ = sqlDB.Close()
return nil, fmt.Errorf(
"migrating webhook database %s: %w",
webhookID, err,
)
}
m.log.Info(
"opened per-webhook database",
"webhook_id", webhookID,
"path", path,
)
return db, nil
}

View File

@@ -1,4 +1,4 @@
package database_test package database
import ( import (
"context" "context"
@@ -10,29 +10,23 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/fx/fxtest" "go.uber.org/fx/fxtest"
"gorm.io/gorm"
"sneak.berlin/go/webhooker/internal/config" "sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/database"
"sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/globals"
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
) )
func setupTestWebhookDBManager( func setupTestWebhookDBManager(t *testing.T) (*WebhookDBManager, *fxtest.Lifecycle) {
t *testing.T,
) (*database.WebhookDBManager, *fxtest.Lifecycle) {
t.Helper() t.Helper()
lc := fxtest.NewLifecycle(t) lc := fxtest.NewLifecycle(t)
g := &globals.Globals{ globals.Appname = "webhooker-test"
Appname: "webhooker-test", globals.Version = "test"
Version: "test",
}
l, err := logger.New( g, err := globals.New(lc)
lc, require.NoError(t, err)
logger.LoggerParams{Globals: g},
) l, err := logger.New(lc, logger.LoggerParams{Globals: g})
require.NoError(t, err) require.NoError(t, err)
dataDir := filepath.Join(t.TempDir(), "events") dataDir := filepath.Join(t.TempDir(), "events")
@@ -41,25 +35,19 @@ func setupTestWebhookDBManager(
DataDir: dataDir, DataDir: dataDir,
} }
mgr, err := database.NewWebhookDBManager( mgr, err := NewWebhookDBManager(lc, WebhookDBManagerParams{
lc,
database.WebhookDBManagerParams{
Config: cfg, Config: cfg,
Logger: l, Logger: l,
}, })
)
require.NoError(t, err) require.NoError(t, err)
return mgr, lc return mgr, lc
} }
func TestWebhookDBManager_CreateAndGetDB(t *testing.T) { func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
t.Parallel()
mgr, lc := setupTestWebhookDBManager(t) mgr, lc := setupTestWebhookDBManager(t)
ctx := context.Background() ctx := context.Background()
require.NoError(t, lc.Start(ctx)) require.NoError(t, lc.Start(ctx))
defer func() { require.NoError(t, lc.Stop(ctx)) }() defer func() { require.NoError(t, lc.Stop(ctx)) }()
webhookID := uuid.New().String() webhookID := uuid.New().String()
@@ -80,7 +68,7 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
require.NotNil(t, db) require.NotNil(t, db)
// Verify we can write an event // Verify we can write an event
event := &database.Event{ event := &Event{
WebhookID: webhookID, WebhookID: webhookID,
EntrypointID: uuid.New().String(), EntrypointID: uuid.New().String(),
Method: "POST", Method: "POST",
@@ -92,35 +80,27 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
assert.NotEmpty(t, event.ID) assert.NotEmpty(t, event.ID)
// Verify we can read it back // Verify we can read it back
var readEvent database.Event var readEvent Event
require.NoError(t, db.First(&readEvent, "id = ?", event.ID).Error)
require.NoError(
t,
db.First(&readEvent, "id = ?", event.ID).Error,
)
assert.Equal(t, webhookID, readEvent.WebhookID) assert.Equal(t, webhookID, readEvent.WebhookID)
assert.Equal(t, "POST", readEvent.Method) assert.Equal(t, "POST", readEvent.Method)
assert.Equal(t, `{"test": true}`, readEvent.Body) assert.Equal(t, `{"test": true}`, readEvent.Body)
} }
func TestWebhookDBManager_DeleteDB(t *testing.T) { func TestWebhookDBManager_DeleteDB(t *testing.T) {
t.Parallel()
mgr, lc := setupTestWebhookDBManager(t) mgr, lc := setupTestWebhookDBManager(t)
ctx := context.Background() ctx := context.Background()
require.NoError(t, lc.Start(ctx)) require.NoError(t, lc.Start(ctx))
defer func() { require.NoError(t, lc.Stop(ctx)) }() defer func() { require.NoError(t, lc.Stop(ctx)) }()
webhookID := uuid.New().String() webhookID := uuid.New().String()
// Create the DB and write some data // Create the DB and write some data
require.NoError(t, mgr.CreateDB(webhookID)) require.NoError(t, mgr.CreateDB(webhookID))
db, err := mgr.GetDB(webhookID) db, err := mgr.GetDB(webhookID)
require.NoError(t, err) require.NoError(t, err)
event := &database.Event{ event := &Event{
WebhookID: webhookID, WebhookID: webhookID,
EntrypointID: uuid.New().String(), EntrypointID: uuid.New().String(),
Method: "POST", Method: "POST",
@@ -136,19 +116,15 @@ func TestWebhookDBManager_DeleteDB(t *testing.T) {
assert.False(t, mgr.DBExists(webhookID)) assert.False(t, mgr.DBExists(webhookID))
// Verify the file is actually gone from disk // Verify the file is actually gone from disk
dbPath := mgr.DBPath(webhookID) dbPath := mgr.dbPath(webhookID)
_, err = os.Stat(dbPath) _, err = os.Stat(dbPath)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
} }
func TestWebhookDBManager_LazyCreation(t *testing.T) { func TestWebhookDBManager_LazyCreation(t *testing.T) {
t.Parallel()
mgr, lc := setupTestWebhookDBManager(t) mgr, lc := setupTestWebhookDBManager(t)
ctx := context.Background() ctx := context.Background()
require.NoError(t, lc.Start(ctx)) require.NoError(t, lc.Start(ctx))
defer func() { require.NoError(t, lc.Stop(ctx)) }() defer func() { require.NoError(t, lc.Stop(ctx)) }()
webhookID := uuid.New().String() webhookID := uuid.New().String()
@@ -163,12 +139,9 @@ func TestWebhookDBManager_LazyCreation(t *testing.T) {
} }
func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) { func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
t.Parallel()
mgr, lc := setupTestWebhookDBManager(t) mgr, lc := setupTestWebhookDBManager(t)
ctx := context.Background() ctx := context.Background()
require.NoError(t, lc.Start(ctx)) require.NoError(t, lc.Start(ctx))
defer func() { require.NoError(t, lc.Stop(ctx)) }() defer func() { require.NoError(t, lc.Stop(ctx)) }()
webhookID := uuid.New().String() webhookID := uuid.New().String()
@@ -177,23 +150,8 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
db, err := mgr.GetDB(webhookID) db, err := mgr.GetDB(webhookID)
require.NoError(t, err) require.NoError(t, err)
event, delivery := seedDeliveryWorkflow( // Create an event
t, db, webhookID, targetID, event := &Event{
)
verifyPendingDeliveries(t, db, event)
completeDelivery(t, db, delivery)
verifyNoPending(t, db)
}
func seedDeliveryWorkflow(
t *testing.T,
db *gorm.DB,
webhookID, targetID string,
) (*database.Event, *database.Delivery) {
t.Helper()
event := &database.Event{
WebhookID: webhookID, WebhookID: webhookID,
EntrypointID: uuid.New().String(), EntrypointID: uuid.New().String(),
Method: "POST", Method: "POST",
@@ -203,45 +161,25 @@ func seedDeliveryWorkflow(
} }
require.NoError(t, db.Create(event).Error) require.NoError(t, db.Create(event).Error)
delivery := &database.Delivery{ // Create a delivery
delivery := &Delivery{
EventID: event.ID, EventID: event.ID,
TargetID: targetID, TargetID: targetID,
Status: database.DeliveryStatusPending, Status: DeliveryStatusPending,
} }
require.NoError(t, db.Create(delivery).Error) require.NoError(t, db.Create(delivery).Error)
return event, delivery // Query pending deliveries
} var pending []Delivery
require.NoError(t, db.Where("status = ?", DeliveryStatusPending).
func verifyPendingDeliveries( Preload("Event").
t *testing.T, Find(&pending).Error)
db *gorm.DB,
event *database.Event,
) {
t.Helper()
var pending []database.Delivery
require.NoError(
t,
db.Where(
"status = ?",
database.DeliveryStatusPending,
).Preload("Event").Find(&pending).Error,
)
require.Len(t, pending, 1) require.Len(t, pending, 1)
assert.Equal(t, event.ID, pending[0].EventID) assert.Equal(t, event.ID, pending[0].EventID)
assert.Equal(t, "POST", pending[0].Event.Method) assert.Equal(t, "POST", pending[0].Event.Method)
}
func completeDelivery( // Create a delivery result
t *testing.T, result := &DeliveryResult{
db *gorm.DB,
delivery *database.Delivery,
) {
t.Helper()
result := &database.DeliveryResult{
DeliveryID: delivery.ID, DeliveryID: delivery.ID,
AttemptNum: 1, AttemptNum: 1,
Success: true, Success: true,
@@ -250,40 +188,19 @@ func completeDelivery(
} }
require.NoError(t, db.Create(result).Error) require.NoError(t, db.Create(result).Error)
require.NoError( // Update delivery status
t, require.NoError(t, db.Model(delivery).Update("status", DeliveryStatusDelivered).Error)
db.Model(delivery).Update(
"status",
database.DeliveryStatusDelivered,
).Error,
)
}
func verifyNoPending( // Verify no more pending deliveries
t *testing.T, var stillPending []Delivery
db *gorm.DB, require.NoError(t, db.Where("status = ?", DeliveryStatusPending).Find(&stillPending).Error)
) {
t.Helper()
var stillPending []database.Delivery
require.NoError(
t,
db.Where(
"status = ?",
database.DeliveryStatusPending,
).Find(&stillPending).Error,
)
assert.Empty(t, stillPending) assert.Empty(t, stillPending)
} }
func TestWebhookDBManager_MultipleWebhooks(t *testing.T) { func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
t.Parallel()
mgr, lc := setupTestWebhookDBManager(t) mgr, lc := setupTestWebhookDBManager(t)
ctx := context.Background() ctx := context.Background()
require.NoError(t, lc.Start(ctx)) require.NoError(t, lc.Start(ctx))
defer func() { require.NoError(t, lc.Stop(ctx)) }() defer func() { require.NoError(t, lc.Stop(ctx)) }()
webhook1 := uuid.New().String() webhook1 := uuid.New().String()
@@ -295,38 +212,34 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
db1, err := mgr.GetDB(webhook1) db1, err := mgr.GetDB(webhook1)
require.NoError(t, err) require.NoError(t, err)
db2, err := mgr.GetDB(webhook2) db2, err := mgr.GetDB(webhook2)
require.NoError(t, err) require.NoError(t, err)
// Write events to each webhook's DB // Write events to each webhook's DB
event1 := &database.Event{ event1 := &Event{
WebhookID: webhook1, WebhookID: webhook1,
EntrypointID: uuid.New().String(), EntrypointID: uuid.New().String(),
Method: "POST", Method: "POST",
Body: `{"webhook": 1}`, Body: `{"webhook": 1}`,
ContentType: "application/json", ContentType: "application/json",
} }
event2 := &database.Event{ event2 := &Event{
WebhookID: webhook2, WebhookID: webhook2,
EntrypointID: uuid.New().String(), EntrypointID: uuid.New().String(),
Method: "PUT", Method: "PUT",
Body: `{"webhook": 2}`, Body: `{"webhook": 2}`,
ContentType: "application/json", ContentType: "application/json",
} }
require.NoError(t, db1.Create(event1).Error) require.NoError(t, db1.Create(event1).Error)
require.NoError(t, db2.Create(event2).Error) require.NoError(t, db2.Create(event2).Error)
// Verify isolation: each DB only has its own events // Verify isolation: each DB only has its own events
var count1 int64 var count1 int64
db1.Model(&Event{}).Count(&count1)
db1.Model(&database.Event{}).Count(&count1)
assert.Equal(t, int64(1), count1) assert.Equal(t, int64(1), count1)
var count2 int64 var count2 int64
db2.Model(&Event{}).Count(&count2)
db2.Model(&database.Event{}).Count(&count2)
assert.Equal(t, int64(1), count2) assert.Equal(t, int64(1), count2)
// Delete webhook1's DB, webhook2 should be unaffected // Delete webhook1's DB, webhook2 should be unaffected
@@ -335,31 +248,25 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
assert.True(t, mgr.DBExists(webhook2)) assert.True(t, mgr.DBExists(webhook2))
// webhook2's data should still be accessible // webhook2's data should still be accessible
var events []database.Event var events []Event
require.NoError(t, db2.Find(&events).Error) require.NoError(t, db2.Find(&events).Error)
assert.Len(t, events, 1) assert.Len(t, events, 1)
assert.Equal(t, "PUT", events[0].Method) assert.Equal(t, "PUT", events[0].Method)
} }
func TestWebhookDBManager_CloseAll(t *testing.T) { func TestWebhookDBManager_CloseAll(t *testing.T) {
t.Parallel()
mgr, lc := setupTestWebhookDBManager(t) mgr, lc := setupTestWebhookDBManager(t)
ctx := context.Background() ctx := context.Background()
require.NoError(t, lc.Start(ctx)) require.NoError(t, lc.Start(ctx))
// Create a few DBs // Create a few DBs
for range 3 { for i := 0; i < 3; i++ {
require.NoError( require.NoError(t, mgr.CreateDB(uuid.New().String()))
t,
mgr.CreateDB(uuid.New().String()),
)
} }
// CloseAll should close all connections without error // CloseAll should close all connections without error
require.NoError(t, mgr.CloseAll()) require.NoError(t, mgr.CloseAll())
// Stop lifecycle (CloseAll already called) // Stop lifecycle (CloseAll already called, but shouldn't panic)
require.NoError(t, lc.Stop(ctx)) require.NoError(t, lc.Stop(ctx))
} }

View File

@@ -5,32 +5,41 @@ import (
"time" "time"
) )
// CircuitState represents the current state of a circuit // CircuitState represents the current state of a circuit breaker.
// breaker.
type CircuitState int type CircuitState int
const ( const (
// CircuitClosed is the normal operating state. // CircuitClosed is the normal operating state. Deliveries flow through.
CircuitClosed CircuitState = iota CircuitClosed CircuitState = iota
// CircuitOpen means the circuit has tripped. // CircuitOpen means the circuit has tripped. Deliveries are skipped
// until the cooldown expires.
CircuitOpen CircuitOpen
// CircuitHalfOpen allows a single probe delivery to // CircuitHalfOpen allows a single probe delivery to test whether
// test whether the target has recovered. // the target has recovered.
CircuitHalfOpen CircuitHalfOpen
) )
const ( const (
// defaultFailureThreshold is the number of consecutive // defaultFailureThreshold is the number of consecutive failures
// failures before a circuit breaker trips open. // before a circuit breaker trips open.
defaultFailureThreshold = 5 defaultFailureThreshold = 5
// defaultCooldown is how long a circuit stays open // defaultCooldown is how long a circuit stays open before
// before transitioning to half-open. // transitioning to half-open for a probe delivery.
defaultCooldown = 30 * time.Second defaultCooldown = 30 * time.Second
) )
// CircuitBreaker implements the circuit breaker pattern // CircuitBreaker implements the circuit breaker pattern for a single
// for a single delivery target. // delivery target. It tracks consecutive failures and prevents
// hammering a down target by temporarily stopping delivery attempts.
//
// States:
// - Closed (normal): deliveries flow through; consecutive failures
// are counted.
// - Open (tripped): deliveries are skipped; a cooldown timer is
// running. After the cooldown expires the state moves to HalfOpen.
// - HalfOpen (probing): one probe delivery is allowed. If it
// succeeds the circuit closes; if it fails the circuit reopens.
type CircuitBreaker struct { type CircuitBreaker struct {
mu sync.Mutex mu sync.Mutex
state CircuitState state CircuitState
@@ -40,8 +49,7 @@ type CircuitBreaker struct {
lastFailure time.Time lastFailure time.Time
} }
// NewCircuitBreaker creates a circuit breaker with default // NewCircuitBreaker creates a circuit breaker with default settings.
// settings.
func NewCircuitBreaker() *CircuitBreaker { func NewCircuitBreaker() *CircuitBreaker {
return &CircuitBreaker{ return &CircuitBreaker{
state: CircuitClosed, state: CircuitClosed,
@@ -50,7 +58,12 @@ func NewCircuitBreaker() *CircuitBreaker {
} }
} }
// Allow checks whether a delivery attempt should proceed. // Allow checks whether a delivery attempt should proceed. It returns
// true if the delivery should be attempted, false if the circuit is
// open and the delivery should be skipped.
//
// When the circuit is open and the cooldown has elapsed, Allow
// transitions to half-open and permits exactly one probe delivery.
func (cb *CircuitBreaker) Allow() bool { func (cb *CircuitBreaker) Allow() bool {
cb.mu.Lock() cb.mu.Lock()
defer cb.mu.Unlock() defer cb.mu.Unlock()
@@ -60,15 +73,17 @@ func (cb *CircuitBreaker) Allow() bool {
return true return true
case CircuitOpen: case CircuitOpen:
// Check if cooldown has elapsed
if time.Since(cb.lastFailure) >= cb.cooldown { if time.Since(cb.lastFailure) >= cb.cooldown {
cb.state = CircuitHalfOpen cb.state = CircuitHalfOpen
return true return true
} }
return false return false
case CircuitHalfOpen: case CircuitHalfOpen:
// Only one probe at a time — reject additional attempts while
// a probe is in flight. The probe goroutine will call
// RecordSuccess or RecordFailure to resolve the state.
return false return false
default: default:
@@ -76,8 +91,9 @@ func (cb *CircuitBreaker) Allow() bool {
} }
} }
// CooldownRemaining returns how much time is left before // CooldownRemaining returns how much time is left before an open circuit
// an open circuit transitions to half-open. // transitions to half-open. Returns zero if the circuit is not open or
// the cooldown has already elapsed.
func (cb *CircuitBreaker) CooldownRemaining() time.Duration { func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
cb.mu.Lock() cb.mu.Lock()
defer cb.mu.Unlock() defer cb.mu.Unlock()
@@ -90,12 +106,11 @@ func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
if remaining < 0 { if remaining < 0 {
return 0 return 0
} }
return remaining return remaining
} }
// RecordSuccess records a successful delivery and resets // RecordSuccess records a successful delivery and resets the circuit
// the circuit breaker to closed state. // breaker to closed state with zero failures.
func (cb *CircuitBreaker) RecordSuccess() { func (cb *CircuitBreaker) RecordSuccess() {
cb.mu.Lock() cb.mu.Lock()
defer cb.mu.Unlock() defer cb.mu.Unlock()
@@ -104,8 +119,8 @@ func (cb *CircuitBreaker) RecordSuccess() {
cb.state = CircuitClosed cb.state = CircuitClosed
} }
// RecordFailure records a failed delivery. If the failure // RecordFailure records a failed delivery. If the failure count reaches
// count reaches the threshold, the circuit trips open. // the threshold, the circuit trips open.
func (cb *CircuitBreaker) RecordFailure() { func (cb *CircuitBreaker) RecordFailure() {
cb.mu.Lock() cb.mu.Lock()
defer cb.mu.Unlock() defer cb.mu.Unlock()
@@ -119,25 +134,20 @@ func (cb *CircuitBreaker) RecordFailure() {
cb.state = CircuitOpen cb.state = CircuitOpen
} }
case CircuitOpen:
// Already open; no state change needed.
case CircuitHalfOpen: case CircuitHalfOpen:
// Probe failed -- reopen immediately. // Probe failed reopen immediately
cb.state = CircuitOpen cb.state = CircuitOpen
} }
} }
// State returns the current circuit state. // State returns the current circuit state. Safe for concurrent use.
func (cb *CircuitBreaker) State() CircuitState { func (cb *CircuitBreaker) State() CircuitState {
cb.mu.Lock() cb.mu.Lock()
defer cb.mu.Unlock() defer cb.mu.Unlock()
return cb.state return cb.state
} }
// String returns the human-readable name of a circuit // String returns the human-readable name of a circuit state.
// state.
func (s CircuitState) String() string { func (s CircuitState) String() string {
switch s { switch s {
case CircuitClosed: case CircuitClosed:

View File

@@ -1,4 +1,4 @@
package delivery_test package delivery
import ( import (
"sync" "sync"
@@ -7,304 +7,237 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"sneak.berlin/go/webhooker/internal/delivery"
) )
func TestCircuitBreaker_ClosedState_AllowsDeliveries( func TestCircuitBreaker_ClosedState_AllowsDeliveries(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
cb := NewCircuitBreaker()
cb := delivery.NewCircuitBreaker() assert.Equal(t, CircuitClosed, cb.State())
assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
assert.Equal(t, delivery.CircuitClosed, cb.State()) // Multiple calls should all succeed
assert.True(t, cb.Allow(), for i := 0; i < 10; i++ {
"closed circuit should allow deliveries",
)
for range 10 {
assert.True(t, cb.Allow()) assert.True(t, cb.Allow())
} }
} }
func TestCircuitBreaker_FailureCounting(t *testing.T) { func TestCircuitBreaker_FailureCounting(t *testing.T) {
t.Parallel() t.Parallel()
cb := NewCircuitBreaker()
cb := delivery.NewCircuitBreaker() // Record failures below threshold — circuit should stay closed
for i := 0; i < defaultFailureThreshold-1; i++ {
for i := range delivery.ExportDefaultFailureThreshold - 1 {
cb.RecordFailure() cb.RecordFailure()
assert.Equal(t, CircuitClosed, cb.State(),
assert.Equal(t, "circuit should remain closed after %d failures", i+1)
delivery.CircuitClosed, cb.State(), assert.True(t, cb.Allow(), "should still allow after %d failures", i+1)
"circuit should remain closed after %d failures",
i+1,
)
assert.True(t, cb.Allow(),
"should still allow after %d failures",
i+1,
)
} }
} }
func TestCircuitBreaker_OpenTransition(t *testing.T) { func TestCircuitBreaker_OpenTransition(t *testing.T) {
t.Parallel() t.Parallel()
cb := NewCircuitBreaker()
cb := delivery.NewCircuitBreaker() // Record exactly threshold failures
for i := 0; i < defaultFailureThreshold; i++ {
for range delivery.ExportDefaultFailureThreshold {
cb.RecordFailure() cb.RecordFailure()
} }
assert.Equal(t, delivery.CircuitOpen, cb.State(), assert.Equal(t, CircuitOpen, cb.State(), "circuit should be open after threshold failures")
"circuit should be open after threshold failures", assert.False(t, cb.Allow(), "open circuit should reject deliveries")
)
assert.False(t, cb.Allow(),
"open circuit should reject deliveries",
)
} }
func TestCircuitBreaker_Cooldown_StaysOpen(t *testing.T) { func TestCircuitBreaker_Cooldown_StaysOpen(t *testing.T) {
t.Parallel() t.Parallel()
// Use a circuit with a known short cooldown for testing
cb := delivery.NewCircuitBreaker() cb := &CircuitBreaker{
state: CircuitClosed,
for range delivery.ExportDefaultFailureThreshold { threshold: defaultFailureThreshold,
cb.RecordFailure() cooldown: 200 * time.Millisecond,
} }
require.Equal(t, delivery.CircuitOpen, cb.State()) // Trip the circuit open
for i := 0; i < defaultFailureThreshold; i++ {
cb.RecordFailure()
}
require.Equal(t, CircuitOpen, cb.State())
assert.False(t, cb.Allow(), // During cooldown, Allow should return false
"should be blocked during cooldown", assert.False(t, cb.Allow(), "should be blocked during cooldown")
)
// CooldownRemaining should be positive
remaining := cb.CooldownRemaining() remaining := cb.CooldownRemaining()
assert.Greater(t, remaining, time.Duration(0), "cooldown should have remaining time")
assert.Greater(t, remaining, time.Duration(0),
"cooldown should have remaining time",
)
} }
func TestCircuitBreaker_HalfOpen_AfterCooldown( func TestCircuitBreaker_HalfOpen_AfterCooldown(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
cb := &CircuitBreaker{
state: CircuitClosed,
threshold: defaultFailureThreshold,
cooldown: 50 * time.Millisecond,
}
cb := newShortCooldownCB(t) // Trip the circuit open
for i := 0; i < defaultFailureThreshold; i++ {
for range delivery.ExportDefaultFailureThreshold {
cb.RecordFailure() cb.RecordFailure()
} }
require.Equal(t, CircuitOpen, cb.State())
require.Equal(t, delivery.CircuitOpen, cb.State()) // Wait for cooldown to expire
time.Sleep(60 * time.Millisecond) time.Sleep(60 * time.Millisecond)
assert.Equal(t, time.Duration(0), // CooldownRemaining should be zero after cooldown
cb.CooldownRemaining(), assert.Equal(t, time.Duration(0), cb.CooldownRemaining())
)
assert.True(t, cb.Allow(), // First Allow after cooldown should succeed (probe)
"should allow one probe after cooldown", assert.True(t, cb.Allow(), "should allow one probe after cooldown")
) assert.Equal(t, CircuitHalfOpen, cb.State(), "should be half-open after probe allowed")
assert.Equal(t, // Second Allow should be rejected (only one probe at a time)
delivery.CircuitHalfOpen, cb.State(), assert.False(t, cb.Allow(), "should reject additional probes while half-open")
"should be half-open after probe allowed",
)
assert.False(t, cb.Allow(),
"should reject additional probes while half-open",
)
} }
func TestCircuitBreaker_ProbeSuccess_ClosesCircuit( func TestCircuitBreaker_ProbeSuccess_ClosesCircuit(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
cb := &CircuitBreaker{
state: CircuitClosed,
threshold: defaultFailureThreshold,
cooldown: 50 * time.Millisecond,
}
cb := newShortCooldownCB(t) // Trip open → wait for cooldown → allow probe
for i := 0; i < defaultFailureThreshold; i++ {
for range delivery.ExportDefaultFailureThreshold {
cb.RecordFailure() cb.RecordFailure()
} }
time.Sleep(60 * time.Millisecond) time.Sleep(60 * time.Millisecond)
require.True(t, cb.Allow()) // probe allowed, state → half-open
require.True(t, cb.Allow()) // Probe succeeds → circuit should close
cb.RecordSuccess() cb.RecordSuccess()
assert.Equal(t, CircuitClosed, cb.State(), "successful probe should close circuit")
assert.Equal(t, delivery.CircuitClosed, cb.State(), // Should allow deliveries again
"successful probe should close circuit", assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
)
assert.True(t, cb.Allow(),
"closed circuit should allow deliveries",
)
} }
func TestCircuitBreaker_ProbeFailure_ReopensCircuit( func TestCircuitBreaker_ProbeFailure_ReopensCircuit(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
cb := &CircuitBreaker{
state: CircuitClosed,
threshold: defaultFailureThreshold,
cooldown: 50 * time.Millisecond,
}
cb := newShortCooldownCB(t) // Trip open → wait for cooldown → allow probe
for i := 0; i < defaultFailureThreshold; i++ {
for range delivery.ExportDefaultFailureThreshold {
cb.RecordFailure() cb.RecordFailure()
} }
time.Sleep(60 * time.Millisecond) time.Sleep(60 * time.Millisecond)
require.True(t, cb.Allow()) // probe allowed, state → half-open
require.True(t, cb.Allow()) // Probe fails → circuit should reopen
cb.RecordFailure() cb.RecordFailure()
assert.Equal(t, CircuitOpen, cb.State(), "failed probe should reopen circuit")
assert.Equal(t, delivery.CircuitOpen, cb.State(), assert.False(t, cb.Allow(), "reopened circuit should reject deliveries")
"failed probe should reopen circuit",
)
assert.False(t, cb.Allow(),
"reopened circuit should reject deliveries",
)
} }
func TestCircuitBreaker_SuccessResetsFailures( func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
cb := NewCircuitBreaker()
cb := delivery.NewCircuitBreaker() // Accumulate failures just below threshold
for i := 0; i < defaultFailureThreshold-1; i++ {
for range delivery.ExportDefaultFailureThreshold - 1 {
cb.RecordFailure() cb.RecordFailure()
} }
require.Equal(t, CircuitClosed, cb.State())
require.Equal(t, delivery.CircuitClosed, cb.State()) // Success should reset the failure counter
cb.RecordSuccess() cb.RecordSuccess()
assert.Equal(t, CircuitClosed, cb.State())
assert.Equal(t, delivery.CircuitClosed, cb.State()) // Now we should need another full threshold of failures to trip
for i := 0; i < defaultFailureThreshold-1; i++ {
for range delivery.ExportDefaultFailureThreshold - 1 {
cb.RecordFailure() cb.RecordFailure()
} }
assert.Equal(t, CircuitClosed, cb.State(),
"circuit should still be closed — success reset the counter")
assert.Equal(t, delivery.CircuitClosed, cb.State(), // One more failure should trip it
"circuit should still be closed -- "+
"success reset the counter",
)
cb.RecordFailure() cb.RecordFailure()
assert.Equal(t, CircuitOpen, cb.State())
assert.Equal(t, delivery.CircuitOpen, cb.State())
} }
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) { func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
t.Parallel() t.Parallel()
cb := NewCircuitBreaker()
cb := delivery.NewCircuitBreaker()
const goroutines = 100 const goroutines = 100
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(goroutines * 3) wg.Add(goroutines * 3)
for range goroutines { // Concurrent Allow calls
for i := 0; i < goroutines; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
cb.Allow() cb.Allow()
}() }()
} }
for range goroutines { // Concurrent RecordFailure calls
for i := 0; i < goroutines; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
cb.RecordFailure() cb.RecordFailure()
}() }()
} }
for range goroutines { // Concurrent RecordSuccess calls
for i := 0; i < goroutines; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
cb.RecordSuccess() cb.RecordSuccess()
}() }()
} }
wg.Wait() wg.Wait()
// No panic or data race — the test passes if -race doesn't flag anything.
// State should be one of the valid states.
state := cb.State() state := cb.State()
assert.Contains(t, []CircuitState{CircuitClosed, CircuitOpen, CircuitHalfOpen}, state,
assert.Contains(t, "state should be valid after concurrent access")
[]delivery.CircuitState{
delivery.CircuitClosed,
delivery.CircuitOpen,
delivery.CircuitHalfOpen,
},
state,
"state should be valid after concurrent access",
)
} }
func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero( func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
cb := NewCircuitBreaker()
cb := delivery.NewCircuitBreaker() assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
"closed circuit should have zero cooldown remaining")
assert.Equal(t, time.Duration(0),
cb.CooldownRemaining(),
"closed circuit should have zero cooldown remaining",
)
} }
func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero( func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
cb := &CircuitBreaker{
state: CircuitClosed,
threshold: defaultFailureThreshold,
cooldown: 50 * time.Millisecond,
}
cb := newShortCooldownCB(t) // Trip open, wait, transition to half-open
for i := 0; i < defaultFailureThreshold; i++ {
for range delivery.ExportDefaultFailureThreshold {
cb.RecordFailure() cb.RecordFailure()
} }
time.Sleep(60 * time.Millisecond) time.Sleep(60 * time.Millisecond)
require.True(t, cb.Allow()) // → half-open
require.True(t, cb.Allow()) assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
"half-open circuit should have zero cooldown remaining")
assert.Equal(t, time.Duration(0),
cb.CooldownRemaining(),
"half-open circuit should have zero cooldown remaining",
)
} }
func TestCircuitState_String(t *testing.T) { func TestCircuitState_String(t *testing.T) {
t.Parallel() t.Parallel()
assert.Equal(t, "closed", CircuitClosed.String())
assert.Equal(t, "closed", delivery.CircuitClosed.String()) assert.Equal(t, "open", CircuitOpen.String())
assert.Equal(t, "open", delivery.CircuitOpen.String()) assert.Equal(t, "half-open", CircuitHalfOpen.String())
assert.Equal(t, "half-open", delivery.CircuitHalfOpen.String()) assert.Equal(t, "unknown", CircuitState(99).String())
assert.Equal(t, "unknown", delivery.CircuitState(99).String())
}
// newShortCooldownCB creates a CircuitBreaker with a short
// cooldown for testing. We use NewCircuitBreaker and
// manipulate through the public API.
func newShortCooldownCB(t *testing.T) *delivery.CircuitBreaker {
t.Helper()
return delivery.NewTestCircuitBreaker(
delivery.ExportDefaultFailureThreshold,
50*time.Millisecond,
)
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,240 +0,0 @@
package delivery
import (
"context"
"log/slog"
"net"
"net/http"
"time"
"gorm.io/gorm"
"sneak.berlin/go/webhooker/internal/database"
)
// Exported constants for test access.
const (
ExportDeliveryChannelSize = deliveryChannelSize
ExportRetryChannelSize = retryChannelSize
ExportDefaultFailureThreshold = defaultFailureThreshold
ExportDefaultCooldown = defaultCooldown
)
// ExportIsBlockedIP exposes isBlockedIP for testing.
func ExportIsBlockedIP(ip net.IP) bool {
return isBlockedIP(ip)
}
// ExportBlockedNetworks exposes blockedNetworks.
func ExportBlockedNetworks() []*net.IPNet {
return blockedNetworks
}
// ExportIsForwardableHeader exposes isForwardableHeader.
func ExportIsForwardableHeader(name string) bool {
return isForwardableHeader(name)
}
// ExportTruncate exposes truncate for testing.
func ExportTruncate(s string, maxLen int) string {
return truncate(s, maxLen)
}
// ExportDeliverHTTP exposes deliverHTTP for testing.
func (e *Engine) ExportDeliverHTTP(
ctx context.Context,
webhookDB *gorm.DB,
d *database.Delivery,
task *Task,
) {
e.deliverHTTP(ctx, webhookDB, d, task)
}
// ExportDeliverDatabase exposes deliverDatabase.
func (e *Engine) ExportDeliverDatabase(
webhookDB *gorm.DB, d *database.Delivery,
) {
e.deliverDatabase(webhookDB, d)
}
// ExportDeliverLog exposes deliverLog for testing.
func (e *Engine) ExportDeliverLog(
webhookDB *gorm.DB, d *database.Delivery,
) {
e.deliverLog(webhookDB, d)
}
// ExportDeliverSlack exposes deliverSlack for testing.
func (e *Engine) ExportDeliverSlack(
ctx context.Context,
webhookDB *gorm.DB,
d *database.Delivery,
) {
e.deliverSlack(ctx, webhookDB, d)
}
// ExportProcessNewTask exposes processNewTask.
func (e *Engine) ExportProcessNewTask(
ctx context.Context, task *Task,
) {
e.processNewTask(ctx, task)
}
// ExportProcessRetryTask exposes processRetryTask.
func (e *Engine) ExportProcessRetryTask(
ctx context.Context, task *Task,
) {
e.processRetryTask(ctx, task)
}
// ExportProcessDelivery exposes processDelivery.
func (e *Engine) ExportProcessDelivery(
ctx context.Context,
webhookDB *gorm.DB,
d *database.Delivery,
task *Task,
) {
e.processDelivery(ctx, webhookDB, d, task)
}
// ExportGetCircuitBreaker exposes getCircuitBreaker.
func (e *Engine) ExportGetCircuitBreaker(
targetID string,
) *CircuitBreaker {
return e.getCircuitBreaker(targetID)
}
// ExportParseHTTPConfig exposes parseHTTPConfig.
func (e *Engine) ExportParseHTTPConfig(
configJSON string,
) (*HTTPTargetConfig, error) {
return e.parseHTTPConfig(configJSON)
}
// ExportParseSlackConfig exposes parseSlackConfig.
func (e *Engine) ExportParseSlackConfig(
configJSON string,
) (*SlackTargetConfig, error) {
return e.parseSlackConfig(configJSON)
}
// ExportDoHTTPRequest exposes doHTTPRequest.
func (e *Engine) ExportDoHTTPRequest(
ctx context.Context,
cfg *HTTPTargetConfig,
event *database.Event,
) (int, string, int64, error) {
return e.doHTTPRequest(ctx, cfg, event)
}
// ExportScheduleRetry exposes scheduleRetry.
func (e *Engine) ExportScheduleRetry(
task Task, delay time.Duration,
) {
e.scheduleRetry(task, delay)
}
// ExportRecoverPendingDeliveries exposes
// recoverPendingDeliveries.
func (e *Engine) ExportRecoverPendingDeliveries(
ctx context.Context,
webhookDB *gorm.DB,
webhookID string,
) {
e.recoverPendingDeliveries(
ctx, webhookDB, webhookID,
)
}
// ExportRecoverWebhookDeliveries exposes
// recoverWebhookDeliveries.
func (e *Engine) ExportRecoverWebhookDeliveries(
ctx context.Context, webhookID string,
) {
e.recoverWebhookDeliveries(ctx, webhookID)
}
// ExportRecoverInFlight exposes recoverInFlight.
func (e *Engine) ExportRecoverInFlight(
ctx context.Context,
) {
e.recoverInFlight(ctx)
}
// ExportStart exposes start for testing.
func (e *Engine) ExportStart(ctx context.Context) {
e.start(ctx)
}
// ExportStop exposes stop for testing.
func (e *Engine) ExportStop() {
e.stop()
}
// ExportDeliveryCh returns the delivery channel.
func (e *Engine) ExportDeliveryCh() chan Task {
return e.deliveryCh
}
// ExportRetryCh returns the retry channel.
func (e *Engine) ExportRetryCh() chan Task {
return e.retryCh
}
// NewTestEngine creates an Engine for unit tests without
// database dependencies.
func NewTestEngine(
log *slog.Logger,
client *http.Client,
workers int,
) *Engine {
return &Engine{
log: log,
client: client,
deliveryCh: make(chan Task, deliveryChannelSize),
retryCh: make(chan Task, retryChannelSize),
workers: workers,
}
}
// NewTestEngineSmallRetry creates an Engine with a tiny
// retry channel buffer for overflow testing.
func NewTestEngineSmallRetry(
log *slog.Logger,
) *Engine {
return &Engine{
log: log,
retryCh: make(chan Task, 1),
}
}
// NewTestEngineWithDB creates an Engine with a real
// database and dbManager for integration tests.
func NewTestEngineWithDB(
db *database.Database,
dbMgr *database.WebhookDBManager,
log *slog.Logger,
client *http.Client,
workers int,
) *Engine {
return &Engine{
database: db,
dbManager: dbMgr,
log: log,
client: client,
deliveryCh: make(chan Task, deliveryChannelSize),
retryCh: make(chan Task, retryChannelSize),
workers: workers,
}
}
// NewTestCircuitBreaker creates a CircuitBreaker with
// custom settings for testing.
func NewTestCircuitBreaker(
threshold int, cooldown time.Duration,
) *CircuitBreaker {
return &CircuitBreaker{
state: CircuitClosed,
threshold: threshold,
cooldown: cooldown,
}
}

View File

@@ -1,222 +0,0 @@
package delivery
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"time"
)
const (
// dnsResolutionTimeout is the maximum time to wait for
// DNS resolution during SSRF validation.
dnsResolutionTimeout = 5 * time.Second
)
// Sentinel errors for SSRF validation.
var (
errNoHostname = errors.New("URL has no hostname")
errNoIPs = errors.New(
"hostname resolved to no IP addresses",
)
errBlockedIP = errors.New(
"blocked private/reserved IP range",
)
errInvalidScheme = errors.New(
"only http and https are allowed",
)
)
// blockedNetworks contains all private/reserved IP ranges
// that should be blocked to prevent SSRF attacks.
//
//nolint:gochecknoglobals // package-level network list is appropriate here
var blockedNetworks []*net.IPNet
//nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup
func init() {
cidrs := []string{
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"169.254.0.0/16",
"0.0.0.0/8",
"100.64.0.0/10",
"192.0.0.0/24",
"192.0.2.0/24",
"198.18.0.0/15",
"198.51.100.0/24",
"203.0.113.0/24",
"224.0.0.0/4",
"240.0.0.0/4",
"::1/128",
"fc00::/7",
"fe80::/10",
}
for _, cidr := range cidrs {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
panic(fmt.Sprintf(
"ssrf: failed to parse CIDR %q: %v",
cidr, err,
))
}
blockedNetworks = append(
blockedNetworks, network,
)
}
}
// isBlockedIP checks whether an IP address falls within
// any blocked private/reserved network range.
func isBlockedIP(ip net.IP) bool {
for _, network := range blockedNetworks {
if network.Contains(ip) {
return true
}
}
return false
}
// ValidateTargetURL checks that an HTTP delivery target
// URL is safe from SSRF attacks.
func ValidateTargetURL(
ctx context.Context, targetURL string,
) error {
parsed, err := url.Parse(targetURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
err = validateScheme(parsed.Scheme)
if err != nil {
return err
}
host := parsed.Hostname()
if host == "" {
return errNoHostname
}
if ip := net.ParseIP(host); ip != nil {
return checkBlockedIP(ip)
}
return validateHostname(ctx, host)
}
func validateScheme(scheme string) error {
if scheme != "http" && scheme != "https" {
return fmt.Errorf(
"unsupported URL scheme %q: %w",
scheme, errInvalidScheme,
)
}
return nil
}
func checkBlockedIP(ip net.IP) error {
if isBlockedIP(ip) {
return fmt.Errorf(
"target IP %s is in a blocked "+
"private/reserved range: %w",
ip, errBlockedIP,
)
}
return nil
}
func validateHostname(
ctx context.Context, host string,
) error {
dnsCtx, cancel := context.WithTimeout(
ctx, dnsResolutionTimeout,
)
defer cancel()
ips, err := net.DefaultResolver.LookupIPAddr(
dnsCtx, host,
)
if err != nil {
return fmt.Errorf(
"failed to resolve hostname %q: %w",
host, err,
)
}
if len(ips) == 0 {
return fmt.Errorf(
"hostname %q: %w", host, errNoIPs,
)
}
for _, ipAddr := range ips {
if isBlockedIP(ipAddr.IP) {
return fmt.Errorf(
"hostname %q resolves to blocked "+
"IP %s: %w",
host, ipAddr.IP, errBlockedIP,
)
}
}
return nil
}
// NewSSRFSafeTransport creates an http.Transport with a
// custom DialContext that blocks connections to
// private/reserved IP addresses.
func NewSSRFSafeTransport() *http.Transport {
return &http.Transport{
DialContext: ssrfDialContext,
}
}
func ssrfDialContext(
ctx context.Context,
network, addr string,
) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf(
"ssrf: invalid address %q: %w",
addr, err,
)
}
ips, err := net.DefaultResolver.LookupIPAddr(
ctx, host,
)
if err != nil {
return nil, fmt.Errorf(
"ssrf: DNS resolution failed for %q: %w",
host, err,
)
}
for _, ipAddr := range ips {
if isBlockedIP(ipAddr.IP) {
return nil, fmt.Errorf(
"ssrf: connection to %s (%s) "+
"blocked: %w",
host, ipAddr.IP, errBlockedIP,
)
}
}
var dialer net.Dialer
return dialer.DialContext(
ctx, network,
net.JoinHostPort(ips[0].IP.String(), port),
)
}

View File

@@ -1,172 +0,0 @@
package delivery_test
import (
"context"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/webhooker/internal/delivery"
)
func TestIsBlockedIP_PrivateRanges(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ip string
blocked bool
}{
{"loopback 127.0.0.1", "127.0.0.1", true},
{"loopback 127.0.0.2", "127.0.0.2", true},
{"loopback 127.255.255.255", "127.255.255.255", true},
{"10.0.0.0", "10.0.0.0", true},
{"10.0.0.1", "10.0.0.1", true},
{"10.255.255.255", "10.255.255.255", true},
{"172.16.0.1", "172.16.0.1", true},
{"172.31.255.255", "172.31.255.255", true},
{"172.15.255.255", "172.15.255.255", false},
{"172.32.0.0", "172.32.0.0", false},
{"192.168.0.1", "192.168.0.1", true},
{"192.168.255.255", "192.168.255.255", true},
{"169.254.0.1", "169.254.0.1", true},
{"169.254.169.254", "169.254.169.254", true},
{"8.8.8.8", "8.8.8.8", false},
{"1.1.1.1", "1.1.1.1", false},
{"93.184.216.34", "93.184.216.34", false},
{"::1", "::1", true},
{"fd00::1", "fd00::1", true},
{"fc00::1", "fc00::1", true},
{"fe80::1", "fe80::1", true},
{
"2607:f8b0:4004:800::200e",
"2607:f8b0:4004:800::200e",
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ip := net.ParseIP(tt.ip)
require.NotNil(t, ip,
"failed to parse IP %s", tt.ip,
)
assert.Equal(t,
tt.blocked,
delivery.ExportIsBlockedIP(ip),
"isBlockedIP(%s) = %v, want %v",
tt.ip,
delivery.ExportIsBlockedIP(ip),
tt.blocked,
)
})
}
}
func TestValidateTargetURL_Blocked(t *testing.T) {
t.Parallel()
blockedURLs := []string{
"http://127.0.0.1/hook",
"http://127.0.0.1:8080/hook",
"https://10.0.0.1/hook",
"http://192.168.1.1/webhook",
"http://172.16.0.1/api",
"http://169.254.169.254/latest/meta-data/",
"http://[::1]/hook",
"http://[fc00::1]/hook",
"http://[fe80::1]/hook",
"http://0.0.0.0/hook",
}
for _, u := range blockedURLs {
t.Run(u, func(t *testing.T) {
t.Parallel()
err := delivery.ValidateTargetURL(
context.Background(), u,
)
assert.Error(t, err,
"URL %s should be blocked", u,
)
})
}
}
func TestValidateTargetURL_Allowed(t *testing.T) {
t.Parallel()
allowedURLs := []string{
"https://example.com/hook",
"http://93.184.216.34/webhook",
"https://hooks.slack.com/services/T00/B00/xxx",
}
for _, u := range allowedURLs {
t.Run(u, func(t *testing.T) {
t.Parallel()
err := delivery.ValidateTargetURL(
context.Background(), u,
)
assert.NoError(t, err,
"URL %s should be allowed", u,
)
})
}
}
func TestValidateTargetURL_InvalidScheme(t *testing.T) {
t.Parallel()
err := delivery.ValidateTargetURL(
context.Background(), "ftp://example.com/hook",
)
require.Error(t, err)
assert.Contains(t, err.Error(),
"unsupported URL scheme",
)
}
func TestValidateTargetURL_EmptyHost(t *testing.T) {
t.Parallel()
err := delivery.ValidateTargetURL(
context.Background(), "http:///path",
)
assert.Error(t, err)
}
func TestValidateTargetURL_InvalidURL(t *testing.T) {
t.Parallel()
err := delivery.ValidateTargetURL(
context.Background(), "://invalid",
)
assert.Error(t, err)
}
func TestBlockedNetworks_Initialized(t *testing.T) {
t.Parallel()
nets := delivery.ExportBlockedNetworks()
assert.NotEmpty(t, nets,
"blockedNetworks should be initialized",
)
assert.GreaterOrEqual(t, len(nets), 8,
"should have at least 8 blocked network ranges",
)
}

View File

@@ -1,34 +1,25 @@
// Package globals provides build-time variables injected via ldflags.
package globals package globals
import ( import (
"go.uber.org/fx" "go.uber.org/fx"
) )
// Build-time variables populated from main() and copied into the // these get populated from main() and copied into the Globals object.
// Globals object.
//
//nolint:gochecknoglobals // Build-time variables set by main().
var ( var (
Appname string Appname string
Version string Version string
) )
// Globals holds build-time metadata about the application.
type Globals struct { type Globals struct {
Appname string Appname string
Version string Version string
} }
// New creates a Globals instance from the package-level // nolint:revive // lc parameter is required by fx even if unused
// build-time variables.
//
//nolint:revive // lc parameter is required by fx even if unused.
func New(lc fx.Lifecycle) (*Globals, error) { func New(lc fx.Lifecycle) (*Globals, error) {
n := &Globals{ n := &Globals{
Appname: Appname, Appname: Appname,
Version: Version, Version: Version,
} }
return n, nil return n, nil
} }

View File

@@ -1,30 +1,26 @@
package globals_test package globals
import ( import (
"testing" "testing"
"sneak.berlin/go/webhooker/internal/globals" "go.uber.org/fx/fxtest"
) )
func TestGlobalsFields(t *testing.T) { func TestNew(t *testing.T) {
t.Parallel() // Set test values
Appname = "test-app"
Version = "1.0.0"
g := &globals.Globals{ lc := fxtest.NewLifecycle(t)
Appname: "test-app", globals, err := New(lc)
Version: "1.0.0", if err != nil {
t.Fatalf("New() error = %v", err)
} }
if g.Appname != "test-app" { if globals.Appname != "test-app" {
t.Errorf( t.Errorf("Appname = %v, want %v", globals.Appname, "test-app")
"Appname = %v, want %v",
g.Appname, "test-app",
)
} }
if globals.Version != "1.0.0" {
if g.Version != "1.0.0" { t.Errorf("Version = %v, want %v", globals.Version, "1.0.0")
t.Errorf(
"Version = %v, want %v",
g.Version, "1.0.0",
)
} }
} }

View File

@@ -13,12 +13,11 @@ func (h *Handlers) HandleLoginPage() http.HandlerFunc {
sess, err := h.session.Get(r) sess, err := h.session.Get(r)
if err == nil && h.session.IsAuthenticated(sess) { if err == nil && h.session.IsAuthenticated(sess) {
http.Redirect(w, r, "/", http.StatusSeeOther) http.Redirect(w, r, "/", http.StatusSeeOther)
return return
} }
// Render login page // Render login page
data := map[string]any{ data := map[string]interface{}{
"Error": "", "Error": "",
} }
@@ -29,15 +28,10 @@ func (h *Handlers) HandleLoginPage() http.HandlerFunc {
// HandleLoginSubmit handles the login form submission (POST) // HandleLoginSubmit handles the login form submission (POST)
func (h *Handlers) HandleLoginSubmit() http.HandlerFunc { func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// Limit request body to prevent memory exhaustion
r.Body = http.MaxBytesReader(w, r.Body, 1<<maxBodyShift)
// Parse form data // Parse form data
err := r.ParseForm() if err := r.ParseForm(); err != nil {
if err != nil {
h.log.Error("failed to parse form", "error", err) h.log.Error("failed to parse form", "error", err)
http.Error(w, "Bad request", http.StatusBadRequest) http.Error(w, "Bad request", http.StatusBadRequest)
return return
} }
@@ -46,147 +40,76 @@ func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
// Validate input // Validate input
if username == "" || password == "" { if username == "" || password == "" {
h.renderLoginError( data := map[string]interface{}{
w, r, "Error": "Username and password are required",
"Username and password are required",
http.StatusBadRequest,
)
return
} }
w.WriteHeader(http.StatusBadRequest)
user, err := h.authenticateUser(
w, r, username, password,
)
if err != nil {
return
}
err = h.createAuthenticatedSession(w, r, user)
if err != nil {
return
}
h.log.Info(
"user logged in",
"username", username,
"user_id", user.ID,
)
// Redirect to home page
http.Redirect(w, r, "/", http.StatusSeeOther)
}
}
// renderLoginError renders the login page with an error message.
func (h *Handlers) renderLoginError(
w http.ResponseWriter,
r *http.Request,
msg string,
status int,
) {
data := map[string]any{
"Error": msg,
}
w.WriteHeader(status)
h.renderTemplate(w, r, "login.html", data) h.renderTemplate(w, r, "login.html", data)
return
} }
// authenticateUser looks up and verifies a user's credentials. // Find user in database
// On failure it writes an HTTP response and returns an error.
func (h *Handlers) authenticateUser(
w http.ResponseWriter,
r *http.Request,
username, password string,
) (database.User, error) {
var user database.User var user database.User
if err := h.db.DB().Where("username = ?", username).First(&user).Error; err != nil {
err := h.db.DB().Where(
"username = ?", username,
).First(&user).Error
if err != nil {
h.log.Debug("user not found", "username", username) h.log.Debug("user not found", "username", username)
h.renderLoginError( data := map[string]interface{}{
w, r, "Error": "Invalid username or password",
"Invalid username or password", }
http.StatusUnauthorized, w.WriteHeader(http.StatusUnauthorized)
) h.renderTemplate(w, r, "login.html", data)
return
return user, err
} }
// Verify password
valid, err := database.VerifyPassword(password, user.Password) valid, err := database.VerifyPassword(password, user.Password)
if err != nil { if err != nil {
h.log.Error("failed to verify password", "error", err) h.log.Error("failed to verify password", "error", err)
http.Error( http.Error(w, "Internal server error", http.StatusInternalServerError)
w, "Internal server error", return
http.StatusInternalServerError,
)
return user, err
} }
if !valid { if !valid {
h.log.Debug("invalid password", "username", username) h.log.Debug("invalid password", "username", username)
h.renderLoginError( data := map[string]interface{}{
w, r, "Error": "Invalid username or password",
"Invalid username or password", }
http.StatusUnauthorized, w.WriteHeader(http.StatusUnauthorized)
) h.renderTemplate(w, r, "login.html", data)
return
return user, errInvalidPassword
} }
return user, nil // Get the current session (may be pre-existing / attacker-set)
}
// createAuthenticatedSession regenerates the session and stores
// user info. On failure it writes an HTTP response and returns
// an error.
func (h *Handlers) createAuthenticatedSession(
w http.ResponseWriter,
r *http.Request,
user database.User,
) error {
oldSess, err := h.session.Get(r) oldSess, err := h.session.Get(r)
if err != nil { if err != nil {
h.log.Error("failed to get session", "error", err) h.log.Error("failed to get session", "error", err)
http.Error( http.Error(w, "Internal server error", http.StatusInternalServerError)
w, "Internal server error", return
http.StatusInternalServerError,
)
return err
} }
// Regenerate the session to prevent session fixation attacks.
// This destroys the old session ID and creates a new one.
sess, err := h.session.Regenerate(r, w, oldSess) sess, err := h.session.Regenerate(r, w, oldSess)
if err != nil { if err != nil {
h.log.Error( h.log.Error("failed to regenerate session", "error", err)
"failed to regenerate session", "error", err, http.Error(w, "Internal server error", http.StatusInternalServerError)
) return
http.Error(
w, "Internal server error",
http.StatusInternalServerError,
)
return err
} }
// Set user in session
h.session.SetUser(sess, user.ID, user.Username) h.session.SetUser(sess, user.ID, user.Username)
err = h.session.Save(r, w, sess) // Save session
if err != nil { if err := h.session.Save(r, w, sess); err != nil {
h.log.Error("failed to save session", "error", err) h.log.Error("failed to save session", "error", err)
http.Error( http.Error(w, "Internal server error", http.StatusInternalServerError)
w, "Internal server error", return
http.StatusInternalServerError,
)
return err
} }
return nil h.log.Info("user logged in", "username", username, "user_id", user.ID)
// Redirect to home page
http.Redirect(w, r, "/", http.StatusSeeOther)
}
} }
// HandleLogout handles user logout // HandleLogout handles user logout
@@ -195,10 +118,7 @@ func (h *Handlers) HandleLogout() http.HandlerFunc {
sess, err := h.session.Get(r) sess, err := h.session.Get(r)
if err != nil { if err != nil {
h.log.Error("failed to get session", "error", err) h.log.Error("failed to get session", "error", err)
http.Redirect( http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
w, r, "/pages/login", http.StatusSeeOther,
)
return return
} }
@@ -206,12 +126,8 @@ func (h *Handlers) HandleLogout() http.HandlerFunc {
h.session.Destroy(sess) h.session.Destroy(sess)
// Save the destroyed session // Save the destroyed session
err = h.session.Save(r, w, sess) if err := h.session.Save(r, w, sess); err != nil {
if err != nil { h.log.Error("failed to save destroyed session", "error", err)
h.log.Error(
"failed to save destroyed session",
"error", err,
)
} }
// Redirect to login page // Redirect to login page

View File

@@ -1,14 +0,0 @@
package handlers
import "net/http"
// RenderTemplateForTest exposes renderTemplate for use in the
// handlers_test package.
func (s *Handlers) RenderTemplateForTest(
w http.ResponseWriter,
r *http.Request,
pageTemplate string,
data any,
) {
s.renderTemplate(w, r, pageTemplate, data)
}

View File

@@ -1,11 +1,8 @@
// Package handlers provides HTTP request handlers for the
// webhooker web UI and API.
package handlers package handlers
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"html/template" "html/template"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -16,29 +13,13 @@ import (
"sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/globals"
"sneak.berlin/go/webhooker/internal/healthcheck" "sneak.berlin/go/webhooker/internal/healthcheck"
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
"sneak.berlin/go/webhooker/internal/middleware"
"sneak.berlin/go/webhooker/internal/session" "sneak.berlin/go/webhooker/internal/session"
"sneak.berlin/go/webhooker/templates" "sneak.berlin/go/webhooker/templates"
) )
const ( // nolint:revive // HandlersParams is a standard fx naming convention
// maxBodyShift is the bit shift for 1 MB body limit.
maxBodyShift = 20
// recentEventLimit is the number of recent events to show.
recentEventLimit = 20
// defaultRetentionDays is the default event retention period.
defaultRetentionDays = 30
// paginationPerPage is the number of items per page.
paginationPerPage = 25
)
// errInvalidPassword is returned when a password does not match.
var errInvalidPassword = errors.New("invalid password")
//nolint:revive // HandlersParams is a standard fx naming convention.
type HandlersParams struct { type HandlersParams struct {
fx.In fx.In
Logger *logger.Logger Logger *logger.Logger
Globals *globals.Globals Globals *globals.Globals
Database *database.Database Database *database.Database
@@ -48,8 +29,6 @@ type HandlersParams struct {
Notifier delivery.Notifier Notifier delivery.Notifier
} }
// Handlers provides HTTP handler methods for all application
// routes.
type Handlers struct { type Handlers struct {
params *HandlersParams params *HandlersParams
log *slog.Logger log *slog.Logger
@@ -61,29 +40,19 @@ type Handlers struct {
templates map[string]*template.Template templates map[string]*template.Template
} }
// parsePageTemplate parses a page-specific template set from the // parsePageTemplate parses a page-specific template set from the embedded FS.
// embedded FS. Each page template is combined with the shared // Each page template is combined with the shared base, htmlheader, and navbar templates.
// base, htmlheader, and navbar templates. The page file must be // The page file must be listed first so that its root action ({{template "base" .}})
// listed first so that its root action ({{template "base" .}}) // becomes the template set's entry point. If a shared partial (e.g. htmlheader.html)
// becomes the template set's entry point. // is listed first, its {{define}} block becomes the root — which is empty — and
// Execute() produces no output.
func parsePageTemplate(pageFile string) *template.Template { func parsePageTemplate(pageFile string) *template.Template {
return template.Must( return template.Must(
template.ParseFS( template.ParseFS(templates.Templates, pageFile, "base.html", "htmlheader.html", "navbar.html"),
templates.Templates,
pageFile,
"base.html",
"htmlheader.html",
"navbar.html",
),
) )
} }
// New creates a Handlers instance, parsing all page templates at func New(lc fx.Lifecycle, params HandlersParams) (*Handlers, error) {
// startup.
func New(
lc fx.Lifecycle,
params HandlersParams,
) (*Handlers, error) {
s := new(Handlers) s := new(Handlers)
s.params = &params s.params = &params
s.log = params.Logger.Get() s.log = params.Logger.Get()
@@ -95,6 +64,7 @@ func New(
// Parse all page templates once at startup // Parse all page templates once at startup
s.templates = map[string]*template.Template{ s.templates = map[string]*template.Template{
"index.html": parsePageTemplate("index.html"),
"login.html": parsePageTemplate("login.html"), "login.html": parsePageTemplate("login.html"),
"profile.html": parsePageTemplate("profile.html"), "profile.html": parsePageTemplate("profile.html"),
"sources_list.html": parsePageTemplate("sources_list.html"), "sources_list.html": parsePageTemplate("sources_list.html"),
@@ -105,23 +75,17 @@ func New(
} }
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(ctx context.Context) error {
return nil return nil
}, },
}) })
return s, nil return s, nil
} }
func (s *Handlers) respondJSON( //nolint:unparam // r parameter will be used in the future for request context
w http.ResponseWriter, func (s *Handlers) respondJSON(w http.ResponseWriter, r *http.Request, data interface{}, status int) {
_ *http.Request,
data any,
status int,
) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
if data != nil { if data != nil {
err := json.NewEncoder(w).Encode(data) err := json.NewEncoder(w).Encode(data)
if err != nil { if err != nil {
@@ -130,15 +94,9 @@ func (s *Handlers) respondJSON(
} }
} }
// serverError logs an error and sends a 500 response. //nolint:unparam,unused // will be used for handling JSON requests
func (s *Handlers) serverError( func (s *Handlers) decodeJSON(w http.ResponseWriter, r *http.Request, v interface{}) error {
w http.ResponseWriter, msg string, err error, return json.NewDecoder(r.Body).Decode(v)
) {
s.log.Error(msg, "error", err)
http.Error(
w, "Internal server error",
http.StatusInternalServerError,
)
} }
// UserInfo represents user information for templates // UserInfo represents user information for templates
@@ -147,91 +105,52 @@ type UserInfo struct {
Username string Username string
} }
// templateDataWrapper wraps non-map data with common fields. // renderTemplate renders a pre-parsed template with common data
type templateDataWrapper struct { func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTemplate string, data interface{}) {
User *UserInfo
CSRFToken string
Data any
}
// getUserInfo extracts user info from the session.
func (s *Handlers) getUserInfo(
r *http.Request,
) *UserInfo {
sess, err := s.session.Get(r)
if err != nil || !s.session.IsAuthenticated(sess) {
return nil
}
username, ok := s.session.GetUsername(sess)
if !ok {
return nil
}
userID, ok := s.session.GetUserID(sess)
if !ok {
return nil
}
return &UserInfo{ID: userID, Username: username}
}
// renderTemplate renders a pre-parsed template with common
// data
func (s *Handlers) renderTemplate(
w http.ResponseWriter,
r *http.Request,
pageTemplate string,
data any,
) {
tmpl, ok := s.templates[pageTemplate] tmpl, ok := s.templates[pageTemplate]
if !ok { if !ok {
s.log.Error( s.log.Error("template not found", "template", pageTemplate)
"template not found", http.Error(w, "Internal server error", http.StatusInternalServerError)
"template", pageTemplate,
)
http.Error(
w, "Internal server error",
http.StatusInternalServerError,
)
return return
} }
userInfo := s.getUserInfo(r) // Get user from session if available
csrfToken := middleware.CSRFToken(r) var userInfo *UserInfo
sess, err := s.session.Get(r)
if err == nil && s.session.IsAuthenticated(sess) {
if username, ok := s.session.GetUsername(sess); ok {
if userID, ok := s.session.GetUserID(sess); ok {
userInfo = &UserInfo{
ID: userID,
Username: username,
}
}
}
}
if m, ok := data.(map[string]any); ok { // If data is a map, merge user info into it
if m, ok := data.(map[string]interface{}); ok {
m["User"] = userInfo m["User"] = userInfo
m["CSRFToken"] = csrfToken if err := tmpl.Execute(w, m); err != nil {
s.executeTemplate(w, tmpl, m) s.log.Error("failed to execute template", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
return return
} }
// Wrap data with base template data
type templateDataWrapper struct {
User *UserInfo
Data interface{}
}
wrapper := templateDataWrapper{ wrapper := templateDataWrapper{
User: userInfo, User: userInfo,
CSRFToken: csrfToken,
Data: data, Data: data,
} }
s.executeTemplate(w, tmpl, wrapper) if err := tmpl.Execute(w, wrapper); err != nil {
} s.log.Error("failed to execute template", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
// executeTemplate runs the template and handles errors.
func (s *Handlers) executeTemplate(
w http.ResponseWriter,
tmpl *template.Template,
data any,
) {
err := tmpl.Execute(w, data)
if err != nil {
s.log.Error(
"failed to execute template", "error", err,
)
http.Error(
w, "Internal server error",
http.StatusInternalServerError,
)
} }
} }

View File

@@ -1,10 +1,10 @@
package handlers_test package handlers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -14,23 +14,20 @@ import (
"sneak.berlin/go/webhooker/internal/database" "sneak.berlin/go/webhooker/internal/database"
"sneak.berlin/go/webhooker/internal/delivery" "sneak.berlin/go/webhooker/internal/delivery"
"sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/globals"
"sneak.berlin/go/webhooker/internal/handlers"
"sneak.berlin/go/webhooker/internal/healthcheck" "sneak.berlin/go/webhooker/internal/healthcheck"
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
"sneak.berlin/go/webhooker/internal/session" "sneak.berlin/go/webhooker/internal/session"
) )
// noopNotifier is a no-op delivery.Notifier for tests.
type noopNotifier struct{} type noopNotifier struct{}
func (n *noopNotifier) Notify([]delivery.Task) {} func (n *noopNotifier) Notify([]delivery.DeliveryTask) {}
func newTestApp( func TestHandleIndex(t *testing.T) {
t *testing.T, var h *Handlers
targets ...any,
) *fxtest.App {
t.Helper()
return fxtest.New( app := fxtest.New(
t, t,
fx.Provide( fx.Provide(
globals.New, globals.New,
@@ -44,99 +41,92 @@ func newTestApp(
database.NewWebhookDBManager, database.NewWebhookDBManager,
healthcheck.New, healthcheck.New,
session.New, session.New,
func() delivery.Notifier { func() delivery.Notifier { return &noopNotifier{} },
return &noopNotifier{} New,
},
handlers.New,
), ),
fx.Populate(targets...), fx.Populate(&h),
) )
}
func TestHandleIndex_Unauthenticated(t *testing.T) {
t.Parallel()
var h *handlers.Handlers
app := newTestApp(t, &h)
app.RequireStart() app.RequireStart()
defer app.RequireStop()
t.Cleanup(app.RequireStop) // Since we can't test actual template rendering without templates,
// let's test that the handler is created and doesn't panic
req := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler := h.HandleIndex() handler := h.HandleIndex()
handler.ServeHTTP(w, req) assert.NotNil(t, handler)
assert.Equal(t, http.StatusSeeOther, w.Code)
assert.Equal(
t, "/pages/login", w.Header().Get("Location"),
)
}
func TestHandleIndex_Authenticated(t *testing.T) {
t.Parallel()
var h *handlers.Handlers
var sess *session.Session
app := newTestApp(t, &h, &sess)
app.RequireStart()
t.Cleanup(app.RequireStop)
req := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
w := httptest.NewRecorder()
s, err := sess.Get(req)
require.NoError(t, err)
sess.SetUser(s, "test-user-id", "testuser")
err = sess.Save(req, w, s)
require.NoError(t, err)
req2 := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
for _, cookie := range w.Result().Cookies() {
req2.AddCookie(cookie)
}
w2 := httptest.NewRecorder()
h.HandleIndex().ServeHTTP(w2, req2)
assert.Equal(t, http.StatusSeeOther, w2.Code)
assert.Equal(
t, "/sources", w2.Header().Get("Location"),
)
} }
func TestRenderTemplate(t *testing.T) { func TestRenderTemplate(t *testing.T) {
t.Parallel() var h *Handlers
var h *handlers.Handlers app := fxtest.New(
t,
app := newTestApp(t, &h) fx.Provide(
globals.New,
logger.New,
func() *config.Config {
return &config.Config{
DataDir: t.TempDir(),
}
},
database.New,
database.NewWebhookDBManager,
healthcheck.New,
session.New,
func() delivery.Notifier { return &noopNotifier{} },
New,
),
fx.Populate(&h),
)
app.RequireStart() app.RequireStart()
defer app.RequireStop()
t.Cleanup(app.RequireStop) t.Run("handles missing templates gracefully", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
data := map[string]any{"Version": "1.0.0"} data := map[string]interface{}{
"Version": "1.0.0",
h.RenderTemplateForTest( }
w, req, "nonexistent.html", data,
) // When a non-existent template name is requested, renderTemplate
// should return an internal server error
assert.Equal( h.renderTemplate(w, req, "nonexistent.html", data)
t, http.StatusInternalServerError, w.Code,
) // Should return internal server error when template is not found
assert.Equal(t, http.StatusInternalServerError, w.Code)
})
}
func TestFormatUptime(t *testing.T) {
tests := []struct {
name string
duration string
expected string
}{
{
name: "minutes only",
duration: "45m",
expected: "45m",
},
{
name: "hours and minutes",
duration: "2h30m",
expected: "2h 30m",
},
{
name: "days, hours and minutes",
duration: "25h45m",
expected: "1d 1h 45m",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d, err := time.ParseDuration(tt.duration)
require.NoError(t, err)
result := formatUptime(d)
assert.Equal(t, tt.expected, result)
})
}
} }

View File

@@ -4,13 +4,9 @@ import (
"net/http" "net/http"
) )
const httpStatusOK = 200
// HandleHealthCheck returns an HTTP handler that reports
// application health.
func (s *Handlers) HandleHealthCheck() http.HandlerFunc { func (s *Handlers) HandleHealthCheck() http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) { return func(w http.ResponseWriter, req *http.Request) {
resp := s.hc.Healthcheck() resp := s.hc.Healthcheck()
s.respondJSON(w, req, resp, httpStatusOK) s.respondJSON(w, req, resp, 200)
} }
} }

View File

@@ -1,21 +1,49 @@
package handlers package handlers
import ( import (
"fmt"
"net/http" "net/http"
"time"
"sneak.berlin/go/webhooker/internal/database"
) )
// HandleIndex returns a handler for the root path that redirects
// based on authentication state: authenticated users go to /sources
// (the dashboard), unauthenticated users go to the login page.
func (s *Handlers) HandleIndex() http.HandlerFunc { func (s *Handlers) HandleIndex() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { // Calculate server start time
sess, err := s.session.Get(r) startTime := time.Now()
if err == nil && s.session.IsAuthenticated(sess) {
http.Redirect(w, r, "/sources", http.StatusSeeOther)
return return func(w http.ResponseWriter, req *http.Request) {
// Calculate uptime
uptime := time.Since(startTime)
uptimeStr := formatUptime(uptime)
// Get user count from database
var userCount int64
s.db.DB().Model(&database.User{}).Count(&userCount)
// Prepare template data
data := map[string]interface{}{
"Version": s.params.Globals.Version,
"Uptime": uptimeStr,
"UserCount": userCount,
} }
http.Redirect(w, r, "/pages/login", http.StatusSeeOther) // Render the template
s.renderTemplate(w, req, "index.html", data)
} }
} }
// formatUptime formats a duration into a human-readable string
func formatUptime(d time.Duration) string {
days := int(d.Hours()) / 24
hours := int(d.Hours()) % 24
minutes := int(d.Minutes()) % 60
if days > 0 {
return fmt.Sprintf("%dd %dh %dm", days, hours, minutes)
}
if hours > 0 {
return fmt.Sprintf("%dh %dm", hours, minutes)
}
return fmt.Sprintf("%dm", minutes)
}

View File

@@ -13,7 +13,6 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
requestedUsername := chi.URLParam(r, "username") requestedUsername := chi.URLParam(r, "username")
if requestedUsername == "" { if requestedUsername == "" {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
@@ -22,7 +21,6 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
if err != nil || !h.session.IsAuthenticated(sess) { if err != nil || !h.session.IsAuthenticated(sess) {
// Redirect to login if not authenticated // Redirect to login if not authenticated
http.Redirect(w, r, "/pages/login", http.StatusSeeOther) http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
return return
} }
@@ -31,7 +29,6 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
if !ok { if !ok {
h.log.Error("authenticated session missing username") h.log.Error("authenticated session missing username")
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
return return
} }
@@ -39,19 +36,17 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
if !ok { if !ok {
h.log.Error("authenticated session missing user ID") h.log.Error("authenticated session missing user ID")
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
return return
} }
// For now, only allow users to view their own profile // For now, only allow users to view their own profile
if requestedUsername != sessionUsername { if requestedUsername != sessionUsername {
http.Error(w, "Forbidden", http.StatusForbidden) http.Error(w, "Forbidden", http.StatusForbidden)
return return
} }
// Prepare data for template // Prepare data for template
data := map[string]any{ data := map[string]interface{}{
"User": &UserInfo{ "User": &UserInfo{
ID: sessionUserID, ID: sessionUserID,
Username: sessionUsername, Username: sessionUsername,

File diff suppressed because it is too large Load Diff

View File

@@ -6,36 +6,31 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"gorm.io/gorm"
"sneak.berlin/go/webhooker/internal/database" "sneak.berlin/go/webhooker/internal/database"
"sneak.berlin/go/webhooker/internal/delivery" "sneak.berlin/go/webhooker/internal/delivery"
) )
const ( const (
// maxWebhookBodySize is the maximum allowed webhook // maxWebhookBodySize is the maximum allowed webhook request body (1 MB).
// request body (1 MB). maxWebhookBodySize = 1 << 20
maxWebhookBodySize = 1 << maxBodyShift
) )
// HandleWebhook handles incoming webhook requests at entrypoint // HandleWebhook handles incoming webhook requests at entrypoint URLs.
// URLs. // Only POST requests are accepted; all other methods return 405 Method Not Allowed.
// Events and deliveries are stored in the per-webhook database. The handler
// builds self-contained DeliveryTask structs with all target and event data
// so the delivery engine can process them without additional DB reads.
func (h *Handlers) HandleWebhook() http.HandlerFunc { func (h *Handlers) HandleWebhook() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
w.Header().Set("Allow", "POST") w.Header().Set("Allow", "POST")
http.Error( http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
w,
"Method Not Allowed",
http.StatusMethodNotAllowed,
)
return return
} }
entrypointUUID := chi.URLParam(r, "uuid") entrypointUUID := chi.URLParam(r, "uuid")
if entrypointUUID == "" { if entrypointUUID == "" {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
@@ -45,241 +40,69 @@ func (h *Handlers) HandleWebhook() http.HandlerFunc {
"remote_addr", r.RemoteAddr, "remote_addr", r.RemoteAddr,
) )
entrypoint, ok := h.lookupEntrypoint( // Look up entrypoint by path (from main application DB)
w, r, entrypointUUID, var entrypoint database.Entrypoint
) result := h.db.DB().Where("path = ?", entrypointUUID).First(&entrypoint)
if !ok { if result.Error != nil {
h.log.Debug("entrypoint not found", "path", entrypointUUID)
http.NotFound(w, r)
return return
} }
// Check if active
if !entrypoint.Active { if !entrypoint.Active {
http.Error(w, "Gone", http.StatusGone) http.Error(w, "Gone", http.StatusGone)
return return
} }
h.processWebhookRequest(w, r, entrypoint) // Read body with size limit
body, err := io.ReadAll(io.LimitReader(r.Body, maxWebhookBodySize+1))
if err != nil {
h.log.Error("failed to read request body", "error", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
} }
} if len(body) > maxWebhookBodySize {
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
// processWebhookRequest reads the body, serializes headers,
// loads targets, and delivers the event.
func (h *Handlers) processWebhookRequest(
w http.ResponseWriter,
r *http.Request,
entrypoint database.Entrypoint,
) {
body, ok := h.readWebhookBody(w, r)
if !ok {
return return
} }
// Serialize headers as JSON
headersJSON, err := json.Marshal(r.Header) headersJSON, err := json.Marshal(r.Header)
if err != nil { if err != nil {
h.serverError(w, "failed to serialize headers", err) h.log.Error("failed to serialize headers", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return return
} }
targets, err := h.loadActiveTargets(entrypoint.WebhookID) // Find all active targets for this webhook (from main application DB)
if err != nil {
h.serverError(w, "failed to query targets", err)
return
}
h.createAndDeliverEvent(
w, r, entrypoint, body, headersJSON, targets,
)
}
// loadActiveTargets returns all active targets for a webhook.
func (h *Handlers) loadActiveTargets(
webhookID string,
) ([]database.Target, error) {
var targets []database.Target var targets []database.Target
if targetErr := h.db.DB().Where("webhook_id = ? AND active = ?", entrypoint.WebhookID, true).Find(&targets).Error; targetErr != nil {
err := h.db.DB().Where( h.log.Error("failed to query targets", "error", targetErr)
"webhook_id = ? AND active = ?", http.Error(w, "Internal server error", http.StatusInternalServerError)
webhookID, true,
).Find(&targets).Error
return targets, err
}
// lookupEntrypoint finds an entrypoint by UUID path.
func (h *Handlers) lookupEntrypoint(
w http.ResponseWriter,
r *http.Request,
entrypointUUID string,
) (database.Entrypoint, bool) {
var entrypoint database.Entrypoint
result := h.db.DB().Where(
"path = ?", entrypointUUID,
).First(&entrypoint)
if result.Error != nil {
h.log.Debug(
"entrypoint not found",
"path", entrypointUUID,
)
http.NotFound(w, r)
return entrypoint, false
}
return entrypoint, true
}
// readWebhookBody reads and validates the request body size.
func (h *Handlers) readWebhookBody(
w http.ResponseWriter,
r *http.Request,
) ([]byte, bool) {
body, err := io.ReadAll(
io.LimitReader(r.Body, maxWebhookBodySize+1),
)
if err != nil {
h.log.Error(
"failed to read request body", "error", err,
)
http.Error(
w, "Bad request", http.StatusBadRequest,
)
return nil, false
}
if len(body) > maxWebhookBodySize {
http.Error(
w,
"Request body too large",
http.StatusRequestEntityTooLarge,
)
return nil, false
}
return body, true
}
// createAndDeliverEvent creates the event and delivery records
// then notifies the delivery engine.
func (h *Handlers) createAndDeliverEvent(
w http.ResponseWriter,
r *http.Request,
entrypoint database.Entrypoint,
body, headersJSON []byte,
targets []database.Target,
) {
tx, err := h.beginWebhookTx(w, entrypoint.WebhookID)
if err != nil {
return return
} }
event := h.buildEvent(r, entrypoint, headersJSON, body) // Get the per-webhook database for event storage
webhookDB, err := h.dbMgr.GetDB(entrypoint.WebhookID)
err = tx.Create(event).Error
if err != nil { if err != nil {
tx.Rollback() h.log.Error("failed to get webhook database",
h.serverError(w, "failed to create event", err) "webhook_id", entrypoint.WebhookID,
"error", err,
return
}
bodyPtr := inlineBody(body)
tasks := h.buildDeliveryTasks(
w, tx, event, entrypoint, targets, bodyPtr,
) )
if tasks == nil { http.Error(w, "Internal server error", http.StatusInternalServerError)
return return
} }
err = tx.Commit().Error // Create the event and deliveries in a transaction on the per-webhook DB
if err != nil {
h.serverError(w, "failed to commit transaction", err)
return
}
h.finishWebhookResponse(w, event, entrypoint, tasks)
}
// beginWebhookTx opens a transaction on the per-webhook DB.
func (h *Handlers) beginWebhookTx(
w http.ResponseWriter,
webhookID string,
) (*gorm.DB, error) {
webhookDB, err := h.dbMgr.GetDB(webhookID)
if err != nil {
h.serverError(
w, "failed to get webhook database", err,
)
return nil, err
}
tx := webhookDB.Begin() tx := webhookDB.Begin()
if tx.Error != nil { if tx.Error != nil {
h.serverError( h.log.Error("failed to begin transaction", "error", tx.Error)
w, "failed to begin transaction", tx.Error, http.Error(w, "Internal server error", http.StatusInternalServerError)
) return
return nil, tx.Error
} }
return tx, nil event := &database.Event{
}
// inlineBody returns a pointer to body as a string if it fits
// within the inline size limit, or nil otherwise.
func inlineBody(body []byte) *string {
if len(body) < delivery.MaxInlineBodySize {
s := string(body)
return &s
}
return nil
}
// finishWebhookResponse notifies the delivery engine, logs the
// event, and writes the HTTP response.
func (h *Handlers) finishWebhookResponse(
w http.ResponseWriter,
event *database.Event,
entrypoint database.Entrypoint,
tasks []delivery.Task,
) {
if len(tasks) > 0 {
h.notifier.Notify(tasks)
}
h.log.Info("webhook event created",
"event_id", event.ID,
"webhook_id", entrypoint.WebhookID,
"entrypoint_id", entrypoint.ID,
"target_count", len(tasks),
)
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(`{"status":"ok"}`))
if err != nil {
h.log.Error(
"failed to write response", "error", err,
)
}
}
// buildEvent creates a new Event struct from request data.
func (h *Handlers) buildEvent(
r *http.Request,
entrypoint database.Entrypoint,
headersJSON, body []byte,
) *database.Event {
return &database.Event{
WebhookID: entrypoint.WebhookID, WebhookID: entrypoint.WebhookID,
EntrypointID: entrypoint.ID, EntrypointID: entrypoint.ID,
Method: r.Method, Method: r.Method,
@@ -287,45 +110,41 @@ func (h *Handlers) buildEvent(
Body: string(body), Body: string(body),
ContentType: r.Header.Get("Content-Type"), ContentType: r.Header.Get("Content-Type"),
} }
if err := tx.Create(event).Error; err != nil {
tx.Rollback()
h.log.Error("failed to create event", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
} }
// buildDeliveryTasks creates delivery records in the // Prepare body pointer for inline transport (≤16KB bodies are
// transaction and returns tasks for the delivery engine. // included in the DeliveryTask so the engine needs no DB read).
// Returns nil if an error occurred. var bodyPtr *string
func (h *Handlers) buildDeliveryTasks( if len(body) < delivery.MaxInlineBodySize {
w http.ResponseWriter, bodyStr := string(body)
tx *gorm.DB, bodyPtr = &bodyStr
event *database.Event, }
entrypoint database.Entrypoint,
targets []database.Target,
bodyPtr *string,
) []delivery.Task {
tasks := make([]delivery.Task, 0, len(targets))
// Create delivery records and build self-contained delivery tasks
tasks := make([]delivery.DeliveryTask, 0, len(targets))
for i := range targets { for i := range targets {
dlv := &database.Delivery{ dlv := &database.Delivery{
EventID: event.ID, EventID: event.ID,
TargetID: targets[i].ID, TargetID: targets[i].ID,
Status: database.DeliveryStatusPending, Status: database.DeliveryStatusPending,
} }
if err := tx.Create(dlv).Error; err != nil {
err := tx.Create(dlv).Error
if err != nil {
tx.Rollback() tx.Rollback()
h.log.Error( h.log.Error("failed to create delivery",
"failed to create delivery",
"target_id", targets[i].ID, "target_id", targets[i].ID,
"error", err, "error", err,
) )
http.Error( http.Error(w, "Internal server error", http.StatusInternalServerError)
w, "Internal server error", return
http.StatusInternalServerError,
)
return nil
} }
tasks = append(tasks, delivery.Task{ tasks = append(tasks, delivery.DeliveryTask{
DeliveryID: dlv.ID, DeliveryID: dlv.ID,
EventID: event.ID, EventID: event.ID,
WebhookID: entrypoint.WebhookID, WebhookID: entrypoint.WebhookID,
@@ -342,5 +161,31 @@ func (h *Handlers) buildDeliveryTasks(
}) })
} }
return tasks if err := tx.Commit().Error; err != nil {
h.log.Error("failed to commit transaction", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// Notify the delivery engine with self-contained delivery tasks.
// Each task carries all target config and event data inline so
// the engine can deliver without touching any database (in the
// ≤16KB happy path). The engine only writes to the DB to record
// delivery results after each attempt.
if len(tasks) > 0 {
h.notifier.Notify(tasks)
}
h.log.Info("webhook event created",
"event_id", event.ID,
"webhook_id", entrypoint.WebhookID,
"entrypoint_id", entrypoint.ID,
"target_count", len(targets),
)
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil {
h.log.Error("failed to write response", "error", err)
}
}
} }

View File

@@ -1,4 +1,3 @@
// Package healthcheck provides application health status reporting.
package healthcheck package healthcheck
import ( import (
@@ -13,51 +12,55 @@ import (
"sneak.berlin/go/webhooker/internal/logger" "sneak.berlin/go/webhooker/internal/logger"
) )
//nolint:revive // HealthcheckParams is a standard fx naming convention. // nolint:revive // HealthcheckParams is a standard fx naming convention
type HealthcheckParams struct { type HealthcheckParams struct {
fx.In fx.In
Globals *globals.Globals Globals *globals.Globals
Config *config.Config Config *config.Config
Logger *logger.Logger Logger *logger.Logger
Database *database.Database Database *database.Database
} }
// Healthcheck tracks application uptime and reports health status.
type Healthcheck struct { type Healthcheck struct {
StartupTime time.Time StartupTime time.Time
log *slog.Logger log *slog.Logger
params *HealthcheckParams params *HealthcheckParams
} }
// New creates a Healthcheck that records the startup time on fx func New(lc fx.Lifecycle, params HealthcheckParams) (*Healthcheck, error) {
// start.
func New(
lc fx.Lifecycle,
params HealthcheckParams,
) (*Healthcheck, error) {
s := new(Healthcheck) s := new(Healthcheck)
s.params = &params s.params = &params
s.log = params.Logger.Get() s.log = params.Logger.Get()
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
s.StartupTime = time.Now() s.StartupTime = time.Now()
return nil return nil
}, },
OnStop: func(_ context.Context) error { OnStop: func(ctx context.Context) error {
return nil return nil
}, },
}) })
return s, nil return s, nil
} }
// Healthcheck returns the current health status of the // nolint:revive // HealthcheckResponse is a clear, descriptive name
// application. type HealthcheckResponse struct {
func (s *Healthcheck) Healthcheck() *Response { Status string `json:"status"`
resp := &Response{ Now string `json:"now"`
UptimeSeconds int64 `json:"uptime_seconds"`
UptimeHuman string `json:"uptime_human"`
Version string `json:"version"`
Appname string `json:"appname"`
Maintenance bool `json:"maintenance_mode"`
}
func (s *Healthcheck) uptime() time.Duration {
return time.Since(s.StartupTime)
}
func (s *Healthcheck) Healthcheck() *HealthcheckResponse {
resp := &HealthcheckResponse{
Status: "ok", Status: "ok",
Now: time.Now().UTC().Format(time.RFC3339Nano), Now: time.Now().UTC().Format(time.RFC3339Nano),
UptimeSeconds: int64(s.uptime().Seconds()), UptimeSeconds: int64(s.uptime().Seconds()),
@@ -66,21 +69,5 @@ func (s *Healthcheck) Healthcheck() *Response {
Version: s.params.Globals.Version, Version: s.params.Globals.Version,
Maintenance: s.params.Config.MaintenanceMode, Maintenance: s.params.Config.MaintenanceMode,
} }
return resp return resp
} }
// Response contains the JSON-serialised health status.
type Response struct {
Status string `json:"status"`
Now string `json:"now"`
UptimeSeconds int64 `json:"uptimeSeconds"`
UptimeHuman string `json:"uptimeHuman"`
Version string `json:"version"`
Appname string `json:"appname"`
Maintenance bool `json:"maintenanceMode"`
}
func (s *Healthcheck) uptime() time.Duration {
return time.Since(s.StartupTime)
}

View File

@@ -1,5 +1,3 @@
// Package logger provides structured logging with dynamic level
// control.
package logger package logger
import ( import (
@@ -12,25 +10,19 @@ import (
"sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/globals"
) )
//nolint:revive // LoggerParams is a standard fx naming convention. // nolint:revive // LoggerParams is a standard fx naming convention
type LoggerParams struct { type LoggerParams struct {
fx.In fx.In
Globals *globals.Globals Globals *globals.Globals
} }
// Logger wraps slog with dynamic level control and structured
// output.
type Logger struct { type Logger struct {
logger *slog.Logger logger *slog.Logger
levelVar *slog.LevelVar levelVar *slog.LevelVar
params LoggerParams params LoggerParams
} }
// New creates a Logger that outputs text (TTY) or JSON (non-TTY) // nolint:revive // lc parameter is required by fx even if unused
// to stdout.
//
//nolint:revive // lc parameter is required by fx even if unused.
func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) { func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
l := new(Logger) l := new(Logger)
l.params = params l.params = params
@@ -45,22 +37,17 @@ func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
tty = true tty = true
} }
//nolint:revive // groups param unused but required by slog ReplaceAttr signature. replaceAttr := func(_ []string, a slog.Attr) slog.Attr { // nolint:revive // groups unused
replaceAttr := func(_ []string, a slog.Attr) slog.Attr {
// Always use UTC for timestamps // Always use UTC for timestamps
if a.Key == slog.TimeKey { if a.Key == slog.TimeKey {
if t, ok := a.Value.Any().(time.Time); ok { if t, ok := a.Value.Any().(time.Time); ok {
return slog.Time(slog.TimeKey, t.UTC()) return slog.Time(slog.TimeKey, t.UTC())
} }
return a
} }
return a return a
} }
var handler slog.Handler var handler slog.Handler
opts := &slog.HandlerOptions{ opts := &slog.HandlerOptions{
Level: l.levelVar, Level: l.levelVar,
ReplaceAttr: replaceAttr, ReplaceAttr: replaceAttr,
@@ -82,18 +69,15 @@ func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
return l, nil return l, nil
} }
// EnableDebugLogging switches the log level to debug.
func (l *Logger) EnableDebugLogging() { func (l *Logger) EnableDebugLogging() {
l.levelVar.Set(slog.LevelDebug) l.levelVar.Set(slog.LevelDebug)
l.logger.Debug("debug logging enabled", "debug", true) l.logger.Debug("debug logging enabled", "debug", true)
} }
// Get returns the underlying slog.Logger.
func (l *Logger) Get() *slog.Logger { func (l *Logger) Get() *slog.Logger {
return l.logger return l.logger
} }
// Identify logs the application name and version at startup.
func (l *Logger) Identify() { func (l *Logger) Identify() {
l.logger.Info("starting", l.logger.Info("starting",
"appname", l.params.Globals.Appname, "appname", l.params.Globals.Appname,
@@ -101,8 +85,7 @@ func (l *Logger) Identify() {
) )
} }
// Writer returns an io.Writer suitable for standard library // Helper methods to maintain compatibility with existing code
// loggers.
func (l *Logger) Writer() io.Writer { func (l *Logger) Writer() io.Writer {
return os.Stdout return os.Stdout
} }

View File

@@ -1,59 +1,63 @@
package logger_test package logger
import ( import (
"testing" "testing"
"go.uber.org/fx/fxtest" "go.uber.org/fx/fxtest"
"sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/globals"
"sneak.berlin/go/webhooker/internal/logger"
) )
func testGlobals() *globals.Globals {
return &globals.Globals{
Appname: "test-app",
Version: "1.0.0",
}
}
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
t.Parallel() // Set up globals
globals.Appname = "test-app"
globals.Version = "1.0.0"
lc := fxtest.NewLifecycle(t) lc := fxtest.NewLifecycle(t)
g, err := globals.New(lc)
params := logger.LoggerParams{ if err != nil {
Globals: testGlobals(), t.Fatalf("globals.New() error = %v", err)
} }
l, err := logger.New(lc, params) params := LoggerParams{
Globals: g,
}
logger, err := New(lc, params)
if err != nil { if err != nil {
t.Fatalf("New() error = %v", err) t.Fatalf("New() error = %v", err)
} }
if l.Get() == nil { if logger.Get() == nil {
t.Error("Get() returned nil logger") t.Error("Get() returned nil logger")
} }
// Test that we can log without panic // Test that we can log without panic
l.Get().Info("test message", "key", "value") logger.Get().Info("test message", "key", "value")
} }
func TestEnableDebugLogging(t *testing.T) { func TestEnableDebugLogging(t *testing.T) {
t.Parallel() // Set up globals
globals.Appname = "test-app"
globals.Version = "1.0.0"
lc := fxtest.NewLifecycle(t) lc := fxtest.NewLifecycle(t)
g, err := globals.New(lc)
params := logger.LoggerParams{ if err != nil {
Globals: testGlobals(), t.Fatalf("globals.New() error = %v", err)
} }
l, err := logger.New(lc, params) params := LoggerParams{
Globals: g,
}
logger, err := New(lc, params)
if err != nil { if err != nil {
t.Fatalf("New() error = %v", err) t.Fatalf("New() error = %v", err)
} }
// Enable debug logging should not panic // Enable debug logging should not panic
l.EnableDebugLogging() logger.EnableDebugLogging()
// Test debug logging // Test debug logging
l.Get().Debug("debug message", "test", true) logger.Get().Debug("debug message", "test", true)
} }

View File

@@ -1,84 +0,0 @@
package middleware
import (
"net/http"
"github.com/gorilla/csrf"
)
// CSRFToken retrieves the CSRF token from the request context.
// Returns an empty string if the gorilla/csrf middleware has not run.
func CSRFToken(r *http.Request) string {
return csrf.Token(r)
}
// isClientTLS reports whether the client-facing connection uses TLS.
// It checks for a direct TLS connection (r.TLS) or a TLS-terminating
// reverse proxy that sets the standard X-Forwarded-Proto header.
func isClientTLS(r *http.Request) bool {
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
}
// CSRF returns middleware that provides CSRF protection using the
// gorilla/csrf library. The middleware uses the session authentication
// key to sign a CSRF cookie and validates a masked token submitted via
// the "csrf_token" form field (or the "X-CSRF-Token" header) on
// POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing
// token receive a 403 Forbidden response.
//
// The middleware detects the client-facing transport protocol per-request
// using r.TLS and the X-Forwarded-Proto header. This allows correct
// behavior in all deployment scenarios:
//
// - Direct HTTPS: strict Referer/Origin checks, Secure cookies.
// - Behind a TLS-terminating reverse proxy: strict checks (the
// browser is on HTTPS, so Origin/Referer headers use https://),
// Secure cookies (the browser sees HTTPS from the proxy).
// - Direct HTTP: relaxed Referer/Origin checks via PlaintextHTTPRequest,
// non-Secure cookies so the browser sends them over HTTP.
//
// Two gorilla/csrf instances are maintained — one with Secure cookies
// (for TLS) and one without (for plaintext HTTP) — because the
// csrf.Secure option is set at creation time, not per-request.
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
csrfErrorHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.log.Warn("csrf: token validation failed",
"method", r.Method,
"path", r.URL.Path,
"remote_addr", r.RemoteAddr,
"reason", csrf.FailureReason(r),
)
http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden)
})
key := m.session.GetKey()
baseOpts := []csrf.Option{
csrf.FieldName("csrf_token"),
csrf.SameSite(csrf.SameSiteLaxMode),
csrf.Path("/"),
csrf.ErrorHandler(csrfErrorHandler),
}
// Two middleware instances with different Secure flags but the
// same signing key, so cookies are interchangeable between them.
tlsProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(true))...)
httpProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(false))...)
return func(next http.Handler) http.Handler {
tlsCSRF := tlsProtect(next)
httpCSRF := httpProtect(next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isClientTLS(r) {
// Client is on TLS (directly or via reverse proxy).
// Use Secure cookies and strict Origin/Referer checks.
tlsCSRF.ServeHTTP(w, r)
} else {
// Plaintext HTTP: use non-Secure cookies and tell
// gorilla/csrf to use "http" for scheme comparisons,
// skipping the strict Referer check that assumes TLS.
httpCSRF.ServeHTTP(w, csrf.PlaintextHTTPRequest(r))
}
})
}
}

View File

@@ -1,494 +0,0 @@
package middleware_test
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/middleware"
)
// csrfCookieName is the gorilla/csrf cookie name.
const csrfCookieName = "_gorilla_csrf"
// csrfGetToken performs a GET request through the CSRF middleware
// and returns the token and cookies.
func csrfGetToken(
t *testing.T,
csrfMW func(http.Handler) http.Handler,
getReq *http.Request,
) (string, []*http.Cookie) {
t.Helper()
var token string
getHandler := csrfMW(http.HandlerFunc(
func(_ http.ResponseWriter, r *http.Request) {
token = middleware.CSRFToken(r)
},
))
getW := httptest.NewRecorder()
getHandler.ServeHTTP(getW, getReq)
cookies := getW.Result().Cookies()
require.NotEmpty(t, cookies, "CSRF cookie should be set")
require.NotEmpty(t, token, "CSRF token should be set")
return token, cookies
}
// csrfPostWithToken performs a POST request with the given CSRF
// token and cookies through the middleware. Returns whether the
// handler was called and the response code.
func csrfPostWithToken(
t *testing.T,
csrfMW func(http.Handler) http.Handler,
postReq *http.Request,
token string,
cookies []*http.Cookie,
) (bool, int) {
t.Helper()
var called bool
postHandler := csrfMW(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true
},
))
form := url.Values{"csrf_token": {token}}
postReq.Body = http.NoBody
postReq.Body = nil
// Rebuild the request with the form body
rebuilt := httptest.NewRequestWithContext(
context.Background(),
postReq.Method, postReq.URL.String(),
strings.NewReader(form.Encode()),
)
rebuilt.Header = postReq.Header.Clone()
rebuilt.TLS = postReq.TLS
rebuilt.Header.Set(
"Content-Type", "application/x-www-form-urlencoded",
)
for _, c := range cookies {
rebuilt.AddCookie(c)
}
postW := httptest.NewRecorder()
postHandler.ServeHTTP(postW, rebuilt)
return called, postW.Code
}
func TestCSRF_GETSetsToken(t *testing.T) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev)
var gotToken string
handler := m.CSRF()(http.HandlerFunc(
func(_ http.ResponseWriter, r *http.Request) {
gotToken = middleware.CSRFToken(r)
},
))
req := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/form", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.NotEmpty(
t, gotToken,
"CSRF token should be set in context on GET",
)
}
func TestCSRF_POSTWithValidToken(t *testing.T) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev)
csrfMW := m.CSRF()
getReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodGet, "/form", nil,
)
token, cookies := csrfGetToken(t, csrfMW, getReq)
postReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/form", nil,
)
called, _ := csrfPostWithToken(
t, csrfMW, postReq, token, cookies,
)
assert.True(
t, called,
"handler should be called with valid CSRF token",
)
}
// csrfPOSTWithoutTokenTest is a shared helper for testing POST
// requests without a CSRF token in both dev and prod modes.
func csrfPOSTWithoutTokenTest(
t *testing.T,
env string,
msg string,
) {
t.Helper()
m, _ := testMiddleware(t, env)
csrfMW := m.CSRF()
// GET to establish the CSRF cookie
getHandler := csrfMW(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {},
))
getReq := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/form", nil)
getW := httptest.NewRecorder()
getHandler.ServeHTTP(getW, getReq)
cookies := getW.Result().Cookies()
// POST without CSRF token
var called bool
postHandler := csrfMW(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true
},
))
postReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/form", nil,
)
postReq.Header.Set(
"Content-Type", "application/x-www-form-urlencoded",
)
for _, c := range cookies {
postReq.AddCookie(c)
}
postW := httptest.NewRecorder()
postHandler.ServeHTTP(postW, postReq)
assert.False(t, called, msg)
assert.Equal(t, http.StatusForbidden, postW.Code)
}
func TestCSRF_POSTWithoutToken(t *testing.T) {
t.Parallel()
csrfPOSTWithoutTokenTest(
t,
config.EnvironmentDev,
"handler should NOT be called without CSRF token",
)
}
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev)
csrfMW := m.CSRF()
// GET to establish the CSRF cookie
getHandler := csrfMW(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {},
))
getReq := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/form", nil)
getW := httptest.NewRecorder()
getHandler.ServeHTTP(getW, getReq)
cookies := getW.Result().Cookies()
// POST with wrong CSRF token
var called bool
postHandler := csrfMW(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true
},
))
form := url.Values{"csrf_token": {"invalid-token-value"}}
postReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/form",
strings.NewReader(form.Encode()),
)
postReq.Header.Set(
"Content-Type", "application/x-www-form-urlencoded",
)
for _, c := range cookies {
postReq.AddCookie(c)
}
postW := httptest.NewRecorder()
postHandler.ServeHTTP(postW, postReq)
assert.False(
t, called,
"handler should NOT be called with invalid CSRF token",
)
assert.Equal(t, http.StatusForbidden, postW.Code)
}
func TestCSRF_GETDoesNotValidate(t *testing.T) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev)
var called bool
handler := m.CSRF()(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true
},
))
req := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/form", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.True(
t, called,
"GET requests should pass through CSRF middleware",
)
}
func TestCSRFToken_NoMiddleware(t *testing.T) {
t.Parallel()
req := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
assert.Empty(
t, middleware.CSRFToken(req),
"CSRFToken should return empty string when "+
"middleware has not run",
)
}
// --- TLS Detection Tests ---
func TestIsClientTLS_DirectTLS(t *testing.T) {
t.Parallel()
r := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
r.TLS = &tls.ConnectionState{}
assert.True(
t, middleware.IsClientTLS(r),
"should detect direct TLS connection",
)
}
func TestIsClientTLS_XForwardedProto(t *testing.T) {
t.Parallel()
r := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
r.Header.Set("X-Forwarded-Proto", "https")
assert.True(
t, middleware.IsClientTLS(r),
"should detect TLS via X-Forwarded-Proto",
)
}
func TestIsClientTLS_PlaintextHTTP(t *testing.T) {
t.Parallel()
r := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
assert.False(
t, middleware.IsClientTLS(r),
"should detect plaintext HTTP",
)
}
func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) {
t.Parallel()
r := httptest.NewRequestWithContext(
context.Background(), http.MethodGet, "/", nil)
r.Header.Set("X-Forwarded-Proto", "http")
assert.False(
t, middleware.IsClientTLS(r),
"should detect plaintext when X-Forwarded-Proto is http",
)
}
// --- Production Mode: POST over plaintext HTTP ---
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(
t *testing.T,
) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentProd)
csrfMW := m.CSRF()
getReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodGet, "/form", nil,
)
token, cookies := csrfGetToken(t, csrfMW, getReq)
// Verify cookie is NOT Secure (plaintext HTTP in prod)
for _, c := range cookies {
if c.Name == csrfCookieName {
assert.False(t, c.Secure,
"CSRF cookie should not be Secure "+
"over plaintext HTTP")
}
}
postReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/form", nil,
)
called, code := csrfPostWithToken(
t, csrfMW, postReq, token, cookies,
)
assert.True(t, called,
"handler should be called -- prod mode over "+
"plaintext HTTP must work")
assert.NotEqual(t, http.StatusForbidden, code,
"should not return 403")
}
// --- Production Mode: POST with X-Forwarded-Proto ---
func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(
t *testing.T,
) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentProd)
csrfMW := m.CSRF()
getReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodGet, "http://example.com/form", nil,
)
getReq.Header.Set("X-Forwarded-Proto", "https")
token, cookies := csrfGetToken(t, csrfMW, getReq)
// Verify cookie IS Secure (X-Forwarded-Proto: https)
for _, c := range cookies {
if c.Name == csrfCookieName {
assert.True(t, c.Secure,
"CSRF cookie should be Secure behind "+
"TLS proxy")
}
}
postReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "http://example.com/form", nil,
)
postReq.Header.Set("X-Forwarded-Proto", "https")
postReq.Header.Set("Origin", "https://example.com")
called, code := csrfPostWithToken(
t, csrfMW, postReq, token, cookies,
)
assert.True(t, called,
"handler should be called -- prod mode behind "+
"TLS proxy must work")
assert.NotEqual(t, http.StatusForbidden, code,
"should not return 403")
}
// --- Production Mode: direct TLS ---
func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(
t *testing.T,
) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentProd)
csrfMW := m.CSRF()
getReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodGet, "https://example.com/form", nil,
)
getReq.TLS = &tls.ConnectionState{}
token, cookies := csrfGetToken(t, csrfMW, getReq)
// Verify cookie IS Secure (direct TLS)
for _, c := range cookies {
if c.Name == csrfCookieName {
assert.True(t, c.Secure,
"CSRF cookie should be Secure over "+
"direct TLS")
}
}
postReq := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "https://example.com/form", nil,
)
postReq.TLS = &tls.ConnectionState{}
postReq.Header.Set("Origin", "https://example.com")
called, code := csrfPostWithToken(
t, csrfMW, postReq, token, cookies,
)
assert.True(t, called,
"handler should be called -- direct TLS must work")
assert.NotEqual(t, http.StatusForbidden, code,
"should not return 403")
}
// --- Production Mode: POST without token still rejects ---
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(
t *testing.T,
) {
t.Parallel()
csrfPOSTWithoutTokenTest(
t,
config.EnvironmentProd,
"handler should NOT be called without CSRF token "+
"even in prod+plaintext",
)
}

View File

@@ -1,34 +0,0 @@
package middleware
import (
"net/http"
)
// NewLoggingResponseWriterForTest wraps newLoggingResponseWriter
// for use in external test packages.
func NewLoggingResponseWriterForTest(
w http.ResponseWriter,
) *loggingResponseWriter {
return newLoggingResponseWriter(w)
}
// LoggingResponseWriterStatusCode returns the status code
// captured by the loggingResponseWriter.
func LoggingResponseWriterStatusCode(
lrw *loggingResponseWriter,
) int {
return lrw.statusCode
}
// IPFromHostPort exposes ipFromHostPort for testing.
func IPFromHostPort(hp string) string {
return ipFromHostPort(hp)
}
// IsClientTLS exposes isClientTLS for testing.
func IsClientTLS(r *http.Request) bool {
return isClientTLS(r)
}
// LoginRateLimitConst exposes the loginRateLimit constant.
const LoginRateLimitConst = loginRateLimit

View File

@@ -1,5 +1,3 @@
// Package middleware provides HTTP middleware for logging, auth,
// CORS, and metrics.
package middleware package middleware
import ( import (
@@ -21,42 +19,26 @@ import (
"sneak.berlin/go/webhooker/internal/session" "sneak.berlin/go/webhooker/internal/session"
) )
const ( // nolint:revive // MiddlewareParams is a standard fx naming convention
// corsMaxAge is the maximum time (in seconds) that a
// preflight response can be cached.
corsMaxAge = 300
)
//nolint:revive // MiddlewareParams is a standard fx naming convention.
type MiddlewareParams struct { type MiddlewareParams struct {
fx.In fx.In
Logger *logger.Logger Logger *logger.Logger
Globals *globals.Globals Globals *globals.Globals
Config *config.Config Config *config.Config
Session *session.Session Session *session.Session
} }
// Middleware provides HTTP middleware for logging, CORS, auth, and
// metrics.
type Middleware struct { type Middleware struct {
log *slog.Logger log *slog.Logger
params *MiddlewareParams params *MiddlewareParams
session *session.Session session *session.Session
} }
// New creates a Middleware from the provided fx parameters. func New(lc fx.Lifecycle, params MiddlewareParams) (*Middleware, error) {
//
//nolint:revive // lc parameter is required by fx even if unused.
func New(
lc fx.Lifecycle,
params MiddlewareParams,
) (*Middleware, error) {
s := new(Middleware) s := new(Middleware)
s.params = &params s.params = &params
s.log = params.Logger.Get() s.log = params.Logger.Get()
s.session = params.Session s.session = params.Session
return s, nil return s, nil
} }
@@ -68,24 +50,19 @@ func ipFromHostPort(hp string) string {
if err != nil { if err != nil {
return "" return ""
} }
if len(h) > 0 && h[0] == '[' { if len(h) > 0 && h[0] == '[' {
return h[1 : len(h)-1] return h[1 : len(h)-1]
} }
return h return h
} }
type loggingResponseWriter struct { type loggingResponseWriter struct {
http.ResponseWriter http.ResponseWriter
statusCode int statusCode int
} }
// newLoggingResponseWriter wraps w and records status codes. // nolint:revive // unexported type is only used internally
func newLoggingResponseWriter( func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
w http.ResponseWriter,
) *loggingResponseWriter {
return &loggingResponseWriter{w, http.StatusOK} return &loggingResponseWriter{w, http.StatusOK}
} }
@@ -94,30 +71,23 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) {
lrw.ResponseWriter.WriteHeader(code) lrw.ResponseWriter.WriteHeader(code)
} }
// Logging returns middleware that logs each HTTP request with // type Middleware func(http.Handler) http.Handler
// timing and metadata. // this returns a Middleware that is designed to do every request through the
// mux, note the signature:
func (s *Middleware) Logging() func(http.Handler) http.Handler { func (s *Middleware) Logging() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func( return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w http.ResponseWriter,
r *http.Request,
) {
start := time.Now() start := time.Now()
lrw := newLoggingResponseWriter(w) lrw := NewLoggingResponseWriter(w)
ctx := r.Context() ctx := r.Context()
defer func() { defer func() {
latency := time.Since(start) latency := time.Since(start)
requestID := "" requestID := ""
if reqID := ctx.Value(middleware.RequestIDKey); reqID != nil {
if reqID := ctx.Value(
middleware.RequestIDKey,
); reqID != nil {
if id, ok := reqID.(string); ok { if id, ok := reqID.(string); ok {
requestID = id requestID = id
} }
} }
s.log.Info("http request", s.log.Info("http request",
"request_start", start, "request_start", start,
"method", r.Method, "method", r.Method,
@@ -137,29 +107,20 @@ func (s *Middleware) Logging() func(http.Handler) http.Handler {
} }
} }
// CORS returns middleware that sets CORS headers (permissive in
// dev, no-op in prod).
func (s *Middleware) CORS() func(http.Handler) http.Handler { func (s *Middleware) CORS() func(http.Handler) http.Handler {
if s.params.Config.IsDev() { if s.params.Config.IsDev() {
// In development, allow any origin for local testing. // In development, allow any origin for local testing.
return cors.Handler(cors.Options{ return cors.Handler(cors.Options{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
AllowedMethods: []string{ AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
"GET", "POST", "PUT", "DELETE", "OPTIONS", AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
},
AllowedHeaders: []string{
"Accept", "Authorization",
"Content-Type", "X-CSRF-Token",
},
ExposedHeaders: []string{"Link"}, ExposedHeaders: []string{"Link"},
AllowCredentials: false, AllowCredentials: false,
MaxAge: corsMaxAge, MaxAge: 300,
}) })
} }
// In production, the web UI is server-rendered so cross-origin
// In production, the web UI is server-rendered so // requests are not expected. Return a no-op middleware.
// cross-origin requests are not expected. Return a no-op
// middleware.
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return next return next
} }
@@ -169,33 +130,20 @@ func (s *Middleware) CORS() func(http.Handler) http.Handler {
// Unauthenticated users are redirected to the login page. // Unauthenticated users are redirected to the login page.
func (s *Middleware) RequireAuth() func(http.Handler) http.Handler { func (s *Middleware) RequireAuth() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func( return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w http.ResponseWriter,
r *http.Request,
) {
sess, err := s.session.Get(r) sess, err := s.session.Get(r)
if err != nil { if err != nil {
s.log.Debug( s.log.Debug("auth middleware: failed to get session", "error", err)
"auth middleware: failed to get session", http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
"error", err,
)
http.Redirect(
w, r, "/pages/login", http.StatusSeeOther,
)
return return
} }
if !s.session.IsAuthenticated(sess) { if !s.session.IsAuthenticated(sess) {
s.log.Debug( s.log.Debug("auth middleware: unauthenticated request",
"auth middleware: unauthenticated request",
"path", r.URL.Path, "path", r.URL.Path,
"method", r.Method, "method", r.Method,
) )
http.Redirect( http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
w, r, "/pages/login", http.StatusSeeOther,
)
return return
} }
@@ -204,19 +152,15 @@ func (s *Middleware) RequireAuth() func(http.Handler) http.Handler {
} }
} }
// Metrics returns middleware that records Prometheus HTTP metrics.
func (s *Middleware) Metrics() func(http.Handler) http.Handler { func (s *Middleware) Metrics() func(http.Handler) http.Handler {
mdlw := ghmm.New(ghmm.Config{ mdlw := ghmm.New(ghmm.Config{
Recorder: metrics.NewRecorder(metrics.Config{}), Recorder: metrics.NewRecorder(metrics.Config{}),
}) })
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return std.Handler("", mdlw, next) return std.Handler("", mdlw, next)
} }
} }
// MetricsAuth returns middleware that protects metrics endpoints
// with basic auth.
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler { func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
return basicauth.New( return basicauth.New(
"metrics", "metrics",
@@ -228,63 +172,33 @@ func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
) )
} }
// SecurityHeaders returns middleware that sets production security // SecurityHeaders returns middleware that sets production security headers
// headers on every response: HSTS, X-Content-Type-Options, // on every response: HSTS, X-Content-Type-Options, X-Frame-Options, CSP,
// X-Frame-Options, CSP, Referrer-Policy, and Permissions-Policy. // Referrer-Policy, and Permissions-Policy.
func (s *Middleware) SecurityHeaders() func(http.Handler) http.Handler { func (s *Middleware) SecurityHeaders() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func( return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w http.ResponseWriter, w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
r *http.Request, w.Header().Set("X-Content-Type-Options", "nosniff")
) {
w.Header().Set(
"Strict-Transport-Security",
"max-age=63072000; includeSubDomains; preload",
)
w.Header().Set(
"X-Content-Type-Options", "nosniff",
)
w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set( w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
"Content-Security-Policy", w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
"default-src 'self'; "+ w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
"script-src 'self' 'unsafe-inline'; "+
"style-src 'self' 'unsafe-inline'",
)
w.Header().Set(
"Referrer-Policy",
"strict-origin-when-cross-origin",
)
w.Header().Set(
"Permissions-Policy",
"camera=(), microphone=(), geolocation=()",
)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
} }
// MaxBodySize returns middleware that limits the request body size // MaxBodySize returns middleware that limits the request body size for POST
// for POST requests. If the body exceeds the given limit in // requests. If the body exceeds the given limit in bytes, the server returns
// bytes, the server returns 413 Request Entity Too Large. This // 413 Request Entity Too Large. This prevents clients from sending arbitrarily
// prevents clients from sending arbitrarily large form bodies. // large form bodies.
func (s *Middleware) MaxBodySize( func (s *Middleware) MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
maxBytes int64,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func( return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w http.ResponseWriter, if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
r *http.Request, r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
) {
if r.Method == http.MethodPost ||
r.Method == http.MethodPut ||
r.Method == http.MethodPatch {
r.Body = http.MaxBytesReader(
w, r.Body, maxBytes,
)
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }

View File

@@ -1,7 +1,6 @@
package middleware_test package middleware
import ( import (
"context"
"encoding/base64" "encoding/base64"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -13,37 +12,25 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"sneak.berlin/go/webhooker/internal/config" "sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/middleware"
"sneak.berlin/go/webhooker/internal/session" "sneak.berlin/go/webhooker/internal/session"
) )
const testKeySize = 32 // testMiddleware creates a Middleware with minimal dependencies for testing.
// It uses a real session.Session backed by an in-memory cookie store.
// testMiddleware creates a Middleware with minimal dependencies func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
// for testing. It uses a real session.Session backed by an
// in-memory cookie store.
func testMiddleware(
t *testing.T,
env string,
) (*middleware.Middleware, *session.Session) {
t.Helper() t.Helper()
log := slog.New(slog.NewTextHandler( log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
cfg := &config.Config{ cfg := &config.Config{
Environment: env, Environment: env,
} }
// Create a real session manager with a known key // Create a real session manager with a known key
key := make([]byte, testKeySize) key := make([]byte, 32)
for i := range key { for i := range key {
key[i] = byte(i) key[i] = byte(i)
} }
store := sessions.NewCookieStore(key) store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{ store.Options = &sessions.Options{
Path: "/", Path: "/",
@@ -53,33 +40,40 @@ func testMiddleware(
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
} }
sessManager := session.NewForTest(store, cfg, log, key) sessManager := newTestSession(t, store, cfg, log)
m := middleware.NewForTest(log, cfg, sessManager) m := &Middleware{
log: log,
params: &MiddlewareParams{
Config: cfg,
},
session: sessManager,
}
return m, sessManager return m, sessManager
} }
// newTestSession creates a session.Session with a pre-configured cookie store
// for testing. This avoids needing the fx lifecycle and database.
func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *session.Session {
t.Helper()
return session.NewForTest(store, cfg, log)
}
// --- Logging Middleware Tests --- // --- Logging Middleware Tests ---
func TestLogging_SetsStatusCode(t *testing.T) { func TestLogging_SetsStatusCode(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
handler := m.Logging()(http.HandlerFunc( handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
if _, err := w.Write([]byte("created")); err != nil {
_, err := w.Write([]byte("created"))
if err != nil {
return return
} }
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/test", nil)
context.Background(), http.MethodGet, "/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
@@ -90,20 +84,15 @@ func TestLogging_SetsStatusCode(t *testing.T) {
func TestLogging_DefaultStatusOK(t *testing.T) { func TestLogging_DefaultStatusOK(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
handler := m.Logging()(http.HandlerFunc( handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
func(w http.ResponseWriter, _ *http.Request) { if _, err := w.Write([]byte("ok")); err != nil {
_, err := w.Write([]byte("ok"))
if err != nil {
return return
} }
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
@@ -114,31 +103,20 @@ func TestLogging_DefaultStatusOK(t *testing.T) {
func TestLogging_PassesThroughToNext(t *testing.T) { func TestLogging_PassesThroughToNext(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
var called bool var called bool
handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
handler := m.Logging()(http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
called = true called = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodPost, "/api/webhook", nil)
context.Background(),
http.MethodPost, "/api/webhook", nil,
)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.True( assert.True(t, called, "logging middleware should call the next handler")
t, called,
"logging middleware should call the next handler",
)
} }
// --- LoggingResponseWriter Tests --- // --- LoggingResponseWriter Tests ---
@@ -147,33 +125,24 @@ func TestLoggingResponseWriter_CapturesStatusCode(t *testing.T) {
t.Parallel() t.Parallel()
w := httptest.NewRecorder() w := httptest.NewRecorder()
lrw := middleware.NewLoggingResponseWriterForTest(w) lrw := NewLoggingResponseWriter(w)
// Default should be 200 // Default should be 200
assert.Equal( assert.Equal(t, http.StatusOK, lrw.statusCode)
t, http.StatusOK,
middleware.LoggingResponseWriterStatusCode(lrw),
)
// WriteHeader should capture the status code // WriteHeader should capture the status code
lrw.WriteHeader(http.StatusNotFound) lrw.WriteHeader(http.StatusNotFound)
assert.Equal(t, http.StatusNotFound, lrw.statusCode)
assert.Equal(
t, http.StatusNotFound,
middleware.LoggingResponseWriterStatusCode(lrw),
)
// Underlying writer should also get the status code // Underlying writer should also get the status code
assert.Equal(t, http.StatusNotFound, w.Code) assert.Equal(t, http.StatusNotFound, w.Code)
} }
func TestLoggingResponseWriter_WriteDelegatesToUnderlying( func TestLoggingResponseWriter_WriteDelegatesToUnderlying(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
w := httptest.NewRecorder() w := httptest.NewRecorder()
lrw := middleware.NewLoggingResponseWriterForTest(w) lrw := NewLoggingResponseWriter(w)
n, err := lrw.Write([]byte("hello world")) n, err := lrw.Write([]byte("hello world"))
require.NoError(t, err) require.NoError(t, err)
@@ -185,124 +154,79 @@ func TestLoggingResponseWriter_WriteDelegatesToUnderlying(
func TestCORS_DevMode_AllowsAnyOrigin(t *testing.T) { func TestCORS_DevMode_AllowsAnyOrigin(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
handler := m.CORS()(http.HandlerFunc( handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}, }))
))
// Preflight request // Preflight request
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodOptions, "/api/test", nil)
context.Background(),
http.MethodOptions, "/api/test", nil,
)
req.Header.Set("Origin", "http://localhost:3000") req.Header.Set("Origin", "http://localhost:3000")
req.Header.Set("Access-Control-Request-Method", "POST") req.Header.Set("Access-Control-Request-Method", "POST")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
// In dev mode, CORS should allow any origin // In dev mode, CORS should allow any origin
assert.Equal( assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
t, "*",
w.Header().Get("Access-Control-Allow-Origin"),
)
} }
func TestCORS_ProdMode_NoOp(t *testing.T) { func TestCORS_ProdMode_NoOp(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentProd) m, _ := testMiddleware(t, config.EnvironmentProd)
var called bool var called bool
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
handler := m.CORS()(http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
called = true called = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
context.Background(),
http.MethodGet, "/api/test", nil,
)
req.Header.Set("Origin", "http://evil.com") req.Header.Set("Origin", "http://evil.com")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.True( assert.True(t, called, "prod CORS middleware should pass through to handler")
t, called,
"prod CORS middleware should pass through to handler",
)
// In prod, no CORS headers should be set (no-op middleware) // In prod, no CORS headers should be set (no-op middleware)
assert.Empty( assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
t, "prod mode should not set CORS headers")
w.Header().Get("Access-Control-Allow-Origin"),
"prod mode should not set CORS headers",
)
} }
// --- RequireAuth Middleware Tests --- // --- RequireAuth Middleware Tests ---
func TestRequireAuth_NoSession_RedirectsToLogin(t *testing.T) { func TestRequireAuth_NoSession_RedirectsToLogin(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
var called bool var called bool
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
handler := m.RequireAuth()(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true called = true
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
context.Background(),
http.MethodGet, "/dashboard", nil,
)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.False( assert.False(t, called, "handler should not be called for unauthenticated request")
t, called,
"handler should not be called for "+
"unauthenticated request",
)
assert.Equal(t, http.StatusSeeOther, w.Code) assert.Equal(t, http.StatusSeeOther, w.Code)
assert.Equal(t, "/pages/login", w.Header().Get("Location")) assert.Equal(t, "/pages/login", w.Header().Get("Location"))
} }
func TestRequireAuth_AuthenticatedSession_PassesThrough( func TestRequireAuth_AuthenticatedSession_PassesThrough(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
m, sessManager := testMiddleware(t, config.EnvironmentDev) m, sessManager := testMiddleware(t, config.EnvironmentDev)
var called bool var called bool
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
handler := m.RequireAuth()(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true called = true
}, }))
))
// Create an authenticated session by making a request, // Create an authenticated session by making a request, setting session data,
// setting session data, and saving the session cookie // and saving the session cookie
setupReq := httptest.NewRequestWithContext( setupReq := httptest.NewRequest(http.MethodGet, "/setup", nil)
context.Background(),
http.MethodGet, "/setup", nil,
)
setupW := httptest.NewRecorder() setupW := httptest.NewRecorder()
sess, err := sessManager.Get(setupReq) sess, err := sessManager.Get(setupReq)
@@ -315,74 +239,47 @@ func TestRequireAuth_AuthenticatedSession_PassesThrough(
require.NotEmpty(t, cookies, "session cookie should be set") require.NotEmpty(t, cookies, "session cookie should be set")
// Make the actual request with the session cookie // Make the actual request with the session cookie
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
context.Background(),
http.MethodGet, "/dashboard", nil,
)
for _, c := range cookies { for _, c := range cookies {
req.AddCookie(c) req.AddCookie(c)
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.True( assert.True(t, called, "handler should be called for authenticated request")
t, called,
"handler should be called for authenticated request",
)
} }
func TestRequireAuth_UnauthenticatedSession_RedirectsToLogin( func TestRequireAuth_UnauthenticatedSession_RedirectsToLogin(t *testing.T) {
t *testing.T,
) {
t.Parallel() t.Parallel()
m, sessManager := testMiddleware(t, config.EnvironmentDev) m, sessManager := testMiddleware(t, config.EnvironmentDev)
var called bool var called bool
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
handler := m.RequireAuth()(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true called = true
}, }))
))
// Create a session but don't authenticate it // Create a session but don't authenticate it
setupReq := httptest.NewRequestWithContext( setupReq := httptest.NewRequest(http.MethodGet, "/setup", nil)
context.Background(),
http.MethodGet, "/setup", nil,
)
setupW := httptest.NewRecorder() setupW := httptest.NewRecorder()
sess, err := sessManager.Get(setupReq) sess, err := sessManager.Get(setupReq)
require.NoError(t, err) require.NoError(t, err)
// Don't call SetUser -- session exists but is not // Don't call SetUser session exists but is not authenticated
// authenticated
require.NoError(t, sessManager.Save(setupReq, setupW, sess)) require.NoError(t, sessManager.Save(setupReq, setupW, sess))
cookies := setupW.Result().Cookies() cookies := setupW.Result().Cookies()
require.NotEmpty(t, cookies) require.NotEmpty(t, cookies)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
context.Background(),
http.MethodGet, "/dashboard", nil,
)
for _, c := range cookies { for _, c := range cookies {
req.AddCookie(c) req.AddCookie(c)
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.False( assert.False(t, called, "handler should not be called for unauthenticated session")
t, called,
"handler should not be called for "+
"unauthenticated session",
)
assert.Equal(t, http.StatusSeeOther, w.Code) assert.Equal(t, http.StatusSeeOther, w.Code)
assert.Equal(t, "/pages/login", w.Header().Get("Location")) assert.Equal(t, "/pages/login", w.Header().Get("Location"))
} }
@@ -407,9 +304,7 @@ func TestIpFromHostPort(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
result := ipFromHostPort(tt.input)
result := middleware.IPFromHostPort(tt.input)
assert.Equal(t, tt.expected, result) assert.Equal(t, tt.expected, result)
}) })
} }
@@ -417,124 +312,122 @@ func TestIpFromHostPort(t *testing.T) {
// --- MetricsAuth Tests --- // --- MetricsAuth Tests ---
// metricsAuthMiddleware creates a Middleware configured for func TestMetricsAuth_ValidCredentials(t *testing.T) {
// metrics auth testing. This helper de-duplicates the setup in t.Parallel()
// metrics auth test functions.
func metricsAuthMiddleware(
t *testing.T,
) *middleware.Middleware {
t.Helper()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := &config.Config{ cfg := &config.Config{
Environment: config.EnvironmentDev, Environment: config.EnvironmentDev,
MetricsUsername: "admin", MetricsUsername: "admin",
MetricsPassword: "secret", MetricsPassword: "secret",
} }
key := make([]byte, testKeySize) key := make([]byte, 32)
store := sessions.NewCookieStore(key) store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{Path: "/", MaxAge: 86400} store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
sessManager := session.NewForTest(store, cfg, log, key) sessManager := session.NewForTest(store, cfg, log)
return middleware.NewForTest(log, cfg, sessManager) m := &Middleware{
log: log,
params: &MiddlewareParams{
Config: cfg,
},
session: sessManager,
} }
func TestMetricsAuth_ValidCredentials(t *testing.T) {
t.Parallel()
m := metricsAuthMiddleware(t)
var called bool var called bool
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
handler := m.MetricsAuth()(http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
called = true called = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
context.Background(),
http.MethodGet, "/metrics", nil,
)
req.SetBasicAuth("admin", "secret") req.SetBasicAuth("admin", "secret")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.True( assert.True(t, called, "handler should be called with valid basic auth")
t, called,
"handler should be called with valid basic auth",
)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
} }
func TestMetricsAuth_InvalidCredentials(t *testing.T) { func TestMetricsAuth_InvalidCredentials(t *testing.T) {
t.Parallel() t.Parallel()
m := metricsAuthMiddleware(t) log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := &config.Config{
Environment: config.EnvironmentDev,
MetricsUsername: "admin",
MetricsPassword: "secret",
}
key := make([]byte, 32)
store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
sessManager := session.NewForTest(store, cfg, log)
m := &Middleware{
log: log,
params: &MiddlewareParams{
Config: cfg,
},
session: sessManager,
}
var called bool var called bool
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
handler := m.MetricsAuth()(http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
called = true called = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
context.Background(),
http.MethodGet, "/metrics", nil,
)
req.SetBasicAuth("admin", "wrong-password") req.SetBasicAuth("admin", "wrong-password")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.False( assert.False(t, called, "handler should not be called with invalid basic auth")
t, called,
"handler should not be called with invalid basic auth",
)
assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Equal(t, http.StatusUnauthorized, w.Code)
} }
func TestMetricsAuth_NoCredentials(t *testing.T) { func TestMetricsAuth_NoCredentials(t *testing.T) {
t.Parallel() t.Parallel()
m := metricsAuthMiddleware(t) log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := &config.Config{
Environment: config.EnvironmentDev,
MetricsUsername: "admin",
MetricsPassword: "secret",
}
key := make([]byte, 32)
store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
sessManager := session.NewForTest(store, cfg, log)
m := &Middleware{
log: log,
params: &MiddlewareParams{
Config: cfg,
},
session: sessManager,
}
var called bool var called bool
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
handler := m.MetricsAuth()(http.HandlerFunc(
func(_ http.ResponseWriter, _ *http.Request) {
called = true called = true
}, }))
))
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
context.Background(),
http.MethodGet, "/metrics", nil,
)
// No basic auth header // No basic auth header
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.False( assert.False(t, called, "handler should not be called without credentials")
t, called,
"handler should not be called without credentials",
)
assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Equal(t, http.StatusUnauthorized, w.Code)
} }
@@ -542,23 +435,16 @@ func TestMetricsAuth_NoCredentials(t *testing.T) {
func TestCORS_DevMode_AllowsMethods(t *testing.T) { func TestCORS_DevMode_AllowsMethods(t *testing.T) {
t.Parallel() t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev) m, _ := testMiddleware(t, config.EnvironmentDev)
handler := m.CORS()(http.HandlerFunc( handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}, }))
))
// Preflight for POST // Preflight for POST
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodOptions, "/api/webhooks", nil)
context.Background(),
http.MethodOptions, "/api/webhooks", nil,
)
req.Header.Set("Origin", "http://localhost:5173") req.Header.Set("Origin", "http://localhost:5173")
req.Header.Set("Access-Control-Request-Method", "POST") req.Header.Set("Access-Control-Request-Method", "POST")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
@@ -572,17 +458,14 @@ func TestCORS_DevMode_AllowsMethods(t *testing.T) {
func TestSessionKeyFormat(t *testing.T) { func TestSessionKeyFormat(t *testing.T) {
t.Parallel() t.Parallel()
// Verify that the session initialization correctly validates // Verify that the session initialization correctly validates key format.
// key format. A proper 32-byte key encoded as base64 should // A proper 32-byte key encoded as base64 should work.
// work. key := make([]byte, 32)
key := make([]byte, testKeySize)
for i := range key { for i := range key {
key[i] = byte(i + 1) key[i] = byte(i + 1)
} }
encoded := base64.StdEncoding.EncodeToString(key) encoded := base64.StdEncoding.EncodeToString(key)
decoded, err := base64.StdEncoding.DecodeString(encoded) decoded, err := base64.StdEncoding.DecodeString(encoded)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, decoded, testKeySize) assert.Len(t, decoded, 32)
} }

View File

@@ -1,64 +0,0 @@
package middleware
import (
"net/http"
"time"
"github.com/go-chi/httprate"
)
const (
// loginRateLimit is the maximum number of login attempts
// per interval.
loginRateLimit = 5
// loginRateInterval is the time window for the rate limit.
loginRateInterval = 1 * time.Minute
)
// LoginRateLimit returns middleware that enforces per-IP rate
// limiting on login attempts using go-chi/httprate. Only POST
// requests are rate-limited; GET requests (rendering the login
// form) pass through unaffected. When the rate limit is exceeded,
// a 429 Too Many Requests response is returned. IP extraction
// honours X-Forwarded-For, X-Real-IP, and True-Client-IP headers
// for reverse-proxy setups.
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
limiter := httprate.Limit(
loginRateLimit,
loginRateInterval,
httprate.WithKeyFuncs(httprate.KeyByRealIP),
httprate.WithLimitHandler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
m.log.Warn("login rate limit exceeded",
"path", r.URL.Path,
)
http.Error(
w,
"Too many login attempts. "+
"Please try again later.",
http.StatusTooManyRequests,
)
},
)),
)
return func(next http.Handler) http.Handler {
limited := limiter(next)
return http.HandlerFunc(func(
w http.ResponseWriter,
r *http.Request,
) {
// Only rate-limit POST requests (actual login
// attempts)
if r.Method != http.MethodPost {
next.ServeHTTP(w, r)
return
}
limited.ServeHTTP(w, r)
})
}
}

View File

@@ -1,147 +0,0 @@
package middleware_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/middleware"
)
func TestLoginRateLimit_AllowsGET(t *testing.T) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev)
var callCount int
handler := m.LoginRateLimit()(http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
callCount++
w.WriteHeader(http.StatusOK)
},
))
// GET requests should never be rate-limited
for i := range 20 {
req := httptest.NewRequestWithContext(
context.Background(),
http.MethodGet, "/pages/login", nil,
)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(
t, http.StatusOK, w.Code,
"GET request %d should pass", i,
)
}
assert.Equal(t, 20, callCount)
}
func TestLoginRateLimit_LimitsPOST(t *testing.T) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev)
var callCount int
handler := m.LoginRateLimit()(http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
callCount++
w.WriteHeader(http.StatusOK)
},
))
// First loginRateLimit POST requests should succeed
for i := range middleware.LoginRateLimitConst {
req := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/pages/login", nil,
)
req.RemoteAddr = "10.0.0.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(
t, http.StatusOK, w.Code,
"POST request %d should pass", i,
)
}
// Next POST should be rate-limited
req := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/pages/login", nil,
)
req.RemoteAddr = "10.0.0.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(
t, http.StatusTooManyRequests, w.Code,
"POST after limit should be 429",
)
assert.Equal(t, middleware.LoginRateLimitConst, callCount)
}
func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
t.Parallel()
m, _ := testMiddleware(t, config.EnvironmentDev)
handler := m.LoginRateLimit()(http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
},
))
// Exhaust limit for IP1
for range middleware.LoginRateLimitConst {
req := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/pages/login", nil,
)
req.RemoteAddr = "1.2.3.4:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}
// IP1 should be rate-limited
req := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/pages/login", nil,
)
req.RemoteAddr = "1.2.3.4:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
// IP2 should still be allowed
req2 := httptest.NewRequestWithContext(
context.Background(),
http.MethodPost, "/pages/login", nil,
)
req2.RemoteAddr = "5.6.7.8:12345"
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(
t, http.StatusOK, w2.Code,
"different IP should not be affected",
)
}

View File

@@ -1,24 +0,0 @@
package middleware
import (
"log/slog"
"sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/session"
)
// NewForTest creates a Middleware with the minimum dependencies
// needed for testing. This bypasses the fx lifecycle.
func NewForTest(
log *slog.Logger,
cfg *config.Config,
sess *session.Session,
) *Middleware {
return &Middleware{
log: log,
params: &MiddlewareParams{
Config: cfg,
},
session: sess,
}
}

View File

@@ -1,33 +1,18 @@
package server package server
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
) )
const (
// httpReadTimeout is the maximum duration for reading the
// entire request, including the body.
httpReadTimeout = 10 * time.Second
// httpWriteTimeout is the maximum duration before timing out
// writes of the response.
httpWriteTimeout = 10 * time.Second
// httpMaxHeaderBytes is the maximum number of bytes the
// server will read parsing the request headers.
httpMaxHeaderBytes = 1 << 20
)
func (s *Server) serveUntilShutdown() { func (s *Server) serveUntilShutdown() {
listenAddr := fmt.Sprintf(":%d", s.params.Config.Port) listenAddr := fmt.Sprintf(":%d", s.params.Config.Port)
s.httpServer = &http.Server{ s.httpServer = &http.Server{
Addr: listenAddr, Addr: listenAddr,
ReadTimeout: httpReadTimeout, ReadTimeout: 10 * time.Second,
WriteTimeout: httpWriteTimeout, WriteTimeout: 10 * time.Second,
MaxHeaderBytes: httpMaxHeaderBytes, MaxHeaderBytes: 1 << 20,
Handler: s, Handler: s,
} }
@@ -36,21 +21,14 @@ func (s *Server) serveUntilShutdown() {
s.SetupRoutes() s.SetupRoutes()
s.log.Info("http begin listen", "listenaddr", listenAddr) s.log.Info("http begin listen", "listenaddr", listenAddr)
if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
err := s.httpServer.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
s.log.Error("listen error", "error", err) s.log.Error("listen error", "error", err)
if s.cancelFunc != nil { if s.cancelFunc != nil {
s.cancelFunc() s.cancelFunc()
} }
} }
} }
// ServeHTTP delegates to the router. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (s *Server) ServeHTTP(
w http.ResponseWriter,
r *http.Request,
) {
s.router.ServeHTTP(w, r) s.router.ServeHTTP(w, r)
} }

View File

@@ -11,24 +11,15 @@ import (
"sneak.berlin/go/webhooker/static" "sneak.berlin/go/webhooker/static"
) )
// maxFormBodySize is the maximum allowed request body size (in // maxFormBodySize is the maximum allowed request body size (in bytes) for
// bytes) for form POST endpoints. 1 MB is generous for any form // form POST endpoints. 1 MB is generous for any form submission while
// submission while preventing abuse from oversized payloads. // preventing abuse from oversized payloads.
const maxFormBodySize int64 = 1 * 1024 * 1024 // 1 MB const maxFormBodySize int64 = 1 * 1024 * 1024 // 1 MB
// requestTimeout is the maximum time allowed for a single HTTP
// request.
const requestTimeout = 60 * time.Second
// SetupRoutes configures all HTTP routes and middleware on the
// server's router.
func (s *Server) SetupRoutes() { func (s *Server) SetupRoutes() {
s.router = chi.NewRouter() s.router = chi.NewRouter()
s.setupGlobalMiddleware()
s.setupRoutes()
}
func (s *Server) setupGlobalMiddleware() { // Global middleware stack — applied to every request.
s.router.Use(middleware.Recoverer) s.router.Use(middleware.Recoverer)
s.router.Use(middleware.RequestID) s.router.Use(middleware.RequestID)
s.router.Use(s.mw.SecurityHeaders()) s.router.Use(s.mw.SecurityHeaders())
@@ -40,28 +31,24 @@ func (s *Server) setupGlobalMiddleware() {
} }
s.router.Use(s.mw.CORS()) s.router.Use(s.mw.CORS())
s.router.Use(middleware.Timeout(requestTimeout)) s.router.Use(middleware.Timeout(60 * time.Second))
// Sentry error reporting (if SENTRY_DSN is set). Repanic is // Sentry error reporting (if SENTRY_DSN is set). Repanic is true
// true so panics still bubble up to the Recoverer middleware. // so panics still bubble up to the Recoverer middleware above.
if s.sentryEnabled { if s.sentryEnabled {
sentryHandler := sentryhttp.New(sentryhttp.Options{ sentryHandler := sentryhttp.New(sentryhttp.Options{
Repanic: true, Repanic: true,
}) })
s.router.Use(sentryHandler.Handle) s.router.Use(sentryHandler.Handle)
} }
}
func (s *Server) setupRoutes() { // Routes
s.router.Get("/", s.h.HandleIndex()) s.router.Get("/", s.h.HandleIndex())
s.router.Mount( s.router.Mount("/s", http.StripPrefix("/s", http.FileServer(http.FS(static.Static))))
"/s",
http.StripPrefix("/s", http.FileServer(http.FS(static.Static))),
)
s.router.Route("/api/v1", func(_ chi.Router) { s.router.Route("/api/v1", func(_ chi.Router) {
// API routes will be added here. // TODO: Add API routes here
}) })
s.router.Get( s.router.Get(
@@ -73,89 +60,54 @@ func (s *Server) setupRoutes() {
if s.params.Config.MetricsUsername != "" { if s.params.Config.MetricsUsername != "" {
s.router.Group(func(r chi.Router) { s.router.Group(func(r chi.Router) {
r.Use(s.mw.MetricsAuth()) r.Use(s.mw.MetricsAuth())
r.Get( r.Get("/metrics", http.HandlerFunc(promhttp.Handler().ServeHTTP))
"/metrics",
http.HandlerFunc(
promhttp.Handler().ServeHTTP,
),
)
}) })
} }
s.setupPageRoutes() // pages that are rendered server-side
s.setupUserRoutes()
s.setupSourceRoutes()
s.setupWebhookRoutes()
}
func (s *Server) setupPageRoutes() {
s.router.Route("/pages", func(r chi.Router) { s.router.Route("/pages", func(r chi.Router) {
r.Use(s.mw.CSRF())
r.Use(s.mw.MaxBodySize(maxFormBodySize)) r.Use(s.mw.MaxBodySize(maxFormBodySize))
r.Group(func(r chi.Router) { // Login page (no auth required)
r.Use(s.mw.LoginRateLimit())
r.Get("/login", s.h.HandleLoginPage()) r.Get("/login", s.h.HandleLoginPage())
r.Post("/login", s.h.HandleLoginSubmit()) r.Post("/login", s.h.HandleLoginSubmit())
})
// Logout (auth required)
r.Post("/logout", s.h.HandleLogout()) r.Post("/logout", s.h.HandleLogout())
}) })
}
func (s *Server) setupUserRoutes() { // User profile routes
s.router.Route("/user/{username}", func(r chi.Router) { s.router.Route("/user/{username}", func(r chi.Router) {
r.Use(s.mw.CSRF())
r.Get("/", s.h.HandleProfile()) r.Get("/", s.h.HandleProfile())
}) })
}
func (s *Server) setupSourceRoutes() { // Webhook management routes (require authentication)
s.router.Route("/sources", func(r chi.Router) { s.router.Route("/sources", func(r chi.Router) {
r.Use(s.mw.CSRF())
r.Use(s.mw.RequireAuth()) r.Use(s.mw.RequireAuth())
r.Use(s.mw.MaxBodySize(maxFormBodySize)) r.Use(s.mw.MaxBodySize(maxFormBodySize))
r.Get("/", s.h.HandleSourceList()) r.Get("/", s.h.HandleSourceList()) // List all webhooks
r.Get("/new", s.h.HandleSourceCreate()) r.Get("/new", s.h.HandleSourceCreate()) // Show create form
r.Post("/new", s.h.HandleSourceCreateSubmit()) r.Post("/new", s.h.HandleSourceCreateSubmit()) // Handle create submission
}) })
s.router.Route("/source/{sourceID}", func(r chi.Router) { s.router.Route("/source/{sourceID}", func(r chi.Router) {
r.Use(s.mw.CSRF())
r.Use(s.mw.RequireAuth()) r.Use(s.mw.RequireAuth())
r.Use(s.mw.MaxBodySize(maxFormBodySize)) r.Use(s.mw.MaxBodySize(maxFormBodySize))
r.Get("/", s.h.HandleSourceDetail()) r.Get("/", s.h.HandleSourceDetail()) // View webhook details
r.Get("/edit", s.h.HandleSourceEdit()) r.Get("/edit", s.h.HandleSourceEdit()) // Show edit form
r.Post("/edit", s.h.HandleSourceEditSubmit()) r.Post("/edit", s.h.HandleSourceEditSubmit()) // Handle edit submission
r.Post("/delete", s.h.HandleSourceDelete()) r.Post("/delete", s.h.HandleSourceDelete()) // Delete webhook
r.Get("/logs", s.h.HandleSourceLogs()) r.Get("/logs", s.h.HandleSourceLogs()) // View webhook logs
r.Post( r.Post("/entrypoints", s.h.HandleEntrypointCreate()) // Add entrypoint
"/entrypoints", r.Post("/entrypoints/{entrypointID}/delete", s.h.HandleEntrypointDelete()) // Delete entrypoint
s.h.HandleEntrypointCreate(), r.Post("/entrypoints/{entrypointID}/toggle", s.h.HandleEntrypointToggle()) // Toggle entrypoint active
) r.Post("/targets", s.h.HandleTargetCreate()) // Add target
r.Post( r.Post("/targets/{targetID}/delete", s.h.HandleTargetDelete()) // Delete target
"/entrypoints/{entrypointID}/delete", r.Post("/targets/{targetID}/toggle", s.h.HandleTargetToggle()) // Toggle target active
s.h.HandleEntrypointDelete(),
)
r.Post(
"/entrypoints/{entrypointID}/toggle",
s.h.HandleEntrypointToggle(),
)
r.Post("/targets", s.h.HandleTargetCreate())
r.Post(
"/targets/{targetID}/delete",
s.h.HandleTargetDelete(),
)
r.Post(
"/targets/{targetID}/toggle",
s.h.HandleTargetToggle(),
)
}) })
}
func (s *Server) setupWebhookRoutes() { // Entrypoint endpoint — accepts incoming webhook POST requests only.
s.router.HandleFunc( // Using HandleFunc so the handler itself can return 405 for non-POST
"/webhook/{uuid}", // methods (chi's Method routing returns 405 without Allow header).
s.h.HandleWebhook(), s.router.HandleFunc("/webhook/{uuid}", s.h.HandleWebhook())
)
} }

View File

@@ -1,5 +1,3 @@
// Package server wires up HTTP routes and manages the
// application lifecycle.
package server package server
import ( import (
@@ -23,20 +21,9 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
) )
const ( // nolint:revive // ServerParams is a standard fx naming convention
// shutdownTimeout is the maximum time to wait for the HTTP
// server to finish in-flight requests during shutdown.
shutdownTimeout = 5 * time.Second
// sentryFlushTimeout is the maximum time to wait for Sentry
// to flush pending events during shutdown.
sentryFlushTimeout = 2 * time.Second
)
//nolint:revive // ServerParams is a standard fx naming convention.
type ServerParams struct { type ServerParams struct {
fx.In fx.In
Logger *logger.Logger Logger *logger.Logger
Globals *globals.Globals Globals *globals.Globals
Config *config.Config Config *config.Config
@@ -44,13 +31,12 @@ type ServerParams struct {
Handlers *handlers.Handlers Handlers *handlers.Handlers
} }
// Server is the main HTTP server that wires up routes and manages
// graceful shutdown.
type Server struct { type Server struct {
startupTime time.Time startupTime time.Time
exitCode int exitCode int
sentryEnabled bool sentryEnabled bool
log *slog.Logger log *slog.Logger
ctx context.Context
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
httpServer *http.Server httpServer *http.Server
router *chi.Mux router *chi.Mux
@@ -59,8 +45,6 @@ type Server struct {
h *handlers.Handlers h *handlers.Handlers
} }
// New creates a Server that starts the HTTP listener on fx start
// and stops it gracefully.
func New(lc fx.Lifecycle, params ServerParams) (*Server, error) { func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
s := new(Server) s := new(Server)
s.params = params s.params = params
@@ -69,23 +53,19 @@ func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
s.log = params.Logger.Get() s.log = params.Logger.Get()
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(ctx context.Context) error {
s.startupTime = time.Now() s.startupTime = time.Now()
go s.Run() go s.Run()
return nil return nil
}, },
OnStop: func(ctx context.Context) error { OnStop: func(ctx context.Context) error {
s.cleanShutdown(ctx) s.cleanShutdown()
return nil return nil
}, },
}) })
return s, nil return s, nil
} }
// Run configures Sentry and starts serving HTTP requests.
func (s *Server) Run() { func (s *Server) Run() {
s.configure() s.configure()
@@ -95,12 +75,6 @@ func (s *Server) Run() {
s.serve() s.serve()
} }
// MaintenanceMode returns whether the server is in maintenance
// mode.
func (s *Server) MaintenanceMode() bool {
return s.params.Config.MaintenanceMode
}
func (s *Server) enableSentry() { func (s *Server) enableSentry() {
s.sentryEnabled = false s.sentryEnabled = false
@@ -110,36 +84,28 @@ func (s *Server) enableSentry() {
err := sentry.Init(sentry.ClientOptions{ err := sentry.Init(sentry.ClientOptions{
Dsn: s.params.Config.SentryDSN, Dsn: s.params.Config.SentryDSN,
Release: fmt.Sprintf( Release: fmt.Sprintf("%s-%s", s.params.Globals.Appname, s.params.Globals.Version),
"%s-%s",
s.params.Globals.Appname,
s.params.Globals.Version,
),
}) })
if err != nil { if err != nil {
s.log.Error("sentry init failure", "error", err) s.log.Error("sentry init failure", "error", err)
// Don't use fatal since we still want the service to run // Don't use fatal since we still want the service to run
return return
} }
s.log.Info("sentry error reporting activated") s.log.Info("sentry error reporting activated")
s.sentryEnabled = true s.sentryEnabled = true
} }
func (s *Server) serve() int { func (s *Server) serve() int {
ctx, cancelFunc := context.WithCancel(context.Background()) s.ctx, s.cancelFunc = context.WithCancel(context.Background())
s.cancelFunc = cancelFunc
// signal watcher // signal watcher
go func() { go func() {
c := make(chan os.Signal, 1) c := make(chan os.Signal, 1)
signal.Ignore(syscall.SIGPIPE) signal.Ignore(syscall.SIGPIPE)
signal.Notify(c, os.Interrupt, syscall.SIGTERM) signal.Notify(c, os.Interrupt, syscall.SIGTERM)
// block and wait for signal // block and wait for signal
sig := <-c sig := <-c
s.log.Info("signal received", "signal", sig.String()) s.log.Info("signal received", "signal", sig.String())
if s.cancelFunc != nil { if s.cancelFunc != nil {
// cancelling the main context will trigger a clean // cancelling the main context will trigger a clean
// shutdown via the fx OnStop hook. // shutdown via the fx OnStop hook.
@@ -149,9 +115,9 @@ func (s *Server) serve() int {
go s.serveUntilShutdown() go s.serveUntilShutdown()
<-ctx.Done() <-s.ctx.Done()
// Shutdown is handled by the fx OnStop hook (cleanShutdown). // Shutdown is handled by the fx OnStop hook (cleanShutdown).
// Do not call cleanShutdown() here to avoid double invocation. // Do not call cleanShutdown() here to avoid a double invocation.
return s.exitCode return s.exitCode
} }
@@ -159,29 +125,27 @@ func (s *Server) cleanupForExit() {
s.log.Info("cleaning up") s.log.Info("cleaning up")
} }
func (s *Server) cleanShutdown(ctx context.Context) { func (s *Server) cleanShutdown() {
// initiate clean shutdown // initiate clean shutdown
s.exitCode = 0 s.exitCode = 0
ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
ctxShutdown, shutdownCancel := context.WithTimeout(
ctx, shutdownTimeout,
)
defer shutdownCancel() defer shutdownCancel()
err := s.httpServer.Shutdown(ctxShutdown) if err := s.httpServer.Shutdown(ctxShutdown); err != nil {
if err != nil { s.log.Error("server clean shutdown failed", "error", err)
s.log.Error(
"server clean shutdown failed", "error", err,
)
} }
s.cleanupForExit() s.cleanupForExit()
if s.sentryEnabled { if s.sentryEnabled {
sentry.Flush(sentryFlushTimeout) sentry.Flush(2 * time.Second)
} }
} }
func (s *Server) MaintenanceMode() bool {
return s.params.Config.MaintenanceMode
}
func (s *Server) configure() { func (s *Server) configure() {
// identify ourselves in the logs // identify ourselves in the logs
s.params.Logger.Identify() s.params.Logger.Identify()

View File

@@ -1,14 +1,10 @@
// Package session manages HTTP session storage and authentication
// state.
package session package session
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"maps"
"net/http" "net/http"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
@@ -19,89 +15,57 @@ import (
) )
const ( const (
// SessionName is the name of the session cookie. // SessionName is the name of the session cookie
SessionName = "webhooker_session" SessionName = "webhooker_session"
// UserIDKey is the session key for user ID. // UserIDKey is the session key for user ID
UserIDKey = "user_id" UserIDKey = "user_id"
// UsernameKey is the session key for username. // UsernameKey is the session key for username
UsernameKey = "username" UsernameKey = "username"
// AuthenticatedKey is the session key for authentication // AuthenticatedKey is the session key for authentication status
// status.
AuthenticatedKey = "authenticated" AuthenticatedKey = "authenticated"
// sessionKeyLength is the required length in bytes for the
// session authentication key.
sessionKeyLength = 32
// sessionMaxAgeDays is the session cookie lifetime in days.
sessionMaxAgeDays = 7
// secondsPerDay is the number of seconds in a day.
secondsPerDay = 86400
) )
// ErrSessionKeyLength is returned when the decoded session key // nolint:revive // SessionParams is a standard fx naming convention
// does not have the expected length. type SessionParams struct {
var ErrSessionKeyLength = errors.New("session key length mismatch")
// Params holds dependencies injected by fx.
type Params struct {
fx.In fx.In
Config *config.Config Config *config.Config
Database *database.Database Database *database.Database
Logger *logger.Logger Logger *logger.Logger
} }
// Session manages encrypted session storage. // Session manages encrypted session storage
type Session struct { type Session struct {
store *sessions.CookieStore store *sessions.CookieStore
key []byte // raw 32-byte auth key, also used for CSRF cookie signing
log *slog.Logger log *slog.Logger
config *config.Config config *config.Config
} }
// New creates a new session manager. The cookie store is // New creates a new session manager. The cookie store is initialized
// initialized during the fx OnStart phase after the database is // during the fx OnStart phase after the database is connected, using
// connected, using a session key that is auto-generated and stored // a session key that is auto-generated and stored in the database.
// in the database. func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
func New(
lc fx.Lifecycle,
params Params,
) (*Session, error) {
s := &Session{ s := &Session{
log: params.Logger.Get(), log: params.Logger.Get(),
config: params.Config, config: params.Config,
} }
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
sessionKey, err := params.Database.GetOrCreateSessionKey() sessionKey, err := params.Database.GetOrCreateSessionKey()
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf("failed to get session key: %w", err)
"failed to get session key: %w", err,
)
} }
keyBytes, err := base64.StdEncoding.DecodeString( keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
sessionKey,
)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf("invalid session key format: %w", err)
"invalid session key format: %w", err,
)
} }
if len(keyBytes) != sessionKeyLength { if len(keyBytes) != 32 {
return fmt.Errorf( return fmt.Errorf("session key must be 32 bytes (got %d)", len(keyBytes))
"%w: want %d, got %d",
ErrSessionKeyLength,
sessionKeyLength,
len(keyBytes),
)
} }
store := sessions.NewCookieStore(keyBytes) store := sessions.NewCookieStore(keyBytes)
@@ -109,16 +73,14 @@ func New(
// Configure cookie options for security // Configure cookie options for security
store.Options = &sessions.Options{ store.Options = &sessions.Options{
Path: "/", Path: "/",
MaxAge: secondsPerDay * sessionMaxAgeDays, MaxAge: 86400 * 7, // 7 days
HttpOnly: true, HttpOnly: true,
Secure: !params.Config.IsDev(), Secure: !params.Config.IsDev(), // HTTPS in production
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
} }
s.key = keyBytes
s.store = store s.store = store
s.log.Info("session manager initialized") s.log.Info("session manager initialized")
return nil return nil
}, },
}) })
@@ -126,126 +88,93 @@ func New(
return s, nil return s, nil
} }
// Get retrieves a session for the request. // Get retrieves a session for the request
func (s *Session) Get( func (s *Session) Get(r *http.Request) (*sessions.Session, error) {
r *http.Request,
) (*sessions.Session, error) {
return s.store.Get(r, SessionName) return s.store.Get(r, SessionName)
} }
// GetKey returns the raw 32-byte authentication key used for // Save saves the session
// session encryption. This key is also suitable for CSRF cookie func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error {
// signing.
func (s *Session) GetKey() []byte {
return s.key
}
// Save saves the session.
func (s *Session) Save(
r *http.Request,
w http.ResponseWriter,
sess *sessions.Session,
) error {
return sess.Save(r, w) return sess.Save(r, w)
} }
// SetUser sets the user information in the session. // SetUser sets the user information in the session
func (s *Session) SetUser( func (s *Session) SetUser(sess *sessions.Session, userID, username string) {
sess *sessions.Session,
userID, username string,
) {
sess.Values[UserIDKey] = userID sess.Values[UserIDKey] = userID
sess.Values[UsernameKey] = username sess.Values[UsernameKey] = username
sess.Values[AuthenticatedKey] = true sess.Values[AuthenticatedKey] = true
} }
// ClearUser removes user information from the session. // ClearUser removes user information from the session
func (s *Session) ClearUser(sess *sessions.Session) { func (s *Session) ClearUser(sess *sessions.Session) {
delete(sess.Values, UserIDKey) delete(sess.Values, UserIDKey)
delete(sess.Values, UsernameKey) delete(sess.Values, UsernameKey)
delete(sess.Values, AuthenticatedKey) delete(sess.Values, AuthenticatedKey)
} }
// IsAuthenticated checks if the session has an authenticated // IsAuthenticated checks if the session has an authenticated user
// user.
func (s *Session) IsAuthenticated(sess *sessions.Session) bool { func (s *Session) IsAuthenticated(sess *sessions.Session) bool {
auth, ok := sess.Values[AuthenticatedKey].(bool) auth, ok := sess.Values[AuthenticatedKey].(bool)
return ok && auth return ok && auth
} }
// GetUserID retrieves the user ID from the session. // GetUserID retrieves the user ID from the session
func (s *Session) GetUserID( func (s *Session) GetUserID(sess *sessions.Session) (string, bool) {
sess *sessions.Session,
) (string, bool) {
userID, ok := sess.Values[UserIDKey].(string) userID, ok := sess.Values[UserIDKey].(string)
return userID, ok return userID, ok
} }
// GetUsername retrieves the username from the session. // GetUsername retrieves the username from the session
func (s *Session) GetUsername( func (s *Session) GetUsername(sess *sessions.Session) (string, bool) {
sess *sessions.Session,
) (string, bool) {
username, ok := sess.Values[UsernameKey].(string) username, ok := sess.Values[UsernameKey].(string)
return username, ok return username, ok
} }
// Destroy invalidates the session. // Destroy invalidates the session
func (s *Session) Destroy(sess *sessions.Session) { func (s *Session) Destroy(sess *sessions.Session) {
sess.Options.MaxAge = -1 sess.Options.MaxAge = -1
s.ClearUser(sess) s.ClearUser(sess)
} }
// Regenerate creates a new session with the same values but a // Regenerate creates a new session with the same values but a fresh ID.
// fresh ID. The old session is destroyed (MaxAge = -1) and saved, // The old session is destroyed (MaxAge = -1) and saved, then a new session
// then a new session is created. This prevents session fixation // is created. This prevents session fixation attacks by ensuring the
// attacks by ensuring the session ID changes after privilege // session ID changes after privilege escalation (e.g. login).
// escalation (e.g. login). func (s *Session) Regenerate(r *http.Request, w http.ResponseWriter, oldSess *sessions.Session) (*sessions.Session, error) {
func (s *Session) Regenerate(
r *http.Request,
w http.ResponseWriter,
oldSess *sessions.Session,
) (*sessions.Session, error) {
// Copy the values from the old session // Copy the values from the old session
oldValues := make(map[any]any) oldValues := make(map[interface{}]interface{})
maps.Copy(oldValues, oldSess.Values) for k, v := range oldSess.Values {
oldValues[k] = v
}
// Destroy the old session // Destroy the old session
oldSess.Options.MaxAge = -1 oldSess.Options.MaxAge = -1
s.ClearUser(oldSess) s.ClearUser(oldSess)
if err := oldSess.Save(r, w); err != nil {
err := oldSess.Save(r, w) return nil, fmt.Errorf("failed to destroy old session: %w", err)
if err != nil {
return nil, fmt.Errorf(
"failed to destroy old session: %w", err,
)
} }
// Create a new session (gorilla/sessions generates a new ID) // Create a new session (gorilla/sessions generates a new ID)
newSess, err := s.store.New(r, SessionName) newSess, err := s.store.New(r, SessionName)
if err != nil { if err != nil {
// store.New may return an error alongside a new empty // store.New may return an error alongside a new empty session
// session if the old cookie is now invalid. That is // if the old cookie is now invalid. That is expected after we
// expected after we destroyed it above. Only fail on a // destroyed it above. Only fail on a nil session.
// nil session.
if newSess == nil { if newSess == nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("failed to create new session: %w", err)
"failed to create new session: %w", err,
)
} }
} }
// Restore the copied values into the new session // Restore the copied values into the new session
maps.Copy(newSess.Values, oldValues) for k, v := range oldValues {
newSess.Values[k] = v
}
// Apply the standard session options (the destroyed old // Apply the standard session options (the destroyed old session had
// session had MaxAge = -1, which store.New might inherit // MaxAge = -1, which store.New might inherit from the cookie).
// from the cookie).
newSess.Options = &sessions.Options{ newSess.Options = &sessions.Options{
Path: "/", Path: "/",
MaxAge: secondsPerDay * sessionMaxAgeDays, MaxAge: 86400 * 7,
HttpOnly: true, HttpOnly: true,
Secure: !s.config.IsDev(), Secure: !s.config.IsDev(),
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,

View File

@@ -1,7 +1,6 @@
package session_test package session
import ( import (
"context"
"log/slog" "log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -12,22 +11,15 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"sneak.berlin/go/webhooker/internal/config" "sneak.berlin/go/webhooker/internal/config"
"sneak.berlin/go/webhooker/internal/session"
) )
const testKeySize = 32 // testSession creates a Session with a real cookie store for testing.
func testSession(t *testing.T) *Session {
// testSession creates a Session with a real cookie store for
// testing.
func testSession(t *testing.T) *session.Session {
t.Helper() t.Helper()
key := make([]byte, 32)
key := make([]byte, testKeySize)
for i := range key { for i := range key {
key[i] = byte(i + 42) key[i] = byte(i + 42)
} }
store := sessions.NewCookieStore(key) store := sessions.NewCookieStore(key)
store.Options = &sessions.Options{ store.Options = &sessions.Options{
Path: "/", Path: "/",
@@ -40,47 +32,34 @@ func testSession(t *testing.T) *session.Session {
cfg := &config.Config{ cfg := &config.Config{
Environment: config.EnvironmentDev, Environment: config.EnvironmentDev,
} }
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
log := slog.New(slog.NewTextHandler( return NewForTest(store, cfg, log)
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
return session.NewForTest(store, cfg, log, key)
} }
// --- Get and Save Tests --- // --- Get and Save Tests ---
func TestGet_NewSession(t *testing.T) { func TestGet_NewSession(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, sess) require.NotNil(t, sess)
assert.True( assert.True(t, sess.IsNew, "session should be new when no cookie is present")
t, sess.IsNew,
"session should be new when no cookie is present",
)
} }
func TestGet_ExistingSession(t *testing.T) { func TestGet_ExistingSession(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
// Create and save a session // Create and save a session
req1 := httptest.NewRequestWithContext( req1 := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
sess1, err := s.Get(req1) sess1, err := s.Get(req1)
require.NoError(t, err) require.NoError(t, err)
sess1.Values["test_key"] = "test_value" sess1.Values["test_key"] = "test_value"
require.NoError(t, s.Save(req1, w1, sess1)) require.NoError(t, s.Save(req1, w1, sess1))
@@ -89,34 +68,26 @@ func TestGet_ExistingSession(t *testing.T) {
require.NotEmpty(t, cookies) require.NotEmpty(t, cookies)
// Make a new request with the session cookie // Make a new request with the session cookie
req2 := httptest.NewRequestWithContext( req2 := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
for _, c := range cookies { for _, c := range cookies {
req2.AddCookie(c) req2.AddCookie(c)
} }
sess2, err := s.Get(req2) sess2, err := s.Get(req2)
require.NoError(t, err) require.NoError(t, err)
assert.False( assert.False(t, sess2.IsNew, "session should not be new when cookie is present")
t, sess2.IsNew,
"session should not be new when cookie is present",
)
assert.Equal(t, "test_value", sess2.Values["test_key"]) assert.Equal(t, "test_value", sess2.Values["test_key"])
} }
func TestSave_SetsCookie(t *testing.T) { func TestSave_SetsCookie(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
sess.Values["key"] = "value" sess.Values["key"] = "value"
err = s.Save(req, w, sess) err = s.Save(req, w, sess)
@@ -127,73 +98,48 @@ func TestSave_SetsCookie(t *testing.T) {
// Verify the cookie has the expected name // Verify the cookie has the expected name
var found bool var found bool
for _, c := range cookies { for _, c := range cookies {
if c.Name == session.SessionName { if c.Name == SessionName {
found = true found = true
assert.True(t, c.HttpOnly, "session cookie should be HTTP-only")
assert.True(
t, c.HttpOnly,
"session cookie should be HTTP-only",
)
break break
} }
} }
assert.True(t, found, "should find a cookie named %s", SessionName)
assert.True(
t, found,
"should find a cookie named %s", session.SessionName,
)
} }
// --- SetUser and User Retrieval Tests --- // --- SetUser and User Retrieval Tests ---
func TestSetUser_SetsAllFields(t *testing.T) { func TestSetUser_SetsAllFields(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
s.SetUser(sess, "user-abc-123", "alice") s.SetUser(sess, "user-abc-123", "alice")
assert.Equal( assert.Equal(t, "user-abc-123", sess.Values[UserIDKey])
t, "user-abc-123", sess.Values[session.UserIDKey], assert.Equal(t, "alice", sess.Values[UsernameKey])
) assert.Equal(t, true, sess.Values[AuthenticatedKey])
assert.Equal(
t, "alice", sess.Values[session.UsernameKey],
)
assert.Equal(
t, true, sess.Values[session.AuthenticatedKey],
)
} }
func TestGetUserID(t *testing.T) { func TestGetUserID(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
// Before setting user // Before setting user
userID, ok := s.GetUserID(sess) userID, ok := s.GetUserID(sess)
assert.False( assert.False(t, ok, "should return false when no user ID is set")
t, ok, "should return false when no user ID is set",
)
assert.Empty(t, userID) assert.Empty(t, userID)
// After setting user // After setting user
s.SetUser(sess, "user-xyz", "bob") s.SetUser(sess, "user-xyz", "bob")
userID, ok = s.GetUserID(sess) userID, ok = s.GetUserID(sess)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "user-xyz", userID) assert.Equal(t, "user-xyz", userID)
@@ -201,25 +147,19 @@ func TestGetUserID(t *testing.T) {
func TestGetUsername(t *testing.T) { func TestGetUsername(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
// Before setting user // Before setting user
username, ok := s.GetUsername(sess) username, ok := s.GetUsername(sess)
assert.False( assert.False(t, ok, "should return false when no username is set")
t, ok, "should return false when no username is set",
)
assert.Empty(t, username) assert.Empty(t, username)
// After setting user // After setting user
s.SetUser(sess, "user-xyz", "bob") s.SetUser(sess, "user-xyz", "bob")
username, ok = s.GetUsername(sess) username, ok = s.GetUsername(sess)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "bob", username) assert.Equal(t, "bob", username)
@@ -229,29 +169,20 @@ func TestGetUsername(t *testing.T) {
func TestIsAuthenticated_NoSession(t *testing.T) { func TestIsAuthenticated_NoSession(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
assert.False( assert.False(t, s.IsAuthenticated(sess), "new session should not be authenticated")
t, s.IsAuthenticated(sess),
"new session should not be authenticated",
)
} }
func TestIsAuthenticated_AfterSetUser(t *testing.T) { func TestIsAuthenticated_AfterSetUser(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
@@ -261,12 +192,9 @@ func TestIsAuthenticated_AfterSetUser(t *testing.T) {
func TestIsAuthenticated_AfterClearUser(t *testing.T) { func TestIsAuthenticated_AfterClearUser(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
@@ -274,71 +202,52 @@ func TestIsAuthenticated_AfterClearUser(t *testing.T) {
require.True(t, s.IsAuthenticated(sess)) require.True(t, s.IsAuthenticated(sess))
s.ClearUser(sess) s.ClearUser(sess)
assert.False(t, s.IsAuthenticated(sess), "should not be authenticated after ClearUser")
assert.False(
t, s.IsAuthenticated(sess),
"should not be authenticated after ClearUser",
)
} }
func TestIsAuthenticated_WrongType(t *testing.T) { func TestIsAuthenticated_WrongType(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
// Set authenticated to a non-bool value // Set authenticated to a non-bool value
sess.Values[session.AuthenticatedKey] = "yes" sess.Values[AuthenticatedKey] = "yes"
assert.False(t, s.IsAuthenticated(sess), "should return false for non-bool authenticated value")
assert.False(
t, s.IsAuthenticated(sess),
"should return false for non-bool authenticated value",
)
} }
// --- ClearUser Tests --- // --- ClearUser Tests ---
func TestClearUser_RemovesAllKeys(t *testing.T) { func TestClearUser_RemovesAllKeys(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
s.SetUser(sess, "user-123", "alice") s.SetUser(sess, "user-123", "alice")
s.ClearUser(sess) s.ClearUser(sess)
_, hasUserID := sess.Values[session.UserIDKey] _, hasUserID := sess.Values[UserIDKey]
assert.False(t, hasUserID, "UserIDKey should be removed") assert.False(t, hasUserID, "UserIDKey should be removed")
_, hasUsername := sess.Values[session.UsernameKey] _, hasUsername := sess.Values[UsernameKey]
assert.False(t, hasUsername, "UsernameKey should be removed") assert.False(t, hasUsername, "UsernameKey should be removed")
_, hasAuth := sess.Values[session.AuthenticatedKey] _, hasAuth := sess.Values[AuthenticatedKey]
assert.False( assert.False(t, hasAuth, "AuthenticatedKey should be removed")
t, hasAuth, "AuthenticatedKey should be removed",
)
} }
// --- Destroy Tests --- // --- Destroy Tests ---
func TestDestroy_InvalidatesSession(t *testing.T) { func TestDestroy_InvalidatesSession(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
@@ -346,18 +255,11 @@ func TestDestroy_InvalidatesSession(t *testing.T) {
s.Destroy(sess) s.Destroy(sess)
// After Destroy: MaxAge should be -1 (delete cookie) and // After Destroy: MaxAge should be -1 (delete cookie) and user data cleared
// user data cleared assert.Equal(t, -1, sess.Options.MaxAge, "Destroy should set MaxAge to -1")
assert.Equal( assert.False(t, s.IsAuthenticated(sess), "should not be authenticated after Destroy")
t, -1, sess.Options.MaxAge,
"Destroy should set MaxAge to -1",
)
assert.False(
t, s.IsAuthenticated(sess),
"should not be authenticated after Destroy",
)
_, hasUserID := sess.Values[session.UserIDKey] _, hasUserID := sess.Values[UserIDKey]
assert.False(t, hasUserID, "Destroy should clear user ID") assert.False(t, hasUserID, "Destroy should clear user ID")
} }
@@ -365,12 +267,10 @@ func TestDestroy_InvalidatesSession(t *testing.T) {
func TestSessionPersistence_RoundTrip(t *testing.T) { func TestSessionPersistence_RoundTrip(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
// Step 1: Create session, set user, save // Step 1: Create session, set user, save
req1 := httptest.NewRequestWithContext( req1 := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
sess1, err := s.Get(req1) sess1, err := s.Get(req1)
@@ -381,13 +281,8 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
cookies := w1.Result().Cookies() cookies := w1.Result().Cookies()
require.NotEmpty(t, cookies) require.NotEmpty(t, cookies)
// Step 2: New request with cookies -- session data should // Step 2: New request with cookies session data should persist
// persist req2 := httptest.NewRequest(http.MethodGet, "/profile", nil)
req2 := httptest.NewRequestWithContext(
context.Background(),
http.MethodGet, "/profile", nil,
)
for _, c := range cookies { for _, c := range cookies {
req2.AddCookie(c) req2.AddCookie(c)
} }
@@ -395,10 +290,7 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
sess2, err := s.Get(req2) sess2, err := s.Get(req2)
require.NoError(t, err) require.NoError(t, err)
assert.True( assert.True(t, s.IsAuthenticated(sess2), "session should be authenticated after round-trip")
t, s.IsAuthenticated(sess2),
"session should be authenticated after round-trip",
)
userID, ok := s.GetUserID(sess2) userID, ok := s.GetUserID(sess2)
assert.True(t, ok) assert.True(t, ok)
@@ -413,23 +305,19 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
func TestSessionConstants(t *testing.T) { func TestSessionConstants(t *testing.T) {
t.Parallel() t.Parallel()
assert.Equal(t, "webhooker_session", SessionName)
assert.Equal(t, "webhooker_session", session.SessionName) assert.Equal(t, "user_id", UserIDKey)
assert.Equal(t, "user_id", session.UserIDKey) assert.Equal(t, "username", UsernameKey)
assert.Equal(t, "username", session.UsernameKey) assert.Equal(t, "authenticated", AuthenticatedKey)
assert.Equal(t, "authenticated", session.AuthenticatedKey)
} }
// --- Edge Cases --- // --- Edge Cases ---
func TestSetUser_OverwritesPreviousUser(t *testing.T) { func TestSetUser_OverwritesPreviousUser(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
req := httptest.NewRequestWithContext( req := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
sess, err := s.Get(req) sess, err := s.Get(req)
require.NoError(t, err) require.NoError(t, err)
@@ -450,12 +338,10 @@ func TestSetUser_OverwritesPreviousUser(t *testing.T) {
func TestDestroy_ThenSave_DeletesCookie(t *testing.T) { func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
t.Parallel() t.Parallel()
s := testSession(t) s := testSession(t)
// Create a session // Create a session
req1 := httptest.NewRequestWithContext( req1 := httptest.NewRequest(http.MethodGet, "/", nil)
context.Background(), http.MethodGet, "/", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
sess, err := s.Get(req1) sess, err := s.Get(req1)
@@ -467,15 +353,10 @@ func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
require.NotEmpty(t, cookies) require.NotEmpty(t, cookies)
// Destroy and save // Destroy and save
req2 := httptest.NewRequestWithContext( req2 := httptest.NewRequest(http.MethodGet, "/logout", nil)
context.Background(),
http.MethodGet, "/logout", nil,
)
for _, c := range cookies { for _, c := range cookies {
req2.AddCookie(c) req2.AddCookie(c)
} }
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
sess2, err := s.Get(req2) sess2, err := s.Get(req2)
@@ -483,25 +364,15 @@ func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
s.Destroy(sess2) s.Destroy(sess2)
require.NoError(t, s.Save(req2, w2, sess2)) require.NoError(t, s.Save(req2, w2, sess2))
// The cookie should have MaxAge = -1 (browser should delete) // The cookie should have MaxAge = -1 (browser should delete it)
responseCookies := w2.Result().Cookies() responseCookies := w2.Result().Cookies()
var sessionCookie *http.Cookie var sessionCookie *http.Cookie
for _, c := range responseCookies { for _, c := range responseCookies {
if c.Name == session.SessionName { if c.Name == SessionName {
sessionCookie = c sessionCookie = c
break break
} }
} }
require.NotNil(t, sessionCookie, "should have a session cookie in response")
require.NotNil( assert.True(t, sessionCookie.MaxAge < 0, "destroyed session cookie should have negative MaxAge")
t, sessionCookie,
"should have a session cookie in response",
)
assert.Negative(
t, sessionCookie.MaxAge,
"destroyed session cookie should have negative MaxAge",
)
} }

View File

@@ -9,13 +9,10 @@ import (
// NewForTest creates a Session with a pre-configured cookie store for use // NewForTest creates a Session with a pre-configured cookie store for use
// in tests. This bypasses the fx lifecycle and database dependency, allowing // in tests. This bypasses the fx lifecycle and database dependency, allowing
// middleware and handler tests to use real session functionality. The key // middleware and handler tests to use real session functionality.
// parameter is the raw 32-byte authentication key used for session encryption func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *Session {
// and CSRF cookie signing.
func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *Session {
return &Session{ return &Session{
store: store, store: store,
key: key,
config: cfg, config: cfg,
log: log, log: log,
} }

View File

@@ -1,11 +1,8 @@
// Package static embeds static assets (CSS, JS) served by the web UI.
package static package static
import ( import (
"embed" "embed"
) )
// Static holds the embedded CSS and JavaScript files for the web UI.
//
//go:embed css js //go:embed css js
var Static embed.FS var Static embed.FS

67
templates/index.html Normal file
View File

@@ -0,0 +1,67 @@
{{template "base" .}}
{{define "title"}}Home - Webhooker{{end}}
{{define "content"}}
<div class="max-w-4xl mx-auto px-6 py-12">
<div class="text-center mb-10">
<h1 class="text-4xl font-medium text-gray-900">Welcome to Webhooker</h1>
<p class="mt-3 text-lg text-gray-500">A reliable webhook proxy service for event delivery</p>
</div>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<!-- Server Status Card -->
<div class="card-elevated p-6">
<div class="flex items-center mb-4">
<div class="rounded-full bg-success-50 p-3 mr-4">
<svg class="w-6 h-6 text-success-500" fill="currentColor" viewBox="0 0 16 16">
<path d="M1.333 2.667C1.333 1.194 4.318 0 8 0s6.667 1.194 6.667 2.667V4c0 1.473-2.985 2.667-6.667 2.667S1.333 5.473 1.333 4V2.667z"/>
<path d="M1.333 6.334v3C1.333 10.805 4.318 12 8 12s6.667-1.194 6.667-2.667V6.334a6.51 6.51 0 0 1-1.458.79C11.81 7.684 9.967 8 8 8c-1.966 0-3.809-.317-5.208-.876a6.508 6.508 0 0 1-1.458-.79z"/>
<path d="M14.667 11.668a6.51 6.51 0 0 1-1.458.789c-1.4.56-3.242.876-5.21.876-1.966 0-3.809-.316-5.208-.876a6.51 6.51 0 0 1-1.458-.79v1.666C1.333 14.806 4.318 16 8 16s6.667-1.194 6.667-2.667v-1.665z"/>
</svg>
</div>
<div>
<h2 class="text-lg font-medium text-gray-900">Server Status</h2>
<span class="badge-success">Online</span>
</div>
</div>
<div class="space-y-3">
<div>
<p class="text-sm text-gray-500">Uptime</p>
<p class="text-2xl font-medium text-gray-900">{{.Uptime}}</p>
</div>
<div>
<p class="text-sm text-gray-500">Version</p>
<p class="font-mono text-sm text-gray-700">{{.Version}}</p>
</div>
</div>
</div>
<!-- Users Card -->
<div class="card-elevated p-6">
<div class="flex items-center mb-4">
<div class="rounded-full bg-primary-50 p-3 mr-4">
<svg class="w-6 h-6 text-primary-500" fill="currentColor" viewBox="0 0 16 16">
<path d="M15 14s1 0 1-1-1-4-5-4-5 3-5 4 1 1 1 1h8zm-7.978-1A.261.261 0 0 1 7 12.996c.001-.264.167-1.03.76-1.72C8.312 10.629 9.282 10 11 10c1.717 0 2.687.63 3.24 1.276.593.69.758 1.457.76 1.72l-.008.002a.274.274 0 0 1-.014.002H7.022zM11 7a2 2 0 1 0 0-4 2 2 0 0 0 0 4zm3-2a3 3 0 1 1-6 0 3 3 0 0 1 6 0zM6.936 9.28a5.88 5.88 0 0 0-1.23-.247A7.35 7.35 0 0 0 5 9c-4 0-5 3-5 4 0 .667.333 1 1 1h4.216A2.238 2.238 0 0 1 5 13c0-1.01.377-2.042 1.09-2.904.243-.294.526-.569.846-.816zM4.92 10A5.493 5.493 0 0 0 4 13H1c0-.26.164-1.03.76-1.724.545-.636 1.492-1.256 3.16-1.275zM1.5 5.5a3 3 0 1 1 6 0 3 3 0 0 1-6 0zm3-2a2 2 0 1 0 0 4 2 2 0 0 0 0-4z"/>
</svg>
</div>
<div>
<h2 class="text-lg font-medium text-gray-900">Users</h2>
<p class="text-sm text-gray-500">Registered accounts</p>
</div>
</div>
<div>
<p class="text-4xl font-medium text-gray-900">{{.UserCount}}</p>
<p class="text-sm text-gray-500 mt-1">Total users</p>
</div>
</div>
</div>
{{if not .User}}
<div class="text-center mt-10">
<p class="text-gray-500 mb-4">Ready to get started?</p>
<a href="/pages/login" class="btn-primary">Login to your account</a>
</div>
{{end}}
</div>
{{end}}

View File

@@ -23,7 +23,6 @@
{{end}} {{end}}
<form method="POST" action="/pages/login" class="space-y-6"> <form method="POST" action="/pages/login" class="space-y-6">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form-group"> <div class="form-group">
<label for="username" class="label">Username</label> <label for="username" class="label">Username</label>
<input <input

View File

@@ -25,7 +25,6 @@
{{.User.Username}} {{.User.Username}}
</a> </a>
<form method="POST" action="/pages/logout" class="inline"> <form method="POST" action="/pages/logout" class="inline">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button type="submit" class="btn-text">Logout</button> <button type="submit" class="btn-text">Logout</button>
</form> </form>
{{else}} {{else}}
@@ -41,7 +40,6 @@
<a href="/sources" class="btn-text w-full text-left">Sources</a> <a href="/sources" class="btn-text w-full text-left">Sources</a>
<a href="/user/{{.User.Username}}" class="btn-text w-full text-left">Profile</a> <a href="/user/{{.User.Username}}" class="btn-text w-full text-left">Profile</a>
<form method="POST" action="/pages/logout"> <form method="POST" action="/pages/logout">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button type="submit" class="btn-text w-full text-left">Logout</button> <button type="submit" class="btn-text w-full text-left">Logout</button>
</form> </form>
{{else}} {{else}}

View File

@@ -17,7 +17,6 @@
<a href="/source/{{.Webhook.ID}}/logs" class="btn-secondary">Event Log</a> <a href="/source/{{.Webhook.ID}}/logs" class="btn-secondary">Event Log</a>
<a href="/source/{{.Webhook.ID}}/edit" class="btn-secondary">Edit</a> <a href="/source/{{.Webhook.ID}}/edit" class="btn-secondary">Edit</a>
<form method="POST" action="/source/{{.Webhook.ID}}/delete" onsubmit="return confirm('Delete this webhook and all its data?')"> <form method="POST" action="/source/{{.Webhook.ID}}/delete" onsubmit="return confirm('Delete this webhook and all its data?')">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button type="submit" class="btn-danger">Delete</button> <button type="submit" class="btn-danger">Delete</button>
</form> </form>
</div> </div>
@@ -40,7 +39,6 @@
<!-- Add entrypoint form --> <!-- Add entrypoint form -->
<div x-show="showAddEntrypoint" x-cloak class="p-4 bg-gray-50 border-b border-gray-200"> <div x-show="showAddEntrypoint" x-cloak class="p-4 bg-gray-50 border-b border-gray-200">
<form method="POST" action="/source/{{.Webhook.ID}}/entrypoints" class="flex gap-2"> <form method="POST" action="/source/{{.Webhook.ID}}/entrypoints" class="flex gap-2">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<input type="text" name="description" placeholder="Description (optional)" class="input text-sm flex-1"> <input type="text" name="description" placeholder="Description (optional)" class="input text-sm flex-1">
<button type="submit" class="btn-primary text-sm">Add</button> <button type="submit" class="btn-primary text-sm">Add</button>
</form> </form>
@@ -58,13 +56,11 @@
<span class="badge-error">Inactive</span> <span class="badge-error">Inactive</span>
{{end}} {{end}}
<form method="POST" action="/source/{{$.Webhook.ID}}/entrypoints/{{.ID}}/toggle" class="inline"> <form method="POST" action="/source/{{$.Webhook.ID}}/entrypoints/{{.ID}}/toggle" class="inline">
<input type="hidden" name="csrf_token" value="{{$.CSRFToken}}">
<button type="submit" class="text-xs text-gray-500 hover:text-primary-600" title="{{if .Active}}Deactivate{{else}}Activate{{end}}"> <button type="submit" class="text-xs text-gray-500 hover:text-primary-600" title="{{if .Active}}Deactivate{{else}}Activate{{end}}">
{{if .Active}}Deactivate{{else}}Activate{{end}} {{if .Active}}Deactivate{{else}}Activate{{end}}
</button> </button>
</form> </form>
<form method="POST" action="/source/{{$.Webhook.ID}}/entrypoints/{{.ID}}/delete" onsubmit="return confirm('Delete this entrypoint?')" class="inline"> <form method="POST" action="/source/{{$.Webhook.ID}}/entrypoints/{{.ID}}/delete" onsubmit="return confirm('Delete this entrypoint?')" class="inline">
<input type="hidden" name="csrf_token" value="{{$.CSRFToken}}">
<button type="submit" class="text-xs text-red-500 hover:text-red-700" title="Delete">Delete</button> <button type="submit" class="text-xs text-red-500 hover:text-red-700" title="Delete">Delete</button>
</form> </form>
</div> </div>
@@ -92,7 +88,6 @@
<!-- Add target form --> <!-- Add target form -->
<div x-show="showAddTarget" x-cloak class="p-4 bg-gray-50 border-b border-gray-200"> <div x-show="showAddTarget" x-cloak class="p-4 bg-gray-50 border-b border-gray-200">
<form method="POST" action="/source/{{.Webhook.ID}}/targets" x-data="{ targetType: 'http' }" class="space-y-3"> <form method="POST" action="/source/{{.Webhook.ID}}/targets" x-data="{ targetType: 'http' }" class="space-y-3">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="flex gap-2"> <div class="flex gap-2">
<input type="text" name="name" placeholder="Target name" required class="input text-sm flex-1"> <input type="text" name="name" placeholder="Target name" required class="input text-sm flex-1">
<select name="type" x-model="targetType" class="input text-sm w-32"> <select name="type" x-model="targetType" class="input text-sm w-32">
@@ -130,13 +125,11 @@
<span class="badge-error">Inactive</span> <span class="badge-error">Inactive</span>
{{end}} {{end}}
<form method="POST" action="/source/{{$.Webhook.ID}}/targets/{{.ID}}/toggle" class="inline"> <form method="POST" action="/source/{{$.Webhook.ID}}/targets/{{.ID}}/toggle" class="inline">
<input type="hidden" name="csrf_token" value="{{$.CSRFToken}}">
<button type="submit" class="text-xs text-gray-500 hover:text-primary-600" title="{{if .Active}}Deactivate{{else}}Activate{{end}}"> <button type="submit" class="text-xs text-gray-500 hover:text-primary-600" title="{{if .Active}}Deactivate{{else}}Activate{{end}}">
{{if .Active}}Deactivate{{else}}Activate{{end}} {{if .Active}}Deactivate{{else}}Activate{{end}}
</button> </button>
</form> </form>
<form method="POST" action="/source/{{$.Webhook.ID}}/targets/{{.ID}}/delete" onsubmit="return confirm('Delete this target?')" class="inline"> <form method="POST" action="/source/{{$.Webhook.ID}}/targets/{{.ID}}/delete" onsubmit="return confirm('Delete this target?')" class="inline">
<input type="hidden" name="csrf_token" value="{{$.CSRFToken}}">
<button type="submit" class="text-xs text-red-500 hover:text-red-700" title="Delete">Delete</button> <button type="submit" class="text-xs text-red-500 hover:text-red-700" title="Delete">Delete</button>
</form> </form>
</div> </div>

View File

@@ -15,7 +15,6 @@
{{end}} {{end}}
<form method="POST" action="/source/{{.Webhook.ID}}/edit" class="space-y-6"> <form method="POST" action="/source/{{.Webhook.ID}}/edit" class="space-y-6">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form-group"> <div class="form-group">
<label for="name" class="label">Name</label> <label for="name" class="label">Name</label>
<input type="text" id="name" name="name" value="{{.Webhook.Name}}" required class="input"> <input type="text" id="name" name="name" value="{{.Webhook.Name}}" required class="input">

View File

@@ -15,7 +15,6 @@
{{end}} {{end}}
<form method="POST" action="/sources/new" class="space-y-6"> <form method="POST" action="/sources/new" class="space-y-6">
<input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div class="form-group"> <div class="form-group">
<label for="name" class="label">Name</label> <label for="name" class="label">Name</label>
<input type="text" id="name" name="name" required autofocus placeholder="My Webhook" class="input"> <input type="text" id="name" name="name" required autofocus placeholder="My Webhook" class="input">

View File

@@ -1,11 +1,8 @@
// Package templates embeds HTML templates used by the web UI.
package templates package templates
import ( import (
"embed" "embed"
) )
// Templates holds the embedded HTML template files.
//
//go:embed *.html //go:embed *.html
var Templates embed.FS var Templates embed.FS