refactor: use pinned golangci-lint Docker image for linting #55
@@ -1,46 +1,32 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
run:
|
run:
|
||||||
timeout: 5m
|
timeout: 5m
|
||||||
tests: true
|
modules-download-mode: readonly
|
||||||
|
|
||||||
linters:
|
linters:
|
||||||
enable:
|
default: all
|
||||||
- gofmt
|
disable:
|
||||||
- revive
|
# Genuinely incompatible with project patterns
|
||||||
- govet
|
- exhaustruct # Requires all struct fields
|
||||||
- errcheck
|
- depguard # Dependency allow/block lists
|
||||||
- staticcheck
|
- godot # Requires comments to end with periods
|
||||||
- unused
|
- wsl # Deprecated, replaced by wsl_v5
|
||||||
- gosimple
|
- wrapcheck # Too verbose for internal packages
|
||||||
- ineffassign
|
- varnamelen # Short names like db, id are idiomatic Go
|
||||||
- typecheck
|
|
||||||
- gosec
|
|
||||||
- misspell
|
|
||||||
- unparam
|
|
||||||
- prealloc
|
|
||||||
- copyloopvar
|
|
||||||
- gocritic
|
|
||||||
- gochecknoinits
|
|
||||||
- gochecknoglobals
|
|
||||||
|
|
||||||
linters-settings:
|
linters-settings:
|
||||||
gofmt:
|
lll:
|
||||||
simplify: true
|
line-length: 88
|
||||||
revive:
|
funlen:
|
||||||
confidence: 0.8
|
lines: 80
|
||||||
govet:
|
statements: 50
|
||||||
enable:
|
cyclop:
|
||||||
- shadow
|
max-complexity: 15
|
||||||
errcheck:
|
dupl:
|
||||||
check-type-assertions: true
|
threshold: 100
|
||||||
check-blank: true
|
|
||||||
|
|
||||||
issues:
|
issues:
|
||||||
exclude-rules:
|
exclude-use-default: false
|
||||||
# Exclude globals check for version variables in main
|
max-issues-per-linter: 0
|
||||||
- path: cmd/webhooker/main.go
|
max-same-issues: 0
|
||||||
linters:
|
|
||||||
- gochecknoglobals
|
|
||||||
# Exclude globals check for version variables in globals package
|
|
||||||
- path: internal/globals/globals.go
|
|
||||||
linters:
|
|
||||||
- gochecknoglobals
|
|
||||||
|
|||||||
74
Dockerfile
74
Dockerfile
@@ -1,56 +1,58 @@
|
|||||||
# golang:1.24 (bookworm) — 2026-03-01
|
# Lint stage
|
||||||
# Using Debian-based image because gorm.io/driver/sqlite pulls in
|
# golangci/golangci-lint:v2.11.3 (Debian-based), 2026-03-17
|
||||||
# mattn/go-sqlite3 (CGO), which does not compile on Alpine musl.
|
# Using Debian-based image because mattn/go-sqlite3 (CGO) does not
|
||||||
FROM golang@sha256:d2d2bc1c84f7e60d7d2438a3836ae7d0c847f4888464e7ec9ba3a1339a1ee804 AS builder
|
# compile on Alpine musl (off64_t is a glibc type).
|
||||||
|
FROM golangci/golangci-lint:v2.11.3@sha256:e838e8ab68aaefe83e2408691510867ade9329c0e0b895a3fb35eb93d1c2a4ba AS lint
|
||||||
|
|
||||||
# gcc is pre-installed in the Debian-based golang image
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends make && rm -rf /var/lib/apt/lists/*
|
RUN apt-get update && apt-get install -y --no-install-recommends make && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /src
|
||||||
|
|
||||||
# Install golangci-lint v1.64.8 — 2026-03-01
|
# Copy go mod files first for better layer caching
|
||||||
# Using v1.x because the repo's .golangci.yml uses v1 config format.
|
|
||||||
RUN set -eux; \
|
|
||||||
GOLANGCI_VERSION="1.64.8"; \
|
|
||||||
ARCH="$(uname -m)"; \
|
|
||||||
case "${ARCH}" in \
|
|
||||||
x86_64) \
|
|
||||||
GOARCH="amd64"; \
|
|
||||||
GOLANGCI_SHA256="b6270687afb143d019f387c791cd2a6f1cb383be9b3124d241ca11bd3ce2e54e"; \
|
|
||||||
;; \
|
|
||||||
aarch64) \
|
|
||||||
GOARCH="arm64"; \
|
|
||||||
GOLANGCI_SHA256="a6ab58ebcb1c48572622146cdaec2956f56871038a54ed1149f1386e287789a5"; \
|
|
||||||
;; \
|
|
||||||
*) echo "unsupported architecture: ${ARCH}" && exit 1 ;; \
|
|
||||||
esac; \
|
|
||||||
wget -q "https://github.com/golangci/golangci-lint/releases/download/v${GOLANGCI_VERSION}/golangci-lint-${GOLANGCI_VERSION}-linux-${GOARCH}.tar.gz" \
|
|
||||||
-O /tmp/golangci-lint.tar.gz; \
|
|
||||||
echo "${GOLANGCI_SHA256} /tmp/golangci-lint.tar.gz" | sha256sum -c -; \
|
|
||||||
tar -xzf /tmp/golangci-lint.tar.gz -C /tmp; \
|
|
||||||
mv "/tmp/golangci-lint-${GOLANGCI_VERSION}-linux-${GOARCH}/golangci-lint" /usr/local/bin/; \
|
|
||||||
rm -rf /tmp/golangci-lint*; \
|
|
||||||
golangci-lint --version
|
|
||||||
|
|
||||||
# Copy go module files and download dependencies
|
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
# Copy source code
|
# Copy source code
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Run all checks (fmt-check, lint, test, build)
|
# Run formatting check and linter
|
||||||
RUN make check
|
RUN make fmt-check
|
||||||
|
RUN make lint
|
||||||
|
|
||||||
|
# Build stage
|
||||||
|
# golang:1.26.1-bookworm (Debian-based), 2026-03-17
|
||||||
|
# Using Debian-based image because gorm.io/driver/sqlite pulls in
|
||||||
|
# mattn/go-sqlite3 (CGO), which does not compile on Alpine musl.
|
||||||
|
FROM golang:1.26.1-bookworm@sha256:4465644228bc2857a954b092167e12aa59c006a3492282a6c820bf4755fd64a4 AS builder
|
||||||
|
|
||||||
|
# Depend on lint stage passing
|
||||||
|
COPY --from=lint /src/go.sum /dev/null
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends make && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# Copy go mod files first for better layer caching
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
# Copy source code
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Run tests and build
|
||||||
|
RUN make test
|
||||||
|
RUN make build
|
||||||
|
|
||||||
# Rebuild with static linking for Alpine runtime.
|
# Rebuild with static linking for Alpine runtime.
|
||||||
# make check already verified formatting, linting, tests, and compilation.
|
# make build already verified compilation.
|
||||||
# The CGO binary from `make build` is dynamically linked against glibc,
|
# The CGO binary from `make build` is dynamically linked against glibc,
|
||||||
# which doesn't exist on Alpine (musl). Rebuild with static linking so
|
# which doesn't exist on Alpine (musl). Rebuild with static linking so
|
||||||
# the binary runs on Alpine without glibc.
|
# the binary runs on Alpine without glibc.
|
||||||
RUN CGO_ENABLED=1 go build -ldflags '-extldflags "-static"' -o bin/webhooker ./cmd/webhooker
|
RUN CGO_ENABLED=1 go build -ldflags '-extldflags "-static"' -o bin/webhooker ./cmd/webhooker
|
||||||
|
|
||||||
# alpine:3.21 — 2026-03-01
|
# Runtime stage
|
||||||
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
|
# alpine:3.21, 2026-03-17
|
||||||
|
FROM alpine:3.21@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
|
||||||
|
|
||||||
RUN apk --no-cache add ca-certificates
|
RUN apk --no-cache add ca-certificates
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ with retry support, logging, and observability. Category: infrastructure
|
|||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
- Go 1.24+
|
- Go 1.26+
|
||||||
- golangci-lint v1.64+
|
- golangci-lint v2.11+
|
||||||
- Docker (for containerized deployment)
|
- Docker (for containerized deployment)
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
@@ -777,7 +777,7 @@ webhooker/
|
|||||||
│ ├── css/style.css # Custom stylesheet (system font stack, card effects, layout)
|
│ ├── css/style.css # Custom stylesheet (system font stack, card effects, layout)
|
||||||
│ └── js/app.js # Client-side JavaScript (minimal bootstrap)
|
│ └── js/app.js # Client-side JavaScript (minimal bootstrap)
|
||||||
├── templates/ # Go HTML templates (base, index, login, etc.)
|
├── templates/ # Go HTML templates (base, index, login, etc.)
|
||||||
├── Dockerfile # Multi-stage: build + check, then Alpine runtime
|
├── Dockerfile # Multi-stage: lint, build+test, then Alpine runtime
|
||||||
├── Makefile # fmt, lint, test, check, build, docker targets
|
├── Makefile # fmt, lint, test, check, build, docker targets
|
||||||
├── go.mod / go.sum
|
├── go.mod / go.sum
|
||||||
└── .golangci.yml # Linter configuration
|
└── .golangci.yml # Linter configuration
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package main is the entry point for the webhooker application.
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -15,6 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Build-time variables set via -ldflags.
|
// Build-time variables set via -ldflags.
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Build-time variables injected by the linker.
|
||||||
var (
|
var (
|
||||||
version = "dev"
|
version = "dev"
|
||||||
appname = "webhooker"
|
appname = "webhooker"
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -1,8 +1,6 @@
|
|||||||
module sneak.berlin/go/webhooker
|
module sneak.berlin/go/webhooker
|
||||||
|
|
||||||
go 1.24.0
|
go 1.26.1
|
||||||
|
|
||||||
toolchain go1.24.1
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8
|
github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
// Package config loads application configuration from environment variables.
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
@@ -17,19 +19,29 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// EnvironmentDev represents development environment
|
// EnvironmentDev represents development environment.
|
||||||
EnvironmentDev = "dev"
|
EnvironmentDev = "dev"
|
||||||
// EnvironmentProd represents production environment
|
// EnvironmentProd represents production environment.
|
||||||
EnvironmentProd = "prod"
|
EnvironmentProd = "prod"
|
||||||
|
|
||||||
|
// defaultPort is the default HTTP listen port.
|
||||||
|
defaultPort = 8080
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // ConfigParams is a standard fx naming convention
|
// ErrInvalidEnvironment is returned when WEBHOOKER_ENVIRONMENT
|
||||||
|
// contains an unrecognised value.
|
||||||
|
var ErrInvalidEnvironment = errors.New("invalid environment")
|
||||||
|
|
||||||
|
//nolint:revive // ConfigParams is a standard fx naming convention.
|
||||||
type ConfigParams struct {
|
type ConfigParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Globals *globals.Globals
|
Globals *globals.Globals
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Config holds all application configuration loaded from
|
||||||
|
// environment variables.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
DataDir string
|
DataDir string
|
||||||
Debug bool
|
Debug bool
|
||||||
@@ -43,56 +55,67 @@ type Config struct {
|
|||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDev returns true if running in development environment
|
// IsDev returns true if running in development environment.
|
||||||
func (c *Config) IsDev() bool {
|
func (c *Config) IsDev() bool {
|
||||||
return c.Environment == EnvironmentDev
|
return c.Environment == EnvironmentDev
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsProd returns true if running in production environment
|
// IsProd returns true if running in production environment.
|
||||||
func (c *Config) IsProd() bool {
|
func (c *Config) IsProd() bool {
|
||||||
return c.Environment == EnvironmentProd
|
return c.Environment == EnvironmentProd
|
||||||
}
|
}
|
||||||
|
|
||||||
// envString returns the value of the named environment variable, or
|
// envString returns the value of the named environment variable,
|
||||||
// an empty string if not set.
|
// or an empty string if not set.
|
||||||
func envString(key string) string {
|
func envString(key string) string {
|
||||||
return os.Getenv(key)
|
return os.Getenv(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// envBool returns the value of the named environment variable parsed as a
|
// envBool returns the value of the named environment variable
|
||||||
// boolean. Returns defaultValue if not set.
|
// parsed as a boolean. Returns defaultValue if not set.
|
||||||
func envBool(key string, defaultValue bool) bool {
|
func envBool(key string, defaultValue bool) bool {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
return strings.EqualFold(v, "true") || v == "1"
|
return strings.EqualFold(v, "true") || v == "1"
|
||||||
}
|
}
|
||||||
|
|
||||||
return defaultValue
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
// envInt returns the value of the named environment variable parsed as an
|
// envInt returns the value of the named environment variable
|
||||||
// integer. Returns defaultValue if not set or unparseable.
|
// parsed as an integer. Returns defaultValue if not set or
|
||||||
|
// unparseable.
|
||||||
func envInt(key string, defaultValue int) int {
|
func envInt(key string, defaultValue int) int {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
if i, err := strconv.Atoi(v); err == nil {
|
i, err := strconv.Atoi(v)
|
||||||
|
if err == nil {
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return defaultValue
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:revive // lc parameter is required by fx even if unused
|
// New creates a Config by reading environment variables.
|
||||||
|
//
|
||||||
|
//nolint:revive // lc parameter is required by fx even if unused.
|
||||||
func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
||||||
log := params.Logger.Get()
|
log := params.Logger.Get()
|
||||||
|
|
||||||
// Determine environment from WEBHOOKER_ENVIRONMENT env var, default to dev
|
// Determine environment from WEBHOOKER_ENVIRONMENT env var,
|
||||||
|
// default to dev
|
||||||
environment := os.Getenv("WEBHOOKER_ENVIRONMENT")
|
environment := os.Getenv("WEBHOOKER_ENVIRONMENT")
|
||||||
if environment == "" {
|
if environment == "" {
|
||||||
environment = EnvironmentDev
|
environment = EnvironmentDev
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate environment
|
// Validate environment
|
||||||
if environment != EnvironmentDev && environment != EnvironmentProd {
|
if environment != EnvironmentDev &&
|
||||||
return nil, fmt.Errorf("WEBHOOKER_ENVIRONMENT must be either '%s' or '%s', got '%s'",
|
environment != EnvironmentProd {
|
||||||
EnvironmentDev, EnvironmentProd, environment)
|
return nil, fmt.Errorf(
|
||||||
|
"%w: WEBHOOKER_ENVIRONMENT must be '%s' or '%s', got '%s'",
|
||||||
|
ErrInvalidEnvironment,
|
||||||
|
EnvironmentDev, EnvironmentProd, environment,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load configuration values from environment variables
|
// Load configuration values from environment variables
|
||||||
@@ -103,15 +126,16 @@ func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
|||||||
Environment: environment,
|
Environment: environment,
|
||||||
MetricsUsername: envString("METRICS_USERNAME"),
|
MetricsUsername: envString("METRICS_USERNAME"),
|
||||||
MetricsPassword: envString("METRICS_PASSWORD"),
|
MetricsPassword: envString("METRICS_PASSWORD"),
|
||||||
Port: envInt("PORT", 8080),
|
Port: envInt("PORT", defaultPort),
|
||||||
SentryDSN: envString("SENTRY_DSN"),
|
SentryDSN: envString("SENTRY_DSN"),
|
||||||
log: log,
|
log: log,
|
||||||
params: ¶ms,
|
params: ¶ms,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default DataDir. All SQLite databases (main application DB
|
// Set default DataDir. All SQLite databases (main application
|
||||||
// and per-webhook event DBs) live here. The same default is used
|
// DB and per-webhook event DBs) live here. The same default is
|
||||||
// regardless of environment; override with DATA_DIR if needed.
|
// used regardless of environment; override with DATA_DIR if
|
||||||
|
// needed.
|
||||||
if s.DataDir == "" {
|
if s.DataDir == "" {
|
||||||
s.DataDir = "/var/lib/webhooker"
|
s.DataDir = "/var/lib/webhooker"
|
||||||
}
|
}
|
||||||
@@ -128,7 +152,8 @@ func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
|||||||
"maintenanceMode", s.MaintenanceMode,
|
"maintenanceMode", s.MaintenanceMode,
|
||||||
"dataDir", s.DataDir,
|
"dataDir", s.DataDir,
|
||||||
"hasSentryDSN", s.SentryDSN != "",
|
"hasSentryDSN", s.SentryDSN != "",
|
||||||
"hasMetricsAuth", s.MetricsUsername != "" && s.MetricsPassword != "",
|
"hasMetricsAuth",
|
||||||
|
s.MetricsUsername != "" && s.MetricsPassword != "",
|
||||||
)
|
)
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package config_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
"go.uber.org/fx/fxtest"
|
"go.uber.org/fx/fxtest"
|
||||||
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
"sneak.berlin/go/webhooker/internal/globals"
|
"sneak.berlin/go/webhooker/internal/globals"
|
||||||
"sneak.berlin/go/webhooker/internal/logger"
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
)
|
)
|
||||||
@@ -22,121 +23,143 @@ func TestEnvironmentConfig(t *testing.T) {
|
|||||||
isProd bool
|
isProd bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "default is dev",
|
name: "default is dev",
|
||||||
envValue: "",
|
isDev: true,
|
||||||
envVars: map[string]string{},
|
isProd: false,
|
||||||
expectError: false,
|
|
||||||
isDev: true,
|
|
||||||
isProd: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "explicit dev",
|
name: "explicit dev",
|
||||||
envValue: "dev",
|
envValue: "dev",
|
||||||
envVars: map[string]string{},
|
isDev: true,
|
||||||
expectError: false,
|
isProd: false,
|
||||||
isDev: true,
|
|
||||||
isProd: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "explicit prod",
|
name: "explicit prod",
|
||||||
envValue: "prod",
|
envValue: "prod",
|
||||||
envVars: map[string]string{},
|
isDev: false,
|
||||||
expectError: false,
|
isProd: true,
|
||||||
isDev: false,
|
|
||||||
isProd: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid environment",
|
name: "invalid environment",
|
||||||
envValue: "staging",
|
envValue: "staging",
|
||||||
envVars: map[string]string{},
|
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Set environment variable if specified
|
// Cannot use t.Parallel() here because t.Setenv
|
||||||
|
// is incompatible with parallel subtests.
|
||||||
if tt.envValue != "" {
|
if tt.envValue != "" {
|
||||||
os.Setenv("WEBHOOKER_ENVIRONMENT", tt.envValue)
|
t.Setenv(
|
||||||
defer os.Unsetenv("WEBHOOKER_ENVIRONMENT")
|
"WEBHOOKER_ENVIRONMENT", tt.envValue,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
os.Unsetenv("WEBHOOKER_ENVIRONMENT")
|
require.NoError(t, os.Unsetenv(
|
||||||
|
"WEBHOOKER_ENVIRONMENT",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set additional environment variables
|
|
||||||
for k, v := range tt.envVars {
|
for k, v := range tt.envVars {
|
||||||
os.Setenv(k, v)
|
t.Setenv(k, v)
|
||||||
defer os.Unsetenv(k)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.expectError {
|
if tt.expectError {
|
||||||
// Use regular fx.New for error cases since fxtest doesn't expose errors the same way
|
testEnvironmentConfigError(t)
|
||||||
var cfg *Config
|
|
||||||
app := fx.New(
|
|
||||||
fx.NopLogger, // Suppress fx logs in tests
|
|
||||||
fx.Provide(
|
|
||||||
globals.New,
|
|
||||||
logger.New,
|
|
||||||
New,
|
|
||||||
),
|
|
||||||
fx.Populate(&cfg),
|
|
||||||
)
|
|
||||||
assert.Error(t, app.Err())
|
|
||||||
} else {
|
} else {
|
||||||
// Use fxtest for success cases
|
testEnvironmentConfigSuccess(
|
||||||
var cfg *Config
|
t, tt.isDev, tt.isProd,
|
||||||
app := fxtest.New(
|
|
||||||
t,
|
|
||||||
fx.Provide(
|
|
||||||
globals.New,
|
|
||||||
logger.New,
|
|
||||||
New,
|
|
||||||
),
|
|
||||||
fx.Populate(&cfg),
|
|
||||||
)
|
)
|
||||||
require.NoError(t, app.Err())
|
|
||||||
app.RequireStart()
|
|
||||||
defer app.RequireStop()
|
|
||||||
|
|
||||||
assert.Equal(t, tt.isDev, cfg.IsDev())
|
|
||||||
assert.Equal(t, tt.isProd, cfg.IsProd())
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testEnvironmentConfigError(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var cfg *config.Config
|
||||||
|
|
||||||
|
app := fx.New(
|
||||||
|
fx.NopLogger,
|
||||||
|
fx.Provide(
|
||||||
|
globals.New,
|
||||||
|
logger.New,
|
||||||
|
config.New,
|
||||||
|
),
|
||||||
|
fx.Populate(&cfg),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Error(t, app.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
func testEnvironmentConfigSuccess(
|
||||||
|
t *testing.T,
|
||||||
|
isDev, isProd bool,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var cfg *config.Config
|
||||||
|
|
||||||
|
app := fxtest.New(
|
||||||
|
t,
|
||||||
|
fx.Provide(
|
||||||
|
globals.New,
|
||||||
|
logger.New,
|
||||||
|
config.New,
|
||||||
|
),
|
||||||
|
fx.Populate(&cfg),
|
||||||
|
)
|
||||||
|
require.NoError(t, app.Err())
|
||||||
|
|
||||||
|
app.RequireStart()
|
||||||
|
|
||||||
|
defer app.RequireStop()
|
||||||
|
|
||||||
|
assert.Equal(t, isDev, cfg.IsDev())
|
||||||
|
assert.Equal(t, isProd, cfg.IsProd())
|
||||||
|
}
|
||||||
|
|
||||||
func TestDefaultDataDir(t *testing.T) {
|
func TestDefaultDataDir(t *testing.T) {
|
||||||
// Verify that when DATA_DIR is unset, the default is /var/lib/webhooker
|
|
||||||
// regardless of the environment setting.
|
|
||||||
for _, env := range []string{"", "dev", "prod"} {
|
for _, env := range []string{"", "dev", "prod"} {
|
||||||
name := env
|
name := env
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name = "unset"
|
name = "unset"
|
||||||
}
|
}
|
||||||
t.Run("env="+name, func(t *testing.T) {
|
|
||||||
if env != "" {
|
|
||||||
os.Setenv("WEBHOOKER_ENVIRONMENT", env)
|
|
||||||
defer os.Unsetenv("WEBHOOKER_ENVIRONMENT")
|
|
||||||
} else {
|
|
||||||
os.Unsetenv("WEBHOOKER_ENVIRONMENT")
|
|
||||||
}
|
|
||||||
os.Unsetenv("DATA_DIR")
|
|
||||||
|
|
||||||
var cfg *Config
|
t.Run("env="+name, func(t *testing.T) {
|
||||||
|
// Cannot use t.Parallel() here because t.Setenv
|
||||||
|
// is incompatible with parallel subtests.
|
||||||
|
if env != "" {
|
||||||
|
t.Setenv("WEBHOOKER_ENVIRONMENT", env)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, os.Unsetenv(
|
||||||
|
"WEBHOOKER_ENVIRONMENT",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, os.Unsetenv("DATA_DIR"))
|
||||||
|
|
||||||
|
var cfg *config.Config
|
||||||
|
|
||||||
app := fxtest.New(
|
app := fxtest.New(
|
||||||
t,
|
t,
|
||||||
fx.Provide(
|
fx.Provide(
|
||||||
globals.New,
|
globals.New,
|
||||||
logger.New,
|
logger.New,
|
||||||
New,
|
config.New,
|
||||||
),
|
),
|
||||||
fx.Populate(&cfg),
|
fx.Populate(&cfg),
|
||||||
)
|
)
|
||||||
require.NoError(t, app.Err())
|
require.NoError(t, app.Err())
|
||||||
|
|
||||||
app.RequireStart()
|
app.RequireStart()
|
||||||
|
|
||||||
defer app.RequireStop()
|
defer app.RequireStop()
|
||||||
|
|
||||||
assert.Equal(t, "/var/lib/webhooker", cfg.DataDir)
|
assert.Equal(
|
||||||
|
t, "/var/lib/webhooker", cfg.DataDir,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,15 +11,16 @@ import (
|
|||||||
// This replaces gorm.Model but uses UUID instead of uint for ID
|
// This replaces gorm.Model but uses UUID instead of uint for ID
|
||||||
type BaseModel struct {
|
type BaseModel struct {
|
||||||
ID string `gorm:"type:uuid;primary_key" json:"id"`
|
ID string `gorm:"type:uuid;primary_key" json:"id"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
DeletedAt gorm.DeletedAt `gorm:"index" json:"deletedAt,omitzero"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeforeCreate hook to set UUID before creating a record
|
// BeforeCreate hook to set UUID before creating a record.
|
||||||
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
|
func (b *BaseModel) BeforeCreate(_ *gorm.DB) error {
|
||||||
if b.ID == "" {
|
if b.ID == "" {
|
||||||
b.ID = uuid.New().String()
|
b.ID = uuid.New().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package database provides SQLite persistence for webhooks, events, and users.
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -19,30 +20,42 @@ import (
|
|||||||
"sneak.berlin/go/webhooker/internal/logger"
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // DatabaseParams is a standard fx naming convention
|
const (
|
||||||
|
dataDirPerm = 0750
|
||||||
|
randomPasswordLen = 16
|
||||||
|
sessionKeyLen = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
//nolint:revive // DatabaseParams is a standard fx naming convention.
|
||||||
type DatabaseParams struct {
|
type DatabaseParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Database manages the main SQLite connection and schema migrations.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
params *DatabaseParams
|
params *DatabaseParams
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
// New creates a Database that connects on fx start and disconnects on stop.
|
||||||
|
func New(
|
||||||
|
lc fx.Lifecycle,
|
||||||
|
params DatabaseParams,
|
||||||
|
) (*Database, error) {
|
||||||
d := &Database{
|
d := &Database{
|
||||||
params: ¶ms,
|
params: ¶ms,
|
||||||
log: params.Logger.Get(),
|
log: params.Logger.Get(),
|
||||||
}
|
}
|
||||||
|
|
||||||
lc.Append(fx.Hook{
|
lc.Append(fx.Hook{
|
||||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
OnStart: func(_ context.Context) error {
|
||||||
return d.connect()
|
return d.connect()
|
||||||
},
|
},
|
||||||
OnStop: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
OnStop: func(_ context.Context) error {
|
||||||
return d.close()
|
return d.close()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -50,21 +63,92 @@ func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
|||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB returns the underlying GORM database handle.
|
||||||
|
func (d *Database) DB() *gorm.DB {
|
||||||
|
return d.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateSessionKey retrieves the session encryption key from the
|
||||||
|
// settings table. If no key exists, a cryptographically secure random
|
||||||
|
// 32-byte key is generated, base64-encoded, and stored for future use.
|
||||||
|
func (d *Database) GetOrCreateSessionKey() (string, error) {
|
||||||
|
var setting Setting
|
||||||
|
|
||||||
|
result := d.db.Where(
|
||||||
|
&Setting{Key: "session_key"},
|
||||||
|
).First(&setting)
|
||||||
|
if result.Error == nil {
|
||||||
|
return setting.Value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"failed to query session key: %w",
|
||||||
|
result.Error,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new cryptographically secure 32-byte key
|
||||||
|
keyBytes := make([]byte, sessionKeyLen)
|
||||||
|
|
||||||
|
_, err := rand.Read(keyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"failed to generate session key: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(keyBytes)
|
||||||
|
|
||||||
|
setting = Setting{
|
||||||
|
Key: "session_key",
|
||||||
|
Value: encoded,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = d.db.Create(&setting).Error
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"failed to store session key: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.log.Info(
|
||||||
|
"generated new session key and stored in database",
|
||||||
|
)
|
||||||
|
|
||||||
|
return encoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) connect() error {
|
func (d *Database) connect() error {
|
||||||
// Ensure the data directory exists before opening the database.
|
// Ensure the data directory exists before opening the database.
|
||||||
dataDir := d.params.Config.DataDir
|
dataDir := d.params.Config.DataDir
|
||||||
if err := os.MkdirAll(dataDir, 0750); err != nil {
|
|
||||||
return fmt.Errorf("creating data directory %s: %w", dataDir, err)
|
err := os.MkdirAll(dataDir, dataDirPerm)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"creating data directory %s: %w",
|
||||||
|
dataDir,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct the main application database path inside DATA_DIR.
|
// Construct the main application database path inside DATA_DIR.
|
||||||
dbPath := filepath.Join(dataDir, "webhooker.db")
|
dbPath := filepath.Join(dataDir, "webhooker.db")
|
||||||
dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", dbPath)
|
dbURL := fmt.Sprintf(
|
||||||
|
"file:%s?cache=shared&mode=rwc",
|
||||||
|
dbPath,
|
||||||
|
)
|
||||||
|
|
||||||
// Open the database with the pure Go SQLite driver
|
// Open the database with the pure Go SQLite driver
|
||||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.log.Error("failed to open database", "error", err)
|
d.log.Error(
|
||||||
|
"failed to open database",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +157,11 @@ func (d *Database) connect() error {
|
|||||||
Conn: sqlDB,
|
Conn: sqlDB,
|
||||||
}, &gorm.Config{})
|
}, &gorm.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.log.Error("failed to connect to database", "error", err)
|
d.log.Error(
|
||||||
|
"failed to connect to database",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,101 +174,100 @@ func (d *Database) connect() error {
|
|||||||
|
|
||||||
func (d *Database) migrate() error {
|
func (d *Database) migrate() error {
|
||||||
// Run GORM auto-migrations
|
// Run GORM auto-migrations
|
||||||
if err := d.Migrate(); err != nil {
|
err := d.Migrate()
|
||||||
d.log.Error("failed to run database migrations", "error", err)
|
if err != nil {
|
||||||
|
d.log.Error(
|
||||||
|
"failed to run database migrations",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.log.Info("database migrations completed")
|
d.log.Info("database migrations completed")
|
||||||
|
|
||||||
// Check if admin user exists
|
// Check if admin user exists
|
||||||
var userCount int64
|
var userCount int64
|
||||||
if err := d.db.Model(&User{}).Count(&userCount).Error; err != nil {
|
|
||||||
d.log.Error("failed to count users", "error", err)
|
err = d.db.Model(&User{}).Count(&userCount).Error
|
||||||
|
if err != nil {
|
||||||
|
d.log.Error(
|
||||||
|
"failed to count users",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if userCount == 0 {
|
if userCount == 0 {
|
||||||
// Create admin user
|
return d.createAdminUser()
|
||||||
d.log.Info("no users found, creating admin user")
|
|
||||||
|
|
||||||
// Generate random password
|
|
||||||
password, err := GenerateRandomPassword(16)
|
|
||||||
if err != nil {
|
|
||||||
d.log.Error("failed to generate random password", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hash the password
|
|
||||||
hashedPassword, err := HashPassword(password)
|
|
||||||
if err != nil {
|
|
||||||
d.log.Error("failed to hash password", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create admin user
|
|
||||||
adminUser := &User{
|
|
||||||
Username: "admin",
|
|
||||||
Password: hashedPassword,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.db.Create(adminUser).Error; err != nil {
|
|
||||||
d.log.Error("failed to create admin user", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.log.Info("admin user created",
|
|
||||||
"username", "admin",
|
|
||||||
"password", password,
|
|
||||||
"message", "SAVE THIS PASSWORD - it will not be shown again!",
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) createAdminUser() error {
|
||||||
|
d.log.Info("no users found, creating admin user")
|
||||||
|
|
||||||
|
// Generate random password
|
||||||
|
password, err := GenerateRandomPassword(
|
||||||
|
randomPasswordLen,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
d.log.Error(
|
||||||
|
"failed to generate random password",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash the password
|
||||||
|
hashedPassword, err := HashPassword(password)
|
||||||
|
if err != nil {
|
||||||
|
d.log.Error(
|
||||||
|
"failed to hash password",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create admin user
|
||||||
|
adminUser := &User{
|
||||||
|
Username: "admin",
|
||||||
|
Password: hashedPassword,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = d.db.Create(adminUser).Error
|
||||||
|
if err != nil {
|
||||||
|
d.log.Error(
|
||||||
|
"failed to create admin user",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
d.log.Info("admin user created",
|
||||||
|
"username", "admin",
|
||||||
|
"password", password,
|
||||||
|
"message",
|
||||||
|
"SAVE THIS PASSWORD - it will not be shown again!",
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) close() error {
|
func (d *Database) close() error {
|
||||||
if d.db != nil {
|
if d.db != nil {
|
||||||
sqlDB, err := d.db.DB()
|
sqlDB, err := d.db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return sqlDB.Close()
|
return sqlDB.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DB() *gorm.DB {
|
|
||||||
return d.db
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOrCreateSessionKey retrieves the session encryption key from the
|
|
||||||
// settings table. If no key exists, a cryptographically secure random
|
|
||||||
// 32-byte key is generated, base64-encoded, and stored for future use.
|
|
||||||
func (d *Database) GetOrCreateSessionKey() (string, error) {
|
|
||||||
var setting Setting
|
|
||||||
result := d.db.Where(&Setting{Key: "session_key"}).First(&setting)
|
|
||||||
if result.Error == nil {
|
|
||||||
return setting.Value, nil
|
|
||||||
}
|
|
||||||
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
||||||
return "", fmt.Errorf("failed to query session key: %w", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a new cryptographically secure 32-byte key
|
|
||||||
keyBytes := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(keyBytes); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to generate session key: %w", err)
|
|
||||||
}
|
|
||||||
encoded := base64.StdEncoding.EncodeToString(keyBytes)
|
|
||||||
|
|
||||||
setting = Setting{
|
|
||||||
Key: "session_key",
|
|
||||||
Value: encoded,
|
|
||||||
}
|
|
||||||
if err := d.db.Create(&setting).Error; err != nil {
|
|
||||||
return "", fmt.Errorf("failed to store session key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
d.log.Info("generated new session key and stored in database")
|
|
||||||
return encoded, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package database
|
package database_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -6,37 +6,37 @@ import (
|
|||||||
|
|
||||||
"go.uber.org/fx/fxtest"
|
"go.uber.org/fx/fxtest"
|
||||||
"sneak.berlin/go/webhooker/internal/config"
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
|
"sneak.berlin/go/webhooker/internal/database"
|
||||||
"sneak.berlin/go/webhooker/internal/globals"
|
"sneak.berlin/go/webhooker/internal/globals"
|
||||||
"sneak.berlin/go/webhooker/internal/logger"
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDatabaseConnection(t *testing.T) {
|
func setupTestDB(
|
||||||
// Set up test dependencies
|
t *testing.T,
|
||||||
|
) (*database.Database, *fxtest.Lifecycle) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
lc := fxtest.NewLifecycle(t)
|
lc := fxtest.NewLifecycle(t)
|
||||||
|
|
||||||
// Create globals
|
g := &globals.Globals{
|
||||||
globals.Appname = "webhooker-test"
|
Appname: "webhooker-test",
|
||||||
globals.Version = "test"
|
Version: "test",
|
||||||
|
|
||||||
g, err := globals.New(lc)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create globals: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create logger
|
l, err := logger.New(
|
||||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
lc,
|
||||||
|
logger.LoggerParams{Globals: g},
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create logger: %v", err)
|
t.Fatalf("Failed to create logger: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create config with DataDir pointing to a temp directory
|
|
||||||
c := &config.Config{
|
c := &config.Config{
|
||||||
DataDir: t.TempDir(),
|
DataDir: t.TempDir(),
|
||||||
Environment: "dev",
|
Environment: "dev",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create database
|
db, err := database.New(lc, database.DatabaseParams{
|
||||||
db, err := New(lc, DatabaseParams{
|
|
||||||
Config: c,
|
Config: c,
|
||||||
Logger: l,
|
Logger: l,
|
||||||
})
|
})
|
||||||
@@ -44,31 +44,45 @@ func TestDatabaseConnection(t *testing.T) {
|
|||||||
t.Fatalf("Failed to create database: %v", err)
|
t.Fatalf("Failed to create database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start lifecycle (this will trigger the connection)
|
return db, lc
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseConnection(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, lc := setupTestDB(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
err = lc.Start(ctx)
|
|
||||||
|
err := lc.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to connect to database: %v", err)
|
t.Fatalf("Failed to connect to database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if stopErr := lc.Stop(ctx); stopErr != nil {
|
stopErr := lc.Stop(ctx)
|
||||||
t.Errorf("Failed to stop lifecycle: %v", stopErr)
|
if stopErr != nil {
|
||||||
|
t.Errorf(
|
||||||
|
"Failed to stop lifecycle: %v",
|
||||||
|
stopErr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Verify we can get the DB instance
|
|
||||||
if db.DB() == nil {
|
if db.DB() == nil {
|
||||||
t.Error("Expected non-nil database connection")
|
t.Error("Expected non-nil database connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that we can perform a simple query
|
|
||||||
var result int
|
var result int
|
||||||
|
|
||||||
err = db.DB().Raw("SELECT 1").Scan(&result).Error
|
err = db.DB().Raw("SELECT 1").Scan(&result).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to execute test query: %v", err)
|
t.Fatalf("Failed to execute test query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result != 1 {
|
if result != 1 {
|
||||||
t.Errorf("Expected query result to be 1, got %d", result)
|
t.Errorf(
|
||||||
|
"Expected query result to be 1, got %d",
|
||||||
|
result,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,11 @@ import "time"
|
|||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
UserID string `gorm:"type:uuid;not null" json:"userId"`
|
||||||
Key string `gorm:"uniqueIndex;not null" json:"key"`
|
Key string `gorm:"uniqueIndex;not null" json:"key"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
LastUsedAt *time.Time `json:"lastUsedAt,omitempty"`
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
User User `json:"user,omitempty"`
|
User User `json:"user,omitzero"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package database
|
|||||||
// DeliveryStatus represents the status of a delivery
|
// DeliveryStatus represents the status of a delivery
|
||||||
type DeliveryStatus string
|
type DeliveryStatus string
|
||||||
|
|
||||||
|
// Delivery status values.
|
||||||
const (
|
const (
|
||||||
DeliveryStatusPending DeliveryStatus = "pending"
|
DeliveryStatusPending DeliveryStatus = "pending"
|
||||||
DeliveryStatusDelivered DeliveryStatus = "delivered"
|
DeliveryStatusDelivered DeliveryStatus = "delivered"
|
||||||
@@ -14,12 +15,12 @@ const (
|
|||||||
type Delivery struct {
|
type Delivery struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
EventID string `gorm:"type:uuid;not null" json:"event_id"`
|
EventID string `gorm:"type:uuid;not null" json:"eventId"`
|
||||||
TargetID string `gorm:"type:uuid;not null" json:"target_id"`
|
TargetID string `gorm:"type:uuid;not null" json:"targetId"`
|
||||||
Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"`
|
Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"`
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
Event Event `json:"event,omitempty"`
|
Event Event `json:"event,omitzero"`
|
||||||
Target Target `json:"target,omitempty"`
|
Target Target `json:"target,omitzero"`
|
||||||
DeliveryResults []DeliveryResult `json:"delivery_results,omitempty"`
|
DeliveryResults []DeliveryResult `json:"deliveryResults,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ package database
|
|||||||
type DeliveryResult struct {
|
type DeliveryResult struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
DeliveryID string `gorm:"type:uuid;not null" json:"delivery_id"`
|
DeliveryID string `gorm:"type:uuid;not null" json:"deliveryId"`
|
||||||
AttemptNum int `gorm:"not null" json:"attempt_num"`
|
AttemptNum int `gorm:"not null" json:"attemptNum"`
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
StatusCode int `json:"status_code,omitempty"`
|
StatusCode int `json:"statusCode,omitempty"`
|
||||||
ResponseBody string `gorm:"type:text" json:"response_body,omitempty"`
|
ResponseBody string `gorm:"type:text" json:"responseBody,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
Duration int64 `json:"duration_ms"` // Duration in milliseconds
|
Duration int64 `json:"durationMs"` // Duration in milliseconds
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
Delivery Delivery `json:"delivery,omitempty"`
|
Delivery Delivery `json:"delivery,omitzero"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ package database
|
|||||||
type Entrypoint struct {
|
type Entrypoint struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||||
Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this entrypoint
|
Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this entrypoint
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Active bool `gorm:"default:true" json:"active"`
|
Active bool `gorm:"default:true" json:"active"`
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
Webhook Webhook `json:"webhook,omitempty"`
|
Webhook Webhook `json:"webhook,omitzero"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ package database
|
|||||||
type Event struct {
|
type Event struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||||
EntrypointID string `gorm:"type:uuid;not null" json:"entrypoint_id"`
|
EntrypointID string `gorm:"type:uuid;not null" json:"entrypointId"`
|
||||||
|
|
||||||
// Request data
|
// Request data
|
||||||
Method string `gorm:"not null" json:"method"`
|
Method string `gorm:"not null" json:"method"`
|
||||||
Headers string `gorm:"type:text" json:"headers"` // JSON
|
Headers string `gorm:"type:text" json:"headers"` // JSON
|
||||||
Body string `gorm:"type:text" json:"body"`
|
Body string `gorm:"type:text" json:"body"`
|
||||||
ContentType string `json:"content_type"`
|
ContentType string `json:"contentType"`
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
Webhook Webhook `json:"webhook,omitempty"`
|
Webhook Webhook `json:"webhook,omitzero"`
|
||||||
Entrypoint Entrypoint `json:"entrypoint,omitempty"`
|
Entrypoint Entrypoint `json:"entrypoint,omitzero"`
|
||||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,6 @@ package database
|
|||||||
// Setting stores application-level key-value configuration.
|
// Setting stores application-level key-value configuration.
|
||||||
// Used for auto-generated values like the session encryption key.
|
// Used for auto-generated values like the session encryption key.
|
||||||
type Setting struct {
|
type Setting struct {
|
||||||
Key string `gorm:"primaryKey" json:"key"`
|
Key string `gorm:"primaryKey" json:"key"`
|
||||||
Value string `gorm:"type:text;not null" json:"value"`
|
Value string `gorm:"type:text;not null" json:"value"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package database
|
|||||||
// TargetType represents the type of delivery target
|
// TargetType represents the type of delivery target
|
||||||
type TargetType string
|
type TargetType string
|
||||||
|
|
||||||
|
// Target type values.
|
||||||
const (
|
const (
|
||||||
TargetTypeHTTP TargetType = "http"
|
TargetTypeHTTP TargetType = "http"
|
||||||
TargetTypeDatabase TargetType = "database"
|
TargetTypeDatabase TargetType = "database"
|
||||||
@@ -14,19 +15,19 @@ const (
|
|||||||
type Target struct {
|
type Target struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||||
Name string `gorm:"not null" json:"name"`
|
Name string `gorm:"not null" json:"name"`
|
||||||
Type TargetType `gorm:"not null" json:"type"`
|
Type TargetType `gorm:"not null" json:"type"`
|
||||||
Active bool `gorm:"default:true" json:"active"`
|
Active bool `gorm:"default:true" json:"active"`
|
||||||
|
|
||||||
// Configuration fields (JSON stored based on type)
|
// Configuration fields (JSON stored based on type)
|
||||||
Config string `gorm:"type:text" json:"config"` // JSON configuration
|
Config string `gorm:"type:text" json:"config"` // JSON configuration
|
||||||
|
|
||||||
// For HTTP targets (max_retries=0 means fire-and-forget, >0 enables retries with backoff)
|
// For HTTP targets (max_retries=0 means fire-and-forget, >0 enables retries with backoff)
|
||||||
MaxRetries int `json:"max_retries,omitempty"`
|
MaxRetries int `json:"maxRetries,omitempty"`
|
||||||
MaxQueueSize int `json:"max_queue_size,omitempty"`
|
MaxQueueSize int `json:"maxQueueSize,omitempty"`
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
Webhook Webhook `json:"webhook,omitempty"`
|
Webhook Webhook `json:"webhook,omitzero"`
|
||||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ type User struct {
|
|||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
Username string `gorm:"uniqueIndex;not null" json:"username"`
|
Username string `gorm:"uniqueIndex;not null" json:"username"`
|
||||||
Password string `gorm:"not null" json:"-"` // Argon2 hashed
|
Password string `gorm:"not null" json:"-"` // Argon2 hashed
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
Webhooks []Webhook `json:"webhooks,omitempty"`
|
Webhooks []Webhook `json:"webhooks,omitempty"`
|
||||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
APIKeys []APIKey `json:"apiKeys,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ package database
|
|||||||
type Webhook struct {
|
type Webhook struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|
||||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
UserID string `gorm:"type:uuid;not null" json:"userId"`
|
||||||
Name string `gorm:"not null" json:"name"`
|
Name string `gorm:"not null" json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
RetentionDays int `gorm:"default:30" json:"retention_days"` // Days to retain events
|
RetentionDays int `gorm:"default:30" json:"retentionDays"` // Days to retain events
|
||||||
|
|
||||||
// Relations
|
// Relations
|
||||||
User User `json:"user,omitempty"`
|
User User `json:"user,omitzero"`
|
||||||
Entrypoints []Entrypoint `json:"entrypoints,omitempty"`
|
Entrypoints []Entrypoint `json:"entrypoints,omitempty"`
|
||||||
Targets []Target `json:"targets,omitempty"`
|
Targets []Target `json:"targets,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -20,6 +21,23 @@ const (
|
|||||||
argon2SaltLen = 16
|
argon2SaltLen = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// hashParts is the expected number of $-separated segments
|
||||||
|
// in an encoded Argon2id hash string.
|
||||||
|
const hashParts = 6
|
||||||
|
|
||||||
|
// minPasswordComplexityLen is the minimum password length that
|
||||||
|
// triggers per-character-class complexity enforcement.
|
||||||
|
const minPasswordComplexityLen = 4
|
||||||
|
|
||||||
|
// Sentinel errors returned by decodeHash.
|
||||||
|
var (
|
||||||
|
errInvalidHashFormat = errors.New("invalid hash format")
|
||||||
|
errInvalidAlgorithm = errors.New("invalid algorithm")
|
||||||
|
errIncompatibleVersion = errors.New("incompatible argon2 version")
|
||||||
|
errSaltLengthOutOfRange = errors.New("salt length out of range")
|
||||||
|
errHashLengthOutOfRange = errors.New("hash length out of range")
|
||||||
|
)
|
||||||
|
|
||||||
// PasswordConfig holds Argon2 configuration
|
// PasswordConfig holds Argon2 configuration
|
||||||
type PasswordConfig struct {
|
type PasswordConfig struct {
|
||||||
Time uint32
|
Time uint32
|
||||||
@@ -46,26 +64,44 @@ func HashPassword(password string) (string, error) {
|
|||||||
|
|
||||||
// Generate a salt
|
// Generate a salt
|
||||||
salt := make([]byte, config.SaltLen)
|
salt := make([]byte, config.SaltLen)
|
||||||
if _, err := rand.Read(salt); err != nil {
|
|
||||||
|
_, err := rand.Read(salt)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the hash
|
// Generate the hash
|
||||||
hash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
hash := argon2.IDKey(
|
||||||
|
[]byte(password),
|
||||||
|
salt,
|
||||||
|
config.Time,
|
||||||
|
config.Memory,
|
||||||
|
config.Threads,
|
||||||
|
config.KeyLen,
|
||||||
|
)
|
||||||
|
|
||||||
// Encode the hash and parameters
|
// Encode the hash and parameters
|
||||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||||
|
|
||||||
// Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash
|
// Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash
|
||||||
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
encoded := fmt.Sprintf(
|
||||||
argon2.Version, config.Memory, config.Time, config.Threads, b64Salt, b64Hash)
|
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||||
|
argon2.Version,
|
||||||
|
config.Memory,
|
||||||
|
config.Time,
|
||||||
|
config.Threads,
|
||||||
|
b64Salt,
|
||||||
|
b64Hash,
|
||||||
|
)
|
||||||
|
|
||||||
return encoded, nil
|
return encoded, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyPassword checks if the provided password matches the hash
|
// VerifyPassword checks if the provided password matches the hash
|
||||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
func VerifyPassword(
|
||||||
|
password, encodedHash string,
|
||||||
|
) (bool, error) {
|
||||||
// Extract parameters and hash from encoded string
|
// Extract parameters and hash from encoded string
|
||||||
config, salt, hash, err := decodeHash(encodedHash)
|
config, salt, hash, err := decodeHash(encodedHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -73,60 +109,119 @@ func VerifyPassword(password, encodedHash string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate hash of the provided password
|
// Generate hash of the provided password
|
||||||
otherHash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
otherHash := argon2.IDKey(
|
||||||
|
[]byte(password),
|
||||||
|
salt,
|
||||||
|
config.Time,
|
||||||
|
config.Memory,
|
||||||
|
config.Threads,
|
||||||
|
config.KeyLen,
|
||||||
|
)
|
||||||
|
|
||||||
// Compare hashes using constant time comparison
|
// Compare hashes using constant time comparison
|
||||||
return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil
|
return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeHash extracts parameters, salt, and hash from an encoded hash string
|
// decodeHash extracts parameters, salt, and hash from an
|
||||||
func decodeHash(encodedHash string) (*PasswordConfig, []byte, []byte, error) {
|
// encoded hash string.
|
||||||
|
func decodeHash(
|
||||||
|
encodedHash string,
|
||||||
|
) (*PasswordConfig, []byte, []byte, error) {
|
||||||
parts := strings.Split(encodedHash, "$")
|
parts := strings.Split(encodedHash, "$")
|
||||||
if len(parts) != 6 {
|
if len(parts) != hashParts {
|
||||||
return nil, nil, nil, fmt.Errorf("invalid hash format")
|
return nil, nil, nil, errInvalidHashFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
if parts[1] != "argon2id" {
|
if parts[1] != "argon2id" {
|
||||||
return nil, nil, nil, fmt.Errorf("invalid algorithm")
|
return nil, nil, nil, errInvalidAlgorithm
|
||||||
}
|
}
|
||||||
|
|
||||||
var version int
|
version, err := parseVersion(parts[2])
|
||||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if version != argon2.Version {
|
if version != argon2.Version {
|
||||||
return nil, nil, nil, fmt.Errorf("incompatible argon2 version")
|
return nil, nil, nil, errIncompatibleVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &PasswordConfig{}
|
config, err := parseParams(parts[3])
|
||||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &config.Memory, &config.Time, &config.Threads); err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
saltLen := len(salt)
|
|
||||||
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
|
||||||
return nil, nil, nil, fmt.Errorf("salt length out of range")
|
|
||||||
}
|
|
||||||
config.SaltLen = uint32(saltLen) // nolint:gosec // checked above
|
|
||||||
|
|
||||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
salt, err := decodeSalt(parts[4])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
hashLen := len(hash)
|
|
||||||
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
config.SaltLen = uint32(len(salt)) //nolint:gosec // validated in decodeSalt
|
||||||
return nil, nil, nil, fmt.Errorf("hash length out of range")
|
|
||||||
|
hash, err := decodeHashBytes(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
config.KeyLen = uint32(hashLen) // nolint:gosec // checked above
|
|
||||||
|
config.KeyLen = uint32(len(hash)) //nolint:gosec // validated in decodeHashBytes
|
||||||
|
|
||||||
return config, salt, hash, nil
|
return config, salt, hash, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRandomPassword generates a cryptographically secure random password
|
func parseVersion(s string) (int, error) {
|
||||||
|
var version int
|
||||||
|
|
||||||
|
_, err := fmt.Sscanf(s, "v=%d", &version)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseParams(s string) (*PasswordConfig, error) {
|
||||||
|
config := &PasswordConfig{}
|
||||||
|
|
||||||
|
_, err := fmt.Sscanf(
|
||||||
|
s, "m=%d,t=%d,p=%d",
|
||||||
|
&config.Memory, &config.Time, &config.Threads,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeSalt(s string) ([]byte, error) {
|
||||||
|
salt, err := base64.RawStdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decoding salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
saltLen := len(salt)
|
||||||
|
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
||||||
|
return nil, errSaltLengthOutOfRange
|
||||||
|
}
|
||||||
|
|
||||||
|
return salt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeHashBytes(s string) ([]byte, error) {
|
||||||
|
hash, err := base64.RawStdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decoding hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hashLen := len(hash)
|
||||||
|
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
||||||
|
return nil, errHashLengthOutOfRange
|
||||||
|
}
|
||||||
|
|
||||||
|
return hash, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRandomPassword generates a cryptographically secure
|
||||||
|
// random password.
|
||||||
func GenerateRandomPassword(length int) (string, error) {
|
func GenerateRandomPassword(length int) (string, error) {
|
||||||
const (
|
const (
|
||||||
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
@@ -141,27 +236,27 @@ func GenerateRandomPassword(length int) (string, error) {
|
|||||||
// Create password slice
|
// Create password slice
|
||||||
password := make([]byte, length)
|
password := make([]byte, length)
|
||||||
|
|
||||||
// Ensure at least one character from each set for password complexity
|
// Ensure at least one character from each set
|
||||||
if length >= 4 {
|
if length >= minPasswordComplexityLen {
|
||||||
// Get one character from each set
|
|
||||||
password[0] = uppercase[cryptoRandInt(len(uppercase))]
|
password[0] = uppercase[cryptoRandInt(len(uppercase))]
|
||||||
password[1] = lowercase[cryptoRandInt(len(lowercase))]
|
password[1] = lowercase[cryptoRandInt(len(lowercase))]
|
||||||
password[2] = digits[cryptoRandInt(len(digits))]
|
password[2] = digits[cryptoRandInt(len(digits))]
|
||||||
password[3] = special[cryptoRandInt(len(special))]
|
password[3] = special[cryptoRandInt(len(special))]
|
||||||
|
|
||||||
// Fill the rest randomly from all characters
|
// Fill the rest randomly from all characters
|
||||||
for i := 4; i < length; i++ {
|
for i := minPasswordComplexityLen; i < length; i++ {
|
||||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shuffle the password to avoid predictable pattern
|
// Shuffle the password to avoid predictable pattern
|
||||||
for i := len(password) - 1; i > 0; i-- {
|
for i := range len(password) - 1 {
|
||||||
j := cryptoRandInt(i + 1)
|
j := cryptoRandInt(len(password) - i)
|
||||||
password[i], password[j] = password[j], password[i]
|
idx := len(password) - 1 - i
|
||||||
|
password[idx], password[j] = password[j], password[idx]
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For very short passwords, just use all characters
|
// For very short passwords, just use all characters
|
||||||
for i := 0; i < length; i++ {
|
for i := range length {
|
||||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -169,16 +264,17 @@ func GenerateRandomPassword(length int) (string, error) {
|
|||||||
return string(password), nil
|
return string(password), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// cryptoRandInt generates a cryptographically secure random integer in [0, max)
|
// cryptoRandInt generates a cryptographically secure random
|
||||||
func cryptoRandInt(max int) int {
|
// integer in [0, upperBound).
|
||||||
if max <= 0 {
|
func cryptoRandInt(upperBound int) int {
|
||||||
panic("max must be positive")
|
if upperBound <= 0 {
|
||||||
|
panic("upperBound must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the maximum valid value to avoid modulo bias
|
nBig, err := rand.Int(
|
||||||
// For example, if max=200 and we have 256 possible values,
|
rand.Reader,
|
||||||
// we only accept values 0-199 (reject 200-255)
|
big.NewInt(int64(upperBound)),
|
||||||
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("crypto/rand error: %v", err))
|
panic(fmt.Sprintf("crypto/rand error: %v", err))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package database
|
package database_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"sneak.berlin/go/webhooker/internal/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGenerateRandomPassword(t *testing.T) {
|
func TestGenerateRandomPassword(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
length int
|
length int
|
||||||
@@ -18,109 +22,172 @@ func TestGenerateRandomPassword(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
password, err := GenerateRandomPassword(tt.length)
|
t.Parallel()
|
||||||
|
|
||||||
|
password, err := database.GenerateRandomPassword(
|
||||||
|
tt.length,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
t.Fatalf(
|
||||||
|
"GenerateRandomPassword() error = %v",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(password) != tt.length {
|
if len(password) != tt.length {
|
||||||
t.Errorf("Password length = %v, want %v", len(password), tt.length)
|
t.Errorf(
|
||||||
|
"Password length = %v, want %v",
|
||||||
|
len(password), tt.length,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// For passwords >= 4 chars, check complexity
|
checkPasswordComplexity(
|
||||||
if tt.length >= 4 {
|
t, password, tt.length,
|
||||||
hasUpper := false
|
)
|
||||||
hasLower := false
|
|
||||||
hasDigit := false
|
|
||||||
hasSpecial := false
|
|
||||||
|
|
||||||
for _, char := range password {
|
|
||||||
switch {
|
|
||||||
case char >= 'A' && char <= 'Z':
|
|
||||||
hasUpper = true
|
|
||||||
case char >= 'a' && char <= 'z':
|
|
||||||
hasLower = true
|
|
||||||
case char >= '0' && char <= '9':
|
|
||||||
hasDigit = true
|
|
||||||
case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", char):
|
|
||||||
hasSpecial = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
|
|
||||||
t.Errorf("Password lacks required complexity: upper=%v, lower=%v, digit=%v, special=%v",
|
|
||||||
hasUpper, hasLower, hasDigit, hasSpecial)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkPasswordComplexity(
|
||||||
|
t *testing.T,
|
||||||
|
password string,
|
||||||
|
length int,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// For passwords >= 4 chars, check complexity
|
||||||
|
if length < 4 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flags := classifyChars(password)
|
||||||
|
|
||||||
|
if !flags[0] || !flags[1] || !flags[2] || !flags[3] {
|
||||||
|
t.Errorf(
|
||||||
|
"Password lacks required complexity: "+
|
||||||
|
"upper=%v, lower=%v, digit=%v, special=%v",
|
||||||
|
flags[0], flags[1], flags[2], flags[3],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyChars(s string) [4]bool {
|
||||||
|
var flags [4]bool // upper, lower, digit, special
|
||||||
|
|
||||||
|
for _, char := range s {
|
||||||
|
switch {
|
||||||
|
case char >= 'A' && char <= 'Z':
|
||||||
|
flags[0] = true
|
||||||
|
case char >= 'a' && char <= 'z':
|
||||||
|
flags[1] = true
|
||||||
|
case char >= '0' && char <= '9':
|
||||||
|
flags[2] = true
|
||||||
|
case strings.ContainsRune(
|
||||||
|
"!@#$%^&*()_+-=[]{}|;:,.<>?",
|
||||||
|
char,
|
||||||
|
):
|
||||||
|
flags[3] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return flags
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateRandomPasswordUniqueness(t *testing.T) {
|
func TestGenerateRandomPasswordUniqueness(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
// Generate multiple passwords and ensure they're different
|
// Generate multiple passwords and ensure they're different
|
||||||
passwords := make(map[string]bool)
|
passwords := make(map[string]bool)
|
||||||
|
|
||||||
const numPasswords = 100
|
const numPasswords = 100
|
||||||
|
|
||||||
for i := 0; i < numPasswords; i++ {
|
for range numPasswords {
|
||||||
password, err := GenerateRandomPassword(16)
|
password, err := database.GenerateRandomPassword(16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
t.Fatalf(
|
||||||
|
"GenerateRandomPassword() error = %v",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if passwords[password] {
|
if passwords[password] {
|
||||||
t.Errorf("Duplicate password generated: %s", password)
|
t.Errorf(
|
||||||
|
"Duplicate password generated: %s",
|
||||||
|
password,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
passwords[password] = true
|
passwords[password] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHashPassword(t *testing.T) {
|
func TestHashPassword(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
password := "testPassword123!"
|
password := "testPassword123!"
|
||||||
|
|
||||||
hash, err := HashPassword(password)
|
hash, err := database.HashPassword(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("HashPassword() error = %v", err)
|
t.Fatalf("HashPassword() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that hash has correct format
|
// Check that hash has correct format
|
||||||
if !strings.HasPrefix(hash, "$argon2id$") {
|
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||||
t.Errorf("Hash doesn't have correct prefix: %s", hash)
|
t.Errorf(
|
||||||
|
"Hash doesn't have correct prefix: %s",
|
||||||
|
hash,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify password
|
// Verify password
|
||||||
valid, err := VerifyPassword(password, hash)
|
valid, err := database.VerifyPassword(password, hash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("VerifyPassword() error = %v", err)
|
t.Fatalf("VerifyPassword() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("VerifyPassword() returned false for correct password")
|
t.Error(
|
||||||
|
"VerifyPassword() returned false " +
|
||||||
|
"for correct password",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify wrong password fails
|
// Verify wrong password fails
|
||||||
valid, err = VerifyPassword("wrongPassword", hash)
|
valid, err = database.VerifyPassword(
|
||||||
|
"wrongPassword", hash,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("VerifyPassword() error = %v", err)
|
t.Fatalf("VerifyPassword() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if valid {
|
if valid {
|
||||||
t.Error("VerifyPassword() returned true for wrong password")
|
t.Error(
|
||||||
|
"VerifyPassword() returned true " +
|
||||||
|
"for wrong password",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHashPasswordUniqueness(t *testing.T) {
|
func TestHashPasswordUniqueness(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
password := "testPassword123!"
|
password := "testPassword123!"
|
||||||
|
|
||||||
// Same password should produce different hashes due to salt
|
// Same password should produce different hashes
|
||||||
hash1, err := HashPassword(password)
|
hash1, err := database.HashPassword(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("HashPassword() error = %v", err)
|
t.Fatalf("HashPassword() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hash2, err := HashPassword(password)
|
hash2, err := database.HashPassword(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("HashPassword() error = %v", err)
|
t.Fatalf("HashPassword() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hash1 == hash2 {
|
if hash1 == hash2 {
|
||||||
t.Error("Same password produced identical hashes (salt not working)")
|
t.Error(
|
||||||
|
"Same password produced identical hashes " +
|
||||||
|
"(salt not working)",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package database
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
@@ -16,87 +17,82 @@ import (
|
|||||||
"sneak.berlin/go/webhooker/internal/logger"
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // WebhookDBManagerParams is a standard fx naming convention
|
// WebhookDBManagerParams holds the fx dependencies for
|
||||||
|
// WebhookDBManager.
|
||||||
type WebhookDBManagerParams struct {
|
type WebhookDBManagerParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// WebhookDBManager manages per-webhook SQLite database files for event storage.
|
// errInvalidCachedDBType indicates a type assertion failure
|
||||||
// Each webhook gets its own dedicated database containing Events, Deliveries,
|
// when retrieving a cached database connection.
|
||||||
// and DeliveryResults. Database connections are opened lazily and cached.
|
var errInvalidCachedDBType = errors.New(
|
||||||
|
"invalid cached database type",
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebhookDBManager manages per-webhook SQLite database files
|
||||||
|
// for event storage. Each webhook gets its own dedicated
|
||||||
|
// database containing Events, Deliveries, and DeliveryResults.
|
||||||
|
// Database connections are opened lazily and cached.
|
||||||
type WebhookDBManager struct {
|
type WebhookDBManager struct {
|
||||||
dataDir string
|
dataDir string
|
||||||
dbs sync.Map // map[webhookID]*gorm.DB
|
dbs sync.Map // map[webhookID]*gorm.DB
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWebhookDBManager creates a new WebhookDBManager and registers lifecycle hooks.
|
// NewWebhookDBManager creates a new WebhookDBManager and
|
||||||
func NewWebhookDBManager(lc fx.Lifecycle, params WebhookDBManagerParams) (*WebhookDBManager, error) {
|
// registers lifecycle hooks.
|
||||||
|
func NewWebhookDBManager(
|
||||||
|
lc fx.Lifecycle,
|
||||||
|
params WebhookDBManagerParams,
|
||||||
|
) (*WebhookDBManager, error) {
|
||||||
m := &WebhookDBManager{
|
m := &WebhookDBManager{
|
||||||
dataDir: params.Config.DataDir,
|
dataDir: params.Config.DataDir,
|
||||||
log: params.Logger.Get(),
|
log: params.Logger.Get(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create data directory if it doesn't exist
|
// Create data directory if it doesn't exist
|
||||||
if err := os.MkdirAll(m.dataDir, 0750); err != nil {
|
err := os.MkdirAll(m.dataDir, dataDirPerm)
|
||||||
return nil, fmt.Errorf("creating data directory %s: %w", m.dataDir, err)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"creating data directory %s: %w",
|
||||||
|
m.dataDir,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
lc.Append(fx.Hook{
|
lc.Append(fx.Hook{
|
||||||
OnStop: func(_ context.Context) error { //nolint:revive // ctx unused but required by fx
|
OnStop: func(_ context.Context) error {
|
||||||
return m.CloseAll()
|
return m.CloseAll()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
m.log.Info("webhook database manager initialized", "data_dir", m.dataDir)
|
m.log.Info(
|
||||||
|
"webhook database manager initialized",
|
||||||
|
"data_dir", m.dataDir,
|
||||||
|
)
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// dbPath returns the filesystem path for a webhook's database file.
|
// GetDB returns the database connection for a webhook,
|
||||||
func (m *WebhookDBManager) dbPath(webhookID string) string {
|
// creating the database file lazily if it doesn't exist.
|
||||||
return filepath.Join(m.dataDir, fmt.Sprintf("events-%s.db", webhookID))
|
func (m *WebhookDBManager) GetDB(
|
||||||
}
|
webhookID string,
|
||||||
|
) (*gorm.DB, error) {
|
||||||
// openDB opens (or creates) a per-webhook SQLite database and runs migrations.
|
|
||||||
func (m *WebhookDBManager) openDB(webhookID string) (*gorm.DB, error) {
|
|
||||||
path := m.dbPath(webhookID)
|
|
||||||
dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", path)
|
|
||||||
|
|
||||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("opening webhook database %s: %w", webhookID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := gorm.Open(sqlite.Dialector{
|
|
||||||
Conn: sqlDB,
|
|
||||||
}, &gorm.Config{})
|
|
||||||
if err != nil {
|
|
||||||
sqlDB.Close()
|
|
||||||
return nil, fmt.Errorf("connecting to webhook database %s: %w", webhookID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run migrations for event-tier models only
|
|
||||||
if err := db.AutoMigrate(&Event{}, &Delivery{}, &DeliveryResult{}); err != nil {
|
|
||||||
sqlDB.Close()
|
|
||||||
return nil, fmt.Errorf("migrating webhook database %s: %w", webhookID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.log.Info("opened per-webhook database", "webhook_id", webhookID, "path", path)
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDB returns the database connection for a webhook, creating the database
|
|
||||||
// file lazily if it doesn't exist. This handles both new webhooks and existing
|
|
||||||
// webhooks that were created before per-webhook databases were introduced.
|
|
||||||
func (m *WebhookDBManager) GetDB(webhookID string) (*gorm.DB, error) {
|
|
||||||
// Fast path: already open
|
// Fast path: already open
|
||||||
if val, ok := m.dbs.Load(webhookID); ok {
|
if val, ok := m.dbs.Load(webhookID); ok {
|
||||||
cachedDB, castOK := val.(*gorm.DB)
|
cachedDB, castOK := val.(*gorm.DB)
|
||||||
if !castOK {
|
if !castOK {
|
||||||
return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
|
return nil, fmt.Errorf(
|
||||||
|
"%w for webhook %s",
|
||||||
|
errInvalidCachedDBType,
|
||||||
|
webhookID,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cachedDB, nil
|
return cachedDB, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,44 +102,61 @@ func (m *WebhookDBManager) GetDB(webhookID string) (*gorm.DB, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store it; if another goroutine beat us, close ours and use theirs
|
// Store it; if another goroutine beat us, close ours
|
||||||
actual, loaded := m.dbs.LoadOrStore(webhookID, db)
|
actual, loaded := m.dbs.LoadOrStore(webhookID, db)
|
||||||
if loaded {
|
if loaded {
|
||||||
// Another goroutine created it first; close our duplicate
|
// Another goroutine created it first; close our duplicate
|
||||||
if sqlDB, closeErr := db.DB(); closeErr == nil {
|
sqlDB, closeErr := db.DB()
|
||||||
sqlDB.Close()
|
if closeErr == nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
existingDB, castOK := actual.(*gorm.DB)
|
existingDB, castOK := actual.(*gorm.DB)
|
||||||
if !castOK {
|
if !castOK {
|
||||||
return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
|
return nil, fmt.Errorf(
|
||||||
|
"%w for webhook %s",
|
||||||
|
errInvalidCachedDBType,
|
||||||
|
webhookID,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return existingDB, nil
|
return existingDB, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDB explicitly creates a new per-webhook database file and runs migrations.
|
// CreateDB explicitly creates a new per-webhook database file
|
||||||
// This is called when a new webhook is created.
|
// and runs migrations.
|
||||||
func (m *WebhookDBManager) CreateDB(webhookID string) error {
|
func (m *WebhookDBManager) CreateDB(
|
||||||
|
webhookID string,
|
||||||
|
) error {
|
||||||
_, err := m.GetDB(webhookID)
|
_, err := m.GetDB(webhookID)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DBExists checks if a per-webhook database file exists on disk.
|
// DBExists checks if a per-webhook database file exists on
|
||||||
func (m *WebhookDBManager) DBExists(webhookID string) bool {
|
// disk.
|
||||||
|
func (m *WebhookDBManager) DBExists(
|
||||||
|
webhookID string,
|
||||||
|
) bool {
|
||||||
_, err := os.Stat(m.dbPath(webhookID))
|
_, err := os.Stat(m.dbPath(webhookID))
|
||||||
|
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDB closes the connection and deletes the database file for a webhook.
|
// DeleteDB closes the connection and deletes the database file
|
||||||
// This performs a hard delete — the file is permanently removed.
|
// for a webhook. The file is permanently removed.
|
||||||
func (m *WebhookDBManager) DeleteDB(webhookID string) error {
|
func (m *WebhookDBManager) DeleteDB(
|
||||||
|
webhookID string,
|
||||||
|
) error {
|
||||||
// Close and remove from cache
|
// Close and remove from cache
|
||||||
if val, ok := m.dbs.LoadAndDelete(webhookID); ok {
|
if val, ok := m.dbs.LoadAndDelete(webhookID); ok {
|
||||||
if gormDB, castOK := val.(*gorm.DB); castOK {
|
if gormDB, castOK := val.(*gorm.DB); castOK {
|
||||||
if sqlDB, err := gormDB.DB(); err == nil {
|
sqlDB, err := gormDB.DB()
|
||||||
sqlDB.Close()
|
if err == nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -151,12 +164,20 @@ func (m *WebhookDBManager) DeleteDB(webhookID string) error {
|
|||||||
// Delete the main DB file and WAL/SHM files
|
// Delete the main DB file and WAL/SHM files
|
||||||
path := m.dbPath(webhookID)
|
path := m.dbPath(webhookID)
|
||||||
for _, suffix := range []string{"", "-wal", "-shm"} {
|
for _, suffix := range []string{"", "-wal", "-shm"} {
|
||||||
if err := os.Remove(path + suffix); err != nil && !os.IsNotExist(err) {
|
err := os.Remove(path + suffix)
|
||||||
return fmt.Errorf("deleting webhook database file %s%s: %w", path, suffix, err)
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"deleting webhook database file %s%s: %w",
|
||||||
|
path, suffix, err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.log.Info("deleted per-webhook database", "webhook_id", webhookID)
|
m.log.Info(
|
||||||
|
"deleted per-webhook database",
|
||||||
|
"webhook_id", webhookID,
|
||||||
|
)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,20 +185,97 @@ func (m *WebhookDBManager) DeleteDB(webhookID string) error {
|
|||||||
// Called during application shutdown.
|
// Called during application shutdown.
|
||||||
func (m *WebhookDBManager) CloseAll() error {
|
func (m *WebhookDBManager) CloseAll() error {
|
||||||
var lastErr error
|
var lastErr error
|
||||||
m.dbs.Range(func(key, value interface{}) bool {
|
|
||||||
|
m.dbs.Range(func(key, value any) bool {
|
||||||
if gormDB, castOK := value.(*gorm.DB); castOK {
|
if gormDB, castOK := value.(*gorm.DB); castOK {
|
||||||
if sqlDB, err := gormDB.DB(); err == nil {
|
sqlDB, err := gormDB.DB()
|
||||||
if closeErr := sqlDB.Close(); closeErr != nil {
|
if err == nil {
|
||||||
|
closeErr := sqlDB.Close()
|
||||||
|
if closeErr != nil {
|
||||||
lastErr = closeErr
|
lastErr = closeErr
|
||||||
m.log.Error("failed to close webhook database",
|
m.log.Error(
|
||||||
|
"failed to close webhook database",
|
||||||
"webhook_id", key,
|
"webhook_id", key,
|
||||||
"error", closeErr,
|
"error", closeErr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.dbs.Delete(key)
|
m.dbs.Delete(key)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
return lastErr
|
return lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DBPath returns the filesystem path for a webhook's database
|
||||||
|
// file.
|
||||||
|
func (m *WebhookDBManager) DBPath(
|
||||||
|
webhookID string,
|
||||||
|
) string {
|
||||||
|
return m.dbPath(webhookID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WebhookDBManager) dbPath(
|
||||||
|
webhookID string,
|
||||||
|
) string {
|
||||||
|
return filepath.Join(
|
||||||
|
m.dataDir,
|
||||||
|
fmt.Sprintf("events-%s.db", webhookID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// openDB opens (or creates) a per-webhook SQLite database and
|
||||||
|
// runs migrations.
|
||||||
|
func (m *WebhookDBManager) openDB(
|
||||||
|
webhookID string,
|
||||||
|
) (*gorm.DB, error) {
|
||||||
|
path := m.dbPath(webhookID)
|
||||||
|
dbURL := fmt.Sprintf(
|
||||||
|
"file:%s?cache=shared&mode=rwc",
|
||||||
|
path,
|
||||||
|
)
|
||||||
|
|
||||||
|
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"opening webhook database %s: %w",
|
||||||
|
webhookID, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Dialector{
|
||||||
|
Conn: sqlDB,
|
||||||
|
}, &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"connecting to webhook database %s: %w",
|
||||||
|
webhookID, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run migrations for event-tier models only
|
||||||
|
err = db.AutoMigrate(
|
||||||
|
&Event{}, &Delivery{}, &DeliveryResult{},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"migrating webhook database %s: %w",
|
||||||
|
webhookID, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.log.Info(
|
||||||
|
"opened per-webhook database",
|
||||||
|
"webhook_id", webhookID,
|
||||||
|
"path", path,
|
||||||
|
)
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package database
|
package database_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -10,23 +10,29 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/fx/fxtest"
|
"go.uber.org/fx/fxtest"
|
||||||
|
"gorm.io/gorm"
|
||||||
"sneak.berlin/go/webhooker/internal/config"
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
|
"sneak.berlin/go/webhooker/internal/database"
|
||||||
"sneak.berlin/go/webhooker/internal/globals"
|
"sneak.berlin/go/webhooker/internal/globals"
|
||||||
"sneak.berlin/go/webhooker/internal/logger"
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupTestWebhookDBManager(t *testing.T) (*WebhookDBManager, *fxtest.Lifecycle) {
|
func setupTestWebhookDBManager(
|
||||||
|
t *testing.T,
|
||||||
|
) (*database.WebhookDBManager, *fxtest.Lifecycle) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
lc := fxtest.NewLifecycle(t)
|
lc := fxtest.NewLifecycle(t)
|
||||||
|
|
||||||
globals.Appname = "webhooker-test"
|
g := &globals.Globals{
|
||||||
globals.Version = "test"
|
Appname: "webhooker-test",
|
||||||
|
Version: "test",
|
||||||
|
}
|
||||||
|
|
||||||
g, err := globals.New(lc)
|
l, err := logger.New(
|
||||||
require.NoError(t, err)
|
lc,
|
||||||
|
logger.LoggerParams{Globals: g},
|
||||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
dataDir := filepath.Join(t.TempDir(), "events")
|
dataDir := filepath.Join(t.TempDir(), "events")
|
||||||
@@ -35,19 +41,25 @@ func setupTestWebhookDBManager(t *testing.T) (*WebhookDBManager, *fxtest.Lifecyc
|
|||||||
DataDir: dataDir,
|
DataDir: dataDir,
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr, err := NewWebhookDBManager(lc, WebhookDBManagerParams{
|
mgr, err := database.NewWebhookDBManager(
|
||||||
Config: cfg,
|
lc,
|
||||||
Logger: l,
|
database.WebhookDBManagerParams{
|
||||||
})
|
Config: cfg,
|
||||||
|
Logger: l,
|
||||||
|
},
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return mgr, lc
|
return mgr, lc
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
mgr, lc := setupTestWebhookDBManager(t)
|
mgr, lc := setupTestWebhookDBManager(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
require.NoError(t, lc.Start(ctx))
|
require.NoError(t, lc.Start(ctx))
|
||||||
|
|
||||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||||
|
|
||||||
webhookID := uuid.New().String()
|
webhookID := uuid.New().String()
|
||||||
@@ -68,7 +80,7 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
|||||||
require.NotNil(t, db)
|
require.NotNil(t, db)
|
||||||
|
|
||||||
// Verify we can write an event
|
// Verify we can write an event
|
||||||
event := &Event{
|
event := &database.Event{
|
||||||
WebhookID: webhookID,
|
WebhookID: webhookID,
|
||||||
EntrypointID: uuid.New().String(),
|
EntrypointID: uuid.New().String(),
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
@@ -80,27 +92,35 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
|||||||
assert.NotEmpty(t, event.ID)
|
assert.NotEmpty(t, event.ID)
|
||||||
|
|
||||||
// Verify we can read it back
|
// Verify we can read it back
|
||||||
var readEvent Event
|
var readEvent database.Event
|
||||||
require.NoError(t, db.First(&readEvent, "id = ?", event.ID).Error)
|
|
||||||
|
require.NoError(
|
||||||
|
t,
|
||||||
|
db.First(&readEvent, "id = ?", event.ID).Error,
|
||||||
|
)
|
||||||
assert.Equal(t, webhookID, readEvent.WebhookID)
|
assert.Equal(t, webhookID, readEvent.WebhookID)
|
||||||
assert.Equal(t, "POST", readEvent.Method)
|
assert.Equal(t, "POST", readEvent.Method)
|
||||||
assert.Equal(t, `{"test": true}`, readEvent.Body)
|
assert.Equal(t, `{"test": true}`, readEvent.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebhookDBManager_DeleteDB(t *testing.T) {
|
func TestWebhookDBManager_DeleteDB(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
mgr, lc := setupTestWebhookDBManager(t)
|
mgr, lc := setupTestWebhookDBManager(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
require.NoError(t, lc.Start(ctx))
|
require.NoError(t, lc.Start(ctx))
|
||||||
|
|
||||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||||
|
|
||||||
webhookID := uuid.New().String()
|
webhookID := uuid.New().String()
|
||||||
|
|
||||||
// Create the DB and write some data
|
// Create the DB and write some data
|
||||||
require.NoError(t, mgr.CreateDB(webhookID))
|
require.NoError(t, mgr.CreateDB(webhookID))
|
||||||
|
|
||||||
db, err := mgr.GetDB(webhookID)
|
db, err := mgr.GetDB(webhookID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
event := &Event{
|
event := &database.Event{
|
||||||
WebhookID: webhookID,
|
WebhookID: webhookID,
|
||||||
EntrypointID: uuid.New().String(),
|
EntrypointID: uuid.New().String(),
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
@@ -116,15 +136,19 @@ func TestWebhookDBManager_DeleteDB(t *testing.T) {
|
|||||||
assert.False(t, mgr.DBExists(webhookID))
|
assert.False(t, mgr.DBExists(webhookID))
|
||||||
|
|
||||||
// Verify the file is actually gone from disk
|
// Verify the file is actually gone from disk
|
||||||
dbPath := mgr.dbPath(webhookID)
|
dbPath := mgr.DBPath(webhookID)
|
||||||
|
|
||||||
_, err = os.Stat(dbPath)
|
_, err = os.Stat(dbPath)
|
||||||
assert.True(t, os.IsNotExist(err))
|
assert.True(t, os.IsNotExist(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebhookDBManager_LazyCreation(t *testing.T) {
|
func TestWebhookDBManager_LazyCreation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
mgr, lc := setupTestWebhookDBManager(t)
|
mgr, lc := setupTestWebhookDBManager(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
require.NoError(t, lc.Start(ctx))
|
require.NoError(t, lc.Start(ctx))
|
||||||
|
|
||||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||||
|
|
||||||
webhookID := uuid.New().String()
|
webhookID := uuid.New().String()
|
||||||
@@ -139,9 +163,12 @@ func TestWebhookDBManager_LazyCreation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
mgr, lc := setupTestWebhookDBManager(t)
|
mgr, lc := setupTestWebhookDBManager(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
require.NoError(t, lc.Start(ctx))
|
require.NoError(t, lc.Start(ctx))
|
||||||
|
|
||||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||||
|
|
||||||
webhookID := uuid.New().String()
|
webhookID := uuid.New().String()
|
||||||
@@ -150,8 +177,23 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
|||||||
db, err := mgr.GetDB(webhookID)
|
db, err := mgr.GetDB(webhookID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create an event
|
event, delivery := seedDeliveryWorkflow(
|
||||||
event := &Event{
|
t, db, webhookID, targetID,
|
||||||
|
)
|
||||||
|
|
||||||
|
verifyPendingDeliveries(t, db, event)
|
||||||
|
completeDelivery(t, db, delivery)
|
||||||
|
verifyNoPending(t, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedDeliveryWorkflow(
|
||||||
|
t *testing.T,
|
||||||
|
db *gorm.DB,
|
||||||
|
webhookID, targetID string,
|
||||||
|
) (*database.Event, *database.Delivery) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
event := &database.Event{
|
||||||
WebhookID: webhookID,
|
WebhookID: webhookID,
|
||||||
EntrypointID: uuid.New().String(),
|
EntrypointID: uuid.New().String(),
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
@@ -161,25 +203,45 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.NoError(t, db.Create(event).Error)
|
require.NoError(t, db.Create(event).Error)
|
||||||
|
|
||||||
// Create a delivery
|
delivery := &database.Delivery{
|
||||||
delivery := &Delivery{
|
|
||||||
EventID: event.ID,
|
EventID: event.ID,
|
||||||
TargetID: targetID,
|
TargetID: targetID,
|
||||||
Status: DeliveryStatusPending,
|
Status: database.DeliveryStatusPending,
|
||||||
}
|
}
|
||||||
require.NoError(t, db.Create(delivery).Error)
|
require.NoError(t, db.Create(delivery).Error)
|
||||||
|
|
||||||
// Query pending deliveries
|
return event, delivery
|
||||||
var pending []Delivery
|
}
|
||||||
require.NoError(t, db.Where("status = ?", DeliveryStatusPending).
|
|
||||||
Preload("Event").
|
func verifyPendingDeliveries(
|
||||||
Find(&pending).Error)
|
t *testing.T,
|
||||||
|
db *gorm.DB,
|
||||||
|
event *database.Event,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var pending []database.Delivery
|
||||||
|
|
||||||
|
require.NoError(
|
||||||
|
t,
|
||||||
|
db.Where(
|
||||||
|
"status = ?",
|
||||||
|
database.DeliveryStatusPending,
|
||||||
|
).Preload("Event").Find(&pending).Error,
|
||||||
|
)
|
||||||
require.Len(t, pending, 1)
|
require.Len(t, pending, 1)
|
||||||
assert.Equal(t, event.ID, pending[0].EventID)
|
assert.Equal(t, event.ID, pending[0].EventID)
|
||||||
assert.Equal(t, "POST", pending[0].Event.Method)
|
assert.Equal(t, "POST", pending[0].Event.Method)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a delivery result
|
func completeDelivery(
|
||||||
result := &DeliveryResult{
|
t *testing.T,
|
||||||
|
db *gorm.DB,
|
||||||
|
delivery *database.Delivery,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
result := &database.DeliveryResult{
|
||||||
DeliveryID: delivery.ID,
|
DeliveryID: delivery.ID,
|
||||||
AttemptNum: 1,
|
AttemptNum: 1,
|
||||||
Success: true,
|
Success: true,
|
||||||
@@ -188,19 +250,40 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.NoError(t, db.Create(result).Error)
|
require.NoError(t, db.Create(result).Error)
|
||||||
|
|
||||||
// Update delivery status
|
require.NoError(
|
||||||
require.NoError(t, db.Model(delivery).Update("status", DeliveryStatusDelivered).Error)
|
t,
|
||||||
|
db.Model(delivery).Update(
|
||||||
|
"status",
|
||||||
|
database.DeliveryStatusDelivered,
|
||||||
|
).Error,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// Verify no more pending deliveries
|
func verifyNoPending(
|
||||||
var stillPending []Delivery
|
t *testing.T,
|
||||||
require.NoError(t, db.Where("status = ?", DeliveryStatusPending).Find(&stillPending).Error)
|
db *gorm.DB,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var stillPending []database.Delivery
|
||||||
|
|
||||||
|
require.NoError(
|
||||||
|
t,
|
||||||
|
db.Where(
|
||||||
|
"status = ?",
|
||||||
|
database.DeliveryStatusPending,
|
||||||
|
).Find(&stillPending).Error,
|
||||||
|
)
|
||||||
assert.Empty(t, stillPending)
|
assert.Empty(t, stillPending)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
mgr, lc := setupTestWebhookDBManager(t)
|
mgr, lc := setupTestWebhookDBManager(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
require.NoError(t, lc.Start(ctx))
|
require.NoError(t, lc.Start(ctx))
|
||||||
|
|
||||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||||
|
|
||||||
webhook1 := uuid.New().String()
|
webhook1 := uuid.New().String()
|
||||||
@@ -212,34 +295,38 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
|||||||
|
|
||||||
db1, err := mgr.GetDB(webhook1)
|
db1, err := mgr.GetDB(webhook1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
db2, err := mgr.GetDB(webhook2)
|
db2, err := mgr.GetDB(webhook2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Write events to each webhook's DB
|
// Write events to each webhook's DB
|
||||||
event1 := &Event{
|
event1 := &database.Event{
|
||||||
WebhookID: webhook1,
|
WebhookID: webhook1,
|
||||||
EntrypointID: uuid.New().String(),
|
EntrypointID: uuid.New().String(),
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
Body: `{"webhook": 1}`,
|
Body: `{"webhook": 1}`,
|
||||||
ContentType: "application/json",
|
ContentType: "application/json",
|
||||||
}
|
}
|
||||||
event2 := &Event{
|
event2 := &database.Event{
|
||||||
WebhookID: webhook2,
|
WebhookID: webhook2,
|
||||||
EntrypointID: uuid.New().String(),
|
EntrypointID: uuid.New().String(),
|
||||||
Method: "PUT",
|
Method: "PUT",
|
||||||
Body: `{"webhook": 2}`,
|
Body: `{"webhook": 2}`,
|
||||||
ContentType: "application/json",
|
ContentType: "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, db1.Create(event1).Error)
|
require.NoError(t, db1.Create(event1).Error)
|
||||||
require.NoError(t, db2.Create(event2).Error)
|
require.NoError(t, db2.Create(event2).Error)
|
||||||
|
|
||||||
// Verify isolation: each DB only has its own events
|
// Verify isolation: each DB only has its own events
|
||||||
var count1 int64
|
var count1 int64
|
||||||
db1.Model(&Event{}).Count(&count1)
|
|
||||||
|
db1.Model(&database.Event{}).Count(&count1)
|
||||||
assert.Equal(t, int64(1), count1)
|
assert.Equal(t, int64(1), count1)
|
||||||
|
|
||||||
var count2 int64
|
var count2 int64
|
||||||
db2.Model(&Event{}).Count(&count2)
|
|
||||||
|
db2.Model(&database.Event{}).Count(&count2)
|
||||||
assert.Equal(t, int64(1), count2)
|
assert.Equal(t, int64(1), count2)
|
||||||
|
|
||||||
// Delete webhook1's DB, webhook2 should be unaffected
|
// Delete webhook1's DB, webhook2 should be unaffected
|
||||||
@@ -248,25 +335,31 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
|||||||
assert.True(t, mgr.DBExists(webhook2))
|
assert.True(t, mgr.DBExists(webhook2))
|
||||||
|
|
||||||
// webhook2's data should still be accessible
|
// webhook2's data should still be accessible
|
||||||
var events []Event
|
var events []database.Event
|
||||||
|
|
||||||
require.NoError(t, db2.Find(&events).Error)
|
require.NoError(t, db2.Find(&events).Error)
|
||||||
assert.Len(t, events, 1)
|
assert.Len(t, events, 1)
|
||||||
assert.Equal(t, "PUT", events[0].Method)
|
assert.Equal(t, "PUT", events[0].Method)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebhookDBManager_CloseAll(t *testing.T) {
|
func TestWebhookDBManager_CloseAll(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
mgr, lc := setupTestWebhookDBManager(t)
|
mgr, lc := setupTestWebhookDBManager(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
require.NoError(t, lc.Start(ctx))
|
require.NoError(t, lc.Start(ctx))
|
||||||
|
|
||||||
// Create a few DBs
|
// Create a few DBs
|
||||||
for i := 0; i < 3; i++ {
|
for range 3 {
|
||||||
require.NoError(t, mgr.CreateDB(uuid.New().String()))
|
require.NoError(
|
||||||
|
t,
|
||||||
|
mgr.CreateDB(uuid.New().String()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseAll should close all connections without error
|
// CloseAll should close all connections without error
|
||||||
require.NoError(t, mgr.CloseAll())
|
require.NoError(t, mgr.CloseAll())
|
||||||
|
|
||||||
// Stop lifecycle (CloseAll already called, but shouldn't panic)
|
// Stop lifecycle (CloseAll already called)
|
||||||
require.NoError(t, lc.Stop(ctx))
|
require.NoError(t, lc.Stop(ctx))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,41 +5,32 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CircuitState represents the current state of a circuit breaker.
|
// CircuitState represents the current state of a circuit
|
||||||
|
// breaker.
|
||||||
type CircuitState int
|
type CircuitState int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// CircuitClosed is the normal operating state. Deliveries flow through.
|
// CircuitClosed is the normal operating state.
|
||||||
CircuitClosed CircuitState = iota
|
CircuitClosed CircuitState = iota
|
||||||
// CircuitOpen means the circuit has tripped. Deliveries are skipped
|
// CircuitOpen means the circuit has tripped.
|
||||||
// until the cooldown expires.
|
|
||||||
CircuitOpen
|
CircuitOpen
|
||||||
// CircuitHalfOpen allows a single probe delivery to test whether
|
// CircuitHalfOpen allows a single probe delivery to
|
||||||
// the target has recovered.
|
// test whether the target has recovered.
|
||||||
CircuitHalfOpen
|
CircuitHalfOpen
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// defaultFailureThreshold is the number of consecutive failures
|
// defaultFailureThreshold is the number of consecutive
|
||||||
// before a circuit breaker trips open.
|
// failures before a circuit breaker trips open.
|
||||||
defaultFailureThreshold = 5
|
defaultFailureThreshold = 5
|
||||||
|
|
||||||
// defaultCooldown is how long a circuit stays open before
|
// defaultCooldown is how long a circuit stays open
|
||||||
// transitioning to half-open for a probe delivery.
|
// before transitioning to half-open.
|
||||||
defaultCooldown = 30 * time.Second
|
defaultCooldown = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// CircuitBreaker implements the circuit breaker pattern for a single
|
// CircuitBreaker implements the circuit breaker pattern
|
||||||
// delivery target. It tracks consecutive failures and prevents
|
// for a single delivery target.
|
||||||
// hammering a down target by temporarily stopping delivery attempts.
|
|
||||||
//
|
|
||||||
// States:
|
|
||||||
// - Closed (normal): deliveries flow through; consecutive failures
|
|
||||||
// are counted.
|
|
||||||
// - Open (tripped): deliveries are skipped; a cooldown timer is
|
|
||||||
// running. After the cooldown expires the state moves to HalfOpen.
|
|
||||||
// - HalfOpen (probing): one probe delivery is allowed. If it
|
|
||||||
// succeeds the circuit closes; if it fails the circuit reopens.
|
|
||||||
type CircuitBreaker struct {
|
type CircuitBreaker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
state CircuitState
|
state CircuitState
|
||||||
@@ -49,7 +40,8 @@ type CircuitBreaker struct {
|
|||||||
lastFailure time.Time
|
lastFailure time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCircuitBreaker creates a circuit breaker with default settings.
|
// NewCircuitBreaker creates a circuit breaker with default
|
||||||
|
// settings.
|
||||||
func NewCircuitBreaker() *CircuitBreaker {
|
func NewCircuitBreaker() *CircuitBreaker {
|
||||||
return &CircuitBreaker{
|
return &CircuitBreaker{
|
||||||
state: CircuitClosed,
|
state: CircuitClosed,
|
||||||
@@ -58,12 +50,7 @@ func NewCircuitBreaker() *CircuitBreaker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow checks whether a delivery attempt should proceed. It returns
|
// Allow checks whether a delivery attempt should proceed.
|
||||||
// true if the delivery should be attempted, false if the circuit is
|
|
||||||
// open and the delivery should be skipped.
|
|
||||||
//
|
|
||||||
// When the circuit is open and the cooldown has elapsed, Allow
|
|
||||||
// transitions to half-open and permits exactly one probe delivery.
|
|
||||||
func (cb *CircuitBreaker) Allow() bool {
|
func (cb *CircuitBreaker) Allow() bool {
|
||||||
cb.mu.Lock()
|
cb.mu.Lock()
|
||||||
defer cb.mu.Unlock()
|
defer cb.mu.Unlock()
|
||||||
@@ -73,17 +60,15 @@ func (cb *CircuitBreaker) Allow() bool {
|
|||||||
return true
|
return true
|
||||||
|
|
||||||
case CircuitOpen:
|
case CircuitOpen:
|
||||||
// Check if cooldown has elapsed
|
|
||||||
if time.Since(cb.lastFailure) >= cb.cooldown {
|
if time.Since(cb.lastFailure) >= cb.cooldown {
|
||||||
cb.state = CircuitHalfOpen
|
cb.state = CircuitHalfOpen
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|
||||||
case CircuitHalfOpen:
|
case CircuitHalfOpen:
|
||||||
// Only one probe at a time — reject additional attempts while
|
|
||||||
// a probe is in flight. The probe goroutine will call
|
|
||||||
// RecordSuccess or RecordFailure to resolve the state.
|
|
||||||
return false
|
return false
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -91,9 +76,8 @@ func (cb *CircuitBreaker) Allow() bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CooldownRemaining returns how much time is left before an open circuit
|
// CooldownRemaining returns how much time is left before
|
||||||
// transitions to half-open. Returns zero if the circuit is not open or
|
// an open circuit transitions to half-open.
|
||||||
// the cooldown has already elapsed.
|
|
||||||
func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
|
func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
|
||||||
cb.mu.Lock()
|
cb.mu.Lock()
|
||||||
defer cb.mu.Unlock()
|
defer cb.mu.Unlock()
|
||||||
@@ -106,11 +90,12 @@ func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
|
|||||||
if remaining < 0 {
|
if remaining < 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return remaining
|
return remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordSuccess records a successful delivery and resets the circuit
|
// RecordSuccess records a successful delivery and resets
|
||||||
// breaker to closed state with zero failures.
|
// the circuit breaker to closed state.
|
||||||
func (cb *CircuitBreaker) RecordSuccess() {
|
func (cb *CircuitBreaker) RecordSuccess() {
|
||||||
cb.mu.Lock()
|
cb.mu.Lock()
|
||||||
defer cb.mu.Unlock()
|
defer cb.mu.Unlock()
|
||||||
@@ -119,8 +104,8 @@ func (cb *CircuitBreaker) RecordSuccess() {
|
|||||||
cb.state = CircuitClosed
|
cb.state = CircuitClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordFailure records a failed delivery. If the failure count reaches
|
// RecordFailure records a failed delivery. If the failure
|
||||||
// the threshold, the circuit trips open.
|
// count reaches the threshold, the circuit trips open.
|
||||||
func (cb *CircuitBreaker) RecordFailure() {
|
func (cb *CircuitBreaker) RecordFailure() {
|
||||||
cb.mu.Lock()
|
cb.mu.Lock()
|
||||||
defer cb.mu.Unlock()
|
defer cb.mu.Unlock()
|
||||||
@@ -134,20 +119,25 @@ func (cb *CircuitBreaker) RecordFailure() {
|
|||||||
cb.state = CircuitOpen
|
cb.state = CircuitOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case CircuitOpen:
|
||||||
|
// Already open; no state change needed.
|
||||||
|
|
||||||
case CircuitHalfOpen:
|
case CircuitHalfOpen:
|
||||||
// Probe failed — reopen immediately
|
// Probe failed -- reopen immediately.
|
||||||
cb.state = CircuitOpen
|
cb.state = CircuitOpen
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// State returns the current circuit state. Safe for concurrent use.
|
// State returns the current circuit state.
|
||||||
func (cb *CircuitBreaker) State() CircuitState {
|
func (cb *CircuitBreaker) State() CircuitState {
|
||||||
cb.mu.Lock()
|
cb.mu.Lock()
|
||||||
defer cb.mu.Unlock()
|
defer cb.mu.Unlock()
|
||||||
|
|
||||||
return cb.state
|
return cb.state
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the human-readable name of a circuit state.
|
// String returns the human-readable name of a circuit
|
||||||
|
// state.
|
||||||
func (s CircuitState) String() string {
|
func (s CircuitState) String() string {
|
||||||
switch s {
|
switch s {
|
||||||
case CircuitClosed:
|
case CircuitClosed:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package delivery
|
package delivery_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
@@ -7,237 +7,304 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"sneak.berlin/go/webhooker/internal/delivery"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCircuitBreaker_ClosedState_AllowsDeliveries(t *testing.T) {
|
func TestCircuitBreaker_ClosedState_AllowsDeliveries(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := NewCircuitBreaker()
|
|
||||||
|
|
||||||
assert.Equal(t, CircuitClosed, cb.State())
|
cb := delivery.NewCircuitBreaker()
|
||||||
assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
|
|
||||||
// Multiple calls should all succeed
|
assert.Equal(t, delivery.CircuitClosed, cb.State())
|
||||||
for i := 0; i < 10; i++ {
|
assert.True(t, cb.Allow(),
|
||||||
|
"closed circuit should allow deliveries",
|
||||||
|
)
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
assert.True(t, cb.Allow())
|
assert.True(t, cb.Allow())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_FailureCounting(t *testing.T) {
|
func TestCircuitBreaker_FailureCounting(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := NewCircuitBreaker()
|
|
||||||
|
|
||||||
// Record failures below threshold — circuit should stay closed
|
cb := delivery.NewCircuitBreaker()
|
||||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
|
||||||
|
for i := range delivery.ExportDefaultFailureThreshold - 1 {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
assert.Equal(t, CircuitClosed, cb.State(),
|
|
||||||
"circuit should remain closed after %d failures", i+1)
|
assert.Equal(t,
|
||||||
assert.True(t, cb.Allow(), "should still allow after %d failures", i+1)
|
delivery.CircuitClosed, cb.State(),
|
||||||
|
"circuit should remain closed after %d failures",
|
||||||
|
i+1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.True(t, cb.Allow(),
|
||||||
|
"should still allow after %d failures",
|
||||||
|
i+1,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_OpenTransition(t *testing.T) {
|
func TestCircuitBreaker_OpenTransition(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := NewCircuitBreaker()
|
|
||||||
|
|
||||||
// Record exactly threshold failures
|
cb := delivery.NewCircuitBreaker()
|
||||||
for i := 0; i < defaultFailureThreshold; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, CircuitOpen, cb.State(), "circuit should be open after threshold failures")
|
assert.Equal(t, delivery.CircuitOpen, cb.State(),
|
||||||
assert.False(t, cb.Allow(), "open circuit should reject deliveries")
|
"circuit should be open after threshold failures",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.False(t, cb.Allow(),
|
||||||
|
"open circuit should reject deliveries",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_Cooldown_StaysOpen(t *testing.T) {
|
func TestCircuitBreaker_Cooldown_StaysOpen(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
// Use a circuit with a known short cooldown for testing
|
|
||||||
cb := &CircuitBreaker{
|
|
||||||
state: CircuitClosed,
|
|
||||||
threshold: defaultFailureThreshold,
|
|
||||||
cooldown: 200 * time.Millisecond,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trip the circuit open
|
cb := delivery.NewCircuitBreaker()
|
||||||
for i := 0; i < defaultFailureThreshold; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
require.Equal(t, CircuitOpen, cb.State())
|
|
||||||
|
|
||||||
// During cooldown, Allow should return false
|
require.Equal(t, delivery.CircuitOpen, cb.State())
|
||||||
assert.False(t, cb.Allow(), "should be blocked during cooldown")
|
|
||||||
|
assert.False(t, cb.Allow(),
|
||||||
|
"should be blocked during cooldown",
|
||||||
|
)
|
||||||
|
|
||||||
// CooldownRemaining should be positive
|
|
||||||
remaining := cb.CooldownRemaining()
|
remaining := cb.CooldownRemaining()
|
||||||
assert.Greater(t, remaining, time.Duration(0), "cooldown should have remaining time")
|
|
||||||
|
assert.Greater(t, remaining, time.Duration(0),
|
||||||
|
"cooldown should have remaining time",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_HalfOpen_AfterCooldown(t *testing.T) {
|
func TestCircuitBreaker_HalfOpen_AfterCooldown(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := &CircuitBreaker{
|
|
||||||
state: CircuitClosed,
|
|
||||||
threshold: defaultFailureThreshold,
|
|
||||||
cooldown: 50 * time.Millisecond,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trip the circuit open
|
cb := newShortCooldownCB(t)
|
||||||
for i := 0; i < defaultFailureThreshold; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
require.Equal(t, CircuitOpen, cb.State())
|
|
||||||
|
|
||||||
// Wait for cooldown to expire
|
require.Equal(t, delivery.CircuitOpen, cb.State())
|
||||||
|
|
||||||
time.Sleep(60 * time.Millisecond)
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
// CooldownRemaining should be zero after cooldown
|
assert.Equal(t, time.Duration(0),
|
||||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining())
|
cb.CooldownRemaining(),
|
||||||
|
)
|
||||||
|
|
||||||
// First Allow after cooldown should succeed (probe)
|
assert.True(t, cb.Allow(),
|
||||||
assert.True(t, cb.Allow(), "should allow one probe after cooldown")
|
"should allow one probe after cooldown",
|
||||||
assert.Equal(t, CircuitHalfOpen, cb.State(), "should be half-open after probe allowed")
|
)
|
||||||
|
|
||||||
// Second Allow should be rejected (only one probe at a time)
|
assert.Equal(t,
|
||||||
assert.False(t, cb.Allow(), "should reject additional probes while half-open")
|
delivery.CircuitHalfOpen, cb.State(),
|
||||||
|
"should be half-open after probe allowed",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.False(t, cb.Allow(),
|
||||||
|
"should reject additional probes while half-open",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_ProbeSuccess_ClosesCircuit(t *testing.T) {
|
func TestCircuitBreaker_ProbeSuccess_ClosesCircuit(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := &CircuitBreaker{
|
|
||||||
state: CircuitClosed,
|
|
||||||
threshold: defaultFailureThreshold,
|
|
||||||
cooldown: 50 * time.Millisecond,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trip open → wait for cooldown → allow probe
|
cb := newShortCooldownCB(t)
|
||||||
for i := 0; i < defaultFailureThreshold; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
time.Sleep(60 * time.Millisecond)
|
|
||||||
require.True(t, cb.Allow()) // probe allowed, state → half-open
|
|
||||||
|
|
||||||
// Probe succeeds → circuit should close
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
require.True(t, cb.Allow())
|
||||||
|
|
||||||
cb.RecordSuccess()
|
cb.RecordSuccess()
|
||||||
assert.Equal(t, CircuitClosed, cb.State(), "successful probe should close circuit")
|
|
||||||
|
|
||||||
// Should allow deliveries again
|
assert.Equal(t, delivery.CircuitClosed, cb.State(),
|
||||||
assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
|
"successful probe should close circuit",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.True(t, cb.Allow(),
|
||||||
|
"closed circuit should allow deliveries",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_ProbeFailure_ReopensCircuit(t *testing.T) {
|
func TestCircuitBreaker_ProbeFailure_ReopensCircuit(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := &CircuitBreaker{
|
|
||||||
state: CircuitClosed,
|
|
||||||
threshold: defaultFailureThreshold,
|
|
||||||
cooldown: 50 * time.Millisecond,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trip open → wait for cooldown → allow probe
|
cb := newShortCooldownCB(t)
|
||||||
for i := 0; i < defaultFailureThreshold; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(60 * time.Millisecond)
|
time.Sleep(60 * time.Millisecond)
|
||||||
require.True(t, cb.Allow()) // probe allowed, state → half-open
|
|
||||||
|
|
||||||
// Probe fails → circuit should reopen
|
require.True(t, cb.Allow())
|
||||||
|
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
assert.Equal(t, CircuitOpen, cb.State(), "failed probe should reopen circuit")
|
|
||||||
assert.False(t, cb.Allow(), "reopened circuit should reject deliveries")
|
assert.Equal(t, delivery.CircuitOpen, cb.State(),
|
||||||
|
"failed probe should reopen circuit",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.False(t, cb.Allow(),
|
||||||
|
"reopened circuit should reject deliveries",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
func TestCircuitBreaker_SuccessResetsFailures(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := NewCircuitBreaker()
|
|
||||||
|
|
||||||
// Accumulate failures just below threshold
|
cb := delivery.NewCircuitBreaker()
|
||||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold - 1 {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
require.Equal(t, CircuitClosed, cb.State())
|
|
||||||
|
|
||||||
// Success should reset the failure counter
|
require.Equal(t, delivery.CircuitClosed, cb.State())
|
||||||
|
|
||||||
cb.RecordSuccess()
|
cb.RecordSuccess()
|
||||||
assert.Equal(t, CircuitClosed, cb.State())
|
|
||||||
|
|
||||||
// Now we should need another full threshold of failures to trip
|
assert.Equal(t, delivery.CircuitClosed, cb.State())
|
||||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold - 1 {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
assert.Equal(t, CircuitClosed, cb.State(),
|
|
||||||
"circuit should still be closed — success reset the counter")
|
|
||||||
|
|
||||||
// One more failure should trip it
|
assert.Equal(t, delivery.CircuitClosed, cb.State(),
|
||||||
|
"circuit should still be closed -- "+
|
||||||
|
"success reset the counter",
|
||||||
|
)
|
||||||
|
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
assert.Equal(t, CircuitOpen, cb.State())
|
|
||||||
|
assert.Equal(t, delivery.CircuitOpen, cb.State())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := NewCircuitBreaker()
|
|
||||||
|
cb := delivery.NewCircuitBreaker()
|
||||||
|
|
||||||
const goroutines = 100
|
const goroutines = 100
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
wg.Add(goroutines * 3)
|
wg.Add(goroutines * 3)
|
||||||
|
|
||||||
// Concurrent Allow calls
|
for range goroutines {
|
||||||
for i := 0; i < goroutines; i++ {
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
cb.Allow()
|
cb.Allow()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Concurrent RecordFailure calls
|
for range goroutines {
|
||||||
for i := 0; i < goroutines; i++ {
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Concurrent RecordSuccess calls
|
for range goroutines {
|
||||||
for i := 0; i < goroutines; i++ {
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
cb.RecordSuccess()
|
cb.RecordSuccess()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
// No panic or data race — the test passes if -race doesn't flag anything.
|
|
||||||
// State should be one of the valid states.
|
|
||||||
state := cb.State()
|
state := cb.State()
|
||||||
assert.Contains(t, []CircuitState{CircuitClosed, CircuitOpen, CircuitHalfOpen}, state,
|
|
||||||
"state should be valid after concurrent access")
|
assert.Contains(t,
|
||||||
|
[]delivery.CircuitState{
|
||||||
|
delivery.CircuitClosed,
|
||||||
|
delivery.CircuitOpen,
|
||||||
|
delivery.CircuitHalfOpen,
|
||||||
|
},
|
||||||
|
state,
|
||||||
|
"state should be valid after concurrent access",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero(t *testing.T) {
|
func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := NewCircuitBreaker()
|
|
||||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
|
cb := delivery.NewCircuitBreaker()
|
||||||
"closed circuit should have zero cooldown remaining")
|
|
||||||
|
assert.Equal(t, time.Duration(0),
|
||||||
|
cb.CooldownRemaining(),
|
||||||
|
"closed circuit should have zero cooldown remaining",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero(t *testing.T) {
|
func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cb := &CircuitBreaker{
|
|
||||||
state: CircuitClosed,
|
|
||||||
threshold: defaultFailureThreshold,
|
|
||||||
cooldown: 50 * time.Millisecond,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trip open, wait, transition to half-open
|
cb := newShortCooldownCB(t)
|
||||||
for i := 0; i < defaultFailureThreshold; i++ {
|
|
||||||
|
for range delivery.ExportDefaultFailureThreshold {
|
||||||
cb.RecordFailure()
|
cb.RecordFailure()
|
||||||
}
|
}
|
||||||
time.Sleep(60 * time.Millisecond)
|
|
||||||
require.True(t, cb.Allow()) // → half-open
|
|
||||||
|
|
||||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
|
time.Sleep(60 * time.Millisecond)
|
||||||
"half-open circuit should have zero cooldown remaining")
|
|
||||||
|
require.True(t, cb.Allow())
|
||||||
|
|
||||||
|
assert.Equal(t, time.Duration(0),
|
||||||
|
cb.CooldownRemaining(),
|
||||||
|
"half-open circuit should have zero cooldown remaining",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCircuitState_String(t *testing.T) {
|
func TestCircuitState_String(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
assert.Equal(t, "closed", CircuitClosed.String())
|
|
||||||
assert.Equal(t, "open", CircuitOpen.String())
|
assert.Equal(t, "closed", delivery.CircuitClosed.String())
|
||||||
assert.Equal(t, "half-open", CircuitHalfOpen.String())
|
assert.Equal(t, "open", delivery.CircuitOpen.String())
|
||||||
assert.Equal(t, "unknown", CircuitState(99).String())
|
assert.Equal(t, "half-open", delivery.CircuitHalfOpen.String())
|
||||||
|
assert.Equal(t, "unknown", delivery.CircuitState(99).String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// newShortCooldownCB creates a CircuitBreaker with a short
|
||||||
|
// cooldown for testing. We use NewCircuitBreaker and
|
||||||
|
// manipulate through the public API.
|
||||||
|
func newShortCooldownCB(t *testing.T) *delivery.CircuitBreaker {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
return delivery.NewTestCircuitBreaker(
|
||||||
|
delivery.ExportDefaultFailureThreshold,
|
||||||
|
50*time.Millisecond,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
240
internal/delivery/export_test.go
Normal file
240
internal/delivery/export_test.go
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
package delivery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"sneak.berlin/go/webhooker/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Exported constants for test access.
|
||||||
|
const (
|
||||||
|
ExportDeliveryChannelSize = deliveryChannelSize
|
||||||
|
ExportRetryChannelSize = retryChannelSize
|
||||||
|
ExportDefaultFailureThreshold = defaultFailureThreshold
|
||||||
|
ExportDefaultCooldown = defaultCooldown
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExportIsBlockedIP exposes isBlockedIP for testing.
|
||||||
|
func ExportIsBlockedIP(ip net.IP) bool {
|
||||||
|
return isBlockedIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportBlockedNetworks exposes blockedNetworks.
|
||||||
|
func ExportBlockedNetworks() []*net.IPNet {
|
||||||
|
return blockedNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportIsForwardableHeader exposes isForwardableHeader.
|
||||||
|
func ExportIsForwardableHeader(name string) bool {
|
||||||
|
return isForwardableHeader(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportTruncate exposes truncate for testing.
|
||||||
|
func ExportTruncate(s string, maxLen int) string {
|
||||||
|
return truncate(s, maxLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportDeliverHTTP exposes deliverHTTP for testing.
|
||||||
|
func (e *Engine) ExportDeliverHTTP(
|
||||||
|
ctx context.Context,
|
||||||
|
webhookDB *gorm.DB,
|
||||||
|
d *database.Delivery,
|
||||||
|
task *Task,
|
||||||
|
) {
|
||||||
|
e.deliverHTTP(ctx, webhookDB, d, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportDeliverDatabase exposes deliverDatabase.
|
||||||
|
func (e *Engine) ExportDeliverDatabase(
|
||||||
|
webhookDB *gorm.DB, d *database.Delivery,
|
||||||
|
) {
|
||||||
|
e.deliverDatabase(webhookDB, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportDeliverLog exposes deliverLog for testing.
|
||||||
|
func (e *Engine) ExportDeliverLog(
|
||||||
|
webhookDB *gorm.DB, d *database.Delivery,
|
||||||
|
) {
|
||||||
|
e.deliverLog(webhookDB, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportDeliverSlack exposes deliverSlack for testing.
|
||||||
|
func (e *Engine) ExportDeliverSlack(
|
||||||
|
ctx context.Context,
|
||||||
|
webhookDB *gorm.DB,
|
||||||
|
d *database.Delivery,
|
||||||
|
) {
|
||||||
|
e.deliverSlack(ctx, webhookDB, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportProcessNewTask exposes processNewTask.
|
||||||
|
func (e *Engine) ExportProcessNewTask(
|
||||||
|
ctx context.Context, task *Task,
|
||||||
|
) {
|
||||||
|
e.processNewTask(ctx, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportProcessRetryTask exposes processRetryTask.
|
||||||
|
func (e *Engine) ExportProcessRetryTask(
|
||||||
|
ctx context.Context, task *Task,
|
||||||
|
) {
|
||||||
|
e.processRetryTask(ctx, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportProcessDelivery exposes processDelivery.
|
||||||
|
func (e *Engine) ExportProcessDelivery(
|
||||||
|
ctx context.Context,
|
||||||
|
webhookDB *gorm.DB,
|
||||||
|
d *database.Delivery,
|
||||||
|
task *Task,
|
||||||
|
) {
|
||||||
|
e.processDelivery(ctx, webhookDB, d, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportGetCircuitBreaker exposes getCircuitBreaker.
|
||||||
|
func (e *Engine) ExportGetCircuitBreaker(
|
||||||
|
targetID string,
|
||||||
|
) *CircuitBreaker {
|
||||||
|
return e.getCircuitBreaker(targetID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportParseHTTPConfig exposes parseHTTPConfig.
|
||||||
|
func (e *Engine) ExportParseHTTPConfig(
|
||||||
|
configJSON string,
|
||||||
|
) (*HTTPTargetConfig, error) {
|
||||||
|
return e.parseHTTPConfig(configJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportParseSlackConfig exposes parseSlackConfig.
|
||||||
|
func (e *Engine) ExportParseSlackConfig(
|
||||||
|
configJSON string,
|
||||||
|
) (*SlackTargetConfig, error) {
|
||||||
|
return e.parseSlackConfig(configJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportDoHTTPRequest exposes doHTTPRequest.
|
||||||
|
func (e *Engine) ExportDoHTTPRequest(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *HTTPTargetConfig,
|
||||||
|
event *database.Event,
|
||||||
|
) (int, string, int64, error) {
|
||||||
|
return e.doHTTPRequest(ctx, cfg, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportScheduleRetry exposes scheduleRetry.
|
||||||
|
func (e *Engine) ExportScheduleRetry(
|
||||||
|
task Task, delay time.Duration,
|
||||||
|
) {
|
||||||
|
e.scheduleRetry(task, delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportRecoverPendingDeliveries exposes
|
||||||
|
// recoverPendingDeliveries.
|
||||||
|
func (e *Engine) ExportRecoverPendingDeliveries(
|
||||||
|
ctx context.Context,
|
||||||
|
webhookDB *gorm.DB,
|
||||||
|
webhookID string,
|
||||||
|
) {
|
||||||
|
e.recoverPendingDeliveries(
|
||||||
|
ctx, webhookDB, webhookID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportRecoverWebhookDeliveries exposes
|
||||||
|
// recoverWebhookDeliveries.
|
||||||
|
func (e *Engine) ExportRecoverWebhookDeliveries(
|
||||||
|
ctx context.Context, webhookID string,
|
||||||
|
) {
|
||||||
|
e.recoverWebhookDeliveries(ctx, webhookID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportRecoverInFlight exposes recoverInFlight.
|
||||||
|
func (e *Engine) ExportRecoverInFlight(
|
||||||
|
ctx context.Context,
|
||||||
|
) {
|
||||||
|
e.recoverInFlight(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportStart exposes start for testing.
|
||||||
|
func (e *Engine) ExportStart(ctx context.Context) {
|
||||||
|
e.start(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportStop exposes stop for testing.
|
||||||
|
func (e *Engine) ExportStop() {
|
||||||
|
e.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportDeliveryCh returns the delivery channel.
|
||||||
|
func (e *Engine) ExportDeliveryCh() chan Task {
|
||||||
|
return e.deliveryCh
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportRetryCh returns the retry channel.
|
||||||
|
func (e *Engine) ExportRetryCh() chan Task {
|
||||||
|
return e.retryCh
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestEngine creates an Engine for unit tests without
|
||||||
|
// database dependencies.
|
||||||
|
func NewTestEngine(
|
||||||
|
log *slog.Logger,
|
||||||
|
client *http.Client,
|
||||||
|
workers int,
|
||||||
|
) *Engine {
|
||||||
|
return &Engine{
|
||||||
|
log: log,
|
||||||
|
client: client,
|
||||||
|
deliveryCh: make(chan Task, deliveryChannelSize),
|
||||||
|
retryCh: make(chan Task, retryChannelSize),
|
||||||
|
workers: workers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestEngineSmallRetry creates an Engine with a tiny
|
||||||
|
// retry channel buffer for overflow testing.
|
||||||
|
func NewTestEngineSmallRetry(
|
||||||
|
log *slog.Logger,
|
||||||
|
) *Engine {
|
||||||
|
return &Engine{
|
||||||
|
log: log,
|
||||||
|
retryCh: make(chan Task, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestEngineWithDB creates an Engine with a real
|
||||||
|
// database and dbManager for integration tests.
|
||||||
|
func NewTestEngineWithDB(
|
||||||
|
db *database.Database,
|
||||||
|
dbMgr *database.WebhookDBManager,
|
||||||
|
log *slog.Logger,
|
||||||
|
client *http.Client,
|
||||||
|
workers int,
|
||||||
|
) *Engine {
|
||||||
|
return &Engine{
|
||||||
|
database: db,
|
||||||
|
dbManager: dbMgr,
|
||||||
|
log: log,
|
||||||
|
client: client,
|
||||||
|
deliveryCh: make(chan Task, deliveryChannelSize),
|
||||||
|
retryCh: make(chan Task, retryChannelSize),
|
||||||
|
workers: workers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestCircuitBreaker creates a CircuitBreaker with
|
||||||
|
// custom settings for testing.
|
||||||
|
func NewTestCircuitBreaker(
|
||||||
|
threshold int, cooldown time.Duration,
|
||||||
|
) *CircuitBreaker {
|
||||||
|
return &CircuitBreaker{
|
||||||
|
state: CircuitClosed,
|
||||||
|
threshold: threshold,
|
||||||
|
cooldown: cooldown,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package delivery
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -10,14 +11,27 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// dnsResolutionTimeout is the maximum time to wait for DNS resolution
|
// dnsResolutionTimeout is the maximum time to wait for
|
||||||
// during SSRF validation.
|
// DNS resolution during SSRF validation.
|
||||||
dnsResolutionTimeout = 5 * time.Second
|
dnsResolutionTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// blockedNetworks contains all private/reserved IP ranges that should be
|
// Sentinel errors for SSRF validation.
|
||||||
// blocked to prevent SSRF attacks. This includes RFC 1918 private
|
var (
|
||||||
// addresses, loopback, link-local, and IPv6 equivalents.
|
errNoHostname = errors.New("URL has no hostname")
|
||||||
|
errNoIPs = errors.New(
|
||||||
|
"hostname resolved to no IP addresses",
|
||||||
|
)
|
||||||
|
errBlockedIP = errors.New(
|
||||||
|
"blocked private/reserved IP range",
|
||||||
|
)
|
||||||
|
errInvalidScheme = errors.New(
|
||||||
|
"only http and https are allowed",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
// blockedNetworks contains all private/reserved IP ranges
|
||||||
|
// that should be blocked to prevent SSRF attacks.
|
||||||
//
|
//
|
||||||
//nolint:gochecknoglobals // package-level network list is appropriate here
|
//nolint:gochecknoglobals // package-level network list is appropriate here
|
||||||
var blockedNetworks []*net.IPNet
|
var blockedNetworks []*net.IPNet
|
||||||
@@ -25,129 +39,184 @@ var blockedNetworks []*net.IPNet
|
|||||||
//nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup
|
//nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup
|
||||||
func init() {
|
func init() {
|
||||||
cidrs := []string{
|
cidrs := []string{
|
||||||
// IPv4 private/reserved ranges
|
"127.0.0.0/8",
|
||||||
"127.0.0.0/8", // Loopback
|
"10.0.0.0/8",
|
||||||
"10.0.0.0/8", // RFC 1918 Class A private
|
"172.16.0.0/12",
|
||||||
"172.16.0.0/12", // RFC 1918 Class B private
|
"192.168.0.0/16",
|
||||||
"192.168.0.0/16", // RFC 1918 Class C private
|
"169.254.0.0/16",
|
||||||
"169.254.0.0/16", // Link-local (cloud metadata)
|
"0.0.0.0/8",
|
||||||
"0.0.0.0/8", // "This" network
|
"100.64.0.0/10",
|
||||||
"100.64.0.0/10", // Shared address space (CGN)
|
"192.0.0.0/24",
|
||||||
"192.0.0.0/24", // IETF protocol assignments
|
"192.0.2.0/24",
|
||||||
"192.0.2.0/24", // TEST-NET-1
|
"198.18.0.0/15",
|
||||||
"198.18.0.0/15", // Benchmarking
|
"198.51.100.0/24",
|
||||||
"198.51.100.0/24", // TEST-NET-2
|
"203.0.113.0/24",
|
||||||
"203.0.113.0/24", // TEST-NET-3
|
"224.0.0.0/4",
|
||||||
"224.0.0.0/4", // Multicast
|
"240.0.0.0/4",
|
||||||
"240.0.0.0/4", // Reserved for future use
|
"::1/128",
|
||||||
|
"fc00::/7",
|
||||||
// IPv6 private/reserved ranges
|
"fe80::/10",
|
||||||
"::1/128", // Loopback
|
|
||||||
"fc00::/7", // Unique local addresses
|
|
||||||
"fe80::/10", // Link-local
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, cidr := range cidrs {
|
for _, cidr := range cidrs {
|
||||||
_, network, err := net.ParseCIDR(cidr)
|
_, network, err := net.ParseCIDR(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("ssrf: failed to parse CIDR %q: %v", cidr, err))
|
panic(fmt.Sprintf(
|
||||||
|
"ssrf: failed to parse CIDR %q: %v",
|
||||||
|
cidr, err,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
blockedNetworks = append(blockedNetworks, network)
|
|
||||||
|
blockedNetworks = append(
|
||||||
|
blockedNetworks, network,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isBlockedIP checks whether an IP address falls within any blocked
|
// isBlockedIP checks whether an IP address falls within
|
||||||
// private/reserved network range.
|
// any blocked private/reserved network range.
|
||||||
func isBlockedIP(ip net.IP) bool {
|
func isBlockedIP(ip net.IP) bool {
|
||||||
for _, network := range blockedNetworks {
|
for _, network := range blockedNetworks {
|
||||||
if network.Contains(ip) {
|
if network.Contains(ip) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateTargetURL checks that an HTTP delivery target URL is safe
|
// ValidateTargetURL checks that an HTTP delivery target
|
||||||
// from SSRF attacks. It validates the URL format, resolves the hostname
|
// URL is safe from SSRF attacks.
|
||||||
// to IP addresses, and verifies that none of the resolved IPs are in
|
func ValidateTargetURL(
|
||||||
// blocked private/reserved ranges.
|
ctx context.Context, targetURL string,
|
||||||
//
|
) error {
|
||||||
// Returns nil if the URL is safe, or an error describing the issue.
|
|
||||||
func ValidateTargetURL(targetURL string) error {
|
|
||||||
parsed, err := url.Parse(targetURL)
|
parsed, err := url.Parse(targetURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid URL: %w", err)
|
return fmt.Errorf("invalid URL: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only allow http and https schemes
|
err = validateScheme(parsed.Scheme)
|
||||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
if err != nil {
|
||||||
return fmt.Errorf("unsupported URL scheme %q: only http and https are allowed", parsed.Scheme)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
host := parsed.Hostname()
|
host := parsed.Hostname()
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return fmt.Errorf("URL has no hostname")
|
return errNoHostname
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the host is a raw IP address first
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
if isBlockedIP(ip) {
|
return checkBlockedIP(ip)
|
||||||
return fmt.Errorf("target IP %s is in a blocked private/reserved range", ip)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve hostname to IPs and check each one
|
return validateHostname(ctx, host)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsResolutionTimeout)
|
}
|
||||||
|
|
||||||
|
func validateScheme(scheme string) error {
|
||||||
|
if scheme != "http" && scheme != "https" {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"unsupported URL scheme %q: %w",
|
||||||
|
scheme, errInvalidScheme,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkBlockedIP(ip net.IP) error {
|
||||||
|
if isBlockedIP(ip) {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"target IP %s is in a blocked "+
|
||||||
|
"private/reserved range: %w",
|
||||||
|
ip, errBlockedIP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateHostname(
|
||||||
|
ctx context.Context, host string,
|
||||||
|
) error {
|
||||||
|
dnsCtx, cancel := context.WithTimeout(
|
||||||
|
ctx, dnsResolutionTimeout,
|
||||||
|
)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
ips, err := net.DefaultResolver.LookupIPAddr(
|
||||||
|
dnsCtx, host,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve hostname %q: %w", host, err)
|
return fmt.Errorf(
|
||||||
|
"failed to resolve hostname %q: %w",
|
||||||
|
host, err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
return fmt.Errorf("hostname %q resolved to no IP addresses", host)
|
return fmt.Errorf(
|
||||||
|
"hostname %q: %w", host, errNoIPs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ipAddr := range ips {
|
for _, ipAddr := range ips {
|
||||||
if isBlockedIP(ipAddr.IP) {
|
if isBlockedIP(ipAddr.IP) {
|
||||||
return fmt.Errorf("hostname %q resolves to blocked IP %s (private/reserved range)", host, ipAddr.IP)
|
return fmt.Errorf(
|
||||||
|
"hostname %q resolves to blocked "+
|
||||||
|
"IP %s: %w",
|
||||||
|
host, ipAddr.IP, errBlockedIP,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSSRFSafeTransport creates an http.Transport with a custom DialContext
|
// NewSSRFSafeTransport creates an http.Transport with a
|
||||||
// that blocks connections to private/reserved IP addresses. This provides
|
// custom DialContext that blocks connections to
|
||||||
// defense-in-depth SSRF protection at the network layer, catching cases
|
// private/reserved IP addresses.
|
||||||
// where DNS records change between target creation and delivery time
|
|
||||||
// (DNS rebinding attacks).
|
|
||||||
func NewSSRFSafeTransport() *http.Transport {
|
func NewSSRFSafeTransport() *http.Transport {
|
||||||
return &http.Transport{
|
return &http.Transport{
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
DialContext: ssrfDialContext,
|
||||||
host, port, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve hostname to IPs
|
|
||||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("ssrf: DNS resolution failed for %q: %w", host, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check all resolved IPs
|
|
||||||
for _, ipAddr := range ips {
|
|
||||||
if isBlockedIP(ipAddr.IP) {
|
|
||||||
return nil, fmt.Errorf("ssrf: connection to %s (%s) blocked — private/reserved IP range", host, ipAddr.IP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to the first allowed IP
|
|
||||||
var dialer net.Dialer
|
|
||||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port))
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ssrfDialContext(
|
||||||
|
ctx context.Context,
|
||||||
|
network, addr string,
|
||||||
|
) (net.Conn, error) {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"ssrf: invalid address %q: %w",
|
||||||
|
addr, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err := net.DefaultResolver.LookupIPAddr(
|
||||||
|
ctx, host,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"ssrf: DNS resolution failed for %q: %w",
|
||||||
|
host, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ipAddr := range ips {
|
||||||
|
if isBlockedIP(ipAddr.IP) {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"ssrf: connection to %s (%s) "+
|
||||||
|
"blocked: %w",
|
||||||
|
host, ipAddr.IP, errBlockedIP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var dialer net.Dialer
|
||||||
|
|
||||||
|
return dialer.DialContext(
|
||||||
|
ctx, network,
|
||||||
|
net.JoinHostPort(ips[0].IP.String(), port),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package delivery
|
package delivery_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"sneak.berlin/go/webhooker/internal/delivery"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsBlockedIP_PrivateRanges(t *testing.T) {
|
func TestIsBlockedIP_PrivateRanges(t *testing.T) {
|
||||||
@@ -16,56 +18,52 @@ func TestIsBlockedIP_PrivateRanges(t *testing.T) {
|
|||||||
ip string
|
ip string
|
||||||
blocked bool
|
blocked bool
|
||||||
}{
|
}{
|
||||||
// Loopback
|
|
||||||
{"loopback 127.0.0.1", "127.0.0.1", true},
|
{"loopback 127.0.0.1", "127.0.0.1", true},
|
||||||
{"loopback 127.0.0.2", "127.0.0.2", true},
|
{"loopback 127.0.0.2", "127.0.0.2", true},
|
||||||
{"loopback 127.255.255.255", "127.255.255.255", true},
|
{"loopback 127.255.255.255", "127.255.255.255", true},
|
||||||
|
|
||||||
// RFC 1918 - Class A
|
|
||||||
{"10.0.0.0", "10.0.0.0", true},
|
{"10.0.0.0", "10.0.0.0", true},
|
||||||
{"10.0.0.1", "10.0.0.1", true},
|
{"10.0.0.1", "10.0.0.1", true},
|
||||||
{"10.255.255.255", "10.255.255.255", true},
|
{"10.255.255.255", "10.255.255.255", true},
|
||||||
|
|
||||||
// RFC 1918 - Class B
|
|
||||||
{"172.16.0.1", "172.16.0.1", true},
|
{"172.16.0.1", "172.16.0.1", true},
|
||||||
{"172.31.255.255", "172.31.255.255", true},
|
{"172.31.255.255", "172.31.255.255", true},
|
||||||
{"172.15.255.255", "172.15.255.255", false},
|
{"172.15.255.255", "172.15.255.255", false},
|
||||||
{"172.32.0.0", "172.32.0.0", false},
|
{"172.32.0.0", "172.32.0.0", false},
|
||||||
|
|
||||||
// RFC 1918 - Class C
|
|
||||||
{"192.168.0.1", "192.168.0.1", true},
|
{"192.168.0.1", "192.168.0.1", true},
|
||||||
{"192.168.255.255", "192.168.255.255", true},
|
{"192.168.255.255", "192.168.255.255", true},
|
||||||
|
|
||||||
// Link-local / cloud metadata
|
|
||||||
{"169.254.0.1", "169.254.0.1", true},
|
{"169.254.0.1", "169.254.0.1", true},
|
||||||
{"169.254.169.254", "169.254.169.254", true},
|
{"169.254.169.254", "169.254.169.254", true},
|
||||||
|
|
||||||
// Public IPs (should NOT be blocked)
|
|
||||||
{"8.8.8.8", "8.8.8.8", false},
|
{"8.8.8.8", "8.8.8.8", false},
|
||||||
{"1.1.1.1", "1.1.1.1", false},
|
{"1.1.1.1", "1.1.1.1", false},
|
||||||
{"93.184.216.34", "93.184.216.34", false},
|
{"93.184.216.34", "93.184.216.34", false},
|
||||||
|
|
||||||
// IPv6 loopback
|
|
||||||
{"::1", "::1", true},
|
{"::1", "::1", true},
|
||||||
|
|
||||||
// IPv6 unique local
|
|
||||||
{"fd00::1", "fd00::1", true},
|
{"fd00::1", "fd00::1", true},
|
||||||
{"fc00::1", "fc00::1", true},
|
{"fc00::1", "fc00::1", true},
|
||||||
|
|
||||||
// IPv6 link-local
|
|
||||||
{"fe80::1", "fe80::1", true},
|
{"fe80::1", "fe80::1", true},
|
||||||
|
{
|
||||||
// IPv6 public (should NOT be blocked)
|
"2607:f8b0:4004:800::200e",
|
||||||
{"2607:f8b0:4004:800::200e", "2607:f8b0:4004:800::200e", false},
|
"2607:f8b0:4004:800::200e",
|
||||||
|
false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ip := net.ParseIP(tt.ip)
|
ip := net.ParseIP(tt.ip)
|
||||||
require.NotNil(t, ip, "failed to parse IP %s", tt.ip)
|
|
||||||
assert.Equal(t, tt.blocked, isBlockedIP(ip),
|
require.NotNil(t, ip,
|
||||||
"isBlockedIP(%s) = %v, want %v", tt.ip, isBlockedIP(ip), tt.blocked)
|
"failed to parse IP %s", tt.ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t,
|
||||||
|
tt.blocked,
|
||||||
|
delivery.ExportIsBlockedIP(ip),
|
||||||
|
"isBlockedIP(%s) = %v, want %v",
|
||||||
|
tt.ip,
|
||||||
|
delivery.ExportIsBlockedIP(ip),
|
||||||
|
tt.blocked,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,8 +87,14 @@ func TestValidateTargetURL_Blocked(t *testing.T) {
|
|||||||
for _, u := range blockedURLs {
|
for _, u := range blockedURLs {
|
||||||
t.Run(u, func(t *testing.T) {
|
t.Run(u, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
err := ValidateTargetURL(u)
|
|
||||||
assert.Error(t, err, "URL %s should be blocked", u)
|
err := delivery.ValidateTargetURL(
|
||||||
|
context.Background(), u,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Error(t, err,
|
||||||
|
"URL %s should be blocked", u,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,7 +102,6 @@ func TestValidateTargetURL_Blocked(t *testing.T) {
|
|||||||
func TestValidateTargetURL_Allowed(t *testing.T) {
|
func TestValidateTargetURL_Allowed(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// These are public IPs and should be allowed
|
|
||||||
allowedURLs := []string{
|
allowedURLs := []string{
|
||||||
"https://example.com/hook",
|
"https://example.com/hook",
|
||||||
"http://93.184.216.34/webhook",
|
"http://93.184.216.34/webhook",
|
||||||
@@ -108,35 +111,62 @@ func TestValidateTargetURL_Allowed(t *testing.T) {
|
|||||||
for _, u := range allowedURLs {
|
for _, u := range allowedURLs {
|
||||||
t.Run(u, func(t *testing.T) {
|
t.Run(u, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
err := ValidateTargetURL(u)
|
|
||||||
assert.NoError(t, err, "URL %s should be allowed", u)
|
err := delivery.ValidateTargetURL(
|
||||||
|
context.Background(), u,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.NoError(t, err,
|
||||||
|
"URL %s should be allowed", u,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateTargetURL_InvalidScheme(t *testing.T) {
|
func TestValidateTargetURL_InvalidScheme(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
err := ValidateTargetURL("ftp://example.com/hook")
|
|
||||||
assert.Error(t, err)
|
err := delivery.ValidateTargetURL(
|
||||||
assert.Contains(t, err.Error(), "unsupported URL scheme")
|
context.Background(), "ftp://example.com/hook",
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
assert.Contains(t, err.Error(),
|
||||||
|
"unsupported URL scheme",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateTargetURL_EmptyHost(t *testing.T) {
|
func TestValidateTargetURL_EmptyHost(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
err := ValidateTargetURL("http:///path")
|
|
||||||
|
err := delivery.ValidateTargetURL(
|
||||||
|
context.Background(), "http:///path",
|
||||||
|
)
|
||||||
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateTargetURL_InvalidURL(t *testing.T) {
|
func TestValidateTargetURL_InvalidURL(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
err := ValidateTargetURL("://invalid")
|
|
||||||
|
err := delivery.ValidateTargetURL(
|
||||||
|
context.Background(), "://invalid",
|
||||||
|
)
|
||||||
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBlockedNetworks_Initialized(t *testing.T) {
|
func TestBlockedNetworks_Initialized(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
assert.NotEmpty(t, blockedNetworks, "blockedNetworks should be initialized")
|
|
||||||
// Should have at least the main RFC 1918 + loopback + link-local ranges
|
nets := delivery.ExportBlockedNetworks()
|
||||||
assert.GreaterOrEqual(t, len(blockedNetworks), 8,
|
|
||||||
"should have at least 8 blocked network ranges")
|
assert.NotEmpty(t, nets,
|
||||||
|
"blockedNetworks should be initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.GreaterOrEqual(t, len(nets), 8,
|
||||||
|
"should have at least 8 blocked network ranges",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,34 @@
|
|||||||
|
// Package globals provides build-time variables injected via ldflags.
|
||||||
package globals
|
package globals
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
)
|
)
|
||||||
|
|
||||||
// these get populated from main() and copied into the Globals object.
|
// Build-time variables populated from main() and copied into the
|
||||||
|
// Globals object.
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Build-time variables set by main().
|
||||||
var (
|
var (
|
||||||
Appname string
|
Appname string
|
||||||
Version string
|
Version string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Globals holds build-time metadata about the application.
|
||||||
type Globals struct {
|
type Globals struct {
|
||||||
Appname string
|
Appname string
|
||||||
Version string
|
Version string
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:revive // lc parameter is required by fx even if unused
|
// New creates a Globals instance from the package-level
|
||||||
|
// build-time variables.
|
||||||
|
//
|
||||||
|
//nolint:revive // lc parameter is required by fx even if unused.
|
||||||
func New(lc fx.Lifecycle) (*Globals, error) {
|
func New(lc fx.Lifecycle) (*Globals, error) {
|
||||||
n := &Globals{
|
n := &Globals{
|
||||||
Appname: Appname,
|
Appname: Appname,
|
||||||
Version: Version,
|
Version: Version,
|
||||||
}
|
}
|
||||||
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,30 @@
|
|||||||
package globals
|
package globals_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"go.uber.org/fx/fxtest"
|
"sneak.berlin/go/webhooker/internal/globals"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestGlobalsFields(t *testing.T) {
|
||||||
// Set test values
|
t.Parallel()
|
||||||
Appname = "test-app"
|
|
||||||
Version = "1.0.0"
|
|
||||||
|
|
||||||
lc := fxtest.NewLifecycle(t)
|
g := &globals.Globals{
|
||||||
globals, err := New(lc)
|
Appname: "test-app",
|
||||||
if err != nil {
|
Version: "1.0.0",
|
||||||
t.Fatalf("New() error = %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if globals.Appname != "test-app" {
|
if g.Appname != "test-app" {
|
||||||
t.Errorf("Appname = %v, want %v", globals.Appname, "test-app")
|
t.Errorf(
|
||||||
|
"Appname = %v, want %v",
|
||||||
|
g.Appname, "test-app",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if globals.Version != "1.0.0" {
|
|
||||||
t.Errorf("Version = %v, want %v", globals.Version, "1.0.0")
|
if g.Version != "1.0.0" {
|
||||||
|
t.Errorf(
|
||||||
|
"Version = %v, want %v",
|
||||||
|
g.Version, "1.0.0",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ func (h *Handlers) HandleLoginPage() http.HandlerFunc {
|
|||||||
sess, err := h.session.Get(r)
|
sess, err := h.session.Get(r)
|
||||||
if err == nil && h.session.IsAuthenticated(sess) {
|
if err == nil && h.session.IsAuthenticated(sess) {
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Render login page
|
// Render login page
|
||||||
data := map[string]interface{}{
|
data := map[string]any{
|
||||||
"Error": "",
|
"Error": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,10 +29,15 @@ func (h *Handlers) HandleLoginPage() http.HandlerFunc {
|
|||||||
// HandleLoginSubmit handles the login form submission (POST)
|
// HandleLoginSubmit handles the login form submission (POST)
|
||||||
func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
|
func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Limit request body to prevent memory exhaustion
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, 1<<maxBodyShift)
|
||||||
|
|
||||||
// Parse form data
|
// Parse form data
|
||||||
if err := r.ParseForm(); err != nil {
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
h.log.Error("failed to parse form", "error", err)
|
h.log.Error("failed to parse form", "error", err)
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,85 +46,159 @@ func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
|
|||||||
|
|
||||||
// Validate input
|
// Validate input
|
||||||
if username == "" || password == "" {
|
if username == "" || password == "" {
|
||||||
data := map[string]interface{}{
|
h.renderLoginError(
|
||||||
"Error": "Username and password are required",
|
w, r,
|
||||||
}
|
"Username and password are required",
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
http.StatusBadRequest,
|
||||||
h.renderTemplate(w, r, "login.html", data)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find user in database
|
user, err := h.authenticateUser(
|
||||||
var user database.User
|
w, r, username, password,
|
||||||
if err := h.db.DB().Where("username = ?", username).First(&user).Error; err != nil {
|
)
|
||||||
h.log.Debug("user not found", "username", username)
|
|
||||||
data := map[string]interface{}{
|
|
||||||
"Error": "Invalid username or password",
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
h.renderTemplate(w, r, "login.html", data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify password
|
|
||||||
valid, err := database.VerifyPassword(password, user.Password)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.log.Error("failed to verify password", "error", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !valid {
|
err = h.createAuthenticatedSession(w, r, user)
|
||||||
h.log.Debug("invalid password", "username", username)
|
|
||||||
data := map[string]interface{}{
|
|
||||||
"Error": "Invalid username or password",
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
h.renderTemplate(w, r, "login.html", data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the current session (may be pre-existing / attacker-set)
|
|
||||||
oldSess, err := h.session.Get(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.log.Error("failed to get session", "error", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Regenerate the session to prevent session fixation attacks.
|
h.log.Info(
|
||||||
// This destroys the old session ID and creates a new one.
|
"user logged in",
|
||||||
sess, err := h.session.Regenerate(r, w, oldSess)
|
"username", username,
|
||||||
if err != nil {
|
"user_id", user.ID,
|
||||||
h.log.Error("failed to regenerate session", "error", err)
|
)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set user in session
|
|
||||||
h.session.SetUser(sess, user.ID, user.Username)
|
|
||||||
|
|
||||||
// Save session
|
|
||||||
if err := h.session.Save(r, w, sess); err != nil {
|
|
||||||
h.log.Error("failed to save session", "error", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.log.Info("user logged in", "username", username, "user_id", user.ID)
|
|
||||||
|
|
||||||
// Redirect to home page
|
// Redirect to home page
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// renderLoginError renders the login page with an error message.
|
||||||
|
func (h *Handlers) renderLoginError(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
status int,
|
||||||
|
) {
|
||||||
|
data := map[string]any{
|
||||||
|
"Error": msg,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(status)
|
||||||
|
h.renderTemplate(w, r, "login.html", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateUser looks up and verifies a user's credentials.
|
||||||
|
// On failure it writes an HTTP response and returns an error.
|
||||||
|
func (h *Handlers) authenticateUser(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
username, password string,
|
||||||
|
) (database.User, error) {
|
||||||
|
var user database.User
|
||||||
|
|
||||||
|
err := h.db.DB().Where(
|
||||||
|
"username = ?", username,
|
||||||
|
).First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
h.log.Debug("user not found", "username", username)
|
||||||
|
h.renderLoginError(
|
||||||
|
w, r,
|
||||||
|
"Invalid username or password",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
)
|
||||||
|
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := database.VerifyPassword(password, user.Password)
|
||||||
|
if err != nil {
|
||||||
|
h.log.Error("failed to verify password", "error", err)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
h.log.Debug("invalid password", "username", username)
|
||||||
|
h.renderLoginError(
|
||||||
|
w, r,
|
||||||
|
"Invalid username or password",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
)
|
||||||
|
|
||||||
|
return user, errInvalidPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAuthenticatedSession regenerates the session and stores
|
||||||
|
// user info. On failure it writes an HTTP response and returns
|
||||||
|
// an error.
|
||||||
|
func (h *Handlers) createAuthenticatedSession(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
user database.User,
|
||||||
|
) error {
|
||||||
|
oldSess, err := h.session.Get(r)
|
||||||
|
if err != nil {
|
||||||
|
h.log.Error("failed to get session", "error", err)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sess, err := h.session.Regenerate(r, w, oldSess)
|
||||||
|
if err != nil {
|
||||||
|
h.log.Error(
|
||||||
|
"failed to regenerate session", "error", err,
|
||||||
|
)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
h.session.SetUser(sess, user.ID, user.Username)
|
||||||
|
|
||||||
|
err = h.session.Save(r, w, sess)
|
||||||
|
if err != nil {
|
||||||
|
h.log.Error("failed to save session", "error", err)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// HandleLogout handles user logout
|
// HandleLogout handles user logout
|
||||||
func (h *Handlers) HandleLogout() http.HandlerFunc {
|
func (h *Handlers) HandleLogout() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
sess, err := h.session.Get(r)
|
sess, err := h.session.Get(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.log.Error("failed to get session", "error", err)
|
h.log.Error("failed to get session", "error", err)
|
||||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
http.Redirect(
|
||||||
|
w, r, "/pages/login", http.StatusSeeOther,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,8 +206,12 @@ func (h *Handlers) HandleLogout() http.HandlerFunc {
|
|||||||
h.session.Destroy(sess)
|
h.session.Destroy(sess)
|
||||||
|
|
||||||
// Save the destroyed session
|
// Save the destroyed session
|
||||||
if err := h.session.Save(r, w, sess); err != nil {
|
err = h.session.Save(r, w, sess)
|
||||||
h.log.Error("failed to save destroyed session", "error", err)
|
if err != nil {
|
||||||
|
h.log.Error(
|
||||||
|
"failed to save destroyed session",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect to login page
|
// Redirect to login page
|
||||||
|
|||||||
14
internal/handlers/export_test.go
Normal file
14
internal/handlers/export_test.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// RenderTemplateForTest exposes renderTemplate for use in the
|
||||||
|
// handlers_test package.
|
||||||
|
func (s *Handlers) RenderTemplateForTest(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
pageTemplate string,
|
||||||
|
data any,
|
||||||
|
) {
|
||||||
|
s.renderTemplate(w, r, pageTemplate, data)
|
||||||
|
}
|
||||||
@@ -1,8 +1,11 @@
|
|||||||
|
// Package handlers provides HTTP request handlers for the
|
||||||
|
// webhooker web UI and API.
|
||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"html/template"
|
"html/template"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -18,9 +21,24 @@ import (
|
|||||||
"sneak.berlin/go/webhooker/templates"
|
"sneak.berlin/go/webhooker/templates"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // HandlersParams is a standard fx naming convention
|
const (
|
||||||
|
// maxBodyShift is the bit shift for 1 MB body limit.
|
||||||
|
maxBodyShift = 20
|
||||||
|
// recentEventLimit is the number of recent events to show.
|
||||||
|
recentEventLimit = 20
|
||||||
|
// defaultRetentionDays is the default event retention period.
|
||||||
|
defaultRetentionDays = 30
|
||||||
|
// paginationPerPage is the number of items per page.
|
||||||
|
paginationPerPage = 25
|
||||||
|
)
|
||||||
|
|
||||||
|
// errInvalidPassword is returned when a password does not match.
|
||||||
|
var errInvalidPassword = errors.New("invalid password")
|
||||||
|
|
||||||
|
//nolint:revive // HandlersParams is a standard fx naming convention.
|
||||||
type HandlersParams struct {
|
type HandlersParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
Globals *globals.Globals
|
Globals *globals.Globals
|
||||||
Database *database.Database
|
Database *database.Database
|
||||||
@@ -30,6 +48,8 @@ type HandlersParams struct {
|
|||||||
Notifier delivery.Notifier
|
Notifier delivery.Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handlers provides HTTP handler methods for all application
|
||||||
|
// routes.
|
||||||
type Handlers struct {
|
type Handlers struct {
|
||||||
params *HandlersParams
|
params *HandlersParams
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
@@ -41,19 +61,29 @@ type Handlers struct {
|
|||||||
templates map[string]*template.Template
|
templates map[string]*template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
// parsePageTemplate parses a page-specific template set from the embedded FS.
|
// parsePageTemplate parses a page-specific template set from the
|
||||||
// Each page template is combined with the shared base, htmlheader, and navbar templates.
|
// embedded FS. Each page template is combined with the shared
|
||||||
// The page file must be listed first so that its root action ({{template "base" .}})
|
// base, htmlheader, and navbar templates. The page file must be
|
||||||
// becomes the template set's entry point. If a shared partial (e.g. htmlheader.html)
|
// listed first so that its root action ({{template "base" .}})
|
||||||
// is listed first, its {{define}} block becomes the root — which is empty — and
|
// becomes the template set's entry point.
|
||||||
// Execute() produces no output.
|
|
||||||
func parsePageTemplate(pageFile string) *template.Template {
|
func parsePageTemplate(pageFile string) *template.Template {
|
||||||
return template.Must(
|
return template.Must(
|
||||||
template.ParseFS(templates.Templates, pageFile, "base.html", "htmlheader.html", "navbar.html"),
|
template.ParseFS(
|
||||||
|
templates.Templates,
|
||||||
|
pageFile,
|
||||||
|
"base.html",
|
||||||
|
"htmlheader.html",
|
||||||
|
"navbar.html",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(lc fx.Lifecycle, params HandlersParams) (*Handlers, error) {
|
// New creates a Handlers instance, parsing all page templates at
|
||||||
|
// startup.
|
||||||
|
func New(
|
||||||
|
lc fx.Lifecycle,
|
||||||
|
params HandlersParams,
|
||||||
|
) (*Handlers, error) {
|
||||||
s := new(Handlers)
|
s := new(Handlers)
|
||||||
s.params = ¶ms
|
s.params = ¶ms
|
||||||
s.log = params.Logger.Get()
|
s.log = params.Logger.Get()
|
||||||
@@ -75,17 +105,23 @@ func New(lc fx.Lifecycle, params HandlersParams) (*Handlers, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
lc.Append(fx.Hook{
|
lc.Append(fx.Hook{
|
||||||
OnStart: func(ctx context.Context) error {
|
OnStart: func(_ context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:unparam // r parameter will be used in the future for request context
|
func (s *Handlers) respondJSON(
|
||||||
func (s *Handlers) respondJSON(w http.ResponseWriter, r *http.Request, data interface{}, status int) {
|
w http.ResponseWriter,
|
||||||
|
_ *http.Request,
|
||||||
|
data any,
|
||||||
|
status int,
|
||||||
|
) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
|
|
||||||
if data != nil {
|
if data != nil {
|
||||||
err := json.NewEncoder(w).Encode(data)
|
err := json.NewEncoder(w).Encode(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -94,9 +130,15 @@ func (s *Handlers) respondJSON(w http.ResponseWriter, r *http.Request, data inte
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:unparam,unused // will be used for handling JSON requests
|
// serverError logs an error and sends a 500 response.
|
||||||
func (s *Handlers) decodeJSON(w http.ResponseWriter, r *http.Request, v interface{}) error {
|
func (s *Handlers) serverError(
|
||||||
return json.NewDecoder(r.Body).Decode(v)
|
w http.ResponseWriter, msg string, err error,
|
||||||
|
) {
|
||||||
|
s.log.Error(msg, "error", err)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserInfo represents user information for templates
|
// UserInfo represents user information for templates
|
||||||
@@ -105,48 +147,66 @@ type UserInfo struct {
|
|||||||
Username string
|
Username string
|
||||||
}
|
}
|
||||||
|
|
||||||
// renderTemplate renders a pre-parsed template with common data
|
// templateDataWrapper wraps non-map data with common fields.
|
||||||
func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTemplate string, data interface{}) {
|
type templateDataWrapper struct {
|
||||||
|
User *UserInfo
|
||||||
|
CSRFToken string
|
||||||
|
Data any
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserInfo extracts user info from the session.
|
||||||
|
func (s *Handlers) getUserInfo(
|
||||||
|
r *http.Request,
|
||||||
|
) *UserInfo {
|
||||||
|
sess, err := s.session.Get(r)
|
||||||
|
if err != nil || !s.session.IsAuthenticated(sess) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
username, ok := s.session.GetUsername(sess)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, ok := s.session.GetUserID(sess)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserInfo{ID: userID, Username: username}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderTemplate renders a pre-parsed template with common
|
||||||
|
// data
|
||||||
|
func (s *Handlers) renderTemplate(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
pageTemplate string,
|
||||||
|
data any,
|
||||||
|
) {
|
||||||
tmpl, ok := s.templates[pageTemplate]
|
tmpl, ok := s.templates[pageTemplate]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.log.Error("template not found", "template", pageTemplate)
|
s.log.Error(
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
"template not found",
|
||||||
|
"template", pageTemplate,
|
||||||
|
)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user from session if available
|
userInfo := s.getUserInfo(r)
|
||||||
var userInfo *UserInfo
|
|
||||||
sess, err := s.session.Get(r)
|
|
||||||
if err == nil && s.session.IsAuthenticated(sess) {
|
|
||||||
if username, ok := s.session.GetUsername(sess); ok {
|
|
||||||
if userID, ok := s.session.GetUserID(sess); ok {
|
|
||||||
userInfo = &UserInfo{
|
|
||||||
ID: userID,
|
|
||||||
Username: username,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get CSRF token from request context (set by CSRF middleware)
|
|
||||||
csrfToken := middleware.CSRFToken(r)
|
csrfToken := middleware.CSRFToken(r)
|
||||||
|
|
||||||
// If data is a map, merge user info and CSRF token into it
|
if m, ok := data.(map[string]any); ok {
|
||||||
if m, ok := data.(map[string]interface{}); ok {
|
|
||||||
m["User"] = userInfo
|
m["User"] = userInfo
|
||||||
m["CSRFToken"] = csrfToken
|
m["CSRFToken"] = csrfToken
|
||||||
if err := tmpl.Execute(w, m); err != nil {
|
s.executeTemplate(w, tmpl, m)
|
||||||
s.log.Error("failed to execute template", "error", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrap data with base template data
|
return
|
||||||
type templateDataWrapper struct {
|
|
||||||
User *UserInfo
|
|
||||||
CSRFToken string
|
|
||||||
Data interface{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
wrapper := templateDataWrapper{
|
wrapper := templateDataWrapper{
|
||||||
@@ -155,8 +215,23 @@ func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTe
|
|||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tmpl.Execute(w, wrapper); err != nil {
|
s.executeTemplate(w, tmpl, wrapper)
|
||||||
s.log.Error("failed to execute template", "error", err)
|
}
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
// executeTemplate runs the template and handles errors.
|
||||||
|
func (s *Handlers) executeTemplate(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
tmpl *template.Template,
|
||||||
|
data any,
|
||||||
|
) {
|
||||||
|
err := tmpl.Execute(w, data)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error(
|
||||||
|
"failed to execute template", "error", err,
|
||||||
|
)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +1,36 @@
|
|||||||
package handlers
|
package handlers_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
"go.uber.org/fx/fxtest"
|
"go.uber.org/fx/fxtest"
|
||||||
"sneak.berlin/go/webhooker/internal/config"
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
"sneak.berlin/go/webhooker/internal/database"
|
"sneak.berlin/go/webhooker/internal/database"
|
||||||
"sneak.berlin/go/webhooker/internal/delivery"
|
"sneak.berlin/go/webhooker/internal/delivery"
|
||||||
"sneak.berlin/go/webhooker/internal/globals"
|
"sneak.berlin/go/webhooker/internal/globals"
|
||||||
|
"sneak.berlin/go/webhooker/internal/handlers"
|
||||||
"sneak.berlin/go/webhooker/internal/healthcheck"
|
"sneak.berlin/go/webhooker/internal/healthcheck"
|
||||||
"sneak.berlin/go/webhooker/internal/logger"
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
"sneak.berlin/go/webhooker/internal/session"
|
"sneak.berlin/go/webhooker/internal/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// noopNotifier is a no-op delivery.Notifier for tests.
|
|
||||||
type noopNotifier struct{}
|
type noopNotifier struct{}
|
||||||
|
|
||||||
func (n *noopNotifier) Notify([]delivery.DeliveryTask) {}
|
func (n *noopNotifier) Notify([]delivery.Task) {}
|
||||||
|
|
||||||
func TestHandleIndex(t *testing.T) {
|
func newTestApp(
|
||||||
var h *Handlers
|
t *testing.T,
|
||||||
var sess *session.Session
|
targets ...any,
|
||||||
|
) *fxtest.App {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
app := fxtest.New(
|
return fxtest.New(
|
||||||
t,
|
t,
|
||||||
fx.Provide(
|
fx.Provide(
|
||||||
globals.New,
|
globals.New,
|
||||||
@@ -40,90 +44,99 @@ func TestHandleIndex(t *testing.T) {
|
|||||||
database.NewWebhookDBManager,
|
database.NewWebhookDBManager,
|
||||||
healthcheck.New,
|
healthcheck.New,
|
||||||
session.New,
|
session.New,
|
||||||
func() delivery.Notifier { return &noopNotifier{} },
|
func() delivery.Notifier {
|
||||||
New,
|
return &noopNotifier{}
|
||||||
|
},
|
||||||
|
handlers.New,
|
||||||
),
|
),
|
||||||
fx.Populate(&h, &sess),
|
fx.Populate(targets...),
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleIndex_Unauthenticated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var h *handlers.Handlers
|
||||||
|
|
||||||
|
app := newTestApp(t, &h)
|
||||||
app.RequireStart()
|
app.RequireStart()
|
||||||
defer app.RequireStop()
|
|
||||||
|
|
||||||
t.Run("unauthenticated redirects to login", func(t *testing.T) {
|
t.Cleanup(app.RequireStop)
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handler := h.HandleIndex()
|
req := httptest.NewRequestWithContext(
|
||||||
handler.ServeHTTP(w, req)
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
assert.Equal(t, http.StatusSeeOther, w.Code)
|
handler := h.HandleIndex()
|
||||||
assert.Equal(t, "/pages/login", w.Header().Get("Location"))
|
handler.ServeHTTP(w, req)
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("authenticated redirects to sources", func(t *testing.T) {
|
assert.Equal(t, http.StatusSeeOther, w.Code)
|
||||||
// Create a request, set up an authenticated session, then test
|
assert.Equal(
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
t, "/pages/login", w.Header().Get("Location"),
|
||||||
w := httptest.NewRecorder()
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// Get a session and mark it as authenticated
|
func TestHandleIndex_Authenticated(t *testing.T) {
|
||||||
s, err := sess.Get(req)
|
t.Parallel()
|
||||||
assert.NoError(t, err)
|
|
||||||
sess.SetUser(s, "test-user-id", "testuser")
|
|
||||||
err = sess.Save(req, w, s)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// Build a new request with the session cookie from the response
|
var h *handlers.Handlers
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
for _, cookie := range w.Result().Cookies() {
|
|
||||||
req2.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
w2 := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handler := h.HandleIndex()
|
var sess *session.Session
|
||||||
handler.ServeHTTP(w2, req2)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusSeeOther, w2.Code)
|
app := newTestApp(t, &h, &sess)
|
||||||
assert.Equal(t, "/sources", w2.Header().Get("Location"))
|
app.RequireStart()
|
||||||
})
|
|
||||||
|
t.Cleanup(app.RequireStop)
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s, err := sess.Get(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sess.SetUser(s, "test-user-id", "testuser")
|
||||||
|
|
||||||
|
err = sess.Save(req, w, s)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req2 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
for _, cookie := range w.Result().Cookies() {
|
||||||
|
req2.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
h.HandleIndex().ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusSeeOther, w2.Code)
|
||||||
|
assert.Equal(
|
||||||
|
t, "/sources", w2.Header().Get("Location"),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRenderTemplate(t *testing.T) {
|
func TestRenderTemplate(t *testing.T) {
|
||||||
var h *Handlers
|
t.Parallel()
|
||||||
|
|
||||||
app := fxtest.New(
|
var h *handlers.Handlers
|
||||||
t,
|
|
||||||
fx.Provide(
|
app := newTestApp(t, &h)
|
||||||
globals.New,
|
|
||||||
logger.New,
|
|
||||||
func() *config.Config {
|
|
||||||
return &config.Config{
|
|
||||||
DataDir: t.TempDir(),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
database.New,
|
|
||||||
database.NewWebhookDBManager,
|
|
||||||
healthcheck.New,
|
|
||||||
session.New,
|
|
||||||
func() delivery.Notifier { return &noopNotifier{} },
|
|
||||||
New,
|
|
||||||
),
|
|
||||||
fx.Populate(&h),
|
|
||||||
)
|
|
||||||
app.RequireStart()
|
app.RequireStart()
|
||||||
defer app.RequireStop()
|
|
||||||
|
|
||||||
t.Run("handles missing templates gracefully", func(t *testing.T) {
|
t.Cleanup(app.RequireStop)
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
data := map[string]interface{}{
|
req := httptest.NewRequestWithContext(
|
||||||
"Version": "1.0.0",
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
}
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// When a non-existent template name is requested, renderTemplate
|
data := map[string]any{"Version": "1.0.0"}
|
||||||
// should return an internal server error
|
|
||||||
h.renderTemplate(w, req, "nonexistent.html", data)
|
|
||||||
|
|
||||||
// Should return internal server error when template is not found
|
h.RenderTemplateForTest(
|
||||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
w, req, "nonexistent.html", data,
|
||||||
})
|
)
|
||||||
|
|
||||||
|
assert.Equal(
|
||||||
|
t, http.StatusInternalServerError, w.Code,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,13 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const httpStatusOK = 200
|
||||||
|
|
||||||
|
// HandleHealthCheck returns an HTTP handler that reports
|
||||||
|
// application health.
|
||||||
func (s *Handlers) HandleHealthCheck() http.HandlerFunc {
|
func (s *Handlers) HandleHealthCheck() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, req *http.Request) {
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
resp := s.hc.Healthcheck()
|
resp := s.hc.Healthcheck()
|
||||||
s.respondJSON(w, req, resp, 200)
|
s.respondJSON(w, req, resp, httpStatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,14 +4,15 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandleIndex returns a handler for the root path that redirects based
|
// HandleIndex returns a handler for the root path that redirects
|
||||||
// on authentication state: authenticated users go to /sources (the
|
// based on authentication state: authenticated users go to /sources
|
||||||
// dashboard), unauthenticated users go to the login page.
|
// (the dashboard), unauthenticated users go to the login page.
|
||||||
func (s *Handlers) HandleIndex() http.HandlerFunc {
|
func (s *Handlers) HandleIndex() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
sess, err := s.session.Get(r)
|
sess, err := s.session.Get(r)
|
||||||
if err == nil && s.session.IsAuthenticated(sess) {
|
if err == nil && s.session.IsAuthenticated(sess) {
|
||||||
http.Redirect(w, r, "/sources", http.StatusSeeOther)
|
http.Redirect(w, r, "/sources", http.StatusSeeOther)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
|||||||
requestedUsername := chi.URLParam(r, "username")
|
requestedUsername := chi.URLParam(r, "username")
|
||||||
if requestedUsername == "" {
|
if requestedUsername == "" {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
|||||||
if err != nil || !h.session.IsAuthenticated(sess) {
|
if err != nil || !h.session.IsAuthenticated(sess) {
|
||||||
// Redirect to login if not authenticated
|
// Redirect to login if not authenticated
|
||||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,6 +31,7 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
|||||||
if !ok {
|
if !ok {
|
||||||
h.log.Error("authenticated session missing username")
|
h.log.Error("authenticated session missing username")
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,17 +39,19 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
|||||||
if !ok {
|
if !ok {
|
||||||
h.log.Error("authenticated session missing user ID")
|
h.log.Error("authenticated session missing user ID")
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For now, only allow users to view their own profile
|
// For now, only allow users to view their own profile
|
||||||
if requestedUsername != sessionUsername {
|
if requestedUsername != sessionUsername {
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare data for template
|
// Prepare data for template
|
||||||
data := map[string]interface{}{
|
data := map[string]any{
|
||||||
"User": &UserInfo{
|
"User": &UserInfo{
|
||||||
ID: sessionUserID,
|
ID: sessionUserID,
|
||||||
Username: sessionUsername,
|
Username: sessionUsername,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -6,31 +6,36 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
"gorm.io/gorm"
|
||||||
"sneak.berlin/go/webhooker/internal/database"
|
"sneak.berlin/go/webhooker/internal/database"
|
||||||
"sneak.berlin/go/webhooker/internal/delivery"
|
"sneak.berlin/go/webhooker/internal/delivery"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// maxWebhookBodySize is the maximum allowed webhook request body (1 MB).
|
// maxWebhookBodySize is the maximum allowed webhook
|
||||||
maxWebhookBodySize = 1 << 20
|
// request body (1 MB).
|
||||||
|
maxWebhookBodySize = 1 << maxBodyShift
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandleWebhook handles incoming webhook requests at entrypoint URLs.
|
// HandleWebhook handles incoming webhook requests at entrypoint
|
||||||
// Only POST requests are accepted; all other methods return 405 Method Not Allowed.
|
// URLs.
|
||||||
// Events and deliveries are stored in the per-webhook database. The handler
|
|
||||||
// builds self-contained DeliveryTask structs with all target and event data
|
|
||||||
// so the delivery engine can process them without additional DB reads.
|
|
||||||
func (h *Handlers) HandleWebhook() http.HandlerFunc {
|
func (h *Handlers) HandleWebhook() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
w.Header().Set("Allow", "POST")
|
w.Header().Set("Allow", "POST")
|
||||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
http.Error(
|
||||||
|
w,
|
||||||
|
"Method Not Allowed",
|
||||||
|
http.StatusMethodNotAllowed,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entrypointUUID := chi.URLParam(r, "uuid")
|
entrypointUUID := chi.URLParam(r, "uuid")
|
||||||
if entrypointUUID == "" {
|
if entrypointUUID == "" {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,152 +45,302 @@ func (h *Handlers) HandleWebhook() http.HandlerFunc {
|
|||||||
"remote_addr", r.RemoteAddr,
|
"remote_addr", r.RemoteAddr,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Look up entrypoint by path (from main application DB)
|
entrypoint, ok := h.lookupEntrypoint(
|
||||||
var entrypoint database.Entrypoint
|
w, r, entrypointUUID,
|
||||||
result := h.db.DB().Where("path = ?", entrypointUUID).First(&entrypoint)
|
)
|
||||||
if result.Error != nil {
|
if !ok {
|
||||||
h.log.Debug("entrypoint not found", "path", entrypointUUID)
|
|
||||||
http.NotFound(w, r)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if active
|
|
||||||
if !entrypoint.Active {
|
if !entrypoint.Active {
|
||||||
http.Error(w, "Gone", http.StatusGone)
|
http.Error(w, "Gone", http.StatusGone)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read body with size limit
|
h.processWebhookRequest(w, r, entrypoint)
|
||||||
body, err := io.ReadAll(io.LimitReader(r.Body, maxWebhookBodySize+1))
|
|
||||||
if err != nil {
|
|
||||||
h.log.Error("failed to read request body", "error", err)
|
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(body) > maxWebhookBodySize {
|
|
||||||
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize headers as JSON
|
|
||||||
headersJSON, err := json.Marshal(r.Header)
|
|
||||||
if err != nil {
|
|
||||||
h.log.Error("failed to serialize headers", "error", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find all active targets for this webhook (from main application DB)
|
|
||||||
var targets []database.Target
|
|
||||||
if targetErr := h.db.DB().Where("webhook_id = ? AND active = ?", entrypoint.WebhookID, true).Find(&targets).Error; targetErr != nil {
|
|
||||||
h.log.Error("failed to query targets", "error", targetErr)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the per-webhook database for event storage
|
|
||||||
webhookDB, err := h.dbMgr.GetDB(entrypoint.WebhookID)
|
|
||||||
if err != nil {
|
|
||||||
h.log.Error("failed to get webhook database",
|
|
||||||
"webhook_id", entrypoint.WebhookID,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the event and deliveries in a transaction on the per-webhook DB
|
|
||||||
tx := webhookDB.Begin()
|
|
||||||
if tx.Error != nil {
|
|
||||||
h.log.Error("failed to begin transaction", "error", tx.Error)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
event := &database.Event{
|
|
||||||
WebhookID: entrypoint.WebhookID,
|
|
||||||
EntrypointID: entrypoint.ID,
|
|
||||||
Method: r.Method,
|
|
||||||
Headers: string(headersJSON),
|
|
||||||
Body: string(body),
|
|
||||||
ContentType: r.Header.Get("Content-Type"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Create(event).Error; err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
h.log.Error("failed to create event", "error", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare body pointer for inline transport (≤16KB bodies are
|
|
||||||
// included in the DeliveryTask so the engine needs no DB read).
|
|
||||||
var bodyPtr *string
|
|
||||||
if len(body) < delivery.MaxInlineBodySize {
|
|
||||||
bodyStr := string(body)
|
|
||||||
bodyPtr = &bodyStr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create delivery records and build self-contained delivery tasks
|
|
||||||
tasks := make([]delivery.DeliveryTask, 0, len(targets))
|
|
||||||
for i := range targets {
|
|
||||||
dlv := &database.Delivery{
|
|
||||||
EventID: event.ID,
|
|
||||||
TargetID: targets[i].ID,
|
|
||||||
Status: database.DeliveryStatusPending,
|
|
||||||
}
|
|
||||||
if err := tx.Create(dlv).Error; err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
h.log.Error("failed to create delivery",
|
|
||||||
"target_id", targets[i].ID,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tasks = append(tasks, delivery.DeliveryTask{
|
|
||||||
DeliveryID: dlv.ID,
|
|
||||||
EventID: event.ID,
|
|
||||||
WebhookID: entrypoint.WebhookID,
|
|
||||||
TargetID: targets[i].ID,
|
|
||||||
TargetName: targets[i].Name,
|
|
||||||
TargetType: targets[i].Type,
|
|
||||||
TargetConfig: targets[i].Config,
|
|
||||||
MaxRetries: targets[i].MaxRetries,
|
|
||||||
Method: event.Method,
|
|
||||||
Headers: event.Headers,
|
|
||||||
ContentType: event.ContentType,
|
|
||||||
Body: bodyPtr,
|
|
||||||
AttemptNum: 1,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit().Error; err != nil {
|
|
||||||
h.log.Error("failed to commit transaction", "error", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify the delivery engine with self-contained delivery tasks.
|
|
||||||
// Each task carries all target config and event data inline so
|
|
||||||
// the engine can deliver without touching any database (in the
|
|
||||||
// ≤16KB happy path). The engine only writes to the DB to record
|
|
||||||
// delivery results after each attempt.
|
|
||||||
if len(tasks) > 0 {
|
|
||||||
h.notifier.Notify(tasks)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.log.Info("webhook event created",
|
|
||||||
"event_id", event.ID,
|
|
||||||
"webhook_id", entrypoint.WebhookID,
|
|
||||||
"entrypoint_id", entrypoint.ID,
|
|
||||||
"target_count", len(targets),
|
|
||||||
)
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil {
|
|
||||||
h.log.Error("failed to write response", "error", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// processWebhookRequest reads the body, serializes headers,
|
||||||
|
// loads targets, and delivers the event.
|
||||||
|
func (h *Handlers) processWebhookRequest(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
entrypoint database.Entrypoint,
|
||||||
|
) {
|
||||||
|
body, ok := h.readWebhookBody(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
headersJSON, err := json.Marshal(r.Header)
|
||||||
|
if err != nil {
|
||||||
|
h.serverError(w, "failed to serialize headers", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targets, err := h.loadActiveTargets(entrypoint.WebhookID)
|
||||||
|
if err != nil {
|
||||||
|
h.serverError(w, "failed to query targets", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.createAndDeliverEvent(
|
||||||
|
w, r, entrypoint, body, headersJSON, targets,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadActiveTargets returns all active targets for a webhook.
|
||||||
|
func (h *Handlers) loadActiveTargets(
|
||||||
|
webhookID string,
|
||||||
|
) ([]database.Target, error) {
|
||||||
|
var targets []database.Target
|
||||||
|
|
||||||
|
err := h.db.DB().Where(
|
||||||
|
"webhook_id = ? AND active = ?",
|
||||||
|
webhookID, true,
|
||||||
|
).Find(&targets).Error
|
||||||
|
|
||||||
|
return targets, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupEntrypoint finds an entrypoint by UUID path.
|
||||||
|
func (h *Handlers) lookupEntrypoint(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
entrypointUUID string,
|
||||||
|
) (database.Entrypoint, bool) {
|
||||||
|
var entrypoint database.Entrypoint
|
||||||
|
|
||||||
|
result := h.db.DB().Where(
|
||||||
|
"path = ?", entrypointUUID,
|
||||||
|
).First(&entrypoint)
|
||||||
|
if result.Error != nil {
|
||||||
|
h.log.Debug(
|
||||||
|
"entrypoint not found",
|
||||||
|
"path", entrypointUUID,
|
||||||
|
)
|
||||||
|
http.NotFound(w, r)
|
||||||
|
|
||||||
|
return entrypoint, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entrypoint, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// readWebhookBody reads and validates the request body size.
|
||||||
|
func (h *Handlers) readWebhookBody(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) ([]byte, bool) {
|
||||||
|
body, err := io.ReadAll(
|
||||||
|
io.LimitReader(r.Body, maxWebhookBodySize+1),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
h.log.Error(
|
||||||
|
"failed to read request body", "error", err,
|
||||||
|
)
|
||||||
|
http.Error(
|
||||||
|
w, "Bad request", http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) > maxWebhookBodySize {
|
||||||
|
http.Error(
|
||||||
|
w,
|
||||||
|
"Request body too large",
|
||||||
|
http.StatusRequestEntityTooLarge,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAndDeliverEvent creates the event and delivery records
|
||||||
|
// then notifies the delivery engine.
|
||||||
|
func (h *Handlers) createAndDeliverEvent(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
entrypoint database.Entrypoint,
|
||||||
|
body, headersJSON []byte,
|
||||||
|
targets []database.Target,
|
||||||
|
) {
|
||||||
|
tx, err := h.beginWebhookTx(w, entrypoint.WebhookID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
event := h.buildEvent(r, entrypoint, headersJSON, body)
|
||||||
|
|
||||||
|
err = tx.Create(event).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
h.serverError(w, "failed to create event", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyPtr := inlineBody(body)
|
||||||
|
|
||||||
|
tasks := h.buildDeliveryTasks(
|
||||||
|
w, tx, event, entrypoint, targets, bodyPtr,
|
||||||
|
)
|
||||||
|
if tasks == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit().Error
|
||||||
|
if err != nil {
|
||||||
|
h.serverError(w, "failed to commit transaction", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.finishWebhookResponse(w, event, entrypoint, tasks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// beginWebhookTx opens a transaction on the per-webhook DB.
|
||||||
|
func (h *Handlers) beginWebhookTx(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
webhookID string,
|
||||||
|
) (*gorm.DB, error) {
|
||||||
|
webhookDB, err := h.dbMgr.GetDB(webhookID)
|
||||||
|
if err != nil {
|
||||||
|
h.serverError(
|
||||||
|
w, "failed to get webhook database", err,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := webhookDB.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
h.serverError(
|
||||||
|
w, "failed to begin transaction", tx.Error,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil, tx.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// inlineBody returns a pointer to body as a string if it fits
|
||||||
|
// within the inline size limit, or nil otherwise.
|
||||||
|
func inlineBody(body []byte) *string {
|
||||||
|
if len(body) < delivery.MaxInlineBodySize {
|
||||||
|
s := string(body)
|
||||||
|
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// finishWebhookResponse notifies the delivery engine, logs the
|
||||||
|
// event, and writes the HTTP response.
|
||||||
|
func (h *Handlers) finishWebhookResponse(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
event *database.Event,
|
||||||
|
entrypoint database.Entrypoint,
|
||||||
|
tasks []delivery.Task,
|
||||||
|
) {
|
||||||
|
if len(tasks) > 0 {
|
||||||
|
h.notifier.Notify(tasks)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.log.Info("webhook event created",
|
||||||
|
"event_id", event.ID,
|
||||||
|
"webhook_id", entrypoint.WebhookID,
|
||||||
|
"entrypoint_id", entrypoint.ID,
|
||||||
|
"target_count", len(tasks),
|
||||||
|
)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
_, err := w.Write([]byte(`{"status":"ok"}`))
|
||||||
|
if err != nil {
|
||||||
|
h.log.Error(
|
||||||
|
"failed to write response", "error", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildEvent creates a new Event struct from request data.
|
||||||
|
func (h *Handlers) buildEvent(
|
||||||
|
r *http.Request,
|
||||||
|
entrypoint database.Entrypoint,
|
||||||
|
headersJSON, body []byte,
|
||||||
|
) *database.Event {
|
||||||
|
return &database.Event{
|
||||||
|
WebhookID: entrypoint.WebhookID,
|
||||||
|
EntrypointID: entrypoint.ID,
|
||||||
|
Method: r.Method,
|
||||||
|
Headers: string(headersJSON),
|
||||||
|
Body: string(body),
|
||||||
|
ContentType: r.Header.Get("Content-Type"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildDeliveryTasks creates delivery records in the
|
||||||
|
// transaction and returns tasks for the delivery engine.
|
||||||
|
// Returns nil if an error occurred.
|
||||||
|
func (h *Handlers) buildDeliveryTasks(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
tx *gorm.DB,
|
||||||
|
event *database.Event,
|
||||||
|
entrypoint database.Entrypoint,
|
||||||
|
targets []database.Target,
|
||||||
|
bodyPtr *string,
|
||||||
|
) []delivery.Task {
|
||||||
|
tasks := make([]delivery.Task, 0, len(targets))
|
||||||
|
|
||||||
|
for i := range targets {
|
||||||
|
dlv := &database.Delivery{
|
||||||
|
EventID: event.ID,
|
||||||
|
TargetID: targets[i].ID,
|
||||||
|
Status: database.DeliveryStatusPending,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tx.Create(dlv).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
h.log.Error(
|
||||||
|
"failed to create delivery",
|
||||||
|
"target_id", targets[i].ID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
http.Error(
|
||||||
|
w, "Internal server error",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks = append(tasks, delivery.Task{
|
||||||
|
DeliveryID: dlv.ID,
|
||||||
|
EventID: event.ID,
|
||||||
|
WebhookID: entrypoint.WebhookID,
|
||||||
|
TargetID: targets[i].ID,
|
||||||
|
TargetName: targets[i].Name,
|
||||||
|
TargetType: targets[i].Type,
|
||||||
|
TargetConfig: targets[i].Config,
|
||||||
|
MaxRetries: targets[i].MaxRetries,
|
||||||
|
Method: event.Method,
|
||||||
|
Headers: event.Headers,
|
||||||
|
ContentType: event.ContentType,
|
||||||
|
Body: bodyPtr,
|
||||||
|
AttemptNum: 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package healthcheck provides application health status reporting.
|
||||||
package healthcheck
|
package healthcheck
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -12,55 +13,51 @@ import (
|
|||||||
"sneak.berlin/go/webhooker/internal/logger"
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // HealthcheckParams is a standard fx naming convention
|
//nolint:revive // HealthcheckParams is a standard fx naming convention.
|
||||||
type HealthcheckParams struct {
|
type HealthcheckParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Globals *globals.Globals
|
Globals *globals.Globals
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
Database *database.Database
|
Database *database.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Healthcheck tracks application uptime and reports health status.
|
||||||
type Healthcheck struct {
|
type Healthcheck struct {
|
||||||
StartupTime time.Time
|
StartupTime time.Time
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
params *HealthcheckParams
|
params *HealthcheckParams
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(lc fx.Lifecycle, params HealthcheckParams) (*Healthcheck, error) {
|
// New creates a Healthcheck that records the startup time on fx
|
||||||
|
// start.
|
||||||
|
func New(
|
||||||
|
lc fx.Lifecycle,
|
||||||
|
params HealthcheckParams,
|
||||||
|
) (*Healthcheck, error) {
|
||||||
s := new(Healthcheck)
|
s := new(Healthcheck)
|
||||||
s.params = ¶ms
|
s.params = ¶ms
|
||||||
s.log = params.Logger.Get()
|
s.log = params.Logger.Get()
|
||||||
|
|
||||||
lc.Append(fx.Hook{
|
lc.Append(fx.Hook{
|
||||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
OnStart: func(_ context.Context) error {
|
||||||
s.StartupTime = time.Now()
|
s.StartupTime = time.Now()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
OnStop: func(ctx context.Context) error {
|
OnStop: func(_ context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:revive // HealthcheckResponse is a clear, descriptive name
|
// Healthcheck returns the current health status of the
|
||||||
type HealthcheckResponse struct {
|
// application.
|
||||||
Status string `json:"status"`
|
func (s *Healthcheck) Healthcheck() *Response {
|
||||||
Now string `json:"now"`
|
resp := &Response{
|
||||||
UptimeSeconds int64 `json:"uptime_seconds"`
|
|
||||||
UptimeHuman string `json:"uptime_human"`
|
|
||||||
Version string `json:"version"`
|
|
||||||
Appname string `json:"appname"`
|
|
||||||
Maintenance bool `json:"maintenance_mode"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Healthcheck) uptime() time.Duration {
|
|
||||||
return time.Since(s.StartupTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Healthcheck) Healthcheck() *HealthcheckResponse {
|
|
||||||
resp := &HealthcheckResponse{
|
|
||||||
Status: "ok",
|
Status: "ok",
|
||||||
Now: time.Now().UTC().Format(time.RFC3339Nano),
|
Now: time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
UptimeSeconds: int64(s.uptime().Seconds()),
|
UptimeSeconds: int64(s.uptime().Seconds()),
|
||||||
@@ -69,5 +66,21 @@ func (s *Healthcheck) Healthcheck() *HealthcheckResponse {
|
|||||||
Version: s.params.Globals.Version,
|
Version: s.params.Globals.Version,
|
||||||
Maintenance: s.params.Config.MaintenanceMode,
|
Maintenance: s.params.Config.MaintenanceMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Response contains the JSON-serialised health status.
|
||||||
|
type Response struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Now string `json:"now"`
|
||||||
|
UptimeSeconds int64 `json:"uptimeSeconds"`
|
||||||
|
UptimeHuman string `json:"uptimeHuman"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
Appname string `json:"appname"`
|
||||||
|
Maintenance bool `json:"maintenanceMode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Healthcheck) uptime() time.Duration {
|
||||||
|
return time.Since(s.StartupTime)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
// Package logger provides structured logging with dynamic level
|
||||||
|
// control.
|
||||||
package logger
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -10,19 +12,25 @@ import (
|
|||||||
"sneak.berlin/go/webhooker/internal/globals"
|
"sneak.berlin/go/webhooker/internal/globals"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // LoggerParams is a standard fx naming convention
|
//nolint:revive // LoggerParams is a standard fx naming convention.
|
||||||
type LoggerParams struct {
|
type LoggerParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Globals *globals.Globals
|
Globals *globals.Globals
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger wraps slog with dynamic level control and structured
|
||||||
|
// output.
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
levelVar *slog.LevelVar
|
levelVar *slog.LevelVar
|
||||||
params LoggerParams
|
params LoggerParams
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:revive // lc parameter is required by fx even if unused
|
// New creates a Logger that outputs text (TTY) or JSON (non-TTY)
|
||||||
|
// to stdout.
|
||||||
|
//
|
||||||
|
//nolint:revive // lc parameter is required by fx even if unused.
|
||||||
func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
|
func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
|
||||||
l := new(Logger)
|
l := new(Logger)
|
||||||
l.params = params
|
l.params = params
|
||||||
@@ -37,17 +45,22 @@ func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
|
|||||||
tty = true
|
tty = true
|
||||||
}
|
}
|
||||||
|
|
||||||
replaceAttr := func(_ []string, a slog.Attr) slog.Attr { // nolint:revive // groups unused
|
//nolint:revive // groups param unused but required by slog ReplaceAttr signature.
|
||||||
|
replaceAttr := func(_ []string, a slog.Attr) slog.Attr {
|
||||||
// Always use UTC for timestamps
|
// Always use UTC for timestamps
|
||||||
if a.Key == slog.TimeKey {
|
if a.Key == slog.TimeKey {
|
||||||
if t, ok := a.Value.Any().(time.Time); ok {
|
if t, ok := a.Value.Any().(time.Time); ok {
|
||||||
return slog.Time(slog.TimeKey, t.UTC())
|
return slog.Time(slog.TimeKey, t.UTC())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
var handler slog.Handler
|
var handler slog.Handler
|
||||||
|
|
||||||
opts := &slog.HandlerOptions{
|
opts := &slog.HandlerOptions{
|
||||||
Level: l.levelVar,
|
Level: l.levelVar,
|
||||||
ReplaceAttr: replaceAttr,
|
ReplaceAttr: replaceAttr,
|
||||||
@@ -69,15 +82,18 @@ func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
|
|||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnableDebugLogging switches the log level to debug.
|
||||||
func (l *Logger) EnableDebugLogging() {
|
func (l *Logger) EnableDebugLogging() {
|
||||||
l.levelVar.Set(slog.LevelDebug)
|
l.levelVar.Set(slog.LevelDebug)
|
||||||
l.logger.Debug("debug logging enabled", "debug", true)
|
l.logger.Debug("debug logging enabled", "debug", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get returns the underlying slog.Logger.
|
||||||
func (l *Logger) Get() *slog.Logger {
|
func (l *Logger) Get() *slog.Logger {
|
||||||
return l.logger
|
return l.logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Identify logs the application name and version at startup.
|
||||||
func (l *Logger) Identify() {
|
func (l *Logger) Identify() {
|
||||||
l.logger.Info("starting",
|
l.logger.Info("starting",
|
||||||
"appname", l.params.Globals.Appname,
|
"appname", l.params.Globals.Appname,
|
||||||
@@ -85,7 +101,8 @@ func (l *Logger) Identify() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper methods to maintain compatibility with existing code
|
// Writer returns an io.Writer suitable for standard library
|
||||||
|
// loggers.
|
||||||
func (l *Logger) Writer() io.Writer {
|
func (l *Logger) Writer() io.Writer {
|
||||||
return os.Stdout
|
return os.Stdout
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,63 +1,59 @@
|
|||||||
package logger
|
package logger_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"go.uber.org/fx/fxtest"
|
"go.uber.org/fx/fxtest"
|
||||||
"sneak.berlin/go/webhooker/internal/globals"
|
"sneak.berlin/go/webhooker/internal/globals"
|
||||||
|
"sneak.berlin/go/webhooker/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func testGlobals() *globals.Globals {
|
||||||
|
return &globals.Globals{
|
||||||
|
Appname: "test-app",
|
||||||
|
Version: "1.0.0",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
// Set up globals
|
t.Parallel()
|
||||||
globals.Appname = "test-app"
|
|
||||||
globals.Version = "1.0.0"
|
|
||||||
|
|
||||||
lc := fxtest.NewLifecycle(t)
|
lc := fxtest.NewLifecycle(t)
|
||||||
g, err := globals.New(lc)
|
|
||||||
if err != nil {
|
params := logger.LoggerParams{
|
||||||
t.Fatalf("globals.New() error = %v", err)
|
Globals: testGlobals(),
|
||||||
}
|
}
|
||||||
|
|
||||||
params := LoggerParams{
|
l, err := logger.New(lc, params)
|
||||||
Globals: g,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger, err := New(lc, params)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("New() error = %v", err)
|
t.Fatalf("New() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if logger.Get() == nil {
|
if l.Get() == nil {
|
||||||
t.Error("Get() returned nil logger")
|
t.Error("Get() returned nil logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that we can log without panic
|
// Test that we can log without panic
|
||||||
logger.Get().Info("test message", "key", "value")
|
l.Get().Info("test message", "key", "value")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEnableDebugLogging(t *testing.T) {
|
func TestEnableDebugLogging(t *testing.T) {
|
||||||
// Set up globals
|
t.Parallel()
|
||||||
globals.Appname = "test-app"
|
|
||||||
globals.Version = "1.0.0"
|
|
||||||
|
|
||||||
lc := fxtest.NewLifecycle(t)
|
lc := fxtest.NewLifecycle(t)
|
||||||
g, err := globals.New(lc)
|
|
||||||
if err != nil {
|
params := logger.LoggerParams{
|
||||||
t.Fatalf("globals.New() error = %v", err)
|
Globals: testGlobals(),
|
||||||
}
|
}
|
||||||
|
|
||||||
params := LoggerParams{
|
l, err := logger.New(lc, params)
|
||||||
Globals: g,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger, err := New(lc, params)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("New() error = %v", err)
|
t.Fatalf("New() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enable debug logging should not panic
|
// Enable debug logging should not panic
|
||||||
logger.EnableDebugLogging()
|
l.EnableDebugLogging()
|
||||||
|
|
||||||
// Test debug logging
|
// Test debug logging
|
||||||
logger.Get().Debug("debug message", "test", true)
|
l.Get().Debug("debug message", "test", true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -11,362 +12,483 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"sneak.berlin/go/webhooker/internal/config"
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
|
"sneak.berlin/go/webhooker/internal/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// csrfCookieName is the gorilla/csrf cookie name.
|
||||||
|
const csrfCookieName = "_gorilla_csrf"
|
||||||
|
|
||||||
|
// csrfGetToken performs a GET request through the CSRF middleware
|
||||||
|
// and returns the token and cookies.
|
||||||
|
func csrfGetToken(
|
||||||
|
t *testing.T,
|
||||||
|
csrfMW func(http.Handler) http.Handler,
|
||||||
|
getReq *http.Request,
|
||||||
|
) (string, []*http.Cookie) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var token string
|
||||||
|
|
||||||
|
getHandler := csrfMW(http.HandlerFunc(
|
||||||
|
func(_ http.ResponseWriter, r *http.Request) {
|
||||||
|
token = middleware.CSRFToken(r)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
getW := httptest.NewRecorder()
|
||||||
|
getHandler.ServeHTTP(getW, getReq)
|
||||||
|
|
||||||
|
cookies := getW.Result().Cookies()
|
||||||
|
require.NotEmpty(t, cookies, "CSRF cookie should be set")
|
||||||
|
require.NotEmpty(t, token, "CSRF token should be set")
|
||||||
|
|
||||||
|
return token, cookies
|
||||||
|
}
|
||||||
|
|
||||||
|
// csrfPostWithToken performs a POST request with the given CSRF
|
||||||
|
// token and cookies through the middleware. Returns whether the
|
||||||
|
// handler was called and the response code.
|
||||||
|
func csrfPostWithToken(
|
||||||
|
t *testing.T,
|
||||||
|
csrfMW func(http.Handler) http.Handler,
|
||||||
|
postReq *http.Request,
|
||||||
|
token string,
|
||||||
|
cookies []*http.Cookie,
|
||||||
|
) (bool, int) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var called bool
|
||||||
|
|
||||||
|
postHandler := csrfMW(http.HandlerFunc(
|
||||||
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
form := url.Values{"csrf_token": {token}}
|
||||||
|
postReq.Body = http.NoBody
|
||||||
|
postReq.Body = nil
|
||||||
|
|
||||||
|
// Rebuild the request with the form body
|
||||||
|
rebuilt := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
postReq.Method, postReq.URL.String(),
|
||||||
|
strings.NewReader(form.Encode()),
|
||||||
|
)
|
||||||
|
rebuilt.Header = postReq.Header.Clone()
|
||||||
|
rebuilt.TLS = postReq.TLS
|
||||||
|
rebuilt.Header.Set(
|
||||||
|
"Content-Type", "application/x-www-form-urlencoded",
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, c := range cookies {
|
||||||
|
rebuilt.AddCookie(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
postW := httptest.NewRecorder()
|
||||||
|
postHandler.ServeHTTP(postW, rebuilt)
|
||||||
|
|
||||||
|
return called, postW.Code
|
||||||
|
}
|
||||||
|
|
||||||
func TestCSRF_GETSetsToken(t *testing.T) {
|
func TestCSRF_GETSetsToken(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var gotToken string
|
var gotToken string
|
||||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
||||||
gotToken = CSRFToken(r)
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
handler := m.CSRF()(http.HandlerFunc(
|
||||||
|
func(_ http.ResponseWriter, r *http.Request) {
|
||||||
|
gotToken = middleware.CSRFToken(r)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/form", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
|
assert.NotEmpty(
|
||||||
|
t, gotToken,
|
||||||
|
"CSRF token should be set in context on GET",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
csrfMW := m.CSRF()
|
||||||
|
|
||||||
// Capture the token from a GET request
|
getReq := httptest.NewRequestWithContext(
|
||||||
var token string
|
context.Background(),
|
||||||
csrfMiddleware := m.CSRF()
|
http.MethodGet, "/form", nil,
|
||||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
)
|
||||||
token = CSRFToken(r)
|
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||||
}))
|
|
||||||
|
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
postReq := httptest.NewRequestWithContext(
|
||||||
getW := httptest.NewRecorder()
|
context.Background(),
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
http.MethodPost, "/form", nil,
|
||||||
|
)
|
||||||
|
called, _ := csrfPostWithToken(
|
||||||
|
t, csrfMW, postReq, token, cookies,
|
||||||
|
)
|
||||||
|
|
||||||
cookies := getW.Result().Cookies()
|
assert.True(
|
||||||
require.NotEmpty(t, cookies)
|
t, called,
|
||||||
require.NotEmpty(t, token)
|
"handler should be called with valid CSRF token",
|
||||||
|
)
|
||||||
// POST with valid token and cookies from the GET response
|
|
||||||
var called bool
|
|
||||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
}))
|
|
||||||
|
|
||||||
form := url.Values{"csrf_token": {token}}
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
for _, c := range cookies {
|
|
||||||
postReq.AddCookie(c)
|
|
||||||
}
|
|
||||||
postW := httptest.NewRecorder()
|
|
||||||
|
|
||||||
postHandler.ServeHTTP(postW, postReq)
|
|
||||||
|
|
||||||
assert.True(t, called, "handler should be called with valid CSRF token")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRF_POSTWithoutToken(t *testing.T) {
|
// csrfPOSTWithoutTokenTest is a shared helper for testing POST
|
||||||
t.Parallel()
|
// requests without a CSRF token in both dev and prod modes.
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
func csrfPOSTWithoutTokenTest(
|
||||||
|
t *testing.T,
|
||||||
|
env string,
|
||||||
|
msg string,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
csrfMiddleware := m.CSRF()
|
m, _ := testMiddleware(t, env)
|
||||||
|
csrfMW := m.CSRF()
|
||||||
|
|
||||||
// GET to establish the CSRF cookie
|
// GET to establish the CSRF cookie
|
||||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
getHandler := csrfMW(http.HandlerFunc(
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
func(_ http.ResponseWriter, _ *http.Request) {},
|
||||||
|
))
|
||||||
|
|
||||||
|
getReq := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/form", nil)
|
||||||
getW := httptest.NewRecorder()
|
getW := httptest.NewRecorder()
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
getHandler.ServeHTTP(getW, getReq)
|
||||||
|
|
||||||
cookies := getW.Result().Cookies()
|
cookies := getW.Result().Cookies()
|
||||||
|
|
||||||
// POST without CSRF token
|
// POST without CSRF token
|
||||||
var called bool
|
var called bool
|
||||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
}))
|
|
||||||
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
postHandler := csrfMW(http.HandlerFunc(
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
postReq := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/form", nil,
|
||||||
|
)
|
||||||
|
postReq.Header.Set(
|
||||||
|
"Content-Type", "application/x-www-form-urlencoded",
|
||||||
|
)
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
postReq.AddCookie(c)
|
postReq.AddCookie(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
postW := httptest.NewRecorder()
|
postW := httptest.NewRecorder()
|
||||||
|
|
||||||
postHandler.ServeHTTP(postW, postReq)
|
postHandler.ServeHTTP(postW, postReq)
|
||||||
|
|
||||||
assert.False(t, called, "handler should NOT be called without CSRF token")
|
assert.False(t, called, msg)
|
||||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCSRF_POSTWithoutToken(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
csrfPOSTWithoutTokenTest(
|
||||||
|
t,
|
||||||
|
config.EnvironmentDev,
|
||||||
|
"handler should NOT be called without CSRF token",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
|
||||||
|
|
||||||
csrfMiddleware := m.CSRF()
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
csrfMW := m.CSRF()
|
||||||
|
|
||||||
// GET to establish the CSRF cookie
|
// GET to establish the CSRF cookie
|
||||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
getHandler := csrfMW(http.HandlerFunc(
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
func(_ http.ResponseWriter, _ *http.Request) {},
|
||||||
|
))
|
||||||
|
|
||||||
|
getReq := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/form", nil)
|
||||||
getW := httptest.NewRecorder()
|
getW := httptest.NewRecorder()
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
getHandler.ServeHTTP(getW, getReq)
|
||||||
|
|
||||||
cookies := getW.Result().Cookies()
|
cookies := getW.Result().Cookies()
|
||||||
|
|
||||||
// POST with wrong CSRF token
|
// POST with wrong CSRF token
|
||||||
var called bool
|
var called bool
|
||||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
postHandler := csrfMW(http.HandlerFunc(
|
||||||
}))
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
form := url.Values{"csrf_token": {"invalid-token-value"}}
|
form := url.Values{"csrf_token": {"invalid-token-value"}}
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
postReq := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/form",
|
||||||
|
strings.NewReader(form.Encode()),
|
||||||
|
)
|
||||||
|
postReq.Header.Set(
|
||||||
|
"Content-Type", "application/x-www-form-urlencoded",
|
||||||
|
)
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
postReq.AddCookie(c)
|
postReq.AddCookie(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
postW := httptest.NewRecorder()
|
postW := httptest.NewRecorder()
|
||||||
|
|
||||||
postHandler.ServeHTTP(postW, postReq)
|
postHandler.ServeHTTP(postW, postReq)
|
||||||
|
|
||||||
assert.False(t, called, "handler should NOT be called with invalid CSRF token")
|
assert.False(
|
||||||
|
t, called,
|
||||||
|
"handler should NOT be called with invalid CSRF token",
|
||||||
|
)
|
||||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRF_GETDoesNotValidate(t *testing.T) {
|
func TestCSRF_GETDoesNotValidate(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
}))
|
|
||||||
|
|
||||||
// GET requests should pass through without CSRF validation
|
handler := m.CSRF()(http.HandlerFunc(
|
||||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/form", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.True(t, called, "GET requests should pass through CSRF middleware")
|
assert.True(
|
||||||
|
t, called,
|
||||||
|
"GET requests should pass through CSRF middleware",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFToken_NoMiddleware(t *testing.T) {
|
func TestCSRFToken_NoMiddleware(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
assert.Empty(
|
||||||
|
t, middleware.CSRFToken(req),
|
||||||
|
"CSRFToken should return empty string when "+
|
||||||
|
"middleware has not run",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- TLS Detection Tests ---
|
// --- TLS Detection Tests ---
|
||||||
|
|
||||||
func TestIsClientTLS_DirectTLS(t *testing.T) {
|
func TestIsClientTLS_DirectTLS(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
r.TLS = &tls.ConnectionState{} // simulate direct TLS
|
r := httptest.NewRequestWithContext(
|
||||||
assert.True(t, isClientTLS(r), "should detect direct TLS connection")
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
r.TLS = &tls.ConnectionState{}
|
||||||
|
|
||||||
|
assert.True(
|
||||||
|
t, middleware.IsClientTLS(r),
|
||||||
|
"should detect direct TLS connection",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsClientTLS_XForwardedProto(t *testing.T) {
|
func TestIsClientTLS_XForwardedProto(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
|
r := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
r.Header.Set("X-Forwarded-Proto", "https")
|
r.Header.Set("X-Forwarded-Proto", "https")
|
||||||
assert.True(t, isClientTLS(r), "should detect TLS via X-Forwarded-Proto")
|
|
||||||
|
assert.True(
|
||||||
|
t, middleware.IsClientTLS(r),
|
||||||
|
"should detect TLS via X-Forwarded-Proto",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsClientTLS_PlaintextHTTP(t *testing.T) {
|
func TestIsClientTLS_PlaintextHTTP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
assert.False(t, isClientTLS(r), "should detect plaintext HTTP")
|
r := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
assert.False(
|
||||||
|
t, middleware.IsClientTLS(r),
|
||||||
|
"should detect plaintext HTTP",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) {
|
func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
|
r := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
r.Header.Set("X-Forwarded-Proto", "http")
|
r.Header.Set("X-Forwarded-Proto", "http")
|
||||||
assert.False(t, isClientTLS(r), "should detect plaintext when X-Forwarded-Proto is http")
|
|
||||||
|
assert.False(
|
||||||
|
t, middleware.IsClientTLS(r),
|
||||||
|
"should detect plaintext when X-Forwarded-Proto is http",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Production Mode: POST over plaintext HTTP ---
|
// --- Production Mode: POST over plaintext HTTP ---
|
||||||
|
|
||||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(t *testing.T) {
|
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||||
|
csrfMW := m.CSRF()
|
||||||
|
|
||||||
// This tests the critical fix: prod mode over plaintext HTTP should
|
getReq := httptest.NewRequestWithContext(
|
||||||
// work because the middleware detects the transport per-request.
|
context.Background(),
|
||||||
var token string
|
http.MethodGet, "/form", nil,
|
||||||
csrfMiddleware := m.CSRF()
|
)
|
||||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||||
token = CSRFToken(r)
|
|
||||||
}))
|
|
||||||
|
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
// Verify cookie is NOT Secure (plaintext HTTP in prod)
|
||||||
getW := httptest.NewRecorder()
|
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
|
||||||
|
|
||||||
cookies := getW.Result().Cookies()
|
|
||||||
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
|
||||||
require.NotEmpty(t, token, "CSRF token should be set in context on GET")
|
|
||||||
|
|
||||||
// Verify the cookie is NOT Secure (plaintext HTTP in prod mode)
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
if c.Name == "_gorilla_csrf" {
|
if c.Name == csrfCookieName {
|
||||||
assert.False(t, c.Secure, "CSRF cookie should not be Secure over plaintext HTTP")
|
assert.False(t, c.Secure,
|
||||||
|
"CSRF cookie should not be Secure "+
|
||||||
|
"over plaintext HTTP")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// POST with valid token — should succeed
|
postReq := httptest.NewRequestWithContext(
|
||||||
var called bool
|
context.Background(),
|
||||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
http.MethodPost, "/form", nil,
|
||||||
called = true
|
)
|
||||||
}))
|
called, code := csrfPostWithToken(
|
||||||
|
t, csrfMW, postReq, token, cookies,
|
||||||
|
)
|
||||||
|
|
||||||
form := url.Values{"csrf_token": {token}}
|
assert.True(t, called,
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
"handler should be called -- prod mode over "+
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
"plaintext HTTP must work")
|
||||||
for _, c := range cookies {
|
assert.NotEqual(t, http.StatusForbidden, code,
|
||||||
postReq.AddCookie(c)
|
"should not return 403")
|
||||||
}
|
|
||||||
postW := httptest.NewRecorder()
|
|
||||||
|
|
||||||
postHandler.ServeHTTP(postW, postReq)
|
|
||||||
|
|
||||||
assert.True(t, called, "handler should be called — prod mode over plaintext HTTP must work")
|
|
||||||
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Production Mode: POST with X-Forwarded-Proto (reverse proxy) ---
|
// --- Production Mode: POST with X-Forwarded-Proto ---
|
||||||
|
|
||||||
func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(t *testing.T) {
|
func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||||
|
csrfMW := m.CSRF()
|
||||||
|
|
||||||
// Simulates a deployment behind a TLS-terminating reverse proxy.
|
getReq := httptest.NewRequestWithContext(
|
||||||
// The Go server sees HTTP but X-Forwarded-Proto is "https".
|
context.Background(),
|
||||||
var token string
|
http.MethodGet, "http://example.com/form", nil,
|
||||||
csrfMiddleware := m.CSRF()
|
)
|
||||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
||||||
token = CSRFToken(r)
|
|
||||||
}))
|
|
||||||
|
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "http://example.com/form", nil)
|
|
||||||
getReq.Header.Set("X-Forwarded-Proto", "https")
|
getReq.Header.Set("X-Forwarded-Proto", "https")
|
||||||
getW := httptest.NewRecorder()
|
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
|
||||||
|
|
||||||
cookies := getW.Result().Cookies()
|
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||||
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
|
||||||
require.NotEmpty(t, token, "CSRF token should be set in context")
|
|
||||||
|
|
||||||
// Verify the cookie IS Secure (X-Forwarded-Proto: https)
|
// Verify cookie IS Secure (X-Forwarded-Proto: https)
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
if c.Name == "_gorilla_csrf" {
|
if c.Name == csrfCookieName {
|
||||||
assert.True(t, c.Secure, "CSRF cookie should be Secure behind TLS proxy")
|
assert.True(t, c.Secure,
|
||||||
|
"CSRF cookie should be Secure behind "+
|
||||||
|
"TLS proxy")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// POST with valid token, HTTPS Origin (as a browser behind proxy would send)
|
postReq := httptest.NewRequestWithContext(
|
||||||
var called bool
|
context.Background(),
|
||||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
http.MethodPost, "http://example.com/form", nil,
|
||||||
called = true
|
)
|
||||||
}))
|
|
||||||
|
|
||||||
form := url.Values{"csrf_token": {token}}
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "http://example.com/form", strings.NewReader(form.Encode()))
|
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
postReq.Header.Set("X-Forwarded-Proto", "https")
|
postReq.Header.Set("X-Forwarded-Proto", "https")
|
||||||
postReq.Header.Set("Origin", "https://example.com")
|
postReq.Header.Set("Origin", "https://example.com")
|
||||||
for _, c := range cookies {
|
|
||||||
postReq.AddCookie(c)
|
|
||||||
}
|
|
||||||
postW := httptest.NewRecorder()
|
|
||||||
|
|
||||||
postHandler.ServeHTTP(postW, postReq)
|
called, code := csrfPostWithToken(
|
||||||
|
t, csrfMW, postReq, token, cookies,
|
||||||
|
)
|
||||||
|
|
||||||
assert.True(t, called, "handler should be called — prod mode behind TLS proxy must work")
|
assert.True(t, called,
|
||||||
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
"handler should be called -- prod mode behind "+
|
||||||
|
"TLS proxy must work")
|
||||||
|
assert.NotEqual(t, http.StatusForbidden, code,
|
||||||
|
"should not return 403")
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Production Mode: direct TLS ---
|
// --- Production Mode: direct TLS ---
|
||||||
|
|
||||||
func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(t *testing.T) {
|
func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||||
|
csrfMW := m.CSRF()
|
||||||
|
|
||||||
var token string
|
getReq := httptest.NewRequestWithContext(
|
||||||
csrfMiddleware := m.CSRF()
|
context.Background(),
|
||||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
http.MethodGet, "https://example.com/form", nil,
|
||||||
token = CSRFToken(r)
|
)
|
||||||
}))
|
|
||||||
|
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "https://example.com/form", nil)
|
|
||||||
getReq.TLS = &tls.ConnectionState{}
|
getReq.TLS = &tls.ConnectionState{}
|
||||||
getW := httptest.NewRecorder()
|
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
|
||||||
|
|
||||||
cookies := getW.Result().Cookies()
|
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||||
require.NotEmpty(t, cookies, "CSRF cookie should be set on GET")
|
|
||||||
require.NotEmpty(t, token, "CSRF token should be set in context")
|
|
||||||
|
|
||||||
// Verify the cookie IS Secure (direct TLS)
|
// Verify cookie IS Secure (direct TLS)
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
if c.Name == "_gorilla_csrf" {
|
if c.Name == csrfCookieName {
|
||||||
assert.True(t, c.Secure, "CSRF cookie should be Secure over direct TLS")
|
assert.True(t, c.Secure,
|
||||||
|
"CSRF cookie should be Secure over "+
|
||||||
|
"direct TLS")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// POST with valid token over direct TLS
|
postReq := httptest.NewRequestWithContext(
|
||||||
var called bool
|
context.Background(),
|
||||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
http.MethodPost, "https://example.com/form", nil,
|
||||||
called = true
|
)
|
||||||
}))
|
|
||||||
|
|
||||||
form := url.Values{"csrf_token": {token}}
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "https://example.com/form", strings.NewReader(form.Encode()))
|
|
||||||
postReq.TLS = &tls.ConnectionState{}
|
postReq.TLS = &tls.ConnectionState{}
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
postReq.Header.Set("Origin", "https://example.com")
|
postReq.Header.Set("Origin", "https://example.com")
|
||||||
for _, c := range cookies {
|
|
||||||
postReq.AddCookie(c)
|
|
||||||
}
|
|
||||||
postW := httptest.NewRecorder()
|
|
||||||
|
|
||||||
postHandler.ServeHTTP(postW, postReq)
|
called, code := csrfPostWithToken(
|
||||||
|
t, csrfMW, postReq, token, cookies,
|
||||||
|
)
|
||||||
|
|
||||||
assert.True(t, called, "handler should be called — direct TLS must work")
|
assert.True(t, called,
|
||||||
assert.NotEqual(t, http.StatusForbidden, postW.Code, "should not return 403")
|
"handler should be called -- direct TLS must work")
|
||||||
|
assert.NotEqual(t, http.StatusForbidden, code,
|
||||||
|
"should not return 403")
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Production Mode: POST without token still rejects ---
|
// --- Production Mode: POST without token still rejects ---
|
||||||
|
|
||||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(t *testing.T) {
|
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
|
||||||
|
|
||||||
csrfMiddleware := m.CSRF()
|
csrfPOSTWithoutTokenTest(
|
||||||
|
t,
|
||||||
// GET to establish the CSRF cookie
|
config.EnvironmentProd,
|
||||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
"handler should NOT be called without CSRF token "+
|
||||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
"even in prod+plaintext",
|
||||||
getW := httptest.NewRecorder()
|
)
|
||||||
getHandler.ServeHTTP(getW, getReq)
|
|
||||||
cookies := getW.Result().Cookies()
|
|
||||||
|
|
||||||
// POST without CSRF token — should be rejected
|
|
||||||
var called bool
|
|
||||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
}))
|
|
||||||
|
|
||||||
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
|
||||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
for _, c := range cookies {
|
|
||||||
postReq.AddCookie(c)
|
|
||||||
}
|
|
||||||
postW := httptest.NewRecorder()
|
|
||||||
|
|
||||||
postHandler.ServeHTTP(postW, postReq)
|
|
||||||
|
|
||||||
assert.False(t, called, "handler should NOT be called without CSRF token even in prod+plaintext")
|
|
||||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
|
||||||
}
|
}
|
||||||
|
|||||||
34
internal/middleware/export_test.go
Normal file
34
internal/middleware/export_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewLoggingResponseWriterForTest wraps newLoggingResponseWriter
|
||||||
|
// for use in external test packages.
|
||||||
|
func NewLoggingResponseWriterForTest(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
) *loggingResponseWriter {
|
||||||
|
return newLoggingResponseWriter(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoggingResponseWriterStatusCode returns the status code
|
||||||
|
// captured by the loggingResponseWriter.
|
||||||
|
func LoggingResponseWriterStatusCode(
|
||||||
|
lrw *loggingResponseWriter,
|
||||||
|
) int {
|
||||||
|
return lrw.statusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPFromHostPort exposes ipFromHostPort for testing.
|
||||||
|
func IPFromHostPort(hp string) string {
|
||||||
|
return ipFromHostPort(hp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClientTLS exposes isClientTLS for testing.
|
||||||
|
func IsClientTLS(r *http.Request) bool {
|
||||||
|
return isClientTLS(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginRateLimitConst exposes the loginRateLimit constant.
|
||||||
|
const LoginRateLimitConst = loginRateLimit
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
// Package middleware provides HTTP middleware for logging, auth,
|
||||||
|
// CORS, and metrics.
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -19,26 +21,42 @@ import (
|
|||||||
"sneak.berlin/go/webhooker/internal/session"
|
"sneak.berlin/go/webhooker/internal/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // MiddlewareParams is a standard fx naming convention
|
const (
|
||||||
|
// corsMaxAge is the maximum time (in seconds) that a
|
||||||
|
// preflight response can be cached.
|
||||||
|
corsMaxAge = 300
|
||||||
|
)
|
||||||
|
|
||||||
|
//nolint:revive // MiddlewareParams is a standard fx naming convention.
|
||||||
type MiddlewareParams struct {
|
type MiddlewareParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
Globals *globals.Globals
|
Globals *globals.Globals
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
Session *session.Session
|
Session *session.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Middleware provides HTTP middleware for logging, CORS, auth, and
|
||||||
|
// metrics.
|
||||||
type Middleware struct {
|
type Middleware struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
params *MiddlewareParams
|
params *MiddlewareParams
|
||||||
session *session.Session
|
session *session.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(lc fx.Lifecycle, params MiddlewareParams) (*Middleware, error) {
|
// New creates a Middleware from the provided fx parameters.
|
||||||
|
//
|
||||||
|
//nolint:revive // lc parameter is required by fx even if unused.
|
||||||
|
func New(
|
||||||
|
lc fx.Lifecycle,
|
||||||
|
params MiddlewareParams,
|
||||||
|
) (*Middleware, error) {
|
||||||
s := new(Middleware)
|
s := new(Middleware)
|
||||||
s.params = ¶ms
|
s.params = ¶ms
|
||||||
s.log = params.Logger.Get()
|
s.log = params.Logger.Get()
|
||||||
s.session = params.Session
|
s.session = params.Session
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,19 +68,24 @@ func ipFromHostPort(hp string) string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(h) > 0 && h[0] == '[' {
|
if len(h) > 0 && h[0] == '[' {
|
||||||
return h[1 : len(h)-1]
|
return h[1 : len(h)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
type loggingResponseWriter struct {
|
type loggingResponseWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
|
|
||||||
statusCode int
|
statusCode int
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:revive // unexported type is only used internally
|
// newLoggingResponseWriter wraps w and records status codes.
|
||||||
func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
|
func newLoggingResponseWriter(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
) *loggingResponseWriter {
|
||||||
return &loggingResponseWriter{w, http.StatusOK}
|
return &loggingResponseWriter{w, http.StatusOK}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,23 +94,30 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
|||||||
lrw.ResponseWriter.WriteHeader(code)
|
lrw.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// type Middleware func(http.Handler) http.Handler
|
// Logging returns middleware that logs each HTTP request with
|
||||||
// this returns a Middleware that is designed to do every request through the
|
// timing and metadata.
|
||||||
// mux, note the signature:
|
|
||||||
func (s *Middleware) Logging() func(http.Handler) http.Handler {
|
func (s *Middleware) Logging() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
lrw := NewLoggingResponseWriter(w)
|
lrw := newLoggingResponseWriter(w)
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
requestID := ""
|
requestID := ""
|
||||||
if reqID := ctx.Value(middleware.RequestIDKey); reqID != nil {
|
|
||||||
|
if reqID := ctx.Value(
|
||||||
|
middleware.RequestIDKey,
|
||||||
|
); reqID != nil {
|
||||||
if id, ok := reqID.(string); ok {
|
if id, ok := reqID.(string); ok {
|
||||||
requestID = id
|
requestID = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.log.Info("http request",
|
s.log.Info("http request",
|
||||||
"request_start", start,
|
"request_start", start,
|
||||||
"method", r.Method,
|
"method", r.Method,
|
||||||
@@ -107,20 +137,29 @@ func (s *Middleware) Logging() func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CORS returns middleware that sets CORS headers (permissive in
|
||||||
|
// dev, no-op in prod).
|
||||||
func (s *Middleware) CORS() func(http.Handler) http.Handler {
|
func (s *Middleware) CORS() func(http.Handler) http.Handler {
|
||||||
if s.params.Config.IsDev() {
|
if s.params.Config.IsDev() {
|
||||||
// In development, allow any origin for local testing.
|
// In development, allow any origin for local testing.
|
||||||
return cors.Handler(cors.Options{
|
return cors.Handler(cors.Options{
|
||||||
AllowedOrigins: []string{"*"},
|
AllowedOrigins: []string{"*"},
|
||||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
AllowedMethods: []string{
|
||||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
"GET", "POST", "PUT", "DELETE", "OPTIONS",
|
||||||
|
},
|
||||||
|
AllowedHeaders: []string{
|
||||||
|
"Accept", "Authorization",
|
||||||
|
"Content-Type", "X-CSRF-Token",
|
||||||
|
},
|
||||||
ExposedHeaders: []string{"Link"},
|
ExposedHeaders: []string{"Link"},
|
||||||
AllowCredentials: false,
|
AllowCredentials: false,
|
||||||
MaxAge: 300,
|
MaxAge: corsMaxAge,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// In production, the web UI is server-rendered so cross-origin
|
|
||||||
// requests are not expected. Return a no-op middleware.
|
// In production, the web UI is server-rendered so
|
||||||
|
// cross-origin requests are not expected. Return a no-op
|
||||||
|
// middleware.
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
@@ -130,20 +169,33 @@ func (s *Middleware) CORS() func(http.Handler) http.Handler {
|
|||||||
// Unauthenticated users are redirected to the login page.
|
// Unauthenticated users are redirected to the login page.
|
||||||
func (s *Middleware) RequireAuth() func(http.Handler) http.Handler {
|
func (s *Middleware) RequireAuth() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) {
|
||||||
sess, err := s.session.Get(r)
|
sess, err := s.session.Get(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Debug("auth middleware: failed to get session", "error", err)
|
s.log.Debug(
|
||||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
"auth middleware: failed to get session",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
http.Redirect(
|
||||||
|
w, r, "/pages/login", http.StatusSeeOther,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.session.IsAuthenticated(sess) {
|
if !s.session.IsAuthenticated(sess) {
|
||||||
s.log.Debug("auth middleware: unauthenticated request",
|
s.log.Debug(
|
||||||
|
"auth middleware: unauthenticated request",
|
||||||
"path", r.URL.Path,
|
"path", r.URL.Path,
|
||||||
"method", r.Method,
|
"method", r.Method,
|
||||||
)
|
)
|
||||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
http.Redirect(
|
||||||
|
w, r, "/pages/login", http.StatusSeeOther,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,15 +204,19 @@ func (s *Middleware) RequireAuth() func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Metrics returns middleware that records Prometheus HTTP metrics.
|
||||||
func (s *Middleware) Metrics() func(http.Handler) http.Handler {
|
func (s *Middleware) Metrics() func(http.Handler) http.Handler {
|
||||||
mdlw := ghmm.New(ghmm.Config{
|
mdlw := ghmm.New(ghmm.Config{
|
||||||
Recorder: metrics.NewRecorder(metrics.Config{}),
|
Recorder: metrics.NewRecorder(metrics.Config{}),
|
||||||
})
|
})
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return std.Handler("", mdlw, next)
|
return std.Handler("", mdlw, next)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MetricsAuth returns middleware that protects metrics endpoints
|
||||||
|
// with basic auth.
|
||||||
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
||||||
return basicauth.New(
|
return basicauth.New(
|
||||||
"metrics",
|
"metrics",
|
||||||
@@ -172,33 +228,63 @@ func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SecurityHeaders returns middleware that sets production security headers
|
// SecurityHeaders returns middleware that sets production security
|
||||||
// on every response: HSTS, X-Content-Type-Options, X-Frame-Options, CSP,
|
// headers on every response: HSTS, X-Content-Type-Options,
|
||||||
// Referrer-Policy, and Permissions-Policy.
|
// X-Frame-Options, CSP, Referrer-Policy, and Permissions-Policy.
|
||||||
func (s *Middleware) SecurityHeaders() func(http.Handler) http.Handler {
|
func (s *Middleware) SecurityHeaders() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(
|
||||||
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
|
w http.ResponseWriter,
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
r *http.Request,
|
||||||
|
) {
|
||||||
|
w.Header().Set(
|
||||||
|
"Strict-Transport-Security",
|
||||||
|
"max-age=63072000; includeSubDomains; preload",
|
||||||
|
)
|
||||||
|
w.Header().Set(
|
||||||
|
"X-Content-Type-Options", "nosniff",
|
||||||
|
)
|
||||||
w.Header().Set("X-Frame-Options", "DENY")
|
w.Header().Set("X-Frame-Options", "DENY")
|
||||||
w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
|
w.Header().Set(
|
||||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
"Content-Security-Policy",
|
||||||
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
|
"default-src 'self'; "+
|
||||||
|
"script-src 'self' 'unsafe-inline'; "+
|
||||||
|
"style-src 'self' 'unsafe-inline'",
|
||||||
|
)
|
||||||
|
w.Header().Set(
|
||||||
|
"Referrer-Policy",
|
||||||
|
"strict-origin-when-cross-origin",
|
||||||
|
)
|
||||||
|
w.Header().Set(
|
||||||
|
"Permissions-Policy",
|
||||||
|
"camera=(), microphone=(), geolocation=()",
|
||||||
|
)
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxBodySize returns middleware that limits the request body size for POST
|
// MaxBodySize returns middleware that limits the request body size
|
||||||
// requests. If the body exceeds the given limit in bytes, the server returns
|
// for POST requests. If the body exceeds the given limit in
|
||||||
// 413 Request Entity Too Large. This prevents clients from sending arbitrarily
|
// bytes, the server returns 413 Request Entity Too Large. This
|
||||||
// large form bodies.
|
// prevents clients from sending arbitrarily large form bodies.
|
||||||
func (s *Middleware) MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
|
func (s *Middleware) MaxBodySize(
|
||||||
|
maxBytes int64,
|
||||||
|
) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(
|
||||||
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
w http.ResponseWriter,
|
||||||
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
|
r *http.Request,
|
||||||
|
) {
|
||||||
|
if r.Method == http.MethodPost ||
|
||||||
|
r.Method == http.MethodPut ||
|
||||||
|
r.Method == http.MethodPatch {
|
||||||
|
r.Body = http.MaxBytesReader(
|
||||||
|
w, r.Body, maxBytes,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -12,25 +13,37 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"sneak.berlin/go/webhooker/internal/config"
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
|
"sneak.berlin/go/webhooker/internal/middleware"
|
||||||
"sneak.berlin/go/webhooker/internal/session"
|
"sneak.berlin/go/webhooker/internal/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testMiddleware creates a Middleware with minimal dependencies for testing.
|
const testKeySize = 32
|
||||||
// It uses a real session.Session backed by an in-memory cookie store.
|
|
||||||
func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
|
// testMiddleware creates a Middleware with minimal dependencies
|
||||||
|
// for testing. It uses a real session.Session backed by an
|
||||||
|
// in-memory cookie store.
|
||||||
|
func testMiddleware(
|
||||||
|
t *testing.T,
|
||||||
|
env string,
|
||||||
|
) (*middleware.Middleware, *session.Session) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
log := slog.New(slog.NewTextHandler(
|
||||||
|
os.Stderr,
|
||||||
|
&slog.HandlerOptions{Level: slog.LevelDebug},
|
||||||
|
))
|
||||||
|
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Environment: env,
|
Environment: env,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a real session manager with a known key
|
// Create a real session manager with a known key
|
||||||
key := make([]byte, 32)
|
key := make([]byte, testKeySize)
|
||||||
|
|
||||||
for i := range key {
|
for i := range key {
|
||||||
key[i] = byte(i)
|
key[i] = byte(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
store := sessions.NewCookieStore(key)
|
store := sessions.NewCookieStore(key)
|
||||||
store.Options = &sessions.Options{
|
store.Options = &sessions.Options{
|
||||||
Path: "/",
|
Path: "/",
|
||||||
@@ -40,40 +53,33 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
|
|||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
sessManager := newTestSession(t, store, cfg, log, key)
|
sessManager := session.NewForTest(store, cfg, log, key)
|
||||||
|
|
||||||
m := &Middleware{
|
m := middleware.NewForTest(log, cfg, sessManager)
|
||||||
log: log,
|
|
||||||
params: &MiddlewareParams{
|
|
||||||
Config: cfg,
|
|
||||||
},
|
|
||||||
session: sessManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, sessManager
|
return m, sessManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTestSession creates a session.Session with a pre-configured cookie store
|
|
||||||
// for testing. This avoids needing the fx lifecycle and database.
|
|
||||||
func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *session.Session {
|
|
||||||
t.Helper()
|
|
||||||
return session.NewForTest(store, cfg, log, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Logging Middleware Tests ---
|
// --- Logging Middleware Tests ---
|
||||||
|
|
||||||
func TestLogging_SetsStatusCode(t *testing.T) {
|
func TestLogging_SetsStatusCode(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.Logging()(http.HandlerFunc(
|
||||||
w.WriteHeader(http.StatusCreated)
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
if _, err := w.Write([]byte("created")); err != nil {
|
w.WriteHeader(http.StatusCreated)
|
||||||
return
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
_, err := w.Write([]byte("created"))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/test", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
@@ -84,15 +90,20 @@ func TestLogging_SetsStatusCode(t *testing.T) {
|
|||||||
|
|
||||||
func TestLogging_DefaultStatusOK(t *testing.T) {
|
func TestLogging_DefaultStatusOK(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.Logging()(http.HandlerFunc(
|
||||||
if _, err := w.Write([]byte("ok")); err != nil {
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
return
|
_, err := w.Write([]byte("ok"))
|
||||||
}
|
if err != nil {
|
||||||
}))
|
return
|
||||||
|
}
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
@@ -103,20 +114,31 @@ func TestLogging_DefaultStatusOK(t *testing.T) {
|
|||||||
|
|
||||||
func TestLogging_PassesThroughToNext(t *testing.T) {
|
func TestLogging_PassesThroughToNext(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/webhook", nil)
|
handler := m.Logging()(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/api/webhook", nil,
|
||||||
|
)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.True(t, called, "logging middleware should call the next handler")
|
assert.True(
|
||||||
|
t, called,
|
||||||
|
"logging middleware should call the next handler",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- LoggingResponseWriter Tests ---
|
// --- LoggingResponseWriter Tests ---
|
||||||
@@ -125,24 +147,33 @@ func TestLoggingResponseWriter_CapturesStatusCode(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
lrw := NewLoggingResponseWriter(w)
|
lrw := middleware.NewLoggingResponseWriterForTest(w)
|
||||||
|
|
||||||
// Default should be 200
|
// Default should be 200
|
||||||
assert.Equal(t, http.StatusOK, lrw.statusCode)
|
assert.Equal(
|
||||||
|
t, http.StatusOK,
|
||||||
|
middleware.LoggingResponseWriterStatusCode(lrw),
|
||||||
|
)
|
||||||
|
|
||||||
// WriteHeader should capture the status code
|
// WriteHeader should capture the status code
|
||||||
lrw.WriteHeader(http.StatusNotFound)
|
lrw.WriteHeader(http.StatusNotFound)
|
||||||
assert.Equal(t, http.StatusNotFound, lrw.statusCode)
|
|
||||||
|
assert.Equal(
|
||||||
|
t, http.StatusNotFound,
|
||||||
|
middleware.LoggingResponseWriterStatusCode(lrw),
|
||||||
|
)
|
||||||
|
|
||||||
// Underlying writer should also get the status code
|
// Underlying writer should also get the status code
|
||||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoggingResponseWriter_WriteDelegatesToUnderlying(t *testing.T) {
|
func TestLoggingResponseWriter_WriteDelegatesToUnderlying(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
lrw := NewLoggingResponseWriter(w)
|
lrw := middleware.NewLoggingResponseWriterForTest(w)
|
||||||
|
|
||||||
n, err := lrw.Write([]byte("hello world"))
|
n, err := lrw.Write([]byte("hello world"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -154,79 +185,124 @@ func TestLoggingResponseWriter_WriteDelegatesToUnderlying(t *testing.T) {
|
|||||||
|
|
||||||
func TestCORS_DevMode_AllowsAnyOrigin(t *testing.T) {
|
func TestCORS_DevMode_AllowsAnyOrigin(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.CORS()(http.HandlerFunc(
|
||||||
w.WriteHeader(http.StatusOK)
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
}))
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
// Preflight request
|
// Preflight request
|
||||||
req := httptest.NewRequest(http.MethodOptions, "/api/test", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodOptions, "/api/test", nil,
|
||||||
|
)
|
||||||
req.Header.Set("Origin", "http://localhost:3000")
|
req.Header.Set("Origin", "http://localhost:3000")
|
||||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
// In dev mode, CORS should allow any origin
|
// In dev mode, CORS should allow any origin
|
||||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
assert.Equal(
|
||||||
|
t, "*",
|
||||||
|
w.Header().Get("Access-Control-Allow-Origin"),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCORS_ProdMode_NoOp(t *testing.T) {
|
func TestCORS_ProdMode_NoOp(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
handler := m.CORS()(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/api/test", nil,
|
||||||
|
)
|
||||||
req.Header.Set("Origin", "http://evil.com")
|
req.Header.Set("Origin", "http://evil.com")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.True(t, called, "prod CORS middleware should pass through to handler")
|
assert.True(
|
||||||
|
t, called,
|
||||||
|
"prod CORS middleware should pass through to handler",
|
||||||
|
)
|
||||||
// In prod, no CORS headers should be set (no-op middleware)
|
// In prod, no CORS headers should be set (no-op middleware)
|
||||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
|
assert.Empty(
|
||||||
"prod mode should not set CORS headers")
|
t,
|
||||||
|
w.Header().Get("Access-Control-Allow-Origin"),
|
||||||
|
"prod mode should not set CORS headers",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- RequireAuth Middleware Tests ---
|
// --- RequireAuth Middleware Tests ---
|
||||||
|
|
||||||
func TestRequireAuth_NoSession_RedirectsToLogin(t *testing.T) {
|
func TestRequireAuth_NoSession_RedirectsToLogin(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
handler := m.RequireAuth()(http.HandlerFunc(
|
||||||
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/dashboard", nil,
|
||||||
|
)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.False(t, called, "handler should not be called for unauthenticated request")
|
assert.False(
|
||||||
|
t, called,
|
||||||
|
"handler should not be called for "+
|
||||||
|
"unauthenticated request",
|
||||||
|
)
|
||||||
assert.Equal(t, http.StatusSeeOther, w.Code)
|
assert.Equal(t, http.StatusSeeOther, w.Code)
|
||||||
assert.Equal(t, "/pages/login", w.Header().Get("Location"))
|
assert.Equal(t, "/pages/login", w.Header().Get("Location"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequireAuth_AuthenticatedSession_PassesThrough(t *testing.T) {
|
func TestRequireAuth_AuthenticatedSession_PassesThrough(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, sessManager := testMiddleware(t, config.EnvironmentDev)
|
m, sessManager := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
}))
|
|
||||||
|
|
||||||
// Create an authenticated session by making a request, setting session data,
|
handler := m.RequireAuth()(http.HandlerFunc(
|
||||||
// and saving the session cookie
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
setupReq := httptest.NewRequest(http.MethodGet, "/setup", nil)
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Create an authenticated session by making a request,
|
||||||
|
// setting session data, and saving the session cookie
|
||||||
|
setupReq := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/setup", nil,
|
||||||
|
)
|
||||||
setupW := httptest.NewRecorder()
|
setupW := httptest.NewRecorder()
|
||||||
|
|
||||||
sess, err := sessManager.Get(setupReq)
|
sess, err := sessManager.Get(setupReq)
|
||||||
@@ -239,47 +315,74 @@ func TestRequireAuth_AuthenticatedSession_PassesThrough(t *testing.T) {
|
|||||||
require.NotEmpty(t, cookies, "session cookie should be set")
|
require.NotEmpty(t, cookies, "session cookie should be set")
|
||||||
|
|
||||||
// Make the actual request with the session cookie
|
// Make the actual request with the session cookie
|
||||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/dashboard", nil,
|
||||||
|
)
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
req.AddCookie(c)
|
req.AddCookie(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.True(t, called, "handler should be called for authenticated request")
|
assert.True(
|
||||||
|
t, called,
|
||||||
|
"handler should be called for authenticated request",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequireAuth_UnauthenticatedSession_RedirectsToLogin(t *testing.T) {
|
func TestRequireAuth_UnauthenticatedSession_RedirectsToLogin(
|
||||||
|
t *testing.T,
|
||||||
|
) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, sessManager := testMiddleware(t, config.EnvironmentDev)
|
m, sessManager := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
handler := m.RequireAuth()(http.HandlerFunc(
|
||||||
}))
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
// Create a session but don't authenticate it
|
// Create a session but don't authenticate it
|
||||||
setupReq := httptest.NewRequest(http.MethodGet, "/setup", nil)
|
setupReq := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/setup", nil,
|
||||||
|
)
|
||||||
setupW := httptest.NewRecorder()
|
setupW := httptest.NewRecorder()
|
||||||
|
|
||||||
sess, err := sessManager.Get(setupReq)
|
sess, err := sessManager.Get(setupReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Don't call SetUser — session exists but is not authenticated
|
// Don't call SetUser -- session exists but is not
|
||||||
|
// authenticated
|
||||||
require.NoError(t, sessManager.Save(setupReq, setupW, sess))
|
require.NoError(t, sessManager.Save(setupReq, setupW, sess))
|
||||||
|
|
||||||
cookies := setupW.Result().Cookies()
|
cookies := setupW.Result().Cookies()
|
||||||
require.NotEmpty(t, cookies)
|
require.NotEmpty(t, cookies)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/dashboard", nil,
|
||||||
|
)
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
req.AddCookie(c)
|
req.AddCookie(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.False(t, called, "handler should not be called for unauthenticated session")
|
assert.False(
|
||||||
|
t, called,
|
||||||
|
"handler should not be called for "+
|
||||||
|
"unauthenticated session",
|
||||||
|
)
|
||||||
assert.Equal(t, http.StatusSeeOther, w.Code)
|
assert.Equal(t, http.StatusSeeOther, w.Code)
|
||||||
assert.Equal(t, "/pages/login", w.Header().Get("Location"))
|
assert.Equal(t, "/pages/login", w.Header().Get("Location"))
|
||||||
}
|
}
|
||||||
@@ -304,7 +407,9 @@ func TestIpFromHostPort(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
result := ipFromHostPort(tt.input)
|
|
||||||
|
result := middleware.IPFromHostPort(tt.input)
|
||||||
|
|
||||||
assert.Equal(t, tt.expected, result)
|
assert.Equal(t, tt.expected, result)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -312,122 +417,124 @@ func TestIpFromHostPort(t *testing.T) {
|
|||||||
|
|
||||||
// --- MetricsAuth Tests ---
|
// --- MetricsAuth Tests ---
|
||||||
|
|
||||||
func TestMetricsAuth_ValidCredentials(t *testing.T) {
|
// metricsAuthMiddleware creates a Middleware configured for
|
||||||
t.Parallel()
|
// metrics auth testing. This helper de-duplicates the setup in
|
||||||
|
// metrics auth test functions.
|
||||||
|
func metricsAuthMiddleware(
|
||||||
|
t *testing.T,
|
||||||
|
) *middleware.Middleware {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
log := slog.New(slog.NewTextHandler(
|
||||||
|
os.Stderr,
|
||||||
|
&slog.HandlerOptions{Level: slog.LevelDebug},
|
||||||
|
))
|
||||||
|
|
||||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Environment: config.EnvironmentDev,
|
Environment: config.EnvironmentDev,
|
||||||
MetricsUsername: "admin",
|
MetricsUsername: "admin",
|
||||||
MetricsPassword: "secret",
|
MetricsPassword: "secret",
|
||||||
}
|
}
|
||||||
|
|
||||||
key := make([]byte, 32)
|
key := make([]byte, testKeySize)
|
||||||
store := sessions.NewCookieStore(key)
|
store := sessions.NewCookieStore(key)
|
||||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
||||||
|
|
||||||
sessManager := session.NewForTest(store, cfg, log, key)
|
sessManager := session.NewForTest(store, cfg, log, key)
|
||||||
|
|
||||||
m := &Middleware{
|
return middleware.NewForTest(log, cfg, sessManager)
|
||||||
log: log,
|
}
|
||||||
params: &MiddlewareParams{
|
|
||||||
Config: cfg,
|
func TestMetricsAuth_ValidCredentials(t *testing.T) {
|
||||||
},
|
t.Parallel()
|
||||||
session: sessManager,
|
|
||||||
}
|
m := metricsAuthMiddleware(t)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
handler := m.MetricsAuth()(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/metrics", nil,
|
||||||
|
)
|
||||||
req.SetBasicAuth("admin", "secret")
|
req.SetBasicAuth("admin", "secret")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.True(t, called, "handler should be called with valid basic auth")
|
assert.True(
|
||||||
|
t, called,
|
||||||
|
"handler should be called with valid basic auth",
|
||||||
|
)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMetricsAuth_InvalidCredentials(t *testing.T) {
|
func TestMetricsAuth_InvalidCredentials(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
m := metricsAuthMiddleware(t)
|
||||||
cfg := &config.Config{
|
|
||||||
Environment: config.EnvironmentDev,
|
|
||||||
MetricsUsername: "admin",
|
|
||||||
MetricsPassword: "secret",
|
|
||||||
}
|
|
||||||
|
|
||||||
key := make([]byte, 32)
|
|
||||||
store := sessions.NewCookieStore(key)
|
|
||||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
|
||||||
|
|
||||||
sessManager := session.NewForTest(store, cfg, log, key)
|
|
||||||
|
|
||||||
m := &Middleware{
|
|
||||||
log: log,
|
|
||||||
params: &MiddlewareParams{
|
|
||||||
Config: cfg,
|
|
||||||
},
|
|
||||||
session: sessManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
handler := m.MetricsAuth()(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/metrics", nil,
|
||||||
|
)
|
||||||
req.SetBasicAuth("admin", "wrong-password")
|
req.SetBasicAuth("admin", "wrong-password")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.False(t, called, "handler should not be called with invalid basic auth")
|
assert.False(
|
||||||
|
t, called,
|
||||||
|
"handler should not be called with invalid basic auth",
|
||||||
|
)
|
||||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMetricsAuth_NoCredentials(t *testing.T) {
|
func TestMetricsAuth_NoCredentials(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
m := metricsAuthMiddleware(t)
|
||||||
cfg := &config.Config{
|
|
||||||
Environment: config.EnvironmentDev,
|
|
||||||
MetricsUsername: "admin",
|
|
||||||
MetricsPassword: "secret",
|
|
||||||
}
|
|
||||||
|
|
||||||
key := make([]byte, 32)
|
|
||||||
store := sessions.NewCookieStore(key)
|
|
||||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
|
||||||
|
|
||||||
sessManager := session.NewForTest(store, cfg, log, key)
|
|
||||||
|
|
||||||
m := &Middleware{
|
|
||||||
log: log,
|
|
||||||
params: &MiddlewareParams{
|
|
||||||
Config: cfg,
|
|
||||||
},
|
|
||||||
session: sessManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
called = true
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
handler := m.MetricsAuth()(http.HandlerFunc(
|
||||||
|
func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/metrics", nil,
|
||||||
|
)
|
||||||
// No basic auth header
|
// No basic auth header
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.False(t, called, "handler should not be called without credentials")
|
assert.False(
|
||||||
|
t, called,
|
||||||
|
"handler should not be called without credentials",
|
||||||
|
)
|
||||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,16 +542,23 @@ func TestMetricsAuth_NoCredentials(t *testing.T) {
|
|||||||
|
|
||||||
func TestCORS_DevMode_AllowsMethods(t *testing.T) {
|
func TestCORS_DevMode_AllowsMethods(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.CORS()(http.HandlerFunc(
|
||||||
w.WriteHeader(http.StatusOK)
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
}))
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
// Preflight for POST
|
// Preflight for POST
|
||||||
req := httptest.NewRequest(http.MethodOptions, "/api/webhooks", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodOptions, "/api/webhooks", nil,
|
||||||
|
)
|
||||||
req.Header.Set("Origin", "http://localhost:5173")
|
req.Header.Set("Origin", "http://localhost:5173")
|
||||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
@@ -458,14 +572,17 @@ func TestCORS_DevMode_AllowsMethods(t *testing.T) {
|
|||||||
func TestSessionKeyFormat(t *testing.T) {
|
func TestSessionKeyFormat(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// Verify that the session initialization correctly validates key format.
|
// Verify that the session initialization correctly validates
|
||||||
// A proper 32-byte key encoded as base64 should work.
|
// key format. A proper 32-byte key encoded as base64 should
|
||||||
key := make([]byte, 32)
|
// work.
|
||||||
|
key := make([]byte, testKeySize)
|
||||||
|
|
||||||
for i := range key {
|
for i := range key {
|
||||||
key[i] = byte(i + 1)
|
key[i] = byte(i + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
encoded := base64.StdEncoding.EncodeToString(key)
|
encoded := base64.StdEncoding.EncodeToString(key)
|
||||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, decoded, 32)
|
assert.Len(t, decoded, testKeySize)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,40 +8,56 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// loginRateLimit is the maximum number of login attempts per interval.
|
// loginRateLimit is the maximum number of login attempts
|
||||||
|
// per interval.
|
||||||
loginRateLimit = 5
|
loginRateLimit = 5
|
||||||
|
|
||||||
// loginRateInterval is the time window for the rate limit.
|
// loginRateInterval is the time window for the rate limit.
|
||||||
loginRateInterval = 1 * time.Minute
|
loginRateInterval = 1 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoginRateLimit returns middleware that enforces per-IP rate limiting
|
// LoginRateLimit returns middleware that enforces per-IP rate
|
||||||
// on login attempts using go-chi/httprate. Only POST requests are
|
// limiting on login attempts using go-chi/httprate. Only POST
|
||||||
// rate-limited; GET requests (rendering the login form) pass through
|
// requests are rate-limited; GET requests (rendering the login
|
||||||
// unaffected. When the rate limit is exceeded, a 429 Too Many Requests
|
// form) pass through unaffected. When the rate limit is exceeded,
|
||||||
// response is returned. IP extraction honours X-Forwarded-For,
|
// a 429 Too Many Requests response is returned. IP extraction
|
||||||
// X-Real-IP, and True-Client-IP headers for reverse-proxy setups.
|
// honours X-Forwarded-For, X-Real-IP, and True-Client-IP headers
|
||||||
|
// for reverse-proxy setups.
|
||||||
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
||||||
limiter := httprate.Limit(
|
limiter := httprate.Limit(
|
||||||
loginRateLimit,
|
loginRateLimit,
|
||||||
loginRateInterval,
|
loginRateInterval,
|
||||||
httprate.WithKeyFuncs(httprate.KeyByRealIP),
|
httprate.WithKeyFuncs(httprate.KeyByRealIP),
|
||||||
httprate.WithLimitHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httprate.WithLimitHandler(http.HandlerFunc(
|
||||||
m.log.Warn("login rate limit exceeded",
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
"path", r.URL.Path,
|
m.log.Warn("login rate limit exceeded",
|
||||||
)
|
"path", r.URL.Path,
|
||||||
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
)
|
||||||
})),
|
http.Error(
|
||||||
|
w,
|
||||||
|
"Too many login attempts. "+
|
||||||
|
"Please try again later.",
|
||||||
|
http.StatusTooManyRequests,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)),
|
||||||
)
|
)
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
limited := limiter(next)
|
limited := limiter(next)
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Only rate-limit POST requests (actual login attempts)
|
return http.HandlerFunc(func(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) {
|
||||||
|
// Only rate-limit POST requests (actual login
|
||||||
|
// attempts)
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
limited.ServeHTTP(w, r)
|
limited.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,90 +1,147 @@
|
|||||||
package middleware
|
package middleware_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"sneak.berlin/go/webhooker/internal/config"
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
|
"sneak.berlin/go/webhooker/internal/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLoginRateLimit_AllowsGET(t *testing.T) {
|
func TestLoginRateLimit_AllowsGET(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var callCount int
|
var callCount int
|
||||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
callCount++
|
handler := m.LoginRateLimit()(http.HandlerFunc(
|
||||||
w.WriteHeader(http.StatusOK)
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
}))
|
callCount++
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
// GET requests should never be rate-limited
|
// GET requests should never be rate-limited
|
||||||
for i := 0; i < 20; i++ {
|
for i := range 20 {
|
||||||
req := httptest.NewRequest(http.MethodGet, "/pages/login", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/pages/login", nil,
|
||||||
|
)
|
||||||
req.RemoteAddr = "192.168.1.1:12345"
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code, "GET request %d should pass", i)
|
|
||||||
|
assert.Equal(
|
||||||
|
t, http.StatusOK, w.Code,
|
||||||
|
"GET request %d should pass", i,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, 20, callCount)
|
assert.Equal(t, 20, callCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoginRateLimit_LimitsPOST(t *testing.T) {
|
func TestLoginRateLimit_LimitsPOST(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
var callCount int
|
var callCount int
|
||||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
callCount++
|
handler := m.LoginRateLimit()(http.HandlerFunc(
|
||||||
w.WriteHeader(http.StatusOK)
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
}))
|
callCount++
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
// First loginRateLimit POST requests should succeed
|
// First loginRateLimit POST requests should succeed
|
||||||
for i := 0; i < loginRateLimit; i++ {
|
for i := range middleware.LoginRateLimitConst {
|
||||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/pages/login", nil,
|
||||||
|
)
|
||||||
req.RemoteAddr = "10.0.0.1:12345"
|
req.RemoteAddr = "10.0.0.1:12345"
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code, "POST request %d should pass", i)
|
|
||||||
|
assert.Equal(
|
||||||
|
t, http.StatusOK, w.Code,
|
||||||
|
"POST request %d should pass", i,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next POST should be rate-limited
|
// Next POST should be rate-limited
|
||||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/pages/login", nil,
|
||||||
|
)
|
||||||
req.RemoteAddr = "10.0.0.1:12345"
|
req.RemoteAddr = "10.0.0.1:12345"
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusTooManyRequests, w.Code, "POST after limit should be 429")
|
|
||||||
assert.Equal(t, loginRateLimit, callCount)
|
assert.Equal(
|
||||||
|
t, http.StatusTooManyRequests, w.Code,
|
||||||
|
"POST after limit should be 429",
|
||||||
|
)
|
||||||
|
assert.Equal(t, middleware.LoginRateLimitConst, callCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
|
func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||||
|
|
||||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.LoginRateLimit()(http.HandlerFunc(
|
||||||
w.WriteHeader(http.StatusOK)
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
}))
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
// Exhaust limit for IP1
|
// Exhaust limit for IP1
|
||||||
for i := 0; i < loginRateLimit; i++ {
|
for range middleware.LoginRateLimitConst {
|
||||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/pages/login", nil,
|
||||||
|
)
|
||||||
req.RemoteAddr = "1.2.3.4:12345"
|
req.RemoteAddr = "1.2.3.4:12345"
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IP1 should be rate-limited
|
// IP1 should be rate-limited
|
||||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/pages/login", nil,
|
||||||
|
)
|
||||||
req.RemoteAddr = "1.2.3.4:12345"
|
req.RemoteAddr = "1.2.3.4:12345"
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||||
|
|
||||||
// IP2 should still be allowed
|
// IP2 should still be allowed
|
||||||
req2 := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
req2 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodPost, "/pages/login", nil,
|
||||||
|
)
|
||||||
req2.RemoteAddr = "5.6.7.8:12345"
|
req2.RemoteAddr = "5.6.7.8:12345"
|
||||||
|
|
||||||
w2 := httptest.NewRecorder()
|
w2 := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w2, req2)
|
handler.ServeHTTP(w2, req2)
|
||||||
assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected")
|
|
||||||
|
assert.Equal(
|
||||||
|
t, http.StatusOK, w2.Code,
|
||||||
|
"different IP should not be affected",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
24
internal/middleware/testing.go
Normal file
24
internal/middleware/testing.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
|
"sneak.berlin/go/webhooker/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewForTest creates a Middleware with the minimum dependencies
|
||||||
|
// needed for testing. This bypasses the fx lifecycle.
|
||||||
|
func NewForTest(
|
||||||
|
log *slog.Logger,
|
||||||
|
cfg *config.Config,
|
||||||
|
sess *session.Session,
|
||||||
|
) *Middleware {
|
||||||
|
return &Middleware{
|
||||||
|
log: log,
|
||||||
|
params: &MiddlewareParams{
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
session: sess,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,18 +1,33 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// httpReadTimeout is the maximum duration for reading the
|
||||||
|
// entire request, including the body.
|
||||||
|
httpReadTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// httpWriteTimeout is the maximum duration before timing out
|
||||||
|
// writes of the response.
|
||||||
|
httpWriteTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// httpMaxHeaderBytes is the maximum number of bytes the
|
||||||
|
// server will read parsing the request headers.
|
||||||
|
httpMaxHeaderBytes = 1 << 20
|
||||||
|
)
|
||||||
|
|
||||||
func (s *Server) serveUntilShutdown() {
|
func (s *Server) serveUntilShutdown() {
|
||||||
listenAddr := fmt.Sprintf(":%d", s.params.Config.Port)
|
listenAddr := fmt.Sprintf(":%d", s.params.Config.Port)
|
||||||
s.httpServer = &http.Server{
|
s.httpServer = &http.Server{
|
||||||
Addr: listenAddr,
|
Addr: listenAddr,
|
||||||
ReadTimeout: 10 * time.Second,
|
ReadTimeout: httpReadTimeout,
|
||||||
WriteTimeout: 10 * time.Second,
|
WriteTimeout: httpWriteTimeout,
|
||||||
MaxHeaderBytes: 1 << 20,
|
MaxHeaderBytes: httpMaxHeaderBytes,
|
||||||
Handler: s,
|
Handler: s,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,14 +36,21 @@ func (s *Server) serveUntilShutdown() {
|
|||||||
s.SetupRoutes()
|
s.SetupRoutes()
|
||||||
|
|
||||||
s.log.Info("http begin listen", "listenaddr", listenAddr)
|
s.log.Info("http begin listen", "listenaddr", listenAddr)
|
||||||
if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
|
err := s.httpServer.ListenAndServe()
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
s.log.Error("listen error", "error", err)
|
s.log.Error("listen error", "error", err)
|
||||||
|
|
||||||
if s.cancelFunc != nil {
|
if s.cancelFunc != nil {
|
||||||
s.cancelFunc()
|
s.cancelFunc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
// ServeHTTP delegates to the router.
|
||||||
|
func (s *Server) ServeHTTP(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) {
|
||||||
s.router.ServeHTTP(w, r)
|
s.router.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,15 +11,24 @@ import (
|
|||||||
"sneak.berlin/go/webhooker/static"
|
"sneak.berlin/go/webhooker/static"
|
||||||
)
|
)
|
||||||
|
|
||||||
// maxFormBodySize is the maximum allowed request body size (in bytes) for
|
// maxFormBodySize is the maximum allowed request body size (in
|
||||||
// form POST endpoints. 1 MB is generous for any form submission while
|
// bytes) for form POST endpoints. 1 MB is generous for any form
|
||||||
// preventing abuse from oversized payloads.
|
// submission while preventing abuse from oversized payloads.
|
||||||
const maxFormBodySize int64 = 1 * 1024 * 1024 // 1 MB
|
const maxFormBodySize int64 = 1 * 1024 * 1024 // 1 MB
|
||||||
|
|
||||||
|
// requestTimeout is the maximum time allowed for a single HTTP
|
||||||
|
// request.
|
||||||
|
const requestTimeout = 60 * time.Second
|
||||||
|
|
||||||
|
// SetupRoutes configures all HTTP routes and middleware on the
|
||||||
|
// server's router.
|
||||||
func (s *Server) SetupRoutes() {
|
func (s *Server) SetupRoutes() {
|
||||||
s.router = chi.NewRouter()
|
s.router = chi.NewRouter()
|
||||||
|
s.setupGlobalMiddleware()
|
||||||
|
s.setupRoutes()
|
||||||
|
}
|
||||||
|
|
||||||
// Global middleware stack — applied to every request.
|
func (s *Server) setupGlobalMiddleware() {
|
||||||
s.router.Use(middleware.Recoverer)
|
s.router.Use(middleware.Recoverer)
|
||||||
s.router.Use(middleware.RequestID)
|
s.router.Use(middleware.RequestID)
|
||||||
s.router.Use(s.mw.SecurityHeaders())
|
s.router.Use(s.mw.SecurityHeaders())
|
||||||
@@ -31,24 +40,28 @@ func (s *Server) SetupRoutes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.router.Use(s.mw.CORS())
|
s.router.Use(s.mw.CORS())
|
||||||
s.router.Use(middleware.Timeout(60 * time.Second))
|
s.router.Use(middleware.Timeout(requestTimeout))
|
||||||
|
|
||||||
// Sentry error reporting (if SENTRY_DSN is set). Repanic is true
|
// Sentry error reporting (if SENTRY_DSN is set). Repanic is
|
||||||
// so panics still bubble up to the Recoverer middleware above.
|
// true so panics still bubble up to the Recoverer middleware.
|
||||||
if s.sentryEnabled {
|
if s.sentryEnabled {
|
||||||
sentryHandler := sentryhttp.New(sentryhttp.Options{
|
sentryHandler := sentryhttp.New(sentryhttp.Options{
|
||||||
Repanic: true,
|
Repanic: true,
|
||||||
})
|
})
|
||||||
s.router.Use(sentryHandler.Handle)
|
s.router.Use(sentryHandler.Handle)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Routes
|
func (s *Server) setupRoutes() {
|
||||||
s.router.Get("/", s.h.HandleIndex())
|
s.router.Get("/", s.h.HandleIndex())
|
||||||
|
|
||||||
s.router.Mount("/s", http.StripPrefix("/s", http.FileServer(http.FS(static.Static))))
|
s.router.Mount(
|
||||||
|
"/s",
|
||||||
|
http.StripPrefix("/s", http.FileServer(http.FS(static.Static))),
|
||||||
|
)
|
||||||
|
|
||||||
s.router.Route("/api/v1", func(_ chi.Router) {
|
s.router.Route("/api/v1", func(_ chi.Router) {
|
||||||
// TODO: Add API routes here
|
// API routes will be added here.
|
||||||
})
|
})
|
||||||
|
|
||||||
s.router.Get(
|
s.router.Get(
|
||||||
@@ -60,62 +73,89 @@ func (s *Server) SetupRoutes() {
|
|||||||
if s.params.Config.MetricsUsername != "" {
|
if s.params.Config.MetricsUsername != "" {
|
||||||
s.router.Group(func(r chi.Router) {
|
s.router.Group(func(r chi.Router) {
|
||||||
r.Use(s.mw.MetricsAuth())
|
r.Use(s.mw.MetricsAuth())
|
||||||
r.Get("/metrics", http.HandlerFunc(promhttp.Handler().ServeHTTP))
|
r.Get(
|
||||||
|
"/metrics",
|
||||||
|
http.HandlerFunc(
|
||||||
|
promhttp.Handler().ServeHTTP,
|
||||||
|
),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// pages that are rendered server-side — CSRF-protected, body-size
|
s.setupPageRoutes()
|
||||||
// limited, and with per-IP rate limiting on the login endpoint.
|
s.setupUserRoutes()
|
||||||
|
s.setupSourceRoutes()
|
||||||
|
s.setupWebhookRoutes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) setupPageRoutes() {
|
||||||
s.router.Route("/pages", func(r chi.Router) {
|
s.router.Route("/pages", func(r chi.Router) {
|
||||||
r.Use(s.mw.CSRF())
|
r.Use(s.mw.CSRF())
|
||||||
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
||||||
|
|
||||||
// Login page — rate-limited to prevent brute-force attacks
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(s.mw.LoginRateLimit())
|
r.Use(s.mw.LoginRateLimit())
|
||||||
r.Get("/login", s.h.HandleLoginPage())
|
r.Get("/login", s.h.HandleLoginPage())
|
||||||
r.Post("/login", s.h.HandleLoginSubmit())
|
r.Post("/login", s.h.HandleLoginSubmit())
|
||||||
})
|
})
|
||||||
|
|
||||||
// Logout (auth required)
|
|
||||||
r.Post("/logout", s.h.HandleLogout())
|
r.Post("/logout", s.h.HandleLogout())
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// User profile routes
|
func (s *Server) setupUserRoutes() {
|
||||||
s.router.Route("/user/{username}", func(r chi.Router) {
|
s.router.Route("/user/{username}", func(r chi.Router) {
|
||||||
r.Use(s.mw.CSRF())
|
r.Use(s.mw.CSRF())
|
||||||
r.Get("/", s.h.HandleProfile())
|
r.Get("/", s.h.HandleProfile())
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Webhook management routes (require authentication, CSRF-protected)
|
func (s *Server) setupSourceRoutes() {
|
||||||
s.router.Route("/sources", func(r chi.Router) {
|
s.router.Route("/sources", func(r chi.Router) {
|
||||||
r.Use(s.mw.CSRF())
|
r.Use(s.mw.CSRF())
|
||||||
r.Use(s.mw.RequireAuth())
|
r.Use(s.mw.RequireAuth())
|
||||||
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
||||||
r.Get("/", s.h.HandleSourceList()) // List all webhooks
|
r.Get("/", s.h.HandleSourceList())
|
||||||
r.Get("/new", s.h.HandleSourceCreate()) // Show create form
|
r.Get("/new", s.h.HandleSourceCreate())
|
||||||
r.Post("/new", s.h.HandleSourceCreateSubmit()) // Handle create submission
|
r.Post("/new", s.h.HandleSourceCreateSubmit())
|
||||||
})
|
})
|
||||||
|
|
||||||
s.router.Route("/source/{sourceID}", func(r chi.Router) {
|
s.router.Route("/source/{sourceID}", func(r chi.Router) {
|
||||||
r.Use(s.mw.CSRF())
|
r.Use(s.mw.CSRF())
|
||||||
r.Use(s.mw.RequireAuth())
|
r.Use(s.mw.RequireAuth())
|
||||||
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
||||||
r.Get("/", s.h.HandleSourceDetail()) // View webhook details
|
r.Get("/", s.h.HandleSourceDetail())
|
||||||
r.Get("/edit", s.h.HandleSourceEdit()) // Show edit form
|
r.Get("/edit", s.h.HandleSourceEdit())
|
||||||
r.Post("/edit", s.h.HandleSourceEditSubmit()) // Handle edit submission
|
r.Post("/edit", s.h.HandleSourceEditSubmit())
|
||||||
r.Post("/delete", s.h.HandleSourceDelete()) // Delete webhook
|
r.Post("/delete", s.h.HandleSourceDelete())
|
||||||
r.Get("/logs", s.h.HandleSourceLogs()) // View webhook logs
|
r.Get("/logs", s.h.HandleSourceLogs())
|
||||||
r.Post("/entrypoints", s.h.HandleEntrypointCreate()) // Add entrypoint
|
r.Post(
|
||||||
r.Post("/entrypoints/{entrypointID}/delete", s.h.HandleEntrypointDelete()) // Delete entrypoint
|
"/entrypoints",
|
||||||
r.Post("/entrypoints/{entrypointID}/toggle", s.h.HandleEntrypointToggle()) // Toggle entrypoint active
|
s.h.HandleEntrypointCreate(),
|
||||||
r.Post("/targets", s.h.HandleTargetCreate()) // Add target
|
)
|
||||||
r.Post("/targets/{targetID}/delete", s.h.HandleTargetDelete()) // Delete target
|
r.Post(
|
||||||
r.Post("/targets/{targetID}/toggle", s.h.HandleTargetToggle()) // Toggle target active
|
"/entrypoints/{entrypointID}/delete",
|
||||||
|
s.h.HandleEntrypointDelete(),
|
||||||
|
)
|
||||||
|
r.Post(
|
||||||
|
"/entrypoints/{entrypointID}/toggle",
|
||||||
|
s.h.HandleEntrypointToggle(),
|
||||||
|
)
|
||||||
|
r.Post("/targets", s.h.HandleTargetCreate())
|
||||||
|
r.Post(
|
||||||
|
"/targets/{targetID}/delete",
|
||||||
|
s.h.HandleTargetDelete(),
|
||||||
|
)
|
||||||
|
r.Post(
|
||||||
|
"/targets/{targetID}/toggle",
|
||||||
|
s.h.HandleTargetToggle(),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
}
|
||||||
// Entrypoint endpoint — accepts incoming webhook POST requests only.
|
|
||||||
// Using HandleFunc so the handler itself can return 405 for non-POST
|
func (s *Server) setupWebhookRoutes() {
|
||||||
// methods (chi's Method routing returns 405 without Allow header).
|
s.router.HandleFunc(
|
||||||
s.router.HandleFunc("/webhook/{uuid}", s.h.HandleWebhook())
|
"/webhook/{uuid}",
|
||||||
|
s.h.HandleWebhook(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
// Package server wires up HTTP routes and manages the
|
||||||
|
// application lifecycle.
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -21,9 +23,20 @@ import (
|
|||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // ServerParams is a standard fx naming convention
|
const (
|
||||||
|
// shutdownTimeout is the maximum time to wait for the HTTP
|
||||||
|
// server to finish in-flight requests during shutdown.
|
||||||
|
shutdownTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// sentryFlushTimeout is the maximum time to wait for Sentry
|
||||||
|
// to flush pending events during shutdown.
|
||||||
|
sentryFlushTimeout = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
//nolint:revive // ServerParams is a standard fx naming convention.
|
||||||
type ServerParams struct {
|
type ServerParams struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
Globals *globals.Globals
|
Globals *globals.Globals
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
@@ -31,12 +44,13 @@ type ServerParams struct {
|
|||||||
Handlers *handlers.Handlers
|
Handlers *handlers.Handlers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server is the main HTTP server that wires up routes and manages
|
||||||
|
// graceful shutdown.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
startupTime time.Time
|
startupTime time.Time
|
||||||
exitCode int
|
exitCode int
|
||||||
sentryEnabled bool
|
sentryEnabled bool
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
ctx context.Context
|
|
||||||
cancelFunc context.CancelFunc
|
cancelFunc context.CancelFunc
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
router *chi.Mux
|
router *chi.Mux
|
||||||
@@ -45,6 +59,8 @@ type Server struct {
|
|||||||
h *handlers.Handlers
|
h *handlers.Handlers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New creates a Server that starts the HTTP listener on fx start
|
||||||
|
// and stops it gracefully.
|
||||||
func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
|
func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
|
||||||
s := new(Server)
|
s := new(Server)
|
||||||
s.params = params
|
s.params = params
|
||||||
@@ -53,19 +69,23 @@ func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
|
|||||||
s.log = params.Logger.Get()
|
s.log = params.Logger.Get()
|
||||||
|
|
||||||
lc.Append(fx.Hook{
|
lc.Append(fx.Hook{
|
||||||
OnStart: func(ctx context.Context) error {
|
OnStart: func(_ context.Context) error {
|
||||||
s.startupTime = time.Now()
|
s.startupTime = time.Now()
|
||||||
go s.Run()
|
go s.Run()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
OnStop: func(ctx context.Context) error {
|
OnStop: func(ctx context.Context) error {
|
||||||
s.cleanShutdown()
|
s.cleanShutdown(ctx)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run configures Sentry and starts serving HTTP requests.
|
||||||
func (s *Server) Run() {
|
func (s *Server) Run() {
|
||||||
s.configure()
|
s.configure()
|
||||||
|
|
||||||
@@ -75,6 +95,12 @@ func (s *Server) Run() {
|
|||||||
s.serve()
|
s.serve()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaintenanceMode returns whether the server is in maintenance
|
||||||
|
// mode.
|
||||||
|
func (s *Server) MaintenanceMode() bool {
|
||||||
|
return s.params.Config.MaintenanceMode
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) enableSentry() {
|
func (s *Server) enableSentry() {
|
||||||
s.sentryEnabled = false
|
s.sentryEnabled = false
|
||||||
|
|
||||||
@@ -83,29 +109,37 @@ func (s *Server) enableSentry() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err := sentry.Init(sentry.ClientOptions{
|
err := sentry.Init(sentry.ClientOptions{
|
||||||
Dsn: s.params.Config.SentryDSN,
|
Dsn: s.params.Config.SentryDSN,
|
||||||
Release: fmt.Sprintf("%s-%s", s.params.Globals.Appname, s.params.Globals.Version),
|
Release: fmt.Sprintf(
|
||||||
|
"%s-%s",
|
||||||
|
s.params.Globals.Appname,
|
||||||
|
s.params.Globals.Version,
|
||||||
|
),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("sentry init failure", "error", err)
|
s.log.Error("sentry init failure", "error", err)
|
||||||
// Don't use fatal since we still want the service to run
|
// Don't use fatal since we still want the service to run
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.log.Info("sentry error reporting activated")
|
s.log.Info("sentry error reporting activated")
|
||||||
s.sentryEnabled = true
|
s.sentryEnabled = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) serve() int {
|
func (s *Server) serve() int {
|
||||||
s.ctx, s.cancelFunc = context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
s.cancelFunc = cancelFunc
|
||||||
|
|
||||||
// signal watcher
|
// signal watcher
|
||||||
go func() {
|
go func() {
|
||||||
c := make(chan os.Signal, 1)
|
c := make(chan os.Signal, 1)
|
||||||
|
|
||||||
signal.Ignore(syscall.SIGPIPE)
|
signal.Ignore(syscall.SIGPIPE)
|
||||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||||
// block and wait for signal
|
// block and wait for signal
|
||||||
sig := <-c
|
sig := <-c
|
||||||
s.log.Info("signal received", "signal", sig.String())
|
s.log.Info("signal received", "signal", sig.String())
|
||||||
|
|
||||||
if s.cancelFunc != nil {
|
if s.cancelFunc != nil {
|
||||||
// cancelling the main context will trigger a clean
|
// cancelling the main context will trigger a clean
|
||||||
// shutdown via the fx OnStop hook.
|
// shutdown via the fx OnStop hook.
|
||||||
@@ -115,9 +149,9 @@ func (s *Server) serve() int {
|
|||||||
|
|
||||||
go s.serveUntilShutdown()
|
go s.serveUntilShutdown()
|
||||||
|
|
||||||
<-s.ctx.Done()
|
<-ctx.Done()
|
||||||
// Shutdown is handled by the fx OnStop hook (cleanShutdown).
|
// Shutdown is handled by the fx OnStop hook (cleanShutdown).
|
||||||
// Do not call cleanShutdown() here to avoid a double invocation.
|
// Do not call cleanShutdown() here to avoid double invocation.
|
||||||
return s.exitCode
|
return s.exitCode
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,27 +159,29 @@ func (s *Server) cleanupForExit() {
|
|||||||
s.log.Info("cleaning up")
|
s.log.Info("cleaning up")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cleanShutdown() {
|
func (s *Server) cleanShutdown(ctx context.Context) {
|
||||||
// initiate clean shutdown
|
// initiate clean shutdown
|
||||||
s.exitCode = 0
|
s.exitCode = 0
|
||||||
ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
|
ctxShutdown, shutdownCancel := context.WithTimeout(
|
||||||
|
ctx, shutdownTimeout,
|
||||||
|
)
|
||||||
defer shutdownCancel()
|
defer shutdownCancel()
|
||||||
|
|
||||||
if err := s.httpServer.Shutdown(ctxShutdown); err != nil {
|
err := s.httpServer.Shutdown(ctxShutdown)
|
||||||
s.log.Error("server clean shutdown failed", "error", err)
|
if err != nil {
|
||||||
|
s.log.Error(
|
||||||
|
"server clean shutdown failed", "error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.cleanupForExit()
|
s.cleanupForExit()
|
||||||
|
|
||||||
if s.sentryEnabled {
|
if s.sentryEnabled {
|
||||||
sentry.Flush(2 * time.Second)
|
sentry.Flush(sentryFlushTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) MaintenanceMode() bool {
|
|
||||||
return s.params.Config.MaintenanceMode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) configure() {
|
func (s *Server) configure() {
|
||||||
// identify ourselves in the logs
|
// identify ourselves in the logs
|
||||||
s.params.Logger.Identify()
|
s.params.Logger.Identify()
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
|
// Package session manages HTTP session storage and authentication
|
||||||
|
// state.
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
@@ -15,28 +19,44 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// SessionName is the name of the session cookie
|
// SessionName is the name of the session cookie.
|
||||||
SessionName = "webhooker_session"
|
SessionName = "webhooker_session"
|
||||||
|
|
||||||
// UserIDKey is the session key for user ID
|
// UserIDKey is the session key for user ID.
|
||||||
UserIDKey = "user_id"
|
UserIDKey = "user_id"
|
||||||
|
|
||||||
// UsernameKey is the session key for username
|
// UsernameKey is the session key for username.
|
||||||
UsernameKey = "username"
|
UsernameKey = "username"
|
||||||
|
|
||||||
// AuthenticatedKey is the session key for authentication status
|
// AuthenticatedKey is the session key for authentication
|
||||||
|
// status.
|
||||||
AuthenticatedKey = "authenticated"
|
AuthenticatedKey = "authenticated"
|
||||||
|
|
||||||
|
// sessionKeyLength is the required length in bytes for the
|
||||||
|
// session authentication key.
|
||||||
|
sessionKeyLength = 32
|
||||||
|
|
||||||
|
// sessionMaxAgeDays is the session cookie lifetime in days.
|
||||||
|
sessionMaxAgeDays = 7
|
||||||
|
|
||||||
|
// secondsPerDay is the number of seconds in a day.
|
||||||
|
secondsPerDay = 86400
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:revive // SessionParams is a standard fx naming convention
|
// ErrSessionKeyLength is returned when the decoded session key
|
||||||
type SessionParams struct {
|
// does not have the expected length.
|
||||||
|
var ErrSessionKeyLength = errors.New("session key length mismatch")
|
||||||
|
|
||||||
|
// Params holds dependencies injected by fx.
|
||||||
|
type Params struct {
|
||||||
fx.In
|
fx.In
|
||||||
|
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
Database *database.Database
|
Database *database.Database
|
||||||
Logger *logger.Logger
|
Logger *logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session manages encrypted session storage
|
// Session manages encrypted session storage.
|
||||||
type Session struct {
|
type Session struct {
|
||||||
store *sessions.CookieStore
|
store *sessions.CookieStore
|
||||||
key []byte // raw 32-byte auth key, also used for CSRF cookie signing
|
key []byte // raw 32-byte auth key, also used for CSRF cookie signing
|
||||||
@@ -44,29 +64,44 @@ type Session struct {
|
|||||||
config *config.Config
|
config *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new session manager. The cookie store is initialized
|
// New creates a new session manager. The cookie store is
|
||||||
// during the fx OnStart phase after the database is connected, using
|
// initialized during the fx OnStart phase after the database is
|
||||||
// a session key that is auto-generated and stored in the database.
|
// connected, using a session key that is auto-generated and stored
|
||||||
func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
|
// in the database.
|
||||||
|
func New(
|
||||||
|
lc fx.Lifecycle,
|
||||||
|
params Params,
|
||||||
|
) (*Session, error) {
|
||||||
s := &Session{
|
s := &Session{
|
||||||
log: params.Logger.Get(),
|
log: params.Logger.Get(),
|
||||||
config: params.Config,
|
config: params.Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
lc.Append(fx.Hook{
|
lc.Append(fx.Hook{
|
||||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
OnStart: func(_ context.Context) error {
|
||||||
sessionKey, err := params.Database.GetOrCreateSessionKey()
|
sessionKey, err := params.Database.GetOrCreateSessionKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get session key: %w", err)
|
return fmt.Errorf(
|
||||||
|
"failed to get session key: %w", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
|
keyBytes, err := base64.StdEncoding.DecodeString(
|
||||||
|
sessionKey,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid session key format: %w", err)
|
return fmt.Errorf(
|
||||||
|
"invalid session key format: %w", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(keyBytes) != 32 {
|
if len(keyBytes) != sessionKeyLength {
|
||||||
return fmt.Errorf("session key must be 32 bytes (got %d)", len(keyBytes))
|
return fmt.Errorf(
|
||||||
|
"%w: want %d, got %d",
|
||||||
|
ErrSessionKeyLength,
|
||||||
|
sessionKeyLength,
|
||||||
|
len(keyBytes),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
store := sessions.NewCookieStore(keyBytes)
|
store := sessions.NewCookieStore(keyBytes)
|
||||||
@@ -74,15 +109,16 @@ func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
|
|||||||
// Configure cookie options for security
|
// Configure cookie options for security
|
||||||
store.Options = &sessions.Options{
|
store.Options = &sessions.Options{
|
||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: 86400 * 7, // 7 days
|
MaxAge: secondsPerDay * sessionMaxAgeDays,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: !params.Config.IsDev(), // HTTPS in production
|
Secure: !params.Config.IsDev(),
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.key = keyBytes
|
s.key = keyBytes
|
||||||
s.store = store
|
s.store = store
|
||||||
s.log.Info("session manager initialized")
|
s.log.Info("session manager initialized")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -90,99 +126,126 @@ func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a session for the request
|
// Get retrieves a session for the request.
|
||||||
func (s *Session) Get(r *http.Request) (*sessions.Session, error) {
|
func (s *Session) Get(
|
||||||
|
r *http.Request,
|
||||||
|
) (*sessions.Session, error) {
|
||||||
return s.store.Get(r, SessionName)
|
return s.store.Get(r, SessionName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKey returns the raw 32-byte authentication key used for session
|
// GetKey returns the raw 32-byte authentication key used for
|
||||||
// encryption. This key is also suitable for CSRF cookie signing.
|
// session encryption. This key is also suitable for CSRF cookie
|
||||||
|
// signing.
|
||||||
func (s *Session) GetKey() []byte {
|
func (s *Session) GetKey() []byte {
|
||||||
return s.key
|
return s.key
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save saves the session
|
// Save saves the session.
|
||||||
func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error {
|
func (s *Session) Save(
|
||||||
|
r *http.Request,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
sess *sessions.Session,
|
||||||
|
) error {
|
||||||
return sess.Save(r, w)
|
return sess.Save(r, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUser sets the user information in the session
|
// SetUser sets the user information in the session.
|
||||||
func (s *Session) SetUser(sess *sessions.Session, userID, username string) {
|
func (s *Session) SetUser(
|
||||||
|
sess *sessions.Session,
|
||||||
|
userID, username string,
|
||||||
|
) {
|
||||||
sess.Values[UserIDKey] = userID
|
sess.Values[UserIDKey] = userID
|
||||||
sess.Values[UsernameKey] = username
|
sess.Values[UsernameKey] = username
|
||||||
sess.Values[AuthenticatedKey] = true
|
sess.Values[AuthenticatedKey] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearUser removes user information from the session
|
// ClearUser removes user information from the session.
|
||||||
func (s *Session) ClearUser(sess *sessions.Session) {
|
func (s *Session) ClearUser(sess *sessions.Session) {
|
||||||
delete(sess.Values, UserIDKey)
|
delete(sess.Values, UserIDKey)
|
||||||
delete(sess.Values, UsernameKey)
|
delete(sess.Values, UsernameKey)
|
||||||
delete(sess.Values, AuthenticatedKey)
|
delete(sess.Values, AuthenticatedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAuthenticated checks if the session has an authenticated user
|
// IsAuthenticated checks if the session has an authenticated
|
||||||
|
// user.
|
||||||
func (s *Session) IsAuthenticated(sess *sessions.Session) bool {
|
func (s *Session) IsAuthenticated(sess *sessions.Session) bool {
|
||||||
auth, ok := sess.Values[AuthenticatedKey].(bool)
|
auth, ok := sess.Values[AuthenticatedKey].(bool)
|
||||||
|
|
||||||
return ok && auth
|
return ok && auth
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserID retrieves the user ID from the session
|
// GetUserID retrieves the user ID from the session.
|
||||||
func (s *Session) GetUserID(sess *sessions.Session) (string, bool) {
|
func (s *Session) GetUserID(
|
||||||
|
sess *sessions.Session,
|
||||||
|
) (string, bool) {
|
||||||
userID, ok := sess.Values[UserIDKey].(string)
|
userID, ok := sess.Values[UserIDKey].(string)
|
||||||
|
|
||||||
return userID, ok
|
return userID, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsername retrieves the username from the session
|
// GetUsername retrieves the username from the session.
|
||||||
func (s *Session) GetUsername(sess *sessions.Session) (string, bool) {
|
func (s *Session) GetUsername(
|
||||||
|
sess *sessions.Session,
|
||||||
|
) (string, bool) {
|
||||||
username, ok := sess.Values[UsernameKey].(string)
|
username, ok := sess.Values[UsernameKey].(string)
|
||||||
|
|
||||||
return username, ok
|
return username, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy invalidates the session
|
// Destroy invalidates the session.
|
||||||
func (s *Session) Destroy(sess *sessions.Session) {
|
func (s *Session) Destroy(sess *sessions.Session) {
|
||||||
sess.Options.MaxAge = -1
|
sess.Options.MaxAge = -1
|
||||||
s.ClearUser(sess)
|
s.ClearUser(sess)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Regenerate creates a new session with the same values but a fresh ID.
|
// Regenerate creates a new session with the same values but a
|
||||||
// The old session is destroyed (MaxAge = -1) and saved, then a new session
|
// fresh ID. The old session is destroyed (MaxAge = -1) and saved,
|
||||||
// is created. This prevents session fixation attacks by ensuring the
|
// then a new session is created. This prevents session fixation
|
||||||
// session ID changes after privilege escalation (e.g. login).
|
// attacks by ensuring the session ID changes after privilege
|
||||||
func (s *Session) Regenerate(r *http.Request, w http.ResponseWriter, oldSess *sessions.Session) (*sessions.Session, error) {
|
// escalation (e.g. login).
|
||||||
|
func (s *Session) Regenerate(
|
||||||
|
r *http.Request,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
oldSess *sessions.Session,
|
||||||
|
) (*sessions.Session, error) {
|
||||||
// Copy the values from the old session
|
// Copy the values from the old session
|
||||||
oldValues := make(map[interface{}]interface{})
|
oldValues := make(map[any]any)
|
||||||
for k, v := range oldSess.Values {
|
maps.Copy(oldValues, oldSess.Values)
|
||||||
oldValues[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Destroy the old session
|
// Destroy the old session
|
||||||
oldSess.Options.MaxAge = -1
|
oldSess.Options.MaxAge = -1
|
||||||
s.ClearUser(oldSess)
|
s.ClearUser(oldSess)
|
||||||
if err := oldSess.Save(r, w); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to destroy old session: %w", err)
|
err := oldSess.Save(r, w)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"failed to destroy old session: %w", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new session (gorilla/sessions generates a new ID)
|
// Create a new session (gorilla/sessions generates a new ID)
|
||||||
newSess, err := s.store.New(r, SessionName)
|
newSess, err := s.store.New(r, SessionName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// store.New may return an error alongside a new empty session
|
// store.New may return an error alongside a new empty
|
||||||
// if the old cookie is now invalid. That is expected after we
|
// session if the old cookie is now invalid. That is
|
||||||
// destroyed it above. Only fail on a nil session.
|
// expected after we destroyed it above. Only fail on a
|
||||||
|
// nil session.
|
||||||
if newSess == nil {
|
if newSess == nil {
|
||||||
return nil, fmt.Errorf("failed to create new session: %w", err)
|
return nil, fmt.Errorf(
|
||||||
|
"failed to create new session: %w", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore the copied values into the new session
|
// Restore the copied values into the new session
|
||||||
for k, v := range oldValues {
|
maps.Copy(newSess.Values, oldValues)
|
||||||
newSess.Values[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply the standard session options (the destroyed old session had
|
// Apply the standard session options (the destroyed old
|
||||||
// MaxAge = -1, which store.New might inherit from the cookie).
|
// session had MaxAge = -1, which store.New might inherit
|
||||||
|
// from the cookie).
|
||||||
newSess.Options = &sessions.Options{
|
newSess.Options = &sessions.Options{
|
||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: 86400 * 7,
|
MaxAge: secondsPerDay * sessionMaxAgeDays,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: !s.config.IsDev(),
|
Secure: !s.config.IsDev(),
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package session
|
package session_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -11,15 +12,22 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"sneak.berlin/go/webhooker/internal/config"
|
"sneak.berlin/go/webhooker/internal/config"
|
||||||
|
"sneak.berlin/go/webhooker/internal/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testSession creates a Session with a real cookie store for testing.
|
const testKeySize = 32
|
||||||
func testSession(t *testing.T) *Session {
|
|
||||||
|
// testSession creates a Session with a real cookie store for
|
||||||
|
// testing.
|
||||||
|
func testSession(t *testing.T) *session.Session {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
key := make([]byte, 32)
|
|
||||||
|
key := make([]byte, testKeySize)
|
||||||
|
|
||||||
for i := range key {
|
for i := range key {
|
||||||
key[i] = byte(i + 42)
|
key[i] = byte(i + 42)
|
||||||
}
|
}
|
||||||
|
|
||||||
store := sessions.NewCookieStore(key)
|
store := sessions.NewCookieStore(key)
|
||||||
store.Options = &sessions.Options{
|
store.Options = &sessions.Options{
|
||||||
Path: "/",
|
Path: "/",
|
||||||
@@ -32,34 +40,47 @@ func testSession(t *testing.T) *Session {
|
|||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Environment: config.EnvironmentDev,
|
Environment: config.EnvironmentDev,
|
||||||
}
|
}
|
||||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
||||||
|
|
||||||
return NewForTest(store, cfg, log, key)
|
log := slog.New(slog.NewTextHandler(
|
||||||
|
os.Stderr,
|
||||||
|
&slog.HandlerOptions{Level: slog.LevelDebug},
|
||||||
|
))
|
||||||
|
|
||||||
|
return session.NewForTest(store, cfg, log, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Get and Save Tests ---
|
// --- Get and Save Tests ---
|
||||||
|
|
||||||
func TestGet_NewSession(t *testing.T) {
|
func TestGet_NewSession(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, sess)
|
require.NotNil(t, sess)
|
||||||
assert.True(t, sess.IsNew, "session should be new when no cookie is present")
|
assert.True(
|
||||||
|
t, sess.IsNew,
|
||||||
|
"session should be new when no cookie is present",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGet_ExistingSession(t *testing.T) {
|
func TestGet_ExistingSession(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
// Create and save a session
|
// Create and save a session
|
||||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
req1 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
w1 := httptest.NewRecorder()
|
w1 := httptest.NewRecorder()
|
||||||
|
|
||||||
sess1, err := s.Get(req1)
|
sess1, err := s.Get(req1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sess1.Values["test_key"] = "test_value"
|
sess1.Values["test_key"] = "test_value"
|
||||||
require.NoError(t, s.Save(req1, w1, sess1))
|
require.NoError(t, s.Save(req1, w1, sess1))
|
||||||
|
|
||||||
@@ -68,26 +89,34 @@ func TestGet_ExistingSession(t *testing.T) {
|
|||||||
require.NotEmpty(t, cookies)
|
require.NotEmpty(t, cookies)
|
||||||
|
|
||||||
// Make a new request with the session cookie
|
// Make a new request with the session cookie
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
req2 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
req2.AddCookie(c)
|
req2.AddCookie(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
sess2, err := s.Get(req2)
|
sess2, err := s.Get(req2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, sess2.IsNew, "session should not be new when cookie is present")
|
assert.False(
|
||||||
|
t, sess2.IsNew,
|
||||||
|
"session should not be new when cookie is present",
|
||||||
|
)
|
||||||
assert.Equal(t, "test_value", sess2.Values["test_key"])
|
assert.Equal(t, "test_value", sess2.Values["test_key"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSave_SetsCookie(t *testing.T) {
|
func TestSave_SetsCookie(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sess.Values["key"] = "value"
|
sess.Values["key"] = "value"
|
||||||
|
|
||||||
err = s.Save(req, w, sess)
|
err = s.Save(req, w, sess)
|
||||||
@@ -98,48 +127,73 @@ func TestSave_SetsCookie(t *testing.T) {
|
|||||||
|
|
||||||
// Verify the cookie has the expected name
|
// Verify the cookie has the expected name
|
||||||
var found bool
|
var found bool
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
if c.Name == SessionName {
|
if c.Name == session.SessionName {
|
||||||
found = true
|
found = true
|
||||||
assert.True(t, c.HttpOnly, "session cookie should be HTTP-only")
|
|
||||||
|
assert.True(
|
||||||
|
t, c.HttpOnly,
|
||||||
|
"session cookie should be HTTP-only",
|
||||||
|
)
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.True(t, found, "should find a cookie named %s", SessionName)
|
|
||||||
|
assert.True(
|
||||||
|
t, found,
|
||||||
|
"should find a cookie named %s", session.SessionName,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- SetUser and User Retrieval Tests ---
|
// --- SetUser and User Retrieval Tests ---
|
||||||
|
|
||||||
func TestSetUser_SetsAllFields(t *testing.T) {
|
func TestSetUser_SetsAllFields(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s.SetUser(sess, "user-abc-123", "alice")
|
s.SetUser(sess, "user-abc-123", "alice")
|
||||||
|
|
||||||
assert.Equal(t, "user-abc-123", sess.Values[UserIDKey])
|
assert.Equal(
|
||||||
assert.Equal(t, "alice", sess.Values[UsernameKey])
|
t, "user-abc-123", sess.Values[session.UserIDKey],
|
||||||
assert.Equal(t, true, sess.Values[AuthenticatedKey])
|
)
|
||||||
|
assert.Equal(
|
||||||
|
t, "alice", sess.Values[session.UsernameKey],
|
||||||
|
)
|
||||||
|
assert.Equal(
|
||||||
|
t, true, sess.Values[session.AuthenticatedKey],
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetUserID(t *testing.T) {
|
func TestGetUserID(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Before setting user
|
// Before setting user
|
||||||
userID, ok := s.GetUserID(sess)
|
userID, ok := s.GetUserID(sess)
|
||||||
assert.False(t, ok, "should return false when no user ID is set")
|
assert.False(
|
||||||
|
t, ok, "should return false when no user ID is set",
|
||||||
|
)
|
||||||
assert.Empty(t, userID)
|
assert.Empty(t, userID)
|
||||||
|
|
||||||
// After setting user
|
// After setting user
|
||||||
s.SetUser(sess, "user-xyz", "bob")
|
s.SetUser(sess, "user-xyz", "bob")
|
||||||
|
|
||||||
userID, ok = s.GetUserID(sess)
|
userID, ok = s.GetUserID(sess)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equal(t, "user-xyz", userID)
|
assert.Equal(t, "user-xyz", userID)
|
||||||
@@ -147,19 +201,25 @@ func TestGetUserID(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetUsername(t *testing.T) {
|
func TestGetUsername(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Before setting user
|
// Before setting user
|
||||||
username, ok := s.GetUsername(sess)
|
username, ok := s.GetUsername(sess)
|
||||||
assert.False(t, ok, "should return false when no username is set")
|
assert.False(
|
||||||
|
t, ok, "should return false when no username is set",
|
||||||
|
)
|
||||||
assert.Empty(t, username)
|
assert.Empty(t, username)
|
||||||
|
|
||||||
// After setting user
|
// After setting user
|
||||||
s.SetUser(sess, "user-xyz", "bob")
|
s.SetUser(sess, "user-xyz", "bob")
|
||||||
|
|
||||||
username, ok = s.GetUsername(sess)
|
username, ok = s.GetUsername(sess)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equal(t, "bob", username)
|
assert.Equal(t, "bob", username)
|
||||||
@@ -169,20 +229,29 @@ func TestGetUsername(t *testing.T) {
|
|||||||
|
|
||||||
func TestIsAuthenticated_NoSession(t *testing.T) {
|
func TestIsAuthenticated_NoSession(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, s.IsAuthenticated(sess), "new session should not be authenticated")
|
assert.False(
|
||||||
|
t, s.IsAuthenticated(sess),
|
||||||
|
"new session should not be authenticated",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsAuthenticated_AfterSetUser(t *testing.T) {
|
func TestIsAuthenticated_AfterSetUser(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -192,9 +261,12 @@ func TestIsAuthenticated_AfterSetUser(t *testing.T) {
|
|||||||
|
|
||||||
func TestIsAuthenticated_AfterClearUser(t *testing.T) {
|
func TestIsAuthenticated_AfterClearUser(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -202,52 +274,71 @@ func TestIsAuthenticated_AfterClearUser(t *testing.T) {
|
|||||||
require.True(t, s.IsAuthenticated(sess))
|
require.True(t, s.IsAuthenticated(sess))
|
||||||
|
|
||||||
s.ClearUser(sess)
|
s.ClearUser(sess)
|
||||||
assert.False(t, s.IsAuthenticated(sess), "should not be authenticated after ClearUser")
|
|
||||||
|
assert.False(
|
||||||
|
t, s.IsAuthenticated(sess),
|
||||||
|
"should not be authenticated after ClearUser",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsAuthenticated_WrongType(t *testing.T) {
|
func TestIsAuthenticated_WrongType(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Set authenticated to a non-bool value
|
// Set authenticated to a non-bool value
|
||||||
sess.Values[AuthenticatedKey] = "yes"
|
sess.Values[session.AuthenticatedKey] = "yes"
|
||||||
assert.False(t, s.IsAuthenticated(sess), "should return false for non-bool authenticated value")
|
|
||||||
|
assert.False(
|
||||||
|
t, s.IsAuthenticated(sess),
|
||||||
|
"should return false for non-bool authenticated value",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- ClearUser Tests ---
|
// --- ClearUser Tests ---
|
||||||
|
|
||||||
func TestClearUser_RemovesAllKeys(t *testing.T) {
|
func TestClearUser_RemovesAllKeys(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s.SetUser(sess, "user-123", "alice")
|
s.SetUser(sess, "user-123", "alice")
|
||||||
s.ClearUser(sess)
|
s.ClearUser(sess)
|
||||||
|
|
||||||
_, hasUserID := sess.Values[UserIDKey]
|
_, hasUserID := sess.Values[session.UserIDKey]
|
||||||
assert.False(t, hasUserID, "UserIDKey should be removed")
|
assert.False(t, hasUserID, "UserIDKey should be removed")
|
||||||
|
|
||||||
_, hasUsername := sess.Values[UsernameKey]
|
_, hasUsername := sess.Values[session.UsernameKey]
|
||||||
assert.False(t, hasUsername, "UsernameKey should be removed")
|
assert.False(t, hasUsername, "UsernameKey should be removed")
|
||||||
|
|
||||||
_, hasAuth := sess.Values[AuthenticatedKey]
|
_, hasAuth := sess.Values[session.AuthenticatedKey]
|
||||||
assert.False(t, hasAuth, "AuthenticatedKey should be removed")
|
assert.False(
|
||||||
|
t, hasAuth, "AuthenticatedKey should be removed",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Destroy Tests ---
|
// --- Destroy Tests ---
|
||||||
|
|
||||||
func TestDestroy_InvalidatesSession(t *testing.T) {
|
func TestDestroy_InvalidatesSession(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -255,11 +346,18 @@ func TestDestroy_InvalidatesSession(t *testing.T) {
|
|||||||
|
|
||||||
s.Destroy(sess)
|
s.Destroy(sess)
|
||||||
|
|
||||||
// After Destroy: MaxAge should be -1 (delete cookie) and user data cleared
|
// After Destroy: MaxAge should be -1 (delete cookie) and
|
||||||
assert.Equal(t, -1, sess.Options.MaxAge, "Destroy should set MaxAge to -1")
|
// user data cleared
|
||||||
assert.False(t, s.IsAuthenticated(sess), "should not be authenticated after Destroy")
|
assert.Equal(
|
||||||
|
t, -1, sess.Options.MaxAge,
|
||||||
|
"Destroy should set MaxAge to -1",
|
||||||
|
)
|
||||||
|
assert.False(
|
||||||
|
t, s.IsAuthenticated(sess),
|
||||||
|
"should not be authenticated after Destroy",
|
||||||
|
)
|
||||||
|
|
||||||
_, hasUserID := sess.Values[UserIDKey]
|
_, hasUserID := sess.Values[session.UserIDKey]
|
||||||
assert.False(t, hasUserID, "Destroy should clear user ID")
|
assert.False(t, hasUserID, "Destroy should clear user ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,10 +365,12 @@ func TestDestroy_InvalidatesSession(t *testing.T) {
|
|||||||
|
|
||||||
func TestSessionPersistence_RoundTrip(t *testing.T) {
|
func TestSessionPersistence_RoundTrip(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
// Step 1: Create session, set user, save
|
// Step 1: Create session, set user, save
|
||||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
req1 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
w1 := httptest.NewRecorder()
|
w1 := httptest.NewRecorder()
|
||||||
|
|
||||||
sess1, err := s.Get(req1)
|
sess1, err := s.Get(req1)
|
||||||
@@ -281,8 +381,13 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
|
|||||||
cookies := w1.Result().Cookies()
|
cookies := w1.Result().Cookies()
|
||||||
require.NotEmpty(t, cookies)
|
require.NotEmpty(t, cookies)
|
||||||
|
|
||||||
// Step 2: New request with cookies — session data should persist
|
// Step 2: New request with cookies -- session data should
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/profile", nil)
|
// persist
|
||||||
|
req2 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/profile", nil,
|
||||||
|
)
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
req2.AddCookie(c)
|
req2.AddCookie(c)
|
||||||
}
|
}
|
||||||
@@ -290,7 +395,10 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
|
|||||||
sess2, err := s.Get(req2)
|
sess2, err := s.Get(req2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, s.IsAuthenticated(sess2), "session should be authenticated after round-trip")
|
assert.True(
|
||||||
|
t, s.IsAuthenticated(sess2),
|
||||||
|
"session should be authenticated after round-trip",
|
||||||
|
)
|
||||||
|
|
||||||
userID, ok := s.GetUserID(sess2)
|
userID, ok := s.GetUserID(sess2)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -305,19 +413,23 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
|
|||||||
|
|
||||||
func TestSessionConstants(t *testing.T) {
|
func TestSessionConstants(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
assert.Equal(t, "webhooker_session", SessionName)
|
|
||||||
assert.Equal(t, "user_id", UserIDKey)
|
assert.Equal(t, "webhooker_session", session.SessionName)
|
||||||
assert.Equal(t, "username", UsernameKey)
|
assert.Equal(t, "user_id", session.UserIDKey)
|
||||||
assert.Equal(t, "authenticated", AuthenticatedKey)
|
assert.Equal(t, "username", session.UsernameKey)
|
||||||
|
assert.Equal(t, "authenticated", session.AuthenticatedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Edge Cases ---
|
// --- Edge Cases ---
|
||||||
|
|
||||||
func TestSetUser_OverwritesPreviousUser(t *testing.T) {
|
func TestSetUser_OverwritesPreviousUser(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
|
|
||||||
sess, err := s.Get(req)
|
sess, err := s.Get(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -338,10 +450,12 @@ func TestSetUser_OverwritesPreviousUser(t *testing.T) {
|
|||||||
|
|
||||||
func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
|
func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := testSession(t)
|
s := testSession(t)
|
||||||
|
|
||||||
// Create a session
|
// Create a session
|
||||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
req1 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(), http.MethodGet, "/", nil)
|
||||||
w1 := httptest.NewRecorder()
|
w1 := httptest.NewRecorder()
|
||||||
|
|
||||||
sess, err := s.Get(req1)
|
sess, err := s.Get(req1)
|
||||||
@@ -353,10 +467,15 @@ func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
|
|||||||
require.NotEmpty(t, cookies)
|
require.NotEmpty(t, cookies)
|
||||||
|
|
||||||
// Destroy and save
|
// Destroy and save
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/logout", nil)
|
req2 := httptest.NewRequestWithContext(
|
||||||
|
context.Background(),
|
||||||
|
http.MethodGet, "/logout", nil,
|
||||||
|
)
|
||||||
|
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
req2.AddCookie(c)
|
req2.AddCookie(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
w2 := httptest.NewRecorder()
|
w2 := httptest.NewRecorder()
|
||||||
|
|
||||||
sess2, err := s.Get(req2)
|
sess2, err := s.Get(req2)
|
||||||
@@ -364,15 +483,25 @@ func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
|
|||||||
s.Destroy(sess2)
|
s.Destroy(sess2)
|
||||||
require.NoError(t, s.Save(req2, w2, sess2))
|
require.NoError(t, s.Save(req2, w2, sess2))
|
||||||
|
|
||||||
// The cookie should have MaxAge = -1 (browser should delete it)
|
// The cookie should have MaxAge = -1 (browser should delete)
|
||||||
responseCookies := w2.Result().Cookies()
|
responseCookies := w2.Result().Cookies()
|
||||||
|
|
||||||
var sessionCookie *http.Cookie
|
var sessionCookie *http.Cookie
|
||||||
|
|
||||||
for _, c := range responseCookies {
|
for _, c := range responseCookies {
|
||||||
if c.Name == SessionName {
|
if c.Name == session.SessionName {
|
||||||
sessionCookie = c
|
sessionCookie = c
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
require.NotNil(t, sessionCookie, "should have a session cookie in response")
|
|
||||||
assert.True(t, sessionCookie.MaxAge < 0, "destroyed session cookie should have negative MaxAge")
|
require.NotNil(
|
||||||
|
t, sessionCookie,
|
||||||
|
"should have a session cookie in response",
|
||||||
|
)
|
||||||
|
assert.Negative(
|
||||||
|
t, sessionCookie.MaxAge,
|
||||||
|
"destroyed session cookie should have negative MaxAge",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
|
// Package static embeds static assets (CSS, JS) served by the web UI.
|
||||||
package static
|
package static
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Static holds the embedded CSS and JavaScript files for the web UI.
|
||||||
|
//
|
||||||
//go:embed css js
|
//go:embed css js
|
||||||
var Static embed.FS
|
var Static embed.FS
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
|
// Package templates embeds HTML templates used by the web UI.
|
||||||
package templates
|
package templates
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Templates holds the embedded HTML template files.
|
||||||
|
//
|
||||||
//go:embed *.html
|
//go:embed *.html
|
||||||
var Templates embed.FS
|
var Templates embed.FS
|
||||||
|
|||||||
Reference in New Issue
Block a user