Compare commits
1 Commits
main
...
4d5ebfd692
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d5ebfd692 |
@@ -2,31 +2,42 @@ version: "2"
|
||||
|
||||
run:
|
||||
timeout: 5m
|
||||
modules-download-mode: readonly
|
||||
tests: true
|
||||
|
||||
linters:
|
||||
default: all
|
||||
disable:
|
||||
# Genuinely incompatible with project patterns
|
||||
- exhaustruct # Requires all struct fields
|
||||
- depguard # Dependency allow/block lists
|
||||
- godot # Requires comments to end with periods
|
||||
- wsl # Deprecated, replaced by wsl_v5
|
||||
- wrapcheck # Too verbose for internal packages
|
||||
- varnamelen # Short names like db, id are idiomatic Go
|
||||
enable:
|
||||
- revive
|
||||
- govet
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- unused
|
||||
- ineffassign
|
||||
- gosec
|
||||
- misspell
|
||||
- unparam
|
||||
- prealloc
|
||||
- copyloopvar
|
||||
- gocritic
|
||||
- gochecknoinits
|
||||
- gochecknoglobals
|
||||
|
||||
linters-settings:
|
||||
lll:
|
||||
line-length: 88
|
||||
funlen:
|
||||
lines: 80
|
||||
statements: 50
|
||||
cyclop:
|
||||
max-complexity: 15
|
||||
dupl:
|
||||
threshold: 100
|
||||
revive:
|
||||
confidence: 0.8
|
||||
govet:
|
||||
enable:
|
||||
- shadow
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
check-blank: true
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
exclude-rules:
|
||||
# Exclude globals check for version variables in main
|
||||
- path: cmd/webhooker/main.go
|
||||
linters:
|
||||
- gochecknoglobals
|
||||
# Exclude globals check for version variables in globals package
|
||||
- path: internal/globals/globals.go
|
||||
linters:
|
||||
- gochecknoglobals
|
||||
|
||||
21
README.md
21
README.md
@@ -62,21 +62,6 @@ or `prod` (default: `dev`). The setting controls several behaviors:
|
||||
| CORS | Allows any origin (`*`) | Disabled (no-op) |
|
||||
| Session cookie Secure | `false` (works over plain HTTP) | `true` (requires HTTPS) |
|
||||
|
||||
The CSRF cookie's `Secure` flag and Origin/Referer validation mode are
|
||||
determined per-request based on the actual transport protocol, not the
|
||||
environment setting. The middleware checks `r.TLS` (direct TLS) and the
|
||||
`X-Forwarded-Proto` header (TLS-terminating reverse proxy) to decide:
|
||||
|
||||
- **Direct TLS or `X-Forwarded-Proto: https`**: Secure cookies, strict
|
||||
Origin/Referer validation.
|
||||
- **Plaintext HTTP**: Non-Secure cookies, relaxed Origin/Referer
|
||||
checks (token validation still enforced).
|
||||
|
||||
This means CSRF protection works correctly in all deployment scenarios:
|
||||
behind a TLS-terminating reverse proxy, with direct TLS, or over plain
|
||||
HTTP during development. When running behind a reverse proxy, ensure it
|
||||
sets the `X-Forwarded-Proto: https` header.
|
||||
|
||||
All other differences (log format, security headers, etc.) are
|
||||
independent of the environment setting — log format is determined by
|
||||
TTY detection, and security headers are always applied.
|
||||
@@ -667,7 +652,7 @@ against a misbehaving sender).
|
||||
|
||||
| Method | Path | Description |
|
||||
| ------ | --------------------------- | ----------- |
|
||||
| `GET` | `/` | Root redirect (authenticated → `/sources`, unauthenticated → `/pages/login`) |
|
||||
| `GET` | `/` | Web UI index page (server-rendered) |
|
||||
| `GET` | `/.well-known/healthcheck` | Health check (JSON: status, uptime, version) |
|
||||
| `GET` | `/s/*` | Static file serving (embedded CSS, JS) |
|
||||
| `ANY` | `/webhook/{uuid}` | Webhook receiver endpoint (accepts all methods) |
|
||||
@@ -856,9 +841,7 @@ Additionally, form endpoints (`/pages`, `/sources`, `/source/*`) apply a
|
||||
on all state-changing forms (cookie-based double-submit tokens with
|
||||
HMAC authentication). Applied to `/pages`, `/sources`, `/source`, and
|
||||
`/user` routes. Excluded from `/webhook` (inbound webhook POSTs) and
|
||||
`/api` (stateless API). The middleware auto-detects TLS status
|
||||
per-request (via `r.TLS` and `X-Forwarded-Proto`) to set appropriate
|
||||
cookie security flags and Origin/Referer validation mode
|
||||
`/api` (stateless API)
|
||||
- **SSRF prevention** for HTTP delivery targets: private/reserved IP
|
||||
ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked
|
||||
both at target creation time (URL validation) and at delivery time
|
||||
|
||||
4
go.mod
4
go.mod
@@ -1,6 +1,8 @@
|
||||
module sneak.berlin/go/webhooker
|
||||
|
||||
go 1.26.1
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.1
|
||||
|
||||
require (
|
||||
github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -19,29 +18,20 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvironmentDev represents development environment.
|
||||
// EnvironmentDev represents development environment
|
||||
EnvironmentDev = "dev"
|
||||
// EnvironmentProd represents production environment.
|
||||
// EnvironmentProd represents production environment
|
||||
EnvironmentProd = "prod"
|
||||
|
||||
// defaultPort is the default HTTP listen port.
|
||||
defaultPort = 8080
|
||||
)
|
||||
|
||||
// ErrInvalidEnvironment is returned when WEBHOOKER_ENVIRONMENT
|
||||
// contains an unrecognised value.
|
||||
var ErrInvalidEnvironment = errors.New("invalid environment")
|
||||
|
||||
//nolint:revive // ConfigParams is a standard fx naming convention.
|
||||
type ConfigParams struct {
|
||||
fx.In
|
||||
|
||||
Globals *globals.Globals
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// Config holds all application configuration loaded from
|
||||
// environment variables.
|
||||
// Config holds all application configuration loaded from environment variables.
|
||||
type Config struct {
|
||||
DataDir string
|
||||
Debug bool
|
||||
@@ -55,43 +45,39 @@ type Config struct {
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// IsDev returns true if running in development environment.
|
||||
// IsDev returns true if running in development environment
|
||||
func (c *Config) IsDev() bool {
|
||||
return c.Environment == EnvironmentDev
|
||||
}
|
||||
|
||||
// IsProd returns true if running in production environment.
|
||||
// IsProd returns true if running in production environment
|
||||
func (c *Config) IsProd() bool {
|
||||
return c.Environment == EnvironmentProd
|
||||
}
|
||||
|
||||
// envString returns the value of the named environment variable,
|
||||
// or an empty string if not set.
|
||||
// envString returns the value of the named environment variable, or
|
||||
// an empty string if not set.
|
||||
func envString(key string) string {
|
||||
return os.Getenv(key)
|
||||
}
|
||||
|
||||
// envBool returns the value of the named environment variable
|
||||
// parsed as a boolean. Returns defaultValue if not set.
|
||||
// envBool returns the value of the named environment variable parsed as a
|
||||
// boolean. Returns defaultValue if not set.
|
||||
func envBool(key string, defaultValue bool) bool {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return strings.EqualFold(v, "true") || v == "1"
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// envInt returns the value of the named environment variable
|
||||
// parsed as an integer. Returns defaultValue if not set or
|
||||
// unparseable.
|
||||
// envInt returns the value of the named environment variable parsed as an
|
||||
// integer. Returns defaultValue if not set or unparseable.
|
||||
func envInt(key string, defaultValue int) int {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
i, err := strconv.Atoi(v)
|
||||
if err == nil {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
@@ -101,21 +87,16 @@ func envInt(key string, defaultValue int) int {
|
||||
func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
||||
log := params.Logger.Get()
|
||||
|
||||
// Determine environment from WEBHOOKER_ENVIRONMENT env var,
|
||||
// default to dev
|
||||
// Determine environment from WEBHOOKER_ENVIRONMENT env var, default to dev
|
||||
environment := os.Getenv("WEBHOOKER_ENVIRONMENT")
|
||||
if environment == "" {
|
||||
environment = EnvironmentDev
|
||||
}
|
||||
|
||||
// Validate environment
|
||||
if environment != EnvironmentDev &&
|
||||
environment != EnvironmentProd {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: WEBHOOKER_ENVIRONMENT must be '%s' or '%s', got '%s'",
|
||||
ErrInvalidEnvironment,
|
||||
EnvironmentDev, EnvironmentProd, environment,
|
||||
)
|
||||
if environment != EnvironmentDev && environment != EnvironmentProd {
|
||||
return nil, fmt.Errorf("WEBHOOKER_ENVIRONMENT must be either '%s' or '%s', got '%s'",
|
||||
EnvironmentDev, EnvironmentProd, environment)
|
||||
}
|
||||
|
||||
// Load configuration values from environment variables
|
||||
@@ -126,16 +107,15 @@ func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
||||
Environment: environment,
|
||||
MetricsUsername: envString("METRICS_USERNAME"),
|
||||
MetricsPassword: envString("METRICS_PASSWORD"),
|
||||
Port: envInt("PORT", defaultPort),
|
||||
Port: envInt("PORT", 8080),
|
||||
SentryDSN: envString("SENTRY_DSN"),
|
||||
log: log,
|
||||
params: ¶ms,
|
||||
}
|
||||
|
||||
// Set default DataDir. All SQLite databases (main application
|
||||
// DB and per-webhook event DBs) live here. The same default is
|
||||
// used regardless of environment; override with DATA_DIR if
|
||||
// needed.
|
||||
// Set default DataDir. All SQLite databases (main application DB
|
||||
// and per-webhook event DBs) live here. The same default is used
|
||||
// regardless of environment; override with DATA_DIR if needed.
|
||||
if s.DataDir == "" {
|
||||
s.DataDir = "/var/lib/webhooker"
|
||||
}
|
||||
@@ -152,8 +132,7 @@ func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
||||
"maintenanceMode", s.MaintenanceMode,
|
||||
"dataDir", s.DataDir,
|
||||
"hasSentryDSN", s.SentryDSN != "",
|
||||
"hasMetricsAuth",
|
||||
s.MetricsUsername != "" && s.MetricsPassword != "",
|
||||
"hasMetricsAuth", s.MetricsUsername != "" && s.MetricsPassword != "",
|
||||
)
|
||||
|
||||
return s, nil
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package config_test
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/fx"
|
||||
"go.uber.org/fx/fxtest"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
@@ -24,142 +23,117 @@ func TestEnvironmentConfig(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "default is dev",
|
||||
envValue: "",
|
||||
envVars: map[string]string{},
|
||||
expectError: false,
|
||||
isDev: true,
|
||||
isProd: false,
|
||||
},
|
||||
{
|
||||
name: "explicit dev",
|
||||
envValue: "dev",
|
||||
envVars: map[string]string{},
|
||||
expectError: false,
|
||||
isDev: true,
|
||||
isProd: false,
|
||||
},
|
||||
{
|
||||
name: "explicit prod",
|
||||
envValue: "prod",
|
||||
envVars: map[string]string{},
|
||||
expectError: false,
|
||||
isDev: false,
|
||||
isProd: true,
|
||||
},
|
||||
{
|
||||
name: "invalid environment",
|
||||
envValue: "staging",
|
||||
envVars: map[string]string{},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Cannot use t.Parallel() here because t.Setenv
|
||||
// is incompatible with parallel subtests.
|
||||
// Set environment variable if specified
|
||||
if tt.envValue != "" {
|
||||
t.Setenv(
|
||||
"WEBHOOKER_ENVIRONMENT", tt.envValue,
|
||||
)
|
||||
t.Setenv("WEBHOOKER_ENVIRONMENT", tt.envValue)
|
||||
} else {
|
||||
require.NoError(t, os.Unsetenv(
|
||||
"WEBHOOKER_ENVIRONMENT",
|
||||
))
|
||||
require.NoError(t, os.Unsetenv("WEBHOOKER_ENVIRONMENT"))
|
||||
}
|
||||
|
||||
// Set additional environment variables
|
||||
for k, v := range tt.envVars {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
testEnvironmentConfigError(t)
|
||||
} else {
|
||||
testEnvironmentConfigSuccess(
|
||||
t, tt.isDev, tt.isProd,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testEnvironmentConfigError(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
var cfg *config.Config
|
||||
|
||||
// Use regular fx.New for error cases since fxtest doesn't expose errors the same way
|
||||
var cfg *Config
|
||||
app := fx.New(
|
||||
fx.NopLogger,
|
||||
fx.NopLogger, // Suppress fx logs in tests
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
config.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&cfg),
|
||||
)
|
||||
|
||||
assert.Error(t, app.Err())
|
||||
}
|
||||
|
||||
func testEnvironmentConfigSuccess(
|
||||
t *testing.T,
|
||||
isDev, isProd bool,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
var cfg *config.Config
|
||||
|
||||
} else {
|
||||
// Use fxtest for success cases
|
||||
var cfg *Config
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
config.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&cfg),
|
||||
)
|
||||
require.NoError(t, app.Err())
|
||||
|
||||
app.RequireStart()
|
||||
|
||||
defer app.RequireStop()
|
||||
|
||||
assert.Equal(t, isDev, cfg.IsDev())
|
||||
assert.Equal(t, isProd, cfg.IsProd())
|
||||
assert.Equal(t, tt.isDev, cfg.IsDev())
|
||||
assert.Equal(t, tt.isProd, cfg.IsProd())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultDataDir(t *testing.T) {
|
||||
// Verify that when DATA_DIR is unset, the default is /var/lib/webhooker
|
||||
// regardless of the environment setting.
|
||||
for _, env := range []string{"", "dev", "prod"} {
|
||||
name := env
|
||||
if name == "" {
|
||||
name = "unset"
|
||||
}
|
||||
|
||||
t.Run("env="+name, func(t *testing.T) {
|
||||
// Cannot use t.Parallel() here because t.Setenv
|
||||
// is incompatible with parallel subtests.
|
||||
if env != "" {
|
||||
t.Setenv("WEBHOOKER_ENVIRONMENT", env)
|
||||
} else {
|
||||
require.NoError(t, os.Unsetenv(
|
||||
"WEBHOOKER_ENVIRONMENT",
|
||||
))
|
||||
require.NoError(t, os.Unsetenv("WEBHOOKER_ENVIRONMENT"))
|
||||
}
|
||||
|
||||
require.NoError(t, os.Unsetenv("DATA_DIR"))
|
||||
|
||||
var cfg *config.Config
|
||||
|
||||
var cfg *Config
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
config.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&cfg),
|
||||
)
|
||||
require.NoError(t, app.Err())
|
||||
|
||||
app.RequireStart()
|
||||
|
||||
defer app.RequireStop()
|
||||
|
||||
assert.Equal(
|
||||
t, "/var/lib/webhooker", cfg.DataDir,
|
||||
)
|
||||
assert.Equal(t, "/var/lib/webhooker", cfg.DataDir)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
// This replaces gorm.Model but uses UUID instead of uint for ID
|
||||
type BaseModel struct {
|
||||
ID string `gorm:"type:uuid;primary_key" json:"id"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deletedAt,omitzero"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||
}
|
||||
|
||||
// BeforeCreate hook to set UUID before creating a record.
|
||||
@@ -21,6 +21,5 @@ func (b *BaseModel) BeforeCreate(_ *gorm.DB) error {
|
||||
if b.ID == "" {
|
||||
b.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,16 +20,9 @@ import (
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
dataDirPerm = 0750
|
||||
randomPasswordLen = 16
|
||||
sessionKeyLen = 32
|
||||
)
|
||||
|
||||
//nolint:revive // DatabaseParams is a standard fx naming convention.
|
||||
type DatabaseParams struct {
|
||||
fx.In
|
||||
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
}
|
||||
@@ -42,20 +35,17 @@ type Database struct {
|
||||
}
|
||||
|
||||
// New creates a Database that connects on fx start and disconnects on stop.
|
||||
func New(
|
||||
lc fx.Lifecycle,
|
||||
params DatabaseParams,
|
||||
) (*Database, error) {
|
||||
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
||||
d := &Database{
|
||||
params: ¶ms,
|
||||
log: params.Logger.Get(),
|
||||
}
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(_ context.Context) error {
|
||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
return d.connect()
|
||||
},
|
||||
OnStop: func(_ context.Context) error {
|
||||
OnStop: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
return d.close()
|
||||
},
|
||||
})
|
||||
@@ -63,92 +53,21 @@ func New(
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// DB returns the underlying GORM database handle.
|
||||
func (d *Database) DB() *gorm.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
// GetOrCreateSessionKey retrieves the session encryption key from the
|
||||
// settings table. If no key exists, a cryptographically secure random
|
||||
// 32-byte key is generated, base64-encoded, and stored for future use.
|
||||
func (d *Database) GetOrCreateSessionKey() (string, error) {
|
||||
var setting Setting
|
||||
|
||||
result := d.db.Where(
|
||||
&Setting{Key: "session_key"},
|
||||
).First(&setting)
|
||||
if result.Error == nil {
|
||||
return setting.Value, nil
|
||||
}
|
||||
|
||||
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", fmt.Errorf(
|
||||
"failed to query session key: %w",
|
||||
result.Error,
|
||||
)
|
||||
}
|
||||
|
||||
// Generate a new cryptographically secure 32-byte key
|
||||
keyBytes := make([]byte, sessionKeyLen)
|
||||
|
||||
_, err := rand.Read(keyBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf(
|
||||
"failed to generate session key: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(keyBytes)
|
||||
|
||||
setting = Setting{
|
||||
Key: "session_key",
|
||||
Value: encoded,
|
||||
}
|
||||
|
||||
err = d.db.Create(&setting).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf(
|
||||
"failed to store session key: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
d.log.Info(
|
||||
"generated new session key and stored in database",
|
||||
)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (d *Database) connect() error {
|
||||
// Ensure the data directory exists before opening the database.
|
||||
dataDir := d.params.Config.DataDir
|
||||
|
||||
err := os.MkdirAll(dataDir, dataDirPerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"creating data directory %s: %w",
|
||||
dataDir,
|
||||
err,
|
||||
)
|
||||
if err := os.MkdirAll(dataDir, 0750); err != nil {
|
||||
return fmt.Errorf("creating data directory %s: %w", dataDir, err)
|
||||
}
|
||||
|
||||
// Construct the main application database path inside DATA_DIR.
|
||||
dbPath := filepath.Join(dataDir, "webhooker.db")
|
||||
dbURL := fmt.Sprintf(
|
||||
"file:%s?cache=shared&mode=rwc",
|
||||
dbPath,
|
||||
)
|
||||
dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", dbPath)
|
||||
|
||||
// Open the database with the pure Go SQLite driver
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to open database",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
d.log.Error("failed to open database", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -157,11 +76,7 @@ func (d *Database) connect() error {
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to connect to database",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
d.log.Error("failed to connect to database", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -174,62 +89,34 @@ func (d *Database) connect() error {
|
||||
|
||||
func (d *Database) migrate() error {
|
||||
// Run GORM auto-migrations
|
||||
err := d.Migrate()
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to run database migrations",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
if err := d.Migrate(); err != nil {
|
||||
d.log.Error("failed to run database migrations", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Info("database migrations completed")
|
||||
|
||||
// Check if admin user exists
|
||||
var userCount int64
|
||||
|
||||
err = d.db.Model(&User{}).Count(&userCount).Error
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to count users",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
if err := d.db.Model(&User{}).Count(&userCount).Error; err != nil {
|
||||
d.log.Error("failed to count users", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if userCount == 0 {
|
||||
return d.createAdminUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) createAdminUser() error {
|
||||
// Create admin user
|
||||
d.log.Info("no users found, creating admin user")
|
||||
|
||||
// Generate random password
|
||||
password, err := GenerateRandomPassword(
|
||||
randomPasswordLen,
|
||||
)
|
||||
password, err := GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to generate random password",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
d.log.Error("failed to generate random password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
hashedPassword, err := HashPassword(password)
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to hash password",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
d.log.Error("failed to hash password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -239,22 +126,17 @@ func (d *Database) createAdminUser() error {
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
err = d.db.Create(adminUser).Error
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to create admin user",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
if err := d.db.Create(adminUser).Error; err != nil {
|
||||
d.log.Error("failed to create admin user", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Info("admin user created",
|
||||
"username", "admin",
|
||||
"password", password,
|
||||
"message",
|
||||
"SAVE THIS PASSWORD - it will not be shown again!",
|
||||
"message", "SAVE THIS PASSWORD - it will not be shown again!",
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -265,9 +147,44 @@ func (d *Database) close() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DB returns the underlying GORM database handle.
|
||||
func (d *Database) DB() *gorm.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
// GetOrCreateSessionKey retrieves the session encryption key from the
|
||||
// settings table. If no key exists, a cryptographically secure random
|
||||
// 32-byte key is generated, base64-encoded, and stored for future use.
|
||||
func (d *Database) GetOrCreateSessionKey() (string, error) {
|
||||
var setting Setting
|
||||
result := d.db.Where(&Setting{Key: "session_key"}).First(&setting)
|
||||
if result.Error == nil {
|
||||
return setting.Value, nil
|
||||
}
|
||||
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", fmt.Errorf("failed to query session key: %w", result.Error)
|
||||
}
|
||||
|
||||
// Generate a new cryptographically secure 32-byte key
|
||||
keyBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(keyBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate session key: %w", err)
|
||||
}
|
||||
encoded := base64.StdEncoding.EncodeToString(keyBytes)
|
||||
|
||||
setting = Setting{
|
||||
Key: "session_key",
|
||||
Value: encoded,
|
||||
}
|
||||
if err := d.db.Create(&setting).Error; err != nil {
|
||||
return "", fmt.Errorf("failed to store session key: %w", err)
|
||||
}
|
||||
|
||||
d.log.Info("generated new session key and stored in database")
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package database_test
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -6,37 +6,37 @@ import (
|
||||
|
||||
"go.uber.org/fx/fxtest"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
func setupTestDB(
|
||||
t *testing.T,
|
||||
) (*database.Database, *fxtest.Lifecycle) {
|
||||
t.Helper()
|
||||
|
||||
func TestDatabaseConnection(t *testing.T) {
|
||||
// Set up test dependencies
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
g := &globals.Globals{
|
||||
Appname: "webhooker-test",
|
||||
Version: "test",
|
||||
// Create globals
|
||||
globals.Appname = "webhooker-test"
|
||||
globals.Version = "test"
|
||||
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create globals: %v", err)
|
||||
}
|
||||
|
||||
l, err := logger.New(
|
||||
lc,
|
||||
logger.LoggerParams{Globals: g},
|
||||
)
|
||||
// Create logger
|
||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create logger: %v", err)
|
||||
}
|
||||
|
||||
// Create config with DataDir pointing to a temp directory
|
||||
c := &config.Config{
|
||||
DataDir: t.TempDir(),
|
||||
Environment: "dev",
|
||||
}
|
||||
|
||||
db, err := database.New(lc, database.DatabaseParams{
|
||||
// Create database
|
||||
db, err := New(lc, DatabaseParams{
|
||||
Config: c,
|
||||
Logger: l,
|
||||
})
|
||||
@@ -44,45 +44,31 @@ func setupTestDB(
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
|
||||
return db, lc
|
||||
}
|
||||
|
||||
func TestDatabaseConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, lc := setupTestDB(t)
|
||||
// Start lifecycle (this will trigger the connection)
|
||||
ctx := context.Background()
|
||||
|
||||
err := lc.Start(ctx)
|
||||
err = lc.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
stopErr := lc.Stop(ctx)
|
||||
if stopErr != nil {
|
||||
t.Errorf(
|
||||
"Failed to stop lifecycle: %v",
|
||||
stopErr,
|
||||
)
|
||||
if stopErr := lc.Stop(ctx); stopErr != nil {
|
||||
t.Errorf("Failed to stop lifecycle: %v", stopErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Verify we can get the DB instance
|
||||
if db.DB() == nil {
|
||||
t.Error("Expected non-nil database connection")
|
||||
}
|
||||
|
||||
// Test that we can perform a simple query
|
||||
var result int
|
||||
|
||||
err = db.DB().Raw("SELECT 1").Scan(&result).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute test query: %v", err)
|
||||
}
|
||||
|
||||
if result != 1 {
|
||||
t.Errorf(
|
||||
"Expected query result to be 1, got %d",
|
||||
result,
|
||||
)
|
||||
t.Errorf("Expected query result to be 1, got %d", result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,11 +6,11 @@ import "time"
|
||||
type APIKey struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"userId"`
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
Key string `gorm:"uniqueIndex;not null" json:"key"`
|
||||
Description string `json:"description"`
|
||||
LastUsedAt *time.Time `json:"lastUsedAt,omitempty"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitzero"`
|
||||
User User `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
@@ -15,12 +15,12 @@ const (
|
||||
type Delivery struct {
|
||||
BaseModel
|
||||
|
||||
EventID string `gorm:"type:uuid;not null" json:"eventId"`
|
||||
TargetID string `gorm:"type:uuid;not null" json:"targetId"`
|
||||
EventID string `gorm:"type:uuid;not null" json:"event_id"`
|
||||
TargetID string `gorm:"type:uuid;not null" json:"target_id"`
|
||||
Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"`
|
||||
|
||||
// Relations
|
||||
Event Event `json:"event,omitzero"`
|
||||
Target Target `json:"target,omitzero"`
|
||||
DeliveryResults []DeliveryResult `json:"deliveryResults,omitempty"`
|
||||
Event Event `json:"event,omitempty"`
|
||||
Target Target `json:"target,omitempty"`
|
||||
DeliveryResults []DeliveryResult `json:"delivery_results,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,14 +4,14 @@ package database
|
||||
type DeliveryResult struct {
|
||||
BaseModel
|
||||
|
||||
DeliveryID string `gorm:"type:uuid;not null" json:"deliveryId"`
|
||||
AttemptNum int `gorm:"not null" json:"attemptNum"`
|
||||
DeliveryID string `gorm:"type:uuid;not null" json:"delivery_id"`
|
||||
AttemptNum int `gorm:"not null" json:"attempt_num"`
|
||||
Success bool `json:"success"`
|
||||
StatusCode int `json:"statusCode,omitempty"`
|
||||
ResponseBody string `gorm:"type:text" json:"responseBody,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ResponseBody string `gorm:"type:text" json:"response_body,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Duration int64 `json:"durationMs"` // Duration in milliseconds
|
||||
Duration int64 `json:"duration_ms"` // Duration in milliseconds
|
||||
|
||||
// Relations
|
||||
Delivery Delivery `json:"delivery,omitzero"`
|
||||
Delivery Delivery `json:"delivery,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ package database
|
||||
type Entrypoint struct {
|
||||
BaseModel
|
||||
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this entrypoint
|
||||
Description string `json:"description"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
|
||||
// Relations
|
||||
Webhook Webhook `json:"webhook,omitzero"`
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,17 +4,17 @@ package database
|
||||
type Event struct {
|
||||
BaseModel
|
||||
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||
EntrypointID string `gorm:"type:uuid;not null" json:"entrypointId"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
EntrypointID string `gorm:"type:uuid;not null" json:"entrypoint_id"`
|
||||
|
||||
// Request data
|
||||
Method string `gorm:"not null" json:"method"`
|
||||
Headers string `gorm:"type:text" json:"headers"` // JSON
|
||||
Body string `gorm:"type:text" json:"body"`
|
||||
ContentType string `json:"contentType"`
|
||||
ContentType string `json:"content_type"`
|
||||
|
||||
// Relations
|
||||
Webhook Webhook `json:"webhook,omitzero"`
|
||||
Entrypoint Entrypoint `json:"entrypoint,omitzero"`
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
Entrypoint Entrypoint `json:"entrypoint,omitempty"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ const (
|
||||
type Target struct {
|
||||
BaseModel
|
||||
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Type TargetType `gorm:"not null" json:"type"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
@@ -24,10 +24,10 @@ type Target struct {
|
||||
Config string `gorm:"type:text" json:"config"` // JSON configuration
|
||||
|
||||
// For HTTP targets (max_retries=0 means fire-and-forget, >0 enables retries with backoff)
|
||||
MaxRetries int `json:"maxRetries,omitempty"`
|
||||
MaxQueueSize int `json:"maxQueueSize,omitempty"`
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
MaxQueueSize int `json:"max_queue_size,omitempty"`
|
||||
|
||||
// Relations
|
||||
Webhook Webhook `json:"webhook,omitzero"`
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
|
||||
@@ -9,5 +9,5 @@ type User struct {
|
||||
|
||||
// Relations
|
||||
Webhooks []Webhook `json:"webhooks,omitempty"`
|
||||
APIKeys []APIKey `json:"apiKeys,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ package database
|
||||
type Webhook struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"userId"`
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Description string `json:"description"`
|
||||
RetentionDays int `gorm:"default:30" json:"retentionDays"` // Days to retain events
|
||||
RetentionDays int `gorm:"default:30" json:"retention_days"` // Days to retain events
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitzero"`
|
||||
User User `json:"user,omitempty"`
|
||||
Entrypoints []Entrypoint `json:"entrypoints,omitempty"`
|
||||
Targets []Target `json:"targets,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
@@ -21,23 +20,6 @@ const (
|
||||
argon2SaltLen = 16
|
||||
)
|
||||
|
||||
// hashParts is the expected number of $-separated segments
|
||||
// in an encoded Argon2id hash string.
|
||||
const hashParts = 6
|
||||
|
||||
// minPasswordComplexityLen is the minimum password length that
|
||||
// triggers per-character-class complexity enforcement.
|
||||
const minPasswordComplexityLen = 4
|
||||
|
||||
// Sentinel errors returned by decodeHash.
|
||||
var (
|
||||
errInvalidHashFormat = errors.New("invalid hash format")
|
||||
errInvalidAlgorithm = errors.New("invalid algorithm")
|
||||
errIncompatibleVersion = errors.New("incompatible argon2 version")
|
||||
errSaltLengthOutOfRange = errors.New("salt length out of range")
|
||||
errHashLengthOutOfRange = errors.New("hash length out of range")
|
||||
)
|
||||
|
||||
// PasswordConfig holds Argon2 configuration
|
||||
type PasswordConfig struct {
|
||||
Time uint32
|
||||
@@ -64,44 +46,26 @@ func HashPassword(password string) (string, error) {
|
||||
|
||||
// Generate a salt
|
||||
salt := make([]byte, config.SaltLen)
|
||||
|
||||
_, err := rand.Read(salt)
|
||||
if err != nil {
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate the hash
|
||||
hash := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
config.Time,
|
||||
config.Memory,
|
||||
config.Threads,
|
||||
config.KeyLen,
|
||||
)
|
||||
hash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
|
||||
// Encode the hash and parameters
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash
|
||||
encoded := fmt.Sprintf(
|
||||
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version,
|
||||
config.Memory,
|
||||
config.Time,
|
||||
config.Threads,
|
||||
b64Salt,
|
||||
b64Hash,
|
||||
)
|
||||
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, config.Memory, config.Time, config.Threads, b64Salt, b64Hash)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks if the provided password matches the hash
|
||||
func VerifyPassword(
|
||||
password, encodedHash string,
|
||||
) (bool, error) {
|
||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
// Extract parameters and hash from encoded string
|
||||
config, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
@@ -109,119 +73,60 @@ func VerifyPassword(
|
||||
}
|
||||
|
||||
// Generate hash of the provided password
|
||||
otherHash := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
config.Time,
|
||||
config.Memory,
|
||||
config.Threads,
|
||||
config.KeyLen,
|
||||
)
|
||||
otherHash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
|
||||
// Compare hashes using constant time comparison
|
||||
return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil
|
||||
}
|
||||
|
||||
// decodeHash extracts parameters, salt, and hash from an
|
||||
// encoded hash string.
|
||||
func decodeHash(
|
||||
encodedHash string,
|
||||
) (*PasswordConfig, []byte, []byte, error) {
|
||||
// decodeHash extracts parameters, salt, and hash from an encoded hash string
|
||||
func decodeHash(encodedHash string) (*PasswordConfig, []byte, []byte, error) {
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != hashParts {
|
||||
return nil, nil, nil, errInvalidHashFormat
|
||||
if len(parts) != 6 {
|
||||
return nil, nil, nil, fmt.Errorf("invalid hash format")
|
||||
}
|
||||
|
||||
if parts[1] != "argon2id" {
|
||||
return nil, nil, nil, errInvalidAlgorithm
|
||||
return nil, nil, nil, fmt.Errorf("invalid algorithm")
|
||||
}
|
||||
|
||||
version, err := parseVersion(parts[2])
|
||||
if err != nil {
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, errIncompatibleVersion
|
||||
return nil, nil, nil, fmt.Errorf("incompatible argon2 version")
|
||||
}
|
||||
|
||||
config, err := parseParams(parts[3])
|
||||
if err != nil {
|
||||
config := &PasswordConfig{}
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &config.Memory, &config.Time, &config.Threads); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err := decodeSalt(parts[4])
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
saltLen := len(salt)
|
||||
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("salt length out of range")
|
||||
}
|
||||
config.SaltLen = uint32(saltLen) // nolint:gosec // checked above
|
||||
|
||||
config.SaltLen = uint32(len(salt)) //nolint:gosec // validated in decodeSalt
|
||||
|
||||
hash, err := decodeHashBytes(parts[5])
|
||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
config.KeyLen = uint32(len(hash)) //nolint:gosec // validated in decodeHashBytes
|
||||
hashLen := len(hash)
|
||||
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("hash length out of range")
|
||||
}
|
||||
config.KeyLen = uint32(hashLen) // nolint:gosec // checked above
|
||||
|
||||
return config, salt, hash, nil
|
||||
}
|
||||
|
||||
func parseVersion(s string) (int, error) {
|
||||
var version int
|
||||
|
||||
_, err := fmt.Sscanf(s, "v=%d", &version)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing version: %w", err)
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
func parseParams(s string) (*PasswordConfig, error) {
|
||||
config := &PasswordConfig{}
|
||||
|
||||
_, err := fmt.Sscanf(
|
||||
s, "m=%d,t=%d,p=%d",
|
||||
&config.Memory, &config.Time, &config.Threads,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing params: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func decodeSalt(s string) ([]byte, error) {
|
||||
salt, err := base64.RawStdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding salt: %w", err)
|
||||
}
|
||||
|
||||
saltLen := len(salt)
|
||||
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
||||
return nil, errSaltLengthOutOfRange
|
||||
}
|
||||
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func decodeHashBytes(s string) ([]byte, error) {
|
||||
hash, err := base64.RawStdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding hash: %w", err)
|
||||
}
|
||||
|
||||
hashLen := len(hash)
|
||||
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
||||
return nil, errHashLengthOutOfRange
|
||||
}
|
||||
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// GenerateRandomPassword generates a cryptographically secure
|
||||
// random password.
|
||||
// GenerateRandomPassword generates a cryptographically secure random password
|
||||
func GenerateRandomPassword(length int) (string, error) {
|
||||
const (
|
||||
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
@@ -236,27 +141,27 @@ func GenerateRandomPassword(length int) (string, error) {
|
||||
// Create password slice
|
||||
password := make([]byte, length)
|
||||
|
||||
// Ensure at least one character from each set
|
||||
if length >= minPasswordComplexityLen {
|
||||
// Ensure at least one character from each set for password complexity
|
||||
if length >= 4 {
|
||||
// Get one character from each set
|
||||
password[0] = uppercase[cryptoRandInt(len(uppercase))]
|
||||
password[1] = lowercase[cryptoRandInt(len(lowercase))]
|
||||
password[2] = digits[cryptoRandInt(len(digits))]
|
||||
password[3] = special[cryptoRandInt(len(special))]
|
||||
|
||||
// Fill the rest randomly from all characters
|
||||
for i := minPasswordComplexityLen; i < length; i++ {
|
||||
for i := 4; i < length; i++ {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
|
||||
// Shuffle the password to avoid predictable pattern
|
||||
for i := range len(password) - 1 {
|
||||
j := cryptoRandInt(len(password) - i)
|
||||
idx := len(password) - 1 - i
|
||||
password[idx], password[j] = password[j], password[idx]
|
||||
for i := len(password) - 1; i > 0; i-- {
|
||||
j := cryptoRandInt(i + 1)
|
||||
password[i], password[j] = password[j], password[i]
|
||||
}
|
||||
} else {
|
||||
// For very short passwords, just use all characters
|
||||
for i := range length {
|
||||
for i := 0; i < length; i++ {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
}
|
||||
@@ -264,17 +169,16 @@ func GenerateRandomPassword(length int) (string, error) {
|
||||
return string(password), nil
|
||||
}
|
||||
|
||||
// cryptoRandInt generates a cryptographically secure random
|
||||
// integer in [0, upperBound).
|
||||
// cryptoRandInt generates a cryptographically secure random integer in [0, upperBound).
|
||||
func cryptoRandInt(upperBound int) int {
|
||||
if upperBound <= 0 {
|
||||
panic("upperBound must be positive")
|
||||
}
|
||||
|
||||
nBig, err := rand.Int(
|
||||
rand.Reader,
|
||||
big.NewInt(int64(upperBound)),
|
||||
)
|
||||
// Calculate the maximum valid value to avoid modulo bias
|
||||
// For example, if upperBound=200 and we have 256 possible values,
|
||||
// we only accept values 0-199 (reject 200-255)
|
||||
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(upperBound)))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand error: %v", err))
|
||||
}
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
package database_test
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
)
|
||||
|
||||
func TestGenerateRandomPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
@@ -22,172 +18,109 @@ func TestGenerateRandomPassword(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
password, err := database.GenerateRandomPassword(
|
||||
tt.length,
|
||||
)
|
||||
password, err := GenerateRandomPassword(tt.length)
|
||||
if err != nil {
|
||||
t.Fatalf(
|
||||
"GenerateRandomPassword() error = %v",
|
||||
err,
|
||||
)
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if len(password) != tt.length {
|
||||
t.Errorf(
|
||||
"Password length = %v, want %v",
|
||||
len(password), tt.length,
|
||||
)
|
||||
t.Errorf("Password length = %v, want %v", len(password), tt.length)
|
||||
}
|
||||
|
||||
checkPasswordComplexity(
|
||||
t, password, tt.length,
|
||||
)
|
||||
// For passwords >= 4 chars, check complexity
|
||||
if tt.length >= 4 {
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
hasSpecial := false
|
||||
|
||||
for _, char := range password {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
|
||||
t.Errorf("Password lacks required complexity: upper=%v, lower=%v, digit=%v, special=%v",
|
||||
hasUpper, hasLower, hasDigit, hasSpecial)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkPasswordComplexity(
|
||||
t *testing.T,
|
||||
password string,
|
||||
length int,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
// For passwords >= 4 chars, check complexity
|
||||
if length < 4 {
|
||||
return
|
||||
}
|
||||
|
||||
flags := classifyChars(password)
|
||||
|
||||
if !flags[0] || !flags[1] || !flags[2] || !flags[3] {
|
||||
t.Errorf(
|
||||
"Password lacks required complexity: "+
|
||||
"upper=%v, lower=%v, digit=%v, special=%v",
|
||||
flags[0], flags[1], flags[2], flags[3],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func classifyChars(s string) [4]bool {
|
||||
var flags [4]bool // upper, lower, digit, special
|
||||
|
||||
for _, char := range s {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
flags[0] = true
|
||||
case char >= 'a' && char <= 'z':
|
||||
flags[1] = true
|
||||
case char >= '0' && char <= '9':
|
||||
flags[2] = true
|
||||
case strings.ContainsRune(
|
||||
"!@#$%^&*()_+-=[]{}|;:,.<>?",
|
||||
char,
|
||||
):
|
||||
flags[3] = true
|
||||
}
|
||||
}
|
||||
|
||||
return flags
|
||||
}
|
||||
|
||||
func TestGenerateRandomPasswordUniqueness(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate multiple passwords and ensure they're different
|
||||
passwords := make(map[string]bool)
|
||||
|
||||
const numPasswords = 100
|
||||
|
||||
for range numPasswords {
|
||||
password, err := database.GenerateRandomPassword(16)
|
||||
for i := 0; i < numPasswords; i++ {
|
||||
password, err := GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
t.Fatalf(
|
||||
"GenerateRandomPassword() error = %v",
|
||||
err,
|
||||
)
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if passwords[password] {
|
||||
t.Errorf(
|
||||
"Duplicate password generated: %s",
|
||||
password,
|
||||
)
|
||||
t.Errorf("Duplicate password generated: %s", password)
|
||||
}
|
||||
|
||||
passwords[password] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
password := "testPassword123!"
|
||||
|
||||
hash, err := database.HashPassword(password)
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
// Check that hash has correct format
|
||||
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||
t.Errorf(
|
||||
"Hash doesn't have correct prefix: %s",
|
||||
hash,
|
||||
)
|
||||
t.Errorf("Hash doesn't have correct prefix: %s", hash)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
valid, err := database.VerifyPassword(password, hash)
|
||||
valid, err := VerifyPassword(password, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error(
|
||||
"VerifyPassword() returned false " +
|
||||
"for correct password",
|
||||
)
|
||||
t.Error("VerifyPassword() returned false for correct password")
|
||||
}
|
||||
|
||||
// Verify wrong password fails
|
||||
valid, err = database.VerifyPassword(
|
||||
"wrongPassword", hash,
|
||||
)
|
||||
valid, err = VerifyPassword("wrongPassword", hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error(
|
||||
"VerifyPassword() returned true " +
|
||||
"for wrong password",
|
||||
)
|
||||
t.Error("VerifyPassword() returned true for wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPasswordUniqueness(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
password := "testPassword123!"
|
||||
|
||||
// Same password should produce different hashes
|
||||
hash1, err := database.HashPassword(password)
|
||||
// Same password should produce different hashes due to salt
|
||||
hash1, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
hash2, err := database.HashPassword(password)
|
||||
hash2, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error(
|
||||
"Same password produced identical hashes " +
|
||||
"(salt not working)",
|
||||
)
|
||||
t.Error("Same password produced identical hashes (salt not working)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package database
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -17,82 +16,87 @@ import (
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
// WebhookDBManagerParams holds the fx dependencies for
|
||||
// WebhookDBManager.
|
||||
// nolint:revive // WebhookDBManagerParams is a standard fx naming convention
|
||||
type WebhookDBManagerParams struct {
|
||||
fx.In
|
||||
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// errInvalidCachedDBType indicates a type assertion failure
|
||||
// when retrieving a cached database connection.
|
||||
var errInvalidCachedDBType = errors.New(
|
||||
"invalid cached database type",
|
||||
)
|
||||
|
||||
// WebhookDBManager manages per-webhook SQLite database files
|
||||
// for event storage. Each webhook gets its own dedicated
|
||||
// database containing Events, Deliveries, and DeliveryResults.
|
||||
// Database connections are opened lazily and cached.
|
||||
// WebhookDBManager manages per-webhook SQLite database files for event storage.
|
||||
// Each webhook gets its own dedicated database containing Events, Deliveries,
|
||||
// and DeliveryResults. Database connections are opened lazily and cached.
|
||||
type WebhookDBManager struct {
|
||||
dataDir string
|
||||
dbs sync.Map // map[webhookID]*gorm.DB
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// NewWebhookDBManager creates a new WebhookDBManager and
|
||||
// registers lifecycle hooks.
|
||||
func NewWebhookDBManager(
|
||||
lc fx.Lifecycle,
|
||||
params WebhookDBManagerParams,
|
||||
) (*WebhookDBManager, error) {
|
||||
// NewWebhookDBManager creates a new WebhookDBManager and registers lifecycle hooks.
|
||||
func NewWebhookDBManager(lc fx.Lifecycle, params WebhookDBManagerParams) (*WebhookDBManager, error) {
|
||||
m := &WebhookDBManager{
|
||||
dataDir: params.Config.DataDir,
|
||||
log: params.Logger.Get(),
|
||||
}
|
||||
|
||||
// Create data directory if it doesn't exist
|
||||
err := os.MkdirAll(m.dataDir, dataDirPerm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"creating data directory %s: %w",
|
||||
m.dataDir,
|
||||
err,
|
||||
)
|
||||
if err := os.MkdirAll(m.dataDir, 0750); err != nil {
|
||||
return nil, fmt.Errorf("creating data directory %s: %w", m.dataDir, err)
|
||||
}
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStop: func(_ context.Context) error {
|
||||
OnStop: func(_ context.Context) error { //nolint:revive // ctx unused but required by fx
|
||||
return m.CloseAll()
|
||||
},
|
||||
})
|
||||
|
||||
m.log.Info(
|
||||
"webhook database manager initialized",
|
||||
"data_dir", m.dataDir,
|
||||
)
|
||||
|
||||
m.log.Info("webhook database manager initialized", "data_dir", m.dataDir)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetDB returns the database connection for a webhook,
|
||||
// creating the database file lazily if it doesn't exist.
|
||||
func (m *WebhookDBManager) GetDB(
|
||||
webhookID string,
|
||||
) (*gorm.DB, error) {
|
||||
// dbPath returns the filesystem path for a webhook's database file.
|
||||
func (m *WebhookDBManager) dbPath(webhookID string) string {
|
||||
return filepath.Join(m.dataDir, fmt.Sprintf("events-%s.db", webhookID))
|
||||
}
|
||||
|
||||
// openDB opens (or creates) a per-webhook SQLite database and runs migrations.
|
||||
func (m *WebhookDBManager) openDB(webhookID string) (*gorm.DB, error) {
|
||||
path := m.dbPath(webhookID)
|
||||
dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", path)
|
||||
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening webhook database %s: %w", webhookID, err)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("connecting to webhook database %s: %w", webhookID, err)
|
||||
}
|
||||
|
||||
// Run migrations for event-tier models only
|
||||
if err := db.AutoMigrate(&Event{}, &Delivery{}, &DeliveryResult{}); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("migrating webhook database %s: %w", webhookID, err)
|
||||
}
|
||||
|
||||
m.log.Info("opened per-webhook database", "webhook_id", webhookID, "path", path)
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// GetDB returns the database connection for a webhook, creating the database
|
||||
// file lazily if it doesn't exist. This handles both new webhooks and existing
|
||||
// webhooks that were created before per-webhook databases were introduced.
|
||||
func (m *WebhookDBManager) GetDB(webhookID string) (*gorm.DB, error) {
|
||||
// Fast path: already open
|
||||
if val, ok := m.dbs.Load(webhookID); ok {
|
||||
cachedDB, castOK := val.(*gorm.DB)
|
||||
if !castOK {
|
||||
return nil, fmt.Errorf(
|
||||
"%w for webhook %s",
|
||||
errInvalidCachedDBType,
|
||||
webhookID,
|
||||
)
|
||||
return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
|
||||
}
|
||||
|
||||
return cachedDB, nil
|
||||
}
|
||||
|
||||
@@ -102,60 +106,43 @@ func (m *WebhookDBManager) GetDB(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store it; if another goroutine beat us, close ours
|
||||
// Store it; if another goroutine beat us, close ours and use theirs
|
||||
actual, loaded := m.dbs.LoadOrStore(webhookID, db)
|
||||
if loaded {
|
||||
// Another goroutine created it first; close our duplicate
|
||||
sqlDB, closeErr := db.DB()
|
||||
if closeErr == nil {
|
||||
if sqlDB, closeErr := db.DB(); closeErr == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
|
||||
existingDB, castOK := actual.(*gorm.DB)
|
||||
if !castOK {
|
||||
return nil, fmt.Errorf(
|
||||
"%w for webhook %s",
|
||||
errInvalidCachedDBType,
|
||||
webhookID,
|
||||
)
|
||||
return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
|
||||
}
|
||||
|
||||
return existingDB, nil
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// CreateDB explicitly creates a new per-webhook database file
|
||||
// and runs migrations.
|
||||
func (m *WebhookDBManager) CreateDB(
|
||||
webhookID string,
|
||||
) error {
|
||||
// CreateDB explicitly creates a new per-webhook database file and runs migrations.
|
||||
// This is called when a new webhook is created.
|
||||
func (m *WebhookDBManager) CreateDB(webhookID string) error {
|
||||
_, err := m.GetDB(webhookID)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DBExists checks if a per-webhook database file exists on
|
||||
// disk.
|
||||
func (m *WebhookDBManager) DBExists(
|
||||
webhookID string,
|
||||
) bool {
|
||||
// DBExists checks if a per-webhook database file exists on disk.
|
||||
func (m *WebhookDBManager) DBExists(webhookID string) bool {
|
||||
_, err := os.Stat(m.dbPath(webhookID))
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// DeleteDB closes the connection and deletes the database file
|
||||
// for a webhook. The file is permanently removed.
|
||||
func (m *WebhookDBManager) DeleteDB(
|
||||
webhookID string,
|
||||
) error {
|
||||
// DeleteDB closes the connection and deletes the database file for a webhook.
|
||||
// This performs a hard delete — the file is permanently removed.
|
||||
func (m *WebhookDBManager) DeleteDB(webhookID string) error {
|
||||
// Close and remove from cache
|
||||
if val, ok := m.dbs.LoadAndDelete(webhookID); ok {
|
||||
if gormDB, castOK := val.(*gorm.DB); castOK {
|
||||
sqlDB, err := gormDB.DB()
|
||||
if err == nil {
|
||||
if sqlDB, err := gormDB.DB(); err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
}
|
||||
@@ -164,20 +151,12 @@ func (m *WebhookDBManager) DeleteDB(
|
||||
// Delete the main DB file and WAL/SHM files
|
||||
path := m.dbPath(webhookID)
|
||||
for _, suffix := range []string{"", "-wal", "-shm"} {
|
||||
err := os.Remove(path + suffix)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf(
|
||||
"deleting webhook database file %s%s: %w",
|
||||
path, suffix, err,
|
||||
)
|
||||
if err := os.Remove(path + suffix); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("deleting webhook database file %s%s: %w", path, suffix, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.log.Info(
|
||||
"deleted per-webhook database",
|
||||
"webhook_id", webhookID,
|
||||
)
|
||||
|
||||
m.log.Info("deleted per-webhook database", "webhook_id", webhookID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -185,97 +164,20 @@ func (m *WebhookDBManager) DeleteDB(
|
||||
// Called during application shutdown.
|
||||
func (m *WebhookDBManager) CloseAll() error {
|
||||
var lastErr error
|
||||
|
||||
m.dbs.Range(func(key, value any) bool {
|
||||
m.dbs.Range(func(key, value interface{}) bool {
|
||||
if gormDB, castOK := value.(*gorm.DB); castOK {
|
||||
sqlDB, err := gormDB.DB()
|
||||
if err == nil {
|
||||
closeErr := sqlDB.Close()
|
||||
if closeErr != nil {
|
||||
if sqlDB, err := gormDB.DB(); err == nil {
|
||||
if closeErr := sqlDB.Close(); closeErr != nil {
|
||||
lastErr = closeErr
|
||||
m.log.Error(
|
||||
"failed to close webhook database",
|
||||
m.log.Error("failed to close webhook database",
|
||||
"webhook_id", key,
|
||||
"error", closeErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.dbs.Delete(key)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// DBPath returns the filesystem path for a webhook's database
|
||||
// file.
|
||||
func (m *WebhookDBManager) DBPath(
|
||||
webhookID string,
|
||||
) string {
|
||||
return m.dbPath(webhookID)
|
||||
}
|
||||
|
||||
func (m *WebhookDBManager) dbPath(
|
||||
webhookID string,
|
||||
) string {
|
||||
return filepath.Join(
|
||||
m.dataDir,
|
||||
fmt.Sprintf("events-%s.db", webhookID),
|
||||
)
|
||||
}
|
||||
|
||||
// openDB opens (or creates) a per-webhook SQLite database and
|
||||
// runs migrations.
|
||||
func (m *WebhookDBManager) openDB(
|
||||
webhookID string,
|
||||
) (*gorm.DB, error) {
|
||||
path := m.dbPath(webhookID)
|
||||
dbURL := fmt.Sprintf(
|
||||
"file:%s?cache=shared&mode=rwc",
|
||||
path,
|
||||
)
|
||||
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"opening webhook database %s: %w",
|
||||
webhookID, err,
|
||||
)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
_ = sqlDB.Close()
|
||||
|
||||
return nil, fmt.Errorf(
|
||||
"connecting to webhook database %s: %w",
|
||||
webhookID, err,
|
||||
)
|
||||
}
|
||||
|
||||
// Run migrations for event-tier models only
|
||||
err = db.AutoMigrate(
|
||||
&Event{}, &Delivery{}, &DeliveryResult{},
|
||||
)
|
||||
if err != nil {
|
||||
_ = sqlDB.Close()
|
||||
|
||||
return nil, fmt.Errorf(
|
||||
"migrating webhook database %s: %w",
|
||||
webhookID, err,
|
||||
)
|
||||
}
|
||||
|
||||
m.log.Info(
|
||||
"opened per-webhook database",
|
||||
"webhook_id", webhookID,
|
||||
"path", path,
|
||||
)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package database_test
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -10,29 +10,23 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/fx/fxtest"
|
||||
"gorm.io/gorm"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
func setupTestWebhookDBManager(
|
||||
t *testing.T,
|
||||
) (*database.WebhookDBManager, *fxtest.Lifecycle) {
|
||||
func setupTestWebhookDBManager(t *testing.T) (*WebhookDBManager, *fxtest.Lifecycle) {
|
||||
t.Helper()
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
g := &globals.Globals{
|
||||
Appname: "webhooker-test",
|
||||
Version: "test",
|
||||
}
|
||||
globals.Appname = "webhooker-test"
|
||||
globals.Version = "test"
|
||||
|
||||
l, err := logger.New(
|
||||
lc,
|
||||
logger.LoggerParams{Globals: g},
|
||||
)
|
||||
g, err := globals.New(lc)
|
||||
require.NoError(t, err)
|
||||
|
||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
||||
require.NoError(t, err)
|
||||
|
||||
dataDir := filepath.Join(t.TempDir(), "events")
|
||||
@@ -41,25 +35,19 @@ func setupTestWebhookDBManager(
|
||||
DataDir: dataDir,
|
||||
}
|
||||
|
||||
mgr, err := database.NewWebhookDBManager(
|
||||
lc,
|
||||
database.WebhookDBManagerParams{
|
||||
mgr, err := NewWebhookDBManager(lc, WebhookDBManagerParams{
|
||||
Config: cfg,
|
||||
Logger: l,
|
||||
},
|
||||
)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return mgr, lc
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
@@ -80,7 +68,7 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
||||
require.NotNil(t, db)
|
||||
|
||||
// Verify we can write an event
|
||||
event := &database.Event{
|
||||
event := &Event{
|
||||
WebhookID: webhookID,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
@@ -92,35 +80,27 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
||||
assert.NotEmpty(t, event.ID)
|
||||
|
||||
// Verify we can read it back
|
||||
var readEvent database.Event
|
||||
|
||||
require.NoError(
|
||||
t,
|
||||
db.First(&readEvent, "id = ?", event.ID).Error,
|
||||
)
|
||||
var readEvent Event
|
||||
require.NoError(t, db.First(&readEvent, "id = ?", event.ID).Error)
|
||||
assert.Equal(t, webhookID, readEvent.WebhookID)
|
||||
assert.Equal(t, "POST", readEvent.Method)
|
||||
assert.Equal(t, `{"test": true}`, readEvent.Body)
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_DeleteDB(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
|
||||
// Create the DB and write some data
|
||||
require.NoError(t, mgr.CreateDB(webhookID))
|
||||
|
||||
db, err := mgr.GetDB(webhookID)
|
||||
require.NoError(t, err)
|
||||
|
||||
event := &database.Event{
|
||||
event := &Event{
|
||||
WebhookID: webhookID,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
@@ -136,19 +116,15 @@ func TestWebhookDBManager_DeleteDB(t *testing.T) {
|
||||
assert.False(t, mgr.DBExists(webhookID))
|
||||
|
||||
// Verify the file is actually gone from disk
|
||||
dbPath := mgr.DBPath(webhookID)
|
||||
|
||||
dbPath := mgr.dbPath(webhookID)
|
||||
_, err = os.Stat(dbPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_LazyCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
@@ -163,12 +139,9 @@ func TestWebhookDBManager_LazyCreation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
@@ -177,23 +150,8 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
||||
db, err := mgr.GetDB(webhookID)
|
||||
require.NoError(t, err)
|
||||
|
||||
event, delivery := seedDeliveryWorkflow(
|
||||
t, db, webhookID, targetID,
|
||||
)
|
||||
|
||||
verifyPendingDeliveries(t, db, event)
|
||||
completeDelivery(t, db, delivery)
|
||||
verifyNoPending(t, db)
|
||||
}
|
||||
|
||||
func seedDeliveryWorkflow(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
webhookID, targetID string,
|
||||
) (*database.Event, *database.Delivery) {
|
||||
t.Helper()
|
||||
|
||||
event := &database.Event{
|
||||
// Create an event
|
||||
event := &Event{
|
||||
WebhookID: webhookID,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
@@ -203,45 +161,25 @@ func seedDeliveryWorkflow(
|
||||
}
|
||||
require.NoError(t, db.Create(event).Error)
|
||||
|
||||
delivery := &database.Delivery{
|
||||
// Create a delivery
|
||||
delivery := &Delivery{
|
||||
EventID: event.ID,
|
||||
TargetID: targetID,
|
||||
Status: database.DeliveryStatusPending,
|
||||
Status: DeliveryStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(delivery).Error)
|
||||
|
||||
return event, delivery
|
||||
}
|
||||
|
||||
func verifyPendingDeliveries(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
event *database.Event,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
var pending []database.Delivery
|
||||
|
||||
require.NoError(
|
||||
t,
|
||||
db.Where(
|
||||
"status = ?",
|
||||
database.DeliveryStatusPending,
|
||||
).Preload("Event").Find(&pending).Error,
|
||||
)
|
||||
// Query pending deliveries
|
||||
var pending []Delivery
|
||||
require.NoError(t, db.Where("status = ?", DeliveryStatusPending).
|
||||
Preload("Event").
|
||||
Find(&pending).Error)
|
||||
require.Len(t, pending, 1)
|
||||
assert.Equal(t, event.ID, pending[0].EventID)
|
||||
assert.Equal(t, "POST", pending[0].Event.Method)
|
||||
}
|
||||
|
||||
func completeDelivery(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
delivery *database.Delivery,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
result := &database.DeliveryResult{
|
||||
// Create a delivery result
|
||||
result := &DeliveryResult{
|
||||
DeliveryID: delivery.ID,
|
||||
AttemptNum: 1,
|
||||
Success: true,
|
||||
@@ -250,40 +188,19 @@ func completeDelivery(
|
||||
}
|
||||
require.NoError(t, db.Create(result).Error)
|
||||
|
||||
require.NoError(
|
||||
t,
|
||||
db.Model(delivery).Update(
|
||||
"status",
|
||||
database.DeliveryStatusDelivered,
|
||||
).Error,
|
||||
)
|
||||
}
|
||||
// Update delivery status
|
||||
require.NoError(t, db.Model(delivery).Update("status", DeliveryStatusDelivered).Error)
|
||||
|
||||
func verifyNoPending(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
var stillPending []database.Delivery
|
||||
|
||||
require.NoError(
|
||||
t,
|
||||
db.Where(
|
||||
"status = ?",
|
||||
database.DeliveryStatusPending,
|
||||
).Find(&stillPending).Error,
|
||||
)
|
||||
// Verify no more pending deliveries
|
||||
var stillPending []Delivery
|
||||
require.NoError(t, db.Where("status = ?", DeliveryStatusPending).Find(&stillPending).Error)
|
||||
assert.Empty(t, stillPending)
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhook1 := uuid.New().String()
|
||||
@@ -295,38 +212,34 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
||||
|
||||
db1, err := mgr.GetDB(webhook1)
|
||||
require.NoError(t, err)
|
||||
|
||||
db2, err := mgr.GetDB(webhook2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write events to each webhook's DB
|
||||
event1 := &database.Event{
|
||||
event1 := &Event{
|
||||
WebhookID: webhook1,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
Body: `{"webhook": 1}`,
|
||||
ContentType: "application/json",
|
||||
}
|
||||
event2 := &database.Event{
|
||||
event2 := &Event{
|
||||
WebhookID: webhook2,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "PUT",
|
||||
Body: `{"webhook": 2}`,
|
||||
ContentType: "application/json",
|
||||
}
|
||||
|
||||
require.NoError(t, db1.Create(event1).Error)
|
||||
require.NoError(t, db2.Create(event2).Error)
|
||||
|
||||
// Verify isolation: each DB only has its own events
|
||||
var count1 int64
|
||||
|
||||
db1.Model(&database.Event{}).Count(&count1)
|
||||
db1.Model(&Event{}).Count(&count1)
|
||||
assert.Equal(t, int64(1), count1)
|
||||
|
||||
var count2 int64
|
||||
|
||||
db2.Model(&database.Event{}).Count(&count2)
|
||||
db2.Model(&Event{}).Count(&count2)
|
||||
assert.Equal(t, int64(1), count2)
|
||||
|
||||
// Delete webhook1's DB, webhook2 should be unaffected
|
||||
@@ -335,31 +248,25 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
||||
assert.True(t, mgr.DBExists(webhook2))
|
||||
|
||||
// webhook2's data should still be accessible
|
||||
var events []database.Event
|
||||
|
||||
var events []Event
|
||||
require.NoError(t, db2.Find(&events).Error)
|
||||
assert.Len(t, events, 1)
|
||||
assert.Equal(t, "PUT", events[0].Method)
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_CloseAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
// Create a few DBs
|
||||
for range 3 {
|
||||
require.NoError(
|
||||
t,
|
||||
mgr.CreateDB(uuid.New().String()),
|
||||
)
|
||||
for i := 0; i < 3; i++ {
|
||||
require.NoError(t, mgr.CreateDB(uuid.New().String()))
|
||||
}
|
||||
|
||||
// CloseAll should close all connections without error
|
||||
require.NoError(t, mgr.CloseAll())
|
||||
|
||||
// Stop lifecycle (CloseAll already called)
|
||||
// Stop lifecycle (CloseAll already called, but shouldn't panic)
|
||||
require.NoError(t, lc.Stop(ctx))
|
||||
}
|
||||
|
||||
@@ -5,32 +5,41 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CircuitState represents the current state of a circuit
|
||||
// breaker.
|
||||
// CircuitState represents the current state of a circuit breaker.
|
||||
type CircuitState int
|
||||
|
||||
const (
|
||||
// CircuitClosed is the normal operating state.
|
||||
// CircuitClosed is the normal operating state. Deliveries flow through.
|
||||
CircuitClosed CircuitState = iota
|
||||
// CircuitOpen means the circuit has tripped.
|
||||
// CircuitOpen means the circuit has tripped. Deliveries are skipped
|
||||
// until the cooldown expires.
|
||||
CircuitOpen
|
||||
// CircuitHalfOpen allows a single probe delivery to
|
||||
// test whether the target has recovered.
|
||||
// CircuitHalfOpen allows a single probe delivery to test whether
|
||||
// the target has recovered.
|
||||
CircuitHalfOpen
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultFailureThreshold is the number of consecutive
|
||||
// failures before a circuit breaker trips open.
|
||||
// defaultFailureThreshold is the number of consecutive failures
|
||||
// before a circuit breaker trips open.
|
||||
defaultFailureThreshold = 5
|
||||
|
||||
// defaultCooldown is how long a circuit stays open
|
||||
// before transitioning to half-open.
|
||||
// defaultCooldown is how long a circuit stays open before
|
||||
// transitioning to half-open for a probe delivery.
|
||||
defaultCooldown = 30 * time.Second
|
||||
)
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
// for a single delivery target.
|
||||
// CircuitBreaker implements the circuit breaker pattern for a single
|
||||
// delivery target. It tracks consecutive failures and prevents
|
||||
// hammering a down target by temporarily stopping delivery attempts.
|
||||
//
|
||||
// States:
|
||||
// - Closed (normal): deliveries flow through; consecutive failures
|
||||
// are counted.
|
||||
// - Open (tripped): deliveries are skipped; a cooldown timer is
|
||||
// running. After the cooldown expires the state moves to HalfOpen.
|
||||
// - HalfOpen (probing): one probe delivery is allowed. If it
|
||||
// succeeds the circuit closes; if it fails the circuit reopens.
|
||||
type CircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state CircuitState
|
||||
@@ -40,8 +49,7 @@ type CircuitBreaker struct {
|
||||
lastFailure time.Time
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a circuit breaker with default
|
||||
// settings.
|
||||
// NewCircuitBreaker creates a circuit breaker with default settings.
|
||||
func NewCircuitBreaker() *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
@@ -50,7 +58,12 @@ func NewCircuitBreaker() *CircuitBreaker {
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks whether a delivery attempt should proceed.
|
||||
// Allow checks whether a delivery attempt should proceed. It returns
|
||||
// true if the delivery should be attempted, false if the circuit is
|
||||
// open and the delivery should be skipped.
|
||||
//
|
||||
// When the circuit is open and the cooldown has elapsed, Allow
|
||||
// transitions to half-open and permits exactly one probe delivery.
|
||||
func (cb *CircuitBreaker) Allow() bool {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -60,15 +73,17 @@ func (cb *CircuitBreaker) Allow() bool {
|
||||
return true
|
||||
|
||||
case CircuitOpen:
|
||||
// Check if cooldown has elapsed
|
||||
if time.Since(cb.lastFailure) >= cb.cooldown {
|
||||
cb.state = CircuitHalfOpen
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
case CircuitHalfOpen:
|
||||
// Only one probe at a time — reject additional attempts while
|
||||
// a probe is in flight. The probe goroutine will call
|
||||
// RecordSuccess or RecordFailure to resolve the state.
|
||||
return false
|
||||
|
||||
default:
|
||||
@@ -76,8 +91,9 @@ func (cb *CircuitBreaker) Allow() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// CooldownRemaining returns how much time is left before
|
||||
// an open circuit transitions to half-open.
|
||||
// CooldownRemaining returns how much time is left before an open circuit
|
||||
// transitions to half-open. Returns zero if the circuit is not open or
|
||||
// the cooldown has already elapsed.
|
||||
func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -90,12 +106,11 @@ func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful delivery and resets
|
||||
// the circuit breaker to closed state.
|
||||
// RecordSuccess records a successful delivery and resets the circuit
|
||||
// breaker to closed state with zero failures.
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -104,8 +119,8 @@ func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.state = CircuitClosed
|
||||
}
|
||||
|
||||
// RecordFailure records a failed delivery. If the failure
|
||||
// count reaches the threshold, the circuit trips open.
|
||||
// RecordFailure records a failed delivery. If the failure count reaches
|
||||
// the threshold, the circuit trips open.
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -119,25 +134,20 @@ func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.state = CircuitOpen
|
||||
}
|
||||
|
||||
case CircuitOpen:
|
||||
// Already open; no state change needed.
|
||||
|
||||
case CircuitHalfOpen:
|
||||
// Probe failed -- reopen immediately.
|
||||
// Probe failed — reopen immediately
|
||||
cb.state = CircuitOpen
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current circuit state.
|
||||
// State returns the current circuit state. Safe for concurrent use.
|
||||
func (cb *CircuitBreaker) State() CircuitState {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// String returns the human-readable name of a circuit
|
||||
// state.
|
||||
// String returns the human-readable name of a circuit state.
|
||||
func (s CircuitState) String() string {
|
||||
switch s {
|
||||
case CircuitClosed:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package delivery_test
|
||||
package delivery
|
||||
|
||||
import (
|
||||
"sync"
|
||||
@@ -7,304 +7,237 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/delivery"
|
||||
)
|
||||
|
||||
func TestCircuitBreaker_ClosedState_AllowsDeliveries(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestCircuitBreaker_ClosedState_AllowsDeliveries(t *testing.T) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State())
|
||||
assert.True(t, cb.Allow(),
|
||||
"closed circuit should allow deliveries",
|
||||
)
|
||||
|
||||
for range 10 {
|
||||
assert.Equal(t, CircuitClosed, cb.State())
|
||||
assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
|
||||
// Multiple calls should all succeed
|
||||
for i := 0; i < 10; i++ {
|
||||
assert.True(t, cb.Allow())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_FailureCounting(t *testing.T) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for i := range delivery.ExportDefaultFailureThreshold - 1 {
|
||||
// Record failures below threshold — circuit should stay closed
|
||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
||||
cb.RecordFailure()
|
||||
|
||||
assert.Equal(t,
|
||||
delivery.CircuitClosed, cb.State(),
|
||||
"circuit should remain closed after %d failures",
|
||||
i+1,
|
||||
)
|
||||
|
||||
assert.True(t, cb.Allow(),
|
||||
"should still allow after %d failures",
|
||||
i+1,
|
||||
)
|
||||
assert.Equal(t, CircuitClosed, cb.State(),
|
||||
"circuit should remain closed after %d failures", i+1)
|
||||
assert.True(t, cb.Allow(), "should still allow after %d failures", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_OpenTransition(t *testing.T) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
// Record exactly threshold failures
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
|
||||
assert.Equal(t, delivery.CircuitOpen, cb.State(),
|
||||
"circuit should be open after threshold failures",
|
||||
)
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"open circuit should reject deliveries",
|
||||
)
|
||||
assert.Equal(t, CircuitOpen, cb.State(), "circuit should be open after threshold failures")
|
||||
assert.False(t, cb.Allow(), "open circuit should reject deliveries")
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Cooldown_StaysOpen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
// Use a circuit with a known short cooldown for testing
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
require.Equal(t, delivery.CircuitOpen, cb.State())
|
||||
// Trip the circuit open
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
require.Equal(t, CircuitOpen, cb.State())
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"should be blocked during cooldown",
|
||||
)
|
||||
// During cooldown, Allow should return false
|
||||
assert.False(t, cb.Allow(), "should be blocked during cooldown")
|
||||
|
||||
// CooldownRemaining should be positive
|
||||
remaining := cb.CooldownRemaining()
|
||||
|
||||
assert.Greater(t, remaining, time.Duration(0),
|
||||
"cooldown should have remaining time",
|
||||
)
|
||||
assert.Greater(t, remaining, time.Duration(0), "cooldown should have remaining time")
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpen_AfterCooldown(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestCircuitBreaker_HalfOpen_AfterCooldown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
require.Equal(t, delivery.CircuitOpen, cb.State())
|
||||
// Trip the circuit open
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
require.Equal(t, CircuitOpen, cb.State())
|
||||
|
||||
// Wait for cooldown to expire
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, time.Duration(0),
|
||||
cb.CooldownRemaining(),
|
||||
)
|
||||
// CooldownRemaining should be zero after cooldown
|
||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining())
|
||||
|
||||
assert.True(t, cb.Allow(),
|
||||
"should allow one probe after cooldown",
|
||||
)
|
||||
// First Allow after cooldown should succeed (probe)
|
||||
assert.True(t, cb.Allow(), "should allow one probe after cooldown")
|
||||
assert.Equal(t, CircuitHalfOpen, cb.State(), "should be half-open after probe allowed")
|
||||
|
||||
assert.Equal(t,
|
||||
delivery.CircuitHalfOpen, cb.State(),
|
||||
"should be half-open after probe allowed",
|
||||
)
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"should reject additional probes while half-open",
|
||||
)
|
||||
// Second Allow should be rejected (only one probe at a time)
|
||||
assert.False(t, cb.Allow(), "should reject additional probes while half-open")
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ProbeSuccess_ClosesCircuit(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestCircuitBreaker_ProbeSuccess_ClosesCircuit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip open → wait for cooldown → allow probe
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
require.True(t, cb.Allow()) // probe allowed, state → half-open
|
||||
|
||||
require.True(t, cb.Allow())
|
||||
|
||||
// Probe succeeds → circuit should close
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, CircuitClosed, cb.State(), "successful probe should close circuit")
|
||||
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State(),
|
||||
"successful probe should close circuit",
|
||||
)
|
||||
|
||||
assert.True(t, cb.Allow(),
|
||||
"closed circuit should allow deliveries",
|
||||
)
|
||||
// Should allow deliveries again
|
||||
assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ProbeFailure_ReopensCircuit(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestCircuitBreaker_ProbeFailure_ReopensCircuit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip open → wait for cooldown → allow probe
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
require.True(t, cb.Allow()) // probe allowed, state → half-open
|
||||
|
||||
require.True(t, cb.Allow())
|
||||
|
||||
// Probe fails → circuit should reopen
|
||||
cb.RecordFailure()
|
||||
|
||||
assert.Equal(t, delivery.CircuitOpen, cb.State(),
|
||||
"failed probe should reopen circuit",
|
||||
)
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"reopened circuit should reject deliveries",
|
||||
)
|
||||
assert.Equal(t, CircuitOpen, cb.State(), "failed probe should reopen circuit")
|
||||
assert.False(t, cb.Allow(), "reopened circuit should reject deliveries")
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_SuccessResetsFailures(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold - 1 {
|
||||
// Accumulate failures just below threshold
|
||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
require.Equal(t, CircuitClosed, cb.State())
|
||||
|
||||
require.Equal(t, delivery.CircuitClosed, cb.State())
|
||||
|
||||
// Success should reset the failure counter
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, CircuitClosed, cb.State())
|
||||
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State())
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold - 1 {
|
||||
// Now we should need another full threshold of failures to trip
|
||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
assert.Equal(t, CircuitClosed, cb.State(),
|
||||
"circuit should still be closed — success reset the counter")
|
||||
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State(),
|
||||
"circuit should still be closed -- "+
|
||||
"success reset the counter",
|
||||
)
|
||||
|
||||
// One more failure should trip it
|
||||
cb.RecordFailure()
|
||||
|
||||
assert.Equal(t, delivery.CircuitOpen, cb.State())
|
||||
assert.Equal(t, CircuitOpen, cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
const goroutines = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(goroutines * 3)
|
||||
|
||||
for range goroutines {
|
||||
// Concurrent Allow calls
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
cb.Allow()
|
||||
}()
|
||||
}
|
||||
|
||||
for range goroutines {
|
||||
// Concurrent RecordFailure calls
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
cb.RecordFailure()
|
||||
}()
|
||||
}
|
||||
|
||||
for range goroutines {
|
||||
// Concurrent RecordSuccess calls
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
cb.RecordSuccess()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// No panic or data race — the test passes if -race doesn't flag anything.
|
||||
// State should be one of the valid states.
|
||||
state := cb.State()
|
||||
|
||||
assert.Contains(t,
|
||||
[]delivery.CircuitState{
|
||||
delivery.CircuitClosed,
|
||||
delivery.CircuitOpen,
|
||||
delivery.CircuitHalfOpen,
|
||||
},
|
||||
state,
|
||||
"state should be valid after concurrent access",
|
||||
)
|
||||
assert.Contains(t, []CircuitState{CircuitClosed, CircuitOpen, CircuitHalfOpen}, state,
|
||||
"state should be valid after concurrent access")
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
assert.Equal(t, time.Duration(0),
|
||||
cb.CooldownRemaining(),
|
||||
"closed circuit should have zero cooldown remaining",
|
||||
)
|
||||
cb := NewCircuitBreaker()
|
||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
|
||||
"closed circuit should have zero cooldown remaining")
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip open, wait, transition to half-open
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
require.True(t, cb.Allow()) // → half-open
|
||||
|
||||
require.True(t, cb.Allow())
|
||||
|
||||
assert.Equal(t, time.Duration(0),
|
||||
cb.CooldownRemaining(),
|
||||
"half-open circuit should have zero cooldown remaining",
|
||||
)
|
||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
|
||||
"half-open circuit should have zero cooldown remaining")
|
||||
}
|
||||
|
||||
func TestCircuitState_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "closed", delivery.CircuitClosed.String())
|
||||
assert.Equal(t, "open", delivery.CircuitOpen.String())
|
||||
assert.Equal(t, "half-open", delivery.CircuitHalfOpen.String())
|
||||
assert.Equal(t, "unknown", delivery.CircuitState(99).String())
|
||||
}
|
||||
|
||||
// newShortCooldownCB creates a CircuitBreaker with a short
|
||||
// cooldown for testing. We use NewCircuitBreaker and
|
||||
// manipulate through the public API.
|
||||
func newShortCooldownCB(t *testing.T) *delivery.CircuitBreaker {
|
||||
t.Helper()
|
||||
|
||||
return delivery.NewTestCircuitBreaker(
|
||||
delivery.ExportDefaultFailureThreshold,
|
||||
50*time.Millisecond,
|
||||
)
|
||||
assert.Equal(t, "closed", CircuitClosed.String())
|
||||
assert.Equal(t, "open", CircuitOpen.String())
|
||||
assert.Equal(t, "half-open", CircuitHalfOpen.String())
|
||||
assert.Equal(t, "unknown", CircuitState(99).String())
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,240 +0,0 @@
|
||||
package delivery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
)
|
||||
|
||||
// Exported constants for test access.
|
||||
const (
|
||||
ExportDeliveryChannelSize = deliveryChannelSize
|
||||
ExportRetryChannelSize = retryChannelSize
|
||||
ExportDefaultFailureThreshold = defaultFailureThreshold
|
||||
ExportDefaultCooldown = defaultCooldown
|
||||
)
|
||||
|
||||
// ExportIsBlockedIP exposes isBlockedIP for testing.
|
||||
func ExportIsBlockedIP(ip net.IP) bool {
|
||||
return isBlockedIP(ip)
|
||||
}
|
||||
|
||||
// ExportBlockedNetworks exposes blockedNetworks.
|
||||
func ExportBlockedNetworks() []*net.IPNet {
|
||||
return blockedNetworks
|
||||
}
|
||||
|
||||
// ExportIsForwardableHeader exposes isForwardableHeader.
|
||||
func ExportIsForwardableHeader(name string) bool {
|
||||
return isForwardableHeader(name)
|
||||
}
|
||||
|
||||
// ExportTruncate exposes truncate for testing.
|
||||
func ExportTruncate(s string, maxLen int) string {
|
||||
return truncate(s, maxLen)
|
||||
}
|
||||
|
||||
// ExportDeliverHTTP exposes deliverHTTP for testing.
|
||||
func (e *Engine) ExportDeliverHTTP(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
d *database.Delivery,
|
||||
task *Task,
|
||||
) {
|
||||
e.deliverHTTP(ctx, webhookDB, d, task)
|
||||
}
|
||||
|
||||
// ExportDeliverDatabase exposes deliverDatabase.
|
||||
func (e *Engine) ExportDeliverDatabase(
|
||||
webhookDB *gorm.DB, d *database.Delivery,
|
||||
) {
|
||||
e.deliverDatabase(webhookDB, d)
|
||||
}
|
||||
|
||||
// ExportDeliverLog exposes deliverLog for testing.
|
||||
func (e *Engine) ExportDeliverLog(
|
||||
webhookDB *gorm.DB, d *database.Delivery,
|
||||
) {
|
||||
e.deliverLog(webhookDB, d)
|
||||
}
|
||||
|
||||
// ExportDeliverSlack exposes deliverSlack for testing.
|
||||
func (e *Engine) ExportDeliverSlack(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
d *database.Delivery,
|
||||
) {
|
||||
e.deliverSlack(ctx, webhookDB, d)
|
||||
}
|
||||
|
||||
// ExportProcessNewTask exposes processNewTask.
|
||||
func (e *Engine) ExportProcessNewTask(
|
||||
ctx context.Context, task *Task,
|
||||
) {
|
||||
e.processNewTask(ctx, task)
|
||||
}
|
||||
|
||||
// ExportProcessRetryTask exposes processRetryTask.
|
||||
func (e *Engine) ExportProcessRetryTask(
|
||||
ctx context.Context, task *Task,
|
||||
) {
|
||||
e.processRetryTask(ctx, task)
|
||||
}
|
||||
|
||||
// ExportProcessDelivery exposes processDelivery.
|
||||
func (e *Engine) ExportProcessDelivery(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
d *database.Delivery,
|
||||
task *Task,
|
||||
) {
|
||||
e.processDelivery(ctx, webhookDB, d, task)
|
||||
}
|
||||
|
||||
// ExportGetCircuitBreaker exposes getCircuitBreaker.
|
||||
func (e *Engine) ExportGetCircuitBreaker(
|
||||
targetID string,
|
||||
) *CircuitBreaker {
|
||||
return e.getCircuitBreaker(targetID)
|
||||
}
|
||||
|
||||
// ExportParseHTTPConfig exposes parseHTTPConfig.
|
||||
func (e *Engine) ExportParseHTTPConfig(
|
||||
configJSON string,
|
||||
) (*HTTPTargetConfig, error) {
|
||||
return e.parseHTTPConfig(configJSON)
|
||||
}
|
||||
|
||||
// ExportParseSlackConfig exposes parseSlackConfig.
|
||||
func (e *Engine) ExportParseSlackConfig(
|
||||
configJSON string,
|
||||
) (*SlackTargetConfig, error) {
|
||||
return e.parseSlackConfig(configJSON)
|
||||
}
|
||||
|
||||
// ExportDoHTTPRequest exposes doHTTPRequest.
|
||||
func (e *Engine) ExportDoHTTPRequest(
|
||||
ctx context.Context,
|
||||
cfg *HTTPTargetConfig,
|
||||
event *database.Event,
|
||||
) (int, string, int64, error) {
|
||||
return e.doHTTPRequest(ctx, cfg, event)
|
||||
}
|
||||
|
||||
// ExportScheduleRetry exposes scheduleRetry.
|
||||
func (e *Engine) ExportScheduleRetry(
|
||||
task Task, delay time.Duration,
|
||||
) {
|
||||
e.scheduleRetry(task, delay)
|
||||
}
|
||||
|
||||
// ExportRecoverPendingDeliveries exposes
|
||||
// recoverPendingDeliveries.
|
||||
func (e *Engine) ExportRecoverPendingDeliveries(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
webhookID string,
|
||||
) {
|
||||
e.recoverPendingDeliveries(
|
||||
ctx, webhookDB, webhookID,
|
||||
)
|
||||
}
|
||||
|
||||
// ExportRecoverWebhookDeliveries exposes
|
||||
// recoverWebhookDeliveries.
|
||||
func (e *Engine) ExportRecoverWebhookDeliveries(
|
||||
ctx context.Context, webhookID string,
|
||||
) {
|
||||
e.recoverWebhookDeliveries(ctx, webhookID)
|
||||
}
|
||||
|
||||
// ExportRecoverInFlight exposes recoverInFlight.
|
||||
func (e *Engine) ExportRecoverInFlight(
|
||||
ctx context.Context,
|
||||
) {
|
||||
e.recoverInFlight(ctx)
|
||||
}
|
||||
|
||||
// ExportStart exposes start for testing.
|
||||
func (e *Engine) ExportStart(ctx context.Context) {
|
||||
e.start(ctx)
|
||||
}
|
||||
|
||||
// ExportStop exposes stop for testing.
|
||||
func (e *Engine) ExportStop() {
|
||||
e.stop()
|
||||
}
|
||||
|
||||
// ExportDeliveryCh returns the delivery channel.
|
||||
func (e *Engine) ExportDeliveryCh() chan Task {
|
||||
return e.deliveryCh
|
||||
}
|
||||
|
||||
// ExportRetryCh returns the retry channel.
|
||||
func (e *Engine) ExportRetryCh() chan Task {
|
||||
return e.retryCh
|
||||
}
|
||||
|
||||
// NewTestEngine creates an Engine for unit tests without
|
||||
// database dependencies.
|
||||
func NewTestEngine(
|
||||
log *slog.Logger,
|
||||
client *http.Client,
|
||||
workers int,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
log: log,
|
||||
client: client,
|
||||
deliveryCh: make(chan Task, deliveryChannelSize),
|
||||
retryCh: make(chan Task, retryChannelSize),
|
||||
workers: workers,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestEngineSmallRetry creates an Engine with a tiny
|
||||
// retry channel buffer for overflow testing.
|
||||
func NewTestEngineSmallRetry(
|
||||
log *slog.Logger,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
log: log,
|
||||
retryCh: make(chan Task, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestEngineWithDB creates an Engine with a real
|
||||
// database and dbManager for integration tests.
|
||||
func NewTestEngineWithDB(
|
||||
db *database.Database,
|
||||
dbMgr *database.WebhookDBManager,
|
||||
log *slog.Logger,
|
||||
client *http.Client,
|
||||
workers int,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
database: db,
|
||||
dbManager: dbMgr,
|
||||
log: log,
|
||||
client: client,
|
||||
deliveryCh: make(chan Task, deliveryChannelSize),
|
||||
retryCh: make(chan Task, retryChannelSize),
|
||||
workers: workers,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestCircuitBreaker creates a CircuitBreaker with
|
||||
// custom settings for testing.
|
||||
func NewTestCircuitBreaker(
|
||||
threshold int, cooldown time.Duration,
|
||||
) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: threshold,
|
||||
cooldown: cooldown,
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package delivery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -11,27 +10,14 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// dnsResolutionTimeout is the maximum time to wait for
|
||||
// DNS resolution during SSRF validation.
|
||||
// dnsResolutionTimeout is the maximum time to wait for DNS resolution
|
||||
// during SSRF validation.
|
||||
dnsResolutionTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Sentinel errors for SSRF validation.
|
||||
var (
|
||||
errNoHostname = errors.New("URL has no hostname")
|
||||
errNoIPs = errors.New(
|
||||
"hostname resolved to no IP addresses",
|
||||
)
|
||||
errBlockedIP = errors.New(
|
||||
"blocked private/reserved IP range",
|
||||
)
|
||||
errInvalidScheme = errors.New(
|
||||
"only http and https are allowed",
|
||||
)
|
||||
)
|
||||
|
||||
// blockedNetworks contains all private/reserved IP ranges
|
||||
// that should be blocked to prevent SSRF attacks.
|
||||
// blockedNetworks contains all private/reserved IP ranges that should be
|
||||
// blocked to prevent SSRF attacks. This includes RFC 1918 private
|
||||
// addresses, loopback, link-local, and IPv6 equivalents.
|
||||
//
|
||||
//nolint:gochecknoglobals // package-level network list is appropriate here
|
||||
var blockedNetworks []*net.IPNet
|
||||
@@ -39,184 +25,129 @@ var blockedNetworks []*net.IPNet
|
||||
//nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup
|
||||
func init() {
|
||||
cidrs := []string{
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"169.254.0.0/16",
|
||||
"0.0.0.0/8",
|
||||
"100.64.0.0/10",
|
||||
"192.0.0.0/24",
|
||||
"192.0.2.0/24",
|
||||
"198.18.0.0/15",
|
||||
"198.51.100.0/24",
|
||||
"203.0.113.0/24",
|
||||
"224.0.0.0/4",
|
||||
"240.0.0.0/4",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
// IPv4 private/reserved ranges
|
||||
"127.0.0.0/8", // Loopback
|
||||
"10.0.0.0/8", // RFC 1918 Class A private
|
||||
"172.16.0.0/12", // RFC 1918 Class B private
|
||||
"192.168.0.0/16", // RFC 1918 Class C private
|
||||
"169.254.0.0/16", // Link-local (cloud metadata)
|
||||
"0.0.0.0/8", // "This" network
|
||||
"100.64.0.0/10", // Shared address space (CGN)
|
||||
"192.0.0.0/24", // IETF protocol assignments
|
||||
"192.0.2.0/24", // TEST-NET-1
|
||||
"198.18.0.0/15", // Benchmarking
|
||||
"198.51.100.0/24", // TEST-NET-2
|
||||
"203.0.113.0/24", // TEST-NET-3
|
||||
"224.0.0.0/4", // Multicast
|
||||
"240.0.0.0/4", // Reserved for future use
|
||||
|
||||
// IPv6 private/reserved ranges
|
||||
"::1/128", // Loopback
|
||||
"fc00::/7", // Unique local addresses
|
||||
"fe80::/10", // Link-local
|
||||
}
|
||||
|
||||
for _, cidr := range cidrs {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf(
|
||||
"ssrf: failed to parse CIDR %q: %v",
|
||||
cidr, err,
|
||||
))
|
||||
panic(fmt.Sprintf("ssrf: failed to parse CIDR %q: %v", cidr, err))
|
||||
}
|
||||
|
||||
blockedNetworks = append(
|
||||
blockedNetworks, network,
|
||||
)
|
||||
blockedNetworks = append(blockedNetworks, network)
|
||||
}
|
||||
}
|
||||
|
||||
// isBlockedIP checks whether an IP address falls within
|
||||
// any blocked private/reserved network range.
|
||||
// isBlockedIP checks whether an IP address falls within any blocked
|
||||
// private/reserved network range.
|
||||
func isBlockedIP(ip net.IP) bool {
|
||||
for _, network := range blockedNetworks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateTargetURL checks that an HTTP delivery target
|
||||
// URL is safe from SSRF attacks.
|
||||
func ValidateTargetURL(
|
||||
ctx context.Context, targetURL string,
|
||||
) error {
|
||||
// ValidateTargetURL checks that an HTTP delivery target URL is safe
|
||||
// from SSRF attacks. It validates the URL format, resolves the hostname
|
||||
// to IP addresses, and verifies that none of the resolved IPs are in
|
||||
// blocked private/reserved ranges.
|
||||
//
|
||||
// Returns nil if the URL is safe, or an error describing the issue.
|
||||
func ValidateTargetURL(targetURL string) error {
|
||||
parsed, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
err = validateScheme(parsed.Scheme)
|
||||
if err != nil {
|
||||
return err
|
||||
// Only allow http and https schemes
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return fmt.Errorf("unsupported URL scheme %q: only http and https are allowed", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
if host == "" {
|
||||
return errNoHostname
|
||||
return fmt.Errorf("URL has no hostname")
|
||||
}
|
||||
|
||||
// Check if the host is a raw IP address first
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return checkBlockedIP(ip)
|
||||
}
|
||||
|
||||
return validateHostname(ctx, host)
|
||||
}
|
||||
|
||||
func validateScheme(scheme string) error {
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return fmt.Errorf(
|
||||
"unsupported URL scheme %q: %w",
|
||||
scheme, errInvalidScheme,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkBlockedIP(ip net.IP) error {
|
||||
if isBlockedIP(ip) {
|
||||
return fmt.Errorf(
|
||||
"target IP %s is in a blocked "+
|
||||
"private/reserved range: %w",
|
||||
ip, errBlockedIP,
|
||||
)
|
||||
return fmt.Errorf("target IP %s is in a blocked private/reserved range", ip)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHostname(
|
||||
ctx context.Context, host string,
|
||||
) error {
|
||||
dnsCtx, cancel := context.WithTimeout(
|
||||
ctx, dnsResolutionTimeout,
|
||||
)
|
||||
// Resolve hostname to IPs and check each one
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsResolutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(
|
||||
dnsCtx, host,
|
||||
)
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to resolve hostname %q: %w",
|
||||
host, err,
|
||||
)
|
||||
return fmt.Errorf("failed to resolve hostname %q: %w", host, err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return fmt.Errorf(
|
||||
"hostname %q: %w", host, errNoIPs,
|
||||
)
|
||||
return fmt.Errorf("hostname %q resolved to no IP addresses", host)
|
||||
}
|
||||
|
||||
for _, ipAddr := range ips {
|
||||
if isBlockedIP(ipAddr.IP) {
|
||||
return fmt.Errorf(
|
||||
"hostname %q resolves to blocked "+
|
||||
"IP %s: %w",
|
||||
host, ipAddr.IP, errBlockedIP,
|
||||
)
|
||||
return fmt.Errorf("hostname %q resolves to blocked IP %s (private/reserved range)", host, ipAddr.IP)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSSRFSafeTransport creates an http.Transport with a
|
||||
// custom DialContext that blocks connections to
|
||||
// private/reserved IP addresses.
|
||||
// NewSSRFSafeTransport creates an http.Transport with a custom DialContext
|
||||
// that blocks connections to private/reserved IP addresses. This provides
|
||||
// defense-in-depth SSRF protection at the network layer, catching cases
|
||||
// where DNS records change between target creation and delivery time
|
||||
// (DNS rebinding attacks).
|
||||
func NewSSRFSafeTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
DialContext: ssrfDialContext,
|
||||
}
|
||||
}
|
||||
|
||||
func ssrfDialContext(
|
||||
ctx context.Context,
|
||||
network, addr string,
|
||||
) (net.Conn, error) {
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"ssrf: invalid address %q: %w",
|
||||
addr, err,
|
||||
)
|
||||
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
|
||||
}
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(
|
||||
ctx, host,
|
||||
)
|
||||
// Resolve hostname to IPs
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"ssrf: DNS resolution failed for %q: %w",
|
||||
host, err,
|
||||
)
|
||||
return nil, fmt.Errorf("ssrf: DNS resolution failed for %q: %w", host, err)
|
||||
}
|
||||
|
||||
// Check all resolved IPs
|
||||
for _, ipAddr := range ips {
|
||||
if isBlockedIP(ipAddr.IP) {
|
||||
return nil, fmt.Errorf(
|
||||
"ssrf: connection to %s (%s) "+
|
||||
"blocked: %w",
|
||||
host, ipAddr.IP, errBlockedIP,
|
||||
)
|
||||
return nil, fmt.Errorf("ssrf: connection to %s (%s) blocked — private/reserved IP range", host, ipAddr.IP)
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to the first allowed IP
|
||||
var dialer net.Dialer
|
||||
|
||||
return dialer.DialContext(
|
||||
ctx, network,
|
||||
net.JoinHostPort(ips[0].IP.String(), port),
|
||||
)
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package delivery_test
|
||||
package delivery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/delivery"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP_PrivateRanges(t *testing.T) {
|
||||
@@ -18,52 +16,56 @@ func TestIsBlockedIP_PrivateRanges(t *testing.T) {
|
||||
ip string
|
||||
blocked bool
|
||||
}{
|
||||
// Loopback
|
||||
{"loopback 127.0.0.1", "127.0.0.1", true},
|
||||
{"loopback 127.0.0.2", "127.0.0.2", true},
|
||||
{"loopback 127.255.255.255", "127.255.255.255", true},
|
||||
|
||||
// RFC 1918 - Class A
|
||||
{"10.0.0.0", "10.0.0.0", true},
|
||||
{"10.0.0.1", "10.0.0.1", true},
|
||||
{"10.255.255.255", "10.255.255.255", true},
|
||||
|
||||
// RFC 1918 - Class B
|
||||
{"172.16.0.1", "172.16.0.1", true},
|
||||
{"172.31.255.255", "172.31.255.255", true},
|
||||
{"172.15.255.255", "172.15.255.255", false},
|
||||
{"172.32.0.0", "172.32.0.0", false},
|
||||
|
||||
// RFC 1918 - Class C
|
||||
{"192.168.0.1", "192.168.0.1", true},
|
||||
{"192.168.255.255", "192.168.255.255", true},
|
||||
|
||||
// Link-local / cloud metadata
|
||||
{"169.254.0.1", "169.254.0.1", true},
|
||||
{"169.254.169.254", "169.254.169.254", true},
|
||||
|
||||
// Public IPs (should NOT be blocked)
|
||||
{"8.8.8.8", "8.8.8.8", false},
|
||||
{"1.1.1.1", "1.1.1.1", false},
|
||||
{"93.184.216.34", "93.184.216.34", false},
|
||||
|
||||
// IPv6 loopback
|
||||
{"::1", "::1", true},
|
||||
|
||||
// IPv6 unique local
|
||||
{"fd00::1", "fd00::1", true},
|
||||
{"fc00::1", "fc00::1", true},
|
||||
|
||||
// IPv6 link-local
|
||||
{"fe80::1", "fe80::1", true},
|
||||
{
|
||||
"2607:f8b0:4004:800::200e",
|
||||
"2607:f8b0:4004:800::200e",
|
||||
false,
|
||||
},
|
||||
|
||||
// IPv6 public (should NOT be blocked)
|
||||
{"2607:f8b0:4004:800::200e", "2607:f8b0:4004:800::200e", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ip := net.ParseIP(tt.ip)
|
||||
|
||||
require.NotNil(t, ip,
|
||||
"failed to parse IP %s", tt.ip,
|
||||
)
|
||||
|
||||
assert.Equal(t,
|
||||
tt.blocked,
|
||||
delivery.ExportIsBlockedIP(ip),
|
||||
"isBlockedIP(%s) = %v, want %v",
|
||||
tt.ip,
|
||||
delivery.ExportIsBlockedIP(ip),
|
||||
tt.blocked,
|
||||
)
|
||||
require.NotNil(t, ip, "failed to parse IP %s", tt.ip)
|
||||
assert.Equal(t, tt.blocked, isBlockedIP(ip),
|
||||
"isBlockedIP(%s) = %v, want %v", tt.ip, isBlockedIP(ip), tt.blocked)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -87,14 +89,8 @@ func TestValidateTargetURL_Blocked(t *testing.T) {
|
||||
for _, u := range blockedURLs {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), u,
|
||||
)
|
||||
|
||||
assert.Error(t, err,
|
||||
"URL %s should be blocked", u,
|
||||
)
|
||||
err := ValidateTargetURL(u)
|
||||
assert.Error(t, err, "URL %s should be blocked", u)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -102,6 +98,7 @@ func TestValidateTargetURL_Blocked(t *testing.T) {
|
||||
func TestValidateTargetURL_Allowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// These are public IPs and should be allowed
|
||||
allowedURLs := []string{
|
||||
"https://example.com/hook",
|
||||
"http://93.184.216.34/webhook",
|
||||
@@ -111,62 +108,35 @@ func TestValidateTargetURL_Allowed(t *testing.T) {
|
||||
for _, u := range allowedURLs {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), u,
|
||||
)
|
||||
|
||||
assert.NoError(t, err,
|
||||
"URL %s should be allowed", u,
|
||||
)
|
||||
err := ValidateTargetURL(u)
|
||||
assert.NoError(t, err, "URL %s should be allowed", u)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetURL_InvalidScheme(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), "ftp://example.com/hook",
|
||||
)
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Contains(t, err.Error(),
|
||||
"unsupported URL scheme",
|
||||
)
|
||||
err := ValidateTargetURL("ftp://example.com/hook")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported URL scheme")
|
||||
}
|
||||
|
||||
func TestValidateTargetURL_EmptyHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), "http:///path",
|
||||
)
|
||||
|
||||
err := ValidateTargetURL("http:///path")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateTargetURL_InvalidURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), "://invalid",
|
||||
)
|
||||
|
||||
err := ValidateTargetURL("://invalid")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBlockedNetworks_Initialized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
nets := delivery.ExportBlockedNetworks()
|
||||
|
||||
assert.NotEmpty(t, nets,
|
||||
"blockedNetworks should be initialized",
|
||||
)
|
||||
|
||||
assert.GreaterOrEqual(t, len(nets), 8,
|
||||
"should have at least 8 blocked network ranges",
|
||||
)
|
||||
assert.NotEmpty(t, blockedNetworks, "blockedNetworks should be initialized")
|
||||
// Should have at least the main RFC 1918 + loopback + link-local ranges
|
||||
assert.GreaterOrEqual(t, len(blockedNetworks), 8,
|
||||
"should have at least 8 blocked network ranges")
|
||||
}
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// Build-time variables populated from main() and copied into the
|
||||
// Globals object.
|
||||
// Build-time variables populated from main() and copied into the Globals object.
|
||||
//
|
||||
//nolint:gochecknoglobals // Build-time variables set by main().
|
||||
var (
|
||||
@@ -20,8 +19,7 @@ type Globals struct {
|
||||
Version string
|
||||
}
|
||||
|
||||
// New creates a Globals instance from the package-level
|
||||
// build-time variables.
|
||||
// New creates a Globals instance from the package-level build-time variables.
|
||||
//
|
||||
//nolint:revive // lc parameter is required by fx even if unused.
|
||||
func New(lc fx.Lifecycle) (*Globals, error) {
|
||||
@@ -29,6 +27,5 @@ func New(lc fx.Lifecycle) (*Globals, error) {
|
||||
Appname: Appname,
|
||||
Version: Version,
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@@ -1,30 +1,26 @@
|
||||
package globals_test
|
||||
package globals
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"go.uber.org/fx/fxtest"
|
||||
)
|
||||
|
||||
func TestGlobalsFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestNew(t *testing.T) {
|
||||
// Set test values
|
||||
Appname = "test-app"
|
||||
Version = "1.0.0"
|
||||
|
||||
g := &globals.Globals{
|
||||
Appname: "test-app",
|
||||
Version: "1.0.0",
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
globals, err := New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if g.Appname != "test-app" {
|
||||
t.Errorf(
|
||||
"Appname = %v, want %v",
|
||||
g.Appname, "test-app",
|
||||
)
|
||||
if globals.Appname != "test-app" {
|
||||
t.Errorf("Appname = %v, want %v", globals.Appname, "test-app")
|
||||
}
|
||||
|
||||
if g.Version != "1.0.0" {
|
||||
t.Errorf(
|
||||
"Version = %v, want %v",
|
||||
g.Version, "1.0.0",
|
||||
)
|
||||
if globals.Version != "1.0.0" {
|
||||
t.Errorf("Version = %v, want %v", globals.Version, "1.0.0")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,12 +13,11 @@ func (h *Handlers) HandleLoginPage() http.HandlerFunc {
|
||||
sess, err := h.session.Get(r)
|
||||
if err == nil && h.session.IsAuthenticated(sess) {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Render login page
|
||||
data := map[string]any{
|
||||
data := map[string]interface{}{
|
||||
"Error": "",
|
||||
}
|
||||
|
||||
@@ -30,14 +29,12 @@ func (h *Handlers) HandleLoginPage() http.HandlerFunc {
|
||||
func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Limit request body to prevent memory exhaustion
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1<<maxBodyShift)
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MB
|
||||
|
||||
// Parse form data
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
h.log.Error("failed to parse form", "error", err)
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -46,147 +43,76 @@ func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
|
||||
|
||||
// Validate input
|
||||
if username == "" || password == "" {
|
||||
h.renderLoginError(
|
||||
w, r,
|
||||
"Username and password are required",
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return
|
||||
data := map[string]interface{}{
|
||||
"Error": "Username and password are required",
|
||||
}
|
||||
|
||||
user, err := h.authenticateUser(
|
||||
w, r, username, password,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.createAuthenticatedSession(w, r, user)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.log.Info(
|
||||
"user logged in",
|
||||
"username", username,
|
||||
"user_id", user.ID,
|
||||
)
|
||||
|
||||
// Redirect to home page
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
// renderLoginError renders the login page with an error message.
|
||||
func (h *Handlers) renderLoginError(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
status int,
|
||||
) {
|
||||
data := map[string]any{
|
||||
"Error": msg,
|
||||
}
|
||||
|
||||
w.WriteHeader(status)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
h.renderTemplate(w, r, "login.html", data)
|
||||
}
|
||||
|
||||
// authenticateUser looks up and verifies a user's credentials.
|
||||
// On failure it writes an HTTP response and returns an error.
|
||||
func (h *Handlers) authenticateUser(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
username, password string,
|
||||
) (database.User, error) {
|
||||
var user database.User
|
||||
|
||||
err := h.db.DB().Where(
|
||||
"username = ?", username,
|
||||
).First(&user).Error
|
||||
if err != nil {
|
||||
h.log.Debug("user not found", "username", username)
|
||||
h.renderLoginError(
|
||||
w, r,
|
||||
"Invalid username or password",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
return user, err
|
||||
return
|
||||
}
|
||||
|
||||
// Find user in database
|
||||
var user database.User
|
||||
if err := h.db.DB().Where("username = ?", username).First(&user).Error; err != nil {
|
||||
h.log.Debug("user not found", "username", username)
|
||||
data := map[string]interface{}{
|
||||
"Error": "Invalid username or password",
|
||||
}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
h.renderTemplate(w, r, "login.html", data)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify password
|
||||
valid, err := database.VerifyPassword(password, user.Password)
|
||||
if err != nil {
|
||||
h.log.Error("failed to verify password", "error", err)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return user, err
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !valid {
|
||||
h.log.Debug("invalid password", "username", username)
|
||||
h.renderLoginError(
|
||||
w, r,
|
||||
"Invalid username or password",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
return user, errInvalidPassword
|
||||
data := map[string]interface{}{
|
||||
"Error": "Invalid username or password",
|
||||
}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
h.renderTemplate(w, r, "login.html", data)
|
||||
return
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// createAuthenticatedSession regenerates the session and stores
|
||||
// user info. On failure it writes an HTTP response and returns
|
||||
// an error.
|
||||
func (h *Handlers) createAuthenticatedSession(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
user database.User,
|
||||
) error {
|
||||
// Get the current session (may be pre-existing / attacker-set)
|
||||
oldSess, err := h.session.Get(r)
|
||||
if err != nil {
|
||||
h.log.Error("failed to get session", "error", err)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return err
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Regenerate the session to prevent session fixation attacks.
|
||||
// This destroys the old session ID and creates a new one.
|
||||
sess, err := h.session.Regenerate(r, w, oldSess)
|
||||
if err != nil {
|
||||
h.log.Error(
|
||||
"failed to regenerate session", "error", err,
|
||||
)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return err
|
||||
h.log.Error("failed to regenerate session", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Set user in session
|
||||
h.session.SetUser(sess, user.ID, user.Username)
|
||||
|
||||
err = h.session.Save(r, w, sess)
|
||||
if err != nil {
|
||||
// Save session
|
||||
if err := h.session.Save(r, w, sess); err != nil {
|
||||
h.log.Error("failed to save session", "error", err)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return err
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
return nil
|
||||
h.log.Info("user logged in", "username", username, "user_id", user.ID)
|
||||
|
||||
// Redirect to home page
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleLogout handles user logout
|
||||
@@ -195,10 +121,7 @@ func (h *Handlers) HandleLogout() http.HandlerFunc {
|
||||
sess, err := h.session.Get(r)
|
||||
if err != nil {
|
||||
h.log.Error("failed to get session", "error", err)
|
||||
http.Redirect(
|
||||
w, r, "/pages/login", http.StatusSeeOther,
|
||||
)
|
||||
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -206,12 +129,8 @@ func (h *Handlers) HandleLogout() http.HandlerFunc {
|
||||
h.session.Destroy(sess)
|
||||
|
||||
// Save the destroyed session
|
||||
err = h.session.Save(r, w, sess)
|
||||
if err != nil {
|
||||
h.log.Error(
|
||||
"failed to save destroyed session",
|
||||
"error", err,
|
||||
)
|
||||
if err := h.session.Save(r, w, sess); err != nil {
|
||||
h.log.Error("failed to save destroyed session", "error", err)
|
||||
}
|
||||
|
||||
// Redirect to login page
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import "net/http"
|
||||
|
||||
// RenderTemplateForTest exposes renderTemplate for use in the
|
||||
// handlers_test package.
|
||||
func (s *Handlers) RenderTemplateForTest(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
pageTemplate string,
|
||||
data any,
|
||||
) {
|
||||
s.renderTemplate(w, r, pageTemplate, data)
|
||||
}
|
||||
@@ -1,11 +1,9 @@
|
||||
// Package handlers provides HTTP request handlers for the
|
||||
// webhooker web UI and API.
|
||||
// Package handlers provides HTTP request handlers for the webhooker web UI and API.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"html/template"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -21,24 +19,9 @@ import (
|
||||
"sneak.berlin/go/webhooker/templates"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxBodyShift is the bit shift for 1 MB body limit.
|
||||
maxBodyShift = 20
|
||||
// recentEventLimit is the number of recent events to show.
|
||||
recentEventLimit = 20
|
||||
// defaultRetentionDays is the default event retention period.
|
||||
defaultRetentionDays = 30
|
||||
// paginationPerPage is the number of items per page.
|
||||
paginationPerPage = 25
|
||||
)
|
||||
|
||||
// errInvalidPassword is returned when a password does not match.
|
||||
var errInvalidPassword = errors.New("invalid password")
|
||||
|
||||
//nolint:revive // HandlersParams is a standard fx naming convention.
|
||||
type HandlersParams struct {
|
||||
fx.In
|
||||
|
||||
Logger *logger.Logger
|
||||
Globals *globals.Globals
|
||||
Database *database.Database
|
||||
@@ -48,8 +31,7 @@ type HandlersParams struct {
|
||||
Notifier delivery.Notifier
|
||||
}
|
||||
|
||||
// Handlers provides HTTP handler methods for all application
|
||||
// routes.
|
||||
// Handlers provides HTTP handler methods for all application routes.
|
||||
type Handlers struct {
|
||||
params *HandlersParams
|
||||
log *slog.Logger
|
||||
@@ -61,29 +43,20 @@ type Handlers struct {
|
||||
templates map[string]*template.Template
|
||||
}
|
||||
|
||||
// parsePageTemplate parses a page-specific template set from the
|
||||
// embedded FS. Each page template is combined with the shared
|
||||
// base, htmlheader, and navbar templates. The page file must be
|
||||
// listed first so that its root action ({{template "base" .}})
|
||||
// becomes the template set's entry point.
|
||||
// parsePageTemplate parses a page-specific template set from the embedded FS.
|
||||
// Each page template is combined with the shared base, htmlheader, and navbar templates.
|
||||
// The page file must be listed first so that its root action ({{template "base" .}})
|
||||
// becomes the template set's entry point. If a shared partial (e.g. htmlheader.html)
|
||||
// is listed first, its {{define}} block becomes the root — which is empty — and
|
||||
// Execute() produces no output.
|
||||
func parsePageTemplate(pageFile string) *template.Template {
|
||||
return template.Must(
|
||||
template.ParseFS(
|
||||
templates.Templates,
|
||||
pageFile,
|
||||
"base.html",
|
||||
"htmlheader.html",
|
||||
"navbar.html",
|
||||
),
|
||||
template.ParseFS(templates.Templates, pageFile, "base.html", "htmlheader.html", "navbar.html"),
|
||||
)
|
||||
}
|
||||
|
||||
// New creates a Handlers instance, parsing all page templates at
|
||||
// startup.
|
||||
func New(
|
||||
lc fx.Lifecycle,
|
||||
params HandlersParams,
|
||||
) (*Handlers, error) {
|
||||
// New creates a Handlers instance, parsing all page templates at startup.
|
||||
func New(lc fx.Lifecycle, params HandlersParams) (*Handlers, error) {
|
||||
s := new(Handlers)
|
||||
s.params = ¶ms
|
||||
s.log = params.Logger.Get()
|
||||
@@ -95,6 +68,7 @@ func New(
|
||||
|
||||
// Parse all page templates once at startup
|
||||
s.templates = map[string]*template.Template{
|
||||
"index.html": parsePageTemplate("index.html"),
|
||||
"login.html": parsePageTemplate("login.html"),
|
||||
"profile.html": parsePageTemplate("profile.html"),
|
||||
"sources_list.html": parsePageTemplate("sources_list.html"),
|
||||
@@ -109,19 +83,13 @@ func New(
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Handlers) respondJSON(
|
||||
w http.ResponseWriter,
|
||||
_ *http.Request,
|
||||
data any,
|
||||
status int,
|
||||
) {
|
||||
//nolint:unparam // r parameter will be used in the future for request context.
|
||||
func (s *Handlers) respondJSON(w http.ResponseWriter, _ *http.Request, data interface{}, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
if data != nil {
|
||||
err := json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
@@ -130,15 +98,9 @@ func (s *Handlers) respondJSON(
|
||||
}
|
||||
}
|
||||
|
||||
// serverError logs an error and sends a 500 response.
|
||||
func (s *Handlers) serverError(
|
||||
w http.ResponseWriter, msg string, err error,
|
||||
) {
|
||||
s.log.Error(msg, "error", err)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
//nolint:unparam,unused // will be used for handling JSON requests.
|
||||
func (s *Handlers) decodeJSON(_ http.ResponseWriter, r *http.Request, v interface{}) error {
|
||||
return json.NewDecoder(r.Body).Decode(v)
|
||||
}
|
||||
|
||||
// UserInfo represents user information for templates
|
||||
@@ -147,91 +109,58 @@ type UserInfo struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
// templateDataWrapper wraps non-map data with common fields.
|
||||
type templateDataWrapper struct {
|
||||
User *UserInfo
|
||||
CSRFToken string
|
||||
Data any
|
||||
}
|
||||
|
||||
// getUserInfo extracts user info from the session.
|
||||
func (s *Handlers) getUserInfo(
|
||||
r *http.Request,
|
||||
) *UserInfo {
|
||||
sess, err := s.session.Get(r)
|
||||
if err != nil || !s.session.IsAuthenticated(sess) {
|
||||
return nil
|
||||
}
|
||||
|
||||
username, ok := s.session.GetUsername(sess)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, ok := s.session.GetUserID(sess)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &UserInfo{ID: userID, Username: username}
|
||||
}
|
||||
|
||||
// renderTemplate renders a pre-parsed template with common
|
||||
// data
|
||||
func (s *Handlers) renderTemplate(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
pageTemplate string,
|
||||
data any,
|
||||
) {
|
||||
// renderTemplate renders a pre-parsed template with common data
|
||||
func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTemplate string, data interface{}) {
|
||||
tmpl, ok := s.templates[pageTemplate]
|
||||
if !ok {
|
||||
s.log.Error(
|
||||
"template not found",
|
||||
"template", pageTemplate,
|
||||
)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
s.log.Error("template not found", "template", pageTemplate)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := s.getUserInfo(r)
|
||||
// Get user from session if available
|
||||
var userInfo *UserInfo
|
||||
sess, err := s.session.Get(r)
|
||||
if err == nil && s.session.IsAuthenticated(sess) {
|
||||
if username, ok := s.session.GetUsername(sess); ok {
|
||||
if userID, ok := s.session.GetUserID(sess); ok {
|
||||
userInfo = &UserInfo{
|
||||
ID: userID,
|
||||
Username: username,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get CSRF token from request context (set by CSRF middleware)
|
||||
csrfToken := middleware.CSRFToken(r)
|
||||
|
||||
if m, ok := data.(map[string]any); ok {
|
||||
// If data is a map, merge user info and CSRF token into it
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
m["User"] = userInfo
|
||||
m["CSRFToken"] = csrfToken
|
||||
s.executeTemplate(w, tmpl, m)
|
||||
|
||||
if err := tmpl.Execute(w, m); err != nil {
|
||||
s.log.Error("failed to execute template", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap data with base template data
|
||||
type templateDataWrapper struct {
|
||||
User *UserInfo
|
||||
CSRFToken string
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
wrapper := templateDataWrapper{
|
||||
User: userInfo,
|
||||
CSRFToken: csrfToken,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
s.executeTemplate(w, tmpl, wrapper)
|
||||
}
|
||||
|
||||
// executeTemplate runs the template and handles errors.
|
||||
func (s *Handlers) executeTemplate(
|
||||
w http.ResponseWriter,
|
||||
tmpl *template.Template,
|
||||
data any,
|
||||
) {
|
||||
err := tmpl.Execute(w, data)
|
||||
if err != nil {
|
||||
s.log.Error(
|
||||
"failed to execute template", "error", err,
|
||||
)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
if err := tmpl.Execute(w, wrapper); err != nil {
|
||||
s.log.Error("failed to execute template", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package handlers_test
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -14,23 +14,20 @@ import (
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
"sneak.berlin/go/webhooker/internal/delivery"
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"sneak.berlin/go/webhooker/internal/handlers"
|
||||
"sneak.berlin/go/webhooker/internal/healthcheck"
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
"sneak.berlin/go/webhooker/internal/session"
|
||||
)
|
||||
|
||||
// noopNotifier is a no-op delivery.Notifier for tests.
|
||||
type noopNotifier struct{}
|
||||
|
||||
func (n *noopNotifier) Notify([]delivery.Task) {}
|
||||
|
||||
func newTestApp(
|
||||
t *testing.T,
|
||||
targets ...any,
|
||||
) *fxtest.App {
|
||||
t.Helper()
|
||||
func TestHandleIndex(t *testing.T) {
|
||||
var h *Handlers
|
||||
|
||||
return fxtest.New(
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
@@ -44,99 +41,92 @@ func newTestApp(
|
||||
database.NewWebhookDBManager,
|
||||
healthcheck.New,
|
||||
session.New,
|
||||
func() delivery.Notifier {
|
||||
return &noopNotifier{}
|
||||
},
|
||||
handlers.New,
|
||||
func() delivery.Notifier { return &noopNotifier{} },
|
||||
New,
|
||||
),
|
||||
fx.Populate(targets...),
|
||||
fx.Populate(&h),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHandleIndex_Unauthenticated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var h *handlers.Handlers
|
||||
|
||||
app := newTestApp(t, &h)
|
||||
app.RequireStart()
|
||||
defer app.RequireStop()
|
||||
|
||||
t.Cleanup(app.RequireStop)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Since we can't test actual template rendering without templates,
|
||||
// let's test that the handler is created and doesn't panic
|
||||
handler := h.HandleIndex()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusSeeOther, w.Code)
|
||||
assert.Equal(
|
||||
t, "/pages/login", w.Header().Get("Location"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHandleIndex_Authenticated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var h *handlers.Handlers
|
||||
|
||||
var sess *session.Session
|
||||
|
||||
app := newTestApp(t, &h, &sess)
|
||||
app.RequireStart()
|
||||
|
||||
t.Cleanup(app.RequireStop)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s, err := sess.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
sess.SetUser(s, "test-user-id", "testuser")
|
||||
|
||||
err = sess.Save(req, w, s)
|
||||
require.NoError(t, err)
|
||||
|
||||
req2 := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
h.HandleIndex().ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusSeeOther, w2.Code)
|
||||
assert.Equal(
|
||||
t, "/sources", w2.Header().Get("Location"),
|
||||
)
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
|
||||
func TestRenderTemplate(t *testing.T) {
|
||||
t.Parallel()
|
||||
var h *Handlers
|
||||
|
||||
var h *handlers.Handlers
|
||||
|
||||
app := newTestApp(t, &h)
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
func() *config.Config {
|
||||
return &config.Config{
|
||||
DataDir: t.TempDir(),
|
||||
}
|
||||
},
|
||||
database.New,
|
||||
database.NewWebhookDBManager,
|
||||
healthcheck.New,
|
||||
session.New,
|
||||
func() delivery.Notifier { return &noopNotifier{} },
|
||||
New,
|
||||
),
|
||||
fx.Populate(&h),
|
||||
)
|
||||
app.RequireStart()
|
||||
defer app.RequireStop()
|
||||
|
||||
t.Cleanup(app.RequireStop)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
t.Run("handles missing templates gracefully", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
data := map[string]any{"Version": "1.0.0"}
|
||||
data := map[string]interface{}{
|
||||
"Version": "1.0.0",
|
||||
}
|
||||
|
||||
h.RenderTemplateForTest(
|
||||
w, req, "nonexistent.html", data,
|
||||
)
|
||||
// When a non-existent template name is requested, renderTemplate
|
||||
// should return an internal server error
|
||||
h.renderTemplate(w, req, "nonexistent.html", data)
|
||||
|
||||
assert.Equal(
|
||||
t, http.StatusInternalServerError, w.Code,
|
||||
)
|
||||
// Should return internal server error when template is not found
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatUptime(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
duration string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "minutes only",
|
||||
duration: "45m",
|
||||
expected: "45m",
|
||||
},
|
||||
{
|
||||
name: "hours and minutes",
|
||||
duration: "2h30m",
|
||||
expected: "2h 30m",
|
||||
},
|
||||
{
|
||||
name: "days, hours and minutes",
|
||||
duration: "25h45m",
|
||||
expected: "1d 1h 45m",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d, err := time.ParseDuration(tt.duration)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := formatUptime(d)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,13 +4,10 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const httpStatusOK = 200
|
||||
|
||||
// HandleHealthCheck returns an HTTP handler that reports
|
||||
// application health.
|
||||
// HandleHealthCheck returns an HTTP handler that reports application health.
|
||||
func (s *Handlers) HandleHealthCheck() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
resp := s.hc.Healthcheck()
|
||||
s.respondJSON(w, req, resp, httpStatusOK)
|
||||
s.respondJSON(w, req, resp, 200)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,50 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
)
|
||||
|
||||
// HandleIndex returns a handler for the root path that redirects
|
||||
// based on authentication state: authenticated users go to /sources
|
||||
// (the dashboard), unauthenticated users go to the login page.
|
||||
// HandleIndex returns an HTTP handler that renders the application dashboard.
|
||||
func (s *Handlers) HandleIndex() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := s.session.Get(r)
|
||||
if err == nil && s.session.IsAuthenticated(sess) {
|
||||
http.Redirect(w, r, "/sources", http.StatusSeeOther)
|
||||
// Calculate server start time
|
||||
startTime := time.Now()
|
||||
|
||||
return
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
// Calculate uptime
|
||||
uptime := time.Since(startTime)
|
||||
uptimeStr := formatUptime(uptime)
|
||||
|
||||
// Get user count from database
|
||||
var userCount int64
|
||||
s.db.DB().Model(&database.User{}).Count(&userCount)
|
||||
|
||||
// Prepare template data
|
||||
data := map[string]interface{}{
|
||||
"Version": s.params.Globals.Version,
|
||||
"Uptime": uptimeStr,
|
||||
"UserCount": userCount,
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
// Render the template
|
||||
s.renderTemplate(w, req, "index.html", data)
|
||||
}
|
||||
}
|
||||
|
||||
// formatUptime formats a duration into a human-readable string
|
||||
func formatUptime(d time.Duration) string {
|
||||
days := int(d.Hours()) / 24
|
||||
hours := int(d.Hours()) % 24
|
||||
minutes := int(d.Minutes()) % 60
|
||||
|
||||
if days > 0 {
|
||||
return fmt.Sprintf("%dd %dh %dm", days, hours, minutes)
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%dm", minutes)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
||||
requestedUsername := chi.URLParam(r, "username")
|
||||
if requestedUsername == "" {
|
||||
http.NotFound(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -22,7 +21,6 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
||||
if err != nil || !h.session.IsAuthenticated(sess) {
|
||||
// Redirect to login if not authenticated
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -31,7 +29,6 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
||||
if !ok {
|
||||
h.log.Error("authenticated session missing username")
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,19 +36,17 @@ func (h *Handlers) HandleProfile() http.HandlerFunc {
|
||||
if !ok {
|
||||
h.log.Error("authenticated session missing user ID")
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// For now, only allow users to view their own profile
|
||||
if requestedUsername != sessionUsername {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare data for template
|
||||
data := map[string]any{
|
||||
data := map[string]interface{}{
|
||||
"User": &UserInfo{
|
||||
ID: sessionUserID,
|
||||
Username: sessionUsername,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,36 +6,31 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"gorm.io/gorm"
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
"sneak.berlin/go/webhooker/internal/delivery"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxWebhookBodySize is the maximum allowed webhook
|
||||
// request body (1 MB).
|
||||
maxWebhookBodySize = 1 << maxBodyShift
|
||||
// maxWebhookBodySize is the maximum allowed webhook request body (1 MB).
|
||||
maxWebhookBodySize = 1 << 20
|
||||
)
|
||||
|
||||
// HandleWebhook handles incoming webhook requests at entrypoint
|
||||
// URLs.
|
||||
// HandleWebhook handles incoming webhook requests at entrypoint URLs.
|
||||
// Only POST requests are accepted; all other methods return 405 Method Not Allowed.
|
||||
// Events and deliveries are stored in the per-webhook database. The handler
|
||||
// builds self-contained Task structs with all target and event data
|
||||
// so the delivery engine can process them without additional DB reads.
|
||||
func (h *Handlers) HandleWebhook() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
w.Header().Set("Allow", "POST")
|
||||
http.Error(
|
||||
w,
|
||||
"Method Not Allowed",
|
||||
http.StatusMethodNotAllowed,
|
||||
)
|
||||
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
entrypointUUID := chi.URLParam(r, "uuid")
|
||||
if entrypointUUID == "" {
|
||||
http.NotFound(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -45,241 +40,69 @@ func (h *Handlers) HandleWebhook() http.HandlerFunc {
|
||||
"remote_addr", r.RemoteAddr,
|
||||
)
|
||||
|
||||
entrypoint, ok := h.lookupEntrypoint(
|
||||
w, r, entrypointUUID,
|
||||
)
|
||||
if !ok {
|
||||
// Look up entrypoint by path (from main application DB)
|
||||
var entrypoint database.Entrypoint
|
||||
result := h.db.DB().Where("path = ?", entrypointUUID).First(&entrypoint)
|
||||
if result.Error != nil {
|
||||
h.log.Debug("entrypoint not found", "path", entrypointUUID)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if active
|
||||
if !entrypoint.Active {
|
||||
http.Error(w, "Gone", http.StatusGone)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.processWebhookRequest(w, r, entrypoint)
|
||||
// Read body with size limit
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, maxWebhookBodySize+1))
|
||||
if err != nil {
|
||||
h.log.Error("failed to read request body", "error", err)
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// processWebhookRequest reads the body, serializes headers,
|
||||
// loads targets, and delivers the event.
|
||||
func (h *Handlers) processWebhookRequest(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
entrypoint database.Entrypoint,
|
||||
) {
|
||||
body, ok := h.readWebhookBody(w, r)
|
||||
if !ok {
|
||||
if len(body) > maxWebhookBodySize {
|
||||
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize headers as JSON
|
||||
headersJSON, err := json.Marshal(r.Header)
|
||||
if err != nil {
|
||||
h.serverError(w, "failed to serialize headers", err)
|
||||
|
||||
h.log.Error("failed to serialize headers", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
targets, err := h.loadActiveTargets(entrypoint.WebhookID)
|
||||
if err != nil {
|
||||
h.serverError(w, "failed to query targets", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.createAndDeliverEvent(
|
||||
w, r, entrypoint, body, headersJSON, targets,
|
||||
)
|
||||
}
|
||||
|
||||
// loadActiveTargets returns all active targets for a webhook.
|
||||
func (h *Handlers) loadActiveTargets(
|
||||
webhookID string,
|
||||
) ([]database.Target, error) {
|
||||
// Find all active targets for this webhook (from main application DB)
|
||||
var targets []database.Target
|
||||
|
||||
err := h.db.DB().Where(
|
||||
"webhook_id = ? AND active = ?",
|
||||
webhookID, true,
|
||||
).Find(&targets).Error
|
||||
|
||||
return targets, err
|
||||
}
|
||||
|
||||
// lookupEntrypoint finds an entrypoint by UUID path.
|
||||
func (h *Handlers) lookupEntrypoint(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
entrypointUUID string,
|
||||
) (database.Entrypoint, bool) {
|
||||
var entrypoint database.Entrypoint
|
||||
|
||||
result := h.db.DB().Where(
|
||||
"path = ?", entrypointUUID,
|
||||
).First(&entrypoint)
|
||||
if result.Error != nil {
|
||||
h.log.Debug(
|
||||
"entrypoint not found",
|
||||
"path", entrypointUUID,
|
||||
)
|
||||
http.NotFound(w, r)
|
||||
|
||||
return entrypoint, false
|
||||
}
|
||||
|
||||
return entrypoint, true
|
||||
}
|
||||
|
||||
// readWebhookBody reads and validates the request body size.
|
||||
func (h *Handlers) readWebhookBody(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) ([]byte, bool) {
|
||||
body, err := io.ReadAll(
|
||||
io.LimitReader(r.Body, maxWebhookBodySize+1),
|
||||
)
|
||||
if err != nil {
|
||||
h.log.Error(
|
||||
"failed to read request body", "error", err,
|
||||
)
|
||||
http.Error(
|
||||
w, "Bad request", http.StatusBadRequest,
|
||||
)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if len(body) > maxWebhookBodySize {
|
||||
http.Error(
|
||||
w,
|
||||
"Request body too large",
|
||||
http.StatusRequestEntityTooLarge,
|
||||
)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return body, true
|
||||
}
|
||||
|
||||
// createAndDeliverEvent creates the event and delivery records
|
||||
// then notifies the delivery engine.
|
||||
func (h *Handlers) createAndDeliverEvent(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
entrypoint database.Entrypoint,
|
||||
body, headersJSON []byte,
|
||||
targets []database.Target,
|
||||
) {
|
||||
tx, err := h.beginWebhookTx(w, entrypoint.WebhookID)
|
||||
if err != nil {
|
||||
if targetErr := h.db.DB().Where("webhook_id = ? AND active = ?", entrypoint.WebhookID, true).Find(&targets).Error; targetErr != nil {
|
||||
h.log.Error("failed to query targets", "error", targetErr)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
event := h.buildEvent(r, entrypoint, headersJSON, body)
|
||||
|
||||
err = tx.Create(event).Error
|
||||
// Get the per-webhook database for event storage
|
||||
webhookDB, err := h.dbMgr.GetDB(entrypoint.WebhookID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
h.serverError(w, "failed to create event", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
bodyPtr := inlineBody(body)
|
||||
|
||||
tasks := h.buildDeliveryTasks(
|
||||
w, tx, event, entrypoint, targets, bodyPtr,
|
||||
h.log.Error("failed to get webhook database",
|
||||
"webhook_id", entrypoint.WebhookID,
|
||||
"error", err,
|
||||
)
|
||||
if tasks == nil {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
h.serverError(w, "failed to commit transaction", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.finishWebhookResponse(w, event, entrypoint, tasks)
|
||||
}
|
||||
|
||||
// beginWebhookTx opens a transaction on the per-webhook DB.
|
||||
func (h *Handlers) beginWebhookTx(
|
||||
w http.ResponseWriter,
|
||||
webhookID string,
|
||||
) (*gorm.DB, error) {
|
||||
webhookDB, err := h.dbMgr.GetDB(webhookID)
|
||||
if err != nil {
|
||||
h.serverError(
|
||||
w, "failed to get webhook database", err,
|
||||
)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the event and deliveries in a transaction on the per-webhook DB
|
||||
tx := webhookDB.Begin()
|
||||
if tx.Error != nil {
|
||||
h.serverError(
|
||||
w, "failed to begin transaction", tx.Error,
|
||||
)
|
||||
|
||||
return nil, tx.Error
|
||||
h.log.Error("failed to begin transaction", "error", tx.Error)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
// inlineBody returns a pointer to body as a string if it fits
|
||||
// within the inline size limit, or nil otherwise.
|
||||
func inlineBody(body []byte) *string {
|
||||
if len(body) < delivery.MaxInlineBodySize {
|
||||
s := string(body)
|
||||
|
||||
return &s
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// finishWebhookResponse notifies the delivery engine, logs the
|
||||
// event, and writes the HTTP response.
|
||||
func (h *Handlers) finishWebhookResponse(
|
||||
w http.ResponseWriter,
|
||||
event *database.Event,
|
||||
entrypoint database.Entrypoint,
|
||||
tasks []delivery.Task,
|
||||
) {
|
||||
if len(tasks) > 0 {
|
||||
h.notifier.Notify(tasks)
|
||||
}
|
||||
|
||||
h.log.Info("webhook event created",
|
||||
"event_id", event.ID,
|
||||
"webhook_id", entrypoint.WebhookID,
|
||||
"entrypoint_id", entrypoint.ID,
|
||||
"target_count", len(tasks),
|
||||
)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := w.Write([]byte(`{"status":"ok"}`))
|
||||
if err != nil {
|
||||
h.log.Error(
|
||||
"failed to write response", "error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEvent creates a new Event struct from request data.
|
||||
func (h *Handlers) buildEvent(
|
||||
r *http.Request,
|
||||
entrypoint database.Entrypoint,
|
||||
headersJSON, body []byte,
|
||||
) *database.Event {
|
||||
return &database.Event{
|
||||
event := &database.Event{
|
||||
WebhookID: entrypoint.WebhookID,
|
||||
EntrypointID: entrypoint.ID,
|
||||
Method: r.Method,
|
||||
@@ -287,42 +110,38 @@ func (h *Handlers) buildEvent(
|
||||
Body: string(body),
|
||||
ContentType: r.Header.Get("Content-Type"),
|
||||
}
|
||||
}
|
||||
|
||||
// buildDeliveryTasks creates delivery records in the
|
||||
// transaction and returns tasks for the delivery engine.
|
||||
// Returns nil if an error occurred.
|
||||
func (h *Handlers) buildDeliveryTasks(
|
||||
w http.ResponseWriter,
|
||||
tx *gorm.DB,
|
||||
event *database.Event,
|
||||
entrypoint database.Entrypoint,
|
||||
targets []database.Target,
|
||||
bodyPtr *string,
|
||||
) []delivery.Task {
|
||||
if err := tx.Create(event).Error; err != nil {
|
||||
tx.Rollback()
|
||||
h.log.Error("failed to create event", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare body pointer for inline transport (≤16KB bodies are
|
||||
// included in the Task so the engine needs no DB read).
|
||||
var bodyPtr *string
|
||||
if len(body) < delivery.MaxInlineBodySize {
|
||||
bodyStr := string(body)
|
||||
bodyPtr = &bodyStr
|
||||
}
|
||||
|
||||
// Create delivery records and build self-contained delivery tasks
|
||||
tasks := make([]delivery.Task, 0, len(targets))
|
||||
|
||||
for i := range targets {
|
||||
dlv := &database.Delivery{
|
||||
EventID: event.ID,
|
||||
TargetID: targets[i].ID,
|
||||
Status: database.DeliveryStatusPending,
|
||||
}
|
||||
|
||||
err := tx.Create(dlv).Error
|
||||
if err != nil {
|
||||
if err := tx.Create(dlv).Error; err != nil {
|
||||
tx.Rollback()
|
||||
h.log.Error(
|
||||
"failed to create delivery",
|
||||
h.log.Error("failed to create delivery",
|
||||
"target_id", targets[i].ID,
|
||||
"error", err,
|
||||
)
|
||||
http.Error(
|
||||
w, "Internal server error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return nil
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tasks = append(tasks, delivery.Task{
|
||||
@@ -342,5 +161,31 @@ func (h *Handlers) buildDeliveryTasks(
|
||||
})
|
||||
}
|
||||
|
||||
return tasks
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
h.log.Error("failed to commit transaction", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Notify the delivery engine with self-contained delivery tasks.
|
||||
// Each task carries all target config and event data inline so
|
||||
// the engine can deliver without touching any database (in the
|
||||
// ≤16KB happy path). The engine only writes to the DB to record
|
||||
// delivery results after each attempt.
|
||||
if len(tasks) > 0 {
|
||||
h.notifier.Notify(tasks)
|
||||
}
|
||||
|
||||
h.log.Info("webhook event created",
|
||||
"event_id", event.ID,
|
||||
"webhook_id", entrypoint.WebhookID,
|
||||
"entrypoint_id", entrypoint.ID,
|
||||
"target_count", len(targets),
|
||||
)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil {
|
||||
h.log.Error("failed to write response", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
//nolint:revive // HealthcheckParams is a standard fx naming convention.
|
||||
type HealthcheckParams struct {
|
||||
fx.In
|
||||
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
@@ -30,34 +29,43 @@ type Healthcheck struct {
|
||||
params *HealthcheckParams
|
||||
}
|
||||
|
||||
// New creates a Healthcheck that records the startup time on fx
|
||||
// start.
|
||||
func New(
|
||||
lc fx.Lifecycle,
|
||||
params HealthcheckParams,
|
||||
) (*Healthcheck, error) {
|
||||
// New creates a Healthcheck that records the startup time on fx start.
|
||||
func New(lc fx.Lifecycle, params HealthcheckParams) (*Healthcheck, error) {
|
||||
s := new(Healthcheck)
|
||||
s.params = ¶ms
|
||||
s.log = params.Logger.Get()
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
//nolint:revive // ctx unused but required by fx.
|
||||
OnStart: func(_ context.Context) error {
|
||||
s.StartupTime = time.Now()
|
||||
|
||||
return nil
|
||||
},
|
||||
OnStop: func(_ context.Context) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Healthcheck returns the current health status of the
|
||||
// application.
|
||||
func (s *Healthcheck) Healthcheck() *Response {
|
||||
resp := &Response{
|
||||
//nolint:revive // HealthcheckResponse is a clear, descriptive name.
|
||||
type HealthcheckResponse struct {
|
||||
Status string `json:"status"`
|
||||
Now string `json:"now"`
|
||||
UptimeSeconds int64 `json:"uptime_seconds"`
|
||||
UptimeHuman string `json:"uptime_human"`
|
||||
Version string `json:"version"`
|
||||
Appname string `json:"appname"`
|
||||
Maintenance bool `json:"maintenance_mode"`
|
||||
}
|
||||
|
||||
func (s *Healthcheck) uptime() time.Duration {
|
||||
return time.Since(s.StartupTime)
|
||||
}
|
||||
|
||||
// Healthcheck returns the current health status of the application.
|
||||
func (s *Healthcheck) Healthcheck() *HealthcheckResponse {
|
||||
resp := &HealthcheckResponse{
|
||||
Status: "ok",
|
||||
Now: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
UptimeSeconds: int64(s.uptime().Seconds()),
|
||||
@@ -66,21 +74,5 @@ func (s *Healthcheck) Healthcheck() *Response {
|
||||
Version: s.params.Globals.Version,
|
||||
Maintenance: s.params.Config.MaintenanceMode,
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Response contains the JSON-serialised health status.
|
||||
type Response struct {
|
||||
Status string `json:"status"`
|
||||
Now string `json:"now"`
|
||||
UptimeSeconds int64 `json:"uptimeSeconds"`
|
||||
UptimeHuman string `json:"uptimeHuman"`
|
||||
Version string `json:"version"`
|
||||
Appname string `json:"appname"`
|
||||
Maintenance bool `json:"maintenanceMode"`
|
||||
}
|
||||
|
||||
func (s *Healthcheck) uptime() time.Duration {
|
||||
return time.Since(s.StartupTime)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Package logger provides structured logging with dynamic level
|
||||
// control.
|
||||
// Package logger provides structured logging with dynamic level control.
|
||||
package logger
|
||||
|
||||
import (
|
||||
@@ -15,20 +14,17 @@ import (
|
||||
//nolint:revive // LoggerParams is a standard fx naming convention.
|
||||
type LoggerParams struct {
|
||||
fx.In
|
||||
|
||||
Globals *globals.Globals
|
||||
}
|
||||
|
||||
// Logger wraps slog with dynamic level control and structured
|
||||
// output.
|
||||
// Logger wraps slog with dynamic level control and structured output.
|
||||
type Logger struct {
|
||||
logger *slog.Logger
|
||||
levelVar *slog.LevelVar
|
||||
params LoggerParams
|
||||
}
|
||||
|
||||
// New creates a Logger that outputs text (TTY) or JSON (non-TTY)
|
||||
// to stdout.
|
||||
// New creates a Logger that outputs text (TTY) or JSON (non-TTY) to stdout.
|
||||
//
|
||||
//nolint:revive // lc parameter is required by fx even if unused.
|
||||
func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
|
||||
@@ -52,15 +48,11 @@ func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
|
||||
if t, ok := a.Value.Any().(time.Time); ok {
|
||||
return slog.Time(slog.TimeKey, t.UTC())
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: l.levelVar,
|
||||
ReplaceAttr: replaceAttr,
|
||||
@@ -101,8 +93,7 @@ func (l *Logger) Identify() {
|
||||
)
|
||||
}
|
||||
|
||||
// Writer returns an io.Writer suitable for standard library
|
||||
// loggers.
|
||||
// Writer returns an io.Writer suitable for standard library loggers.
|
||||
func (l *Logger) Writer() io.Writer {
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
@@ -1,59 +1,63 @@
|
||||
package logger_test
|
||||
package logger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/fx/fxtest"
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
func testGlobals() *globals.Globals {
|
||||
return &globals.Globals{
|
||||
Appname: "test-app",
|
||||
Version: "1.0.0",
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Set up globals
|
||||
globals.Appname = "test-app"
|
||||
globals.Version = "1.0.0"
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
params := logger.LoggerParams{
|
||||
Globals: testGlobals(),
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("globals.New() error = %v", err)
|
||||
}
|
||||
|
||||
l, err := logger.New(lc, params)
|
||||
params := LoggerParams{
|
||||
Globals: g,
|
||||
}
|
||||
|
||||
logger, err := New(lc, params)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if l.Get() == nil {
|
||||
if logger.Get() == nil {
|
||||
t.Error("Get() returned nil logger")
|
||||
}
|
||||
|
||||
// Test that we can log without panic
|
||||
l.Get().Info("test message", "key", "value")
|
||||
logger.Get().Info("test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestEnableDebugLogging(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Set up globals
|
||||
globals.Appname = "test-app"
|
||||
globals.Version = "1.0.0"
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
params := logger.LoggerParams{
|
||||
Globals: testGlobals(),
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("globals.New() error = %v", err)
|
||||
}
|
||||
|
||||
l, err := logger.New(lc, params)
|
||||
params := LoggerParams{
|
||||
Globals: g,
|
||||
}
|
||||
|
||||
logger, err := New(lc, params)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
// Enable debug logging should not panic
|
||||
l.EnableDebugLogging()
|
||||
logger.EnableDebugLogging()
|
||||
|
||||
// Test debug logging
|
||||
l.Get().Debug("debug message", "test", true)
|
||||
logger.Get().Debug("debug message", "test", true)
|
||||
}
|
||||
|
||||
@@ -12,13 +12,6 @@ func CSRFToken(r *http.Request) string {
|
||||
return csrf.Token(r)
|
||||
}
|
||||
|
||||
// isClientTLS reports whether the client-facing connection uses TLS.
|
||||
// It checks for a direct TLS connection (r.TLS) or a TLS-terminating
|
||||
// reverse proxy that sets the standard X-Forwarded-Proto header.
|
||||
func isClientTLS(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
|
||||
// CSRF returns middleware that provides CSRF protection using the
|
||||
// gorilla/csrf library. The middleware uses the session authentication
|
||||
// key to sign a CSRF cookie and validates a masked token submitted via
|
||||
@@ -26,22 +19,17 @@ func isClientTLS(r *http.Request) bool {
|
||||
// POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing
|
||||
// token receive a 403 Forbidden response.
|
||||
//
|
||||
// The middleware detects the client-facing transport protocol per-request
|
||||
// using r.TLS and the X-Forwarded-Proto header. This allows correct
|
||||
// behavior in all deployment scenarios:
|
||||
//
|
||||
// - Direct HTTPS: strict Referer/Origin checks, Secure cookies.
|
||||
// - Behind a TLS-terminating reverse proxy: strict checks (the
|
||||
// browser is on HTTPS, so Origin/Referer headers use https://),
|
||||
// Secure cookies (the browser sees HTTPS from the proxy).
|
||||
// - Direct HTTP: relaxed Referer/Origin checks via PlaintextHTTPRequest,
|
||||
// non-Secure cookies so the browser sends them over HTTP.
|
||||
//
|
||||
// Two gorilla/csrf instances are maintained — one with Secure cookies
|
||||
// (for TLS) and one without (for plaintext HTTP) — because the
|
||||
// csrf.Secure option is set at creation time, not per-request.
|
||||
// In development mode, requests are marked as plaintext HTTP so that
|
||||
// gorilla/csrf skips the strict Referer-origin check (which is only
|
||||
// meaningful over TLS).
|
||||
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
||||
csrfErrorHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
protect := csrf.Protect(
|
||||
m.session.GetKey(),
|
||||
csrf.FieldName("csrf_token"),
|
||||
csrf.Secure(!m.params.Config.IsDev()),
|
||||
csrf.SameSite(csrf.SameSiteLaxMode),
|
||||
csrf.Path("/"),
|
||||
csrf.ErrorHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
m.log.Warn("csrf: token validation failed",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
@@ -49,36 +37,20 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
||||
"reason", csrf.FailureReason(r),
|
||||
)
|
||||
http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden)
|
||||
})
|
||||
|
||||
key := m.session.GetKey()
|
||||
baseOpts := []csrf.Option{
|
||||
csrf.FieldName("csrf_token"),
|
||||
csrf.SameSite(csrf.SameSiteLaxMode),
|
||||
csrf.Path("/"),
|
||||
csrf.ErrorHandler(csrfErrorHandler),
|
||||
}
|
||||
|
||||
// Two middleware instances with different Secure flags but the
|
||||
// same signing key, so cookies are interchangeable between them.
|
||||
tlsProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(true))...)
|
||||
httpProtect := csrf.Protect(key, append(baseOpts, csrf.Secure(false))...)
|
||||
})),
|
||||
)
|
||||
|
||||
// In development (plaintext HTTP), signal gorilla/csrf to skip
|
||||
// the strict TLS Referer check by injecting the PlaintextHTTP
|
||||
// context key before the CSRF handler sees the request.
|
||||
if m.params.Config.IsDev() {
|
||||
return func(next http.Handler) http.Handler {
|
||||
tlsCSRF := tlsProtect(next)
|
||||
httpCSRF := httpProtect(next)
|
||||
|
||||
csrfHandler := protect(next)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if isClientTLS(r) {
|
||||
// Client is on TLS (directly or via reverse proxy).
|
||||
// Use Secure cookies and strict Origin/Referer checks.
|
||||
tlsCSRF.ServeHTTP(w, r)
|
||||
} else {
|
||||
// Plaintext HTTP: use non-Secure cookies and tell
|
||||
// gorilla/csrf to use "http" for scheme comparisons,
|
||||
// skipping the strict Referer check that assumes TLS.
|
||||
httpCSRF.ServeHTTP(w, csrf.PlaintextHTTPRequest(r))
|
||||
}
|
||||
csrfHandler.ServeHTTP(w, csrf.PlaintextHTTPRequest(r))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return protect
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package middleware_test
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -12,483 +10,148 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/middleware"
|
||||
)
|
||||
|
||||
// csrfCookieName is the gorilla/csrf cookie name.
|
||||
const csrfCookieName = "_gorilla_csrf"
|
||||
|
||||
// csrfGetToken performs a GET request through the CSRF middleware
|
||||
// and returns the token and cookies.
|
||||
func csrfGetToken(
|
||||
t *testing.T,
|
||||
csrfMW func(http.Handler) http.Handler,
|
||||
getReq *http.Request,
|
||||
) (string, []*http.Cookie) {
|
||||
t.Helper()
|
||||
|
||||
var token string
|
||||
|
||||
getHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = middleware.CSRFToken(r)
|
||||
},
|
||||
))
|
||||
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies, "CSRF cookie should be set")
|
||||
require.NotEmpty(t, token, "CSRF token should be set")
|
||||
|
||||
return token, cookies
|
||||
}
|
||||
|
||||
// csrfPostWithToken performs a POST request with the given CSRF
|
||||
// token and cookies through the middleware. Returns whether the
|
||||
// handler was called and the response code.
|
||||
func csrfPostWithToken(
|
||||
t *testing.T,
|
||||
csrfMW func(http.Handler) http.Handler,
|
||||
postReq *http.Request,
|
||||
token string,
|
||||
cookies []*http.Cookie,
|
||||
) (bool, int) {
|
||||
t.Helper()
|
||||
|
||||
var called bool
|
||||
|
||||
postHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
postReq.Body = http.NoBody
|
||||
postReq.Body = nil
|
||||
|
||||
// Rebuild the request with the form body
|
||||
rebuilt := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
postReq.Method, postReq.URL.String(),
|
||||
strings.NewReader(form.Encode()),
|
||||
)
|
||||
rebuilt.Header = postReq.Header.Clone()
|
||||
rebuilt.TLS = postReq.TLS
|
||||
rebuilt.Header.Set(
|
||||
"Content-Type", "application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
for _, c := range cookies {
|
||||
rebuilt.AddCookie(c)
|
||||
}
|
||||
|
||||
postW := httptest.NewRecorder()
|
||||
postHandler.ServeHTTP(postW, rebuilt)
|
||||
|
||||
return called, postW.Code
|
||||
}
|
||||
|
||||
func TestCSRF_GETSetsToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var gotToken string
|
||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
gotToken = CSRFToken(r)
|
||||
}))
|
||||
|
||||
handler := m.CSRF()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, r *http.Request) {
|
||||
gotToken = middleware.CSRFToken(r)
|
||||
},
|
||||
))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.NotEmpty(
|
||||
t, gotToken,
|
||||
"CSRF token should be set in context on GET",
|
||||
)
|
||||
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/form", nil,
|
||||
)
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
// Capture the token from a GET request
|
||||
var token string
|
||||
csrfMiddleware := m.CSRF()
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = CSRFToken(r)
|
||||
}))
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form", nil,
|
||||
)
|
||||
called, _ := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
assert.True(
|
||||
t, called,
|
||||
"handler should be called with valid CSRF token",
|
||||
)
|
||||
}
|
||||
|
||||
// csrfPOSTWithoutTokenTest is a shared helper for testing POST
|
||||
// requests without a CSRF token in both dev and prod modes.
|
||||
func csrfPOSTWithoutTokenTest(
|
||||
t *testing.T,
|
||||
env string,
|
||||
msg string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
m, _ := testMiddleware(t, env)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
// GET to establish the CSRF cookie
|
||||
getHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {},
|
||||
))
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
// POST without CSRF token
|
||||
// POST with valid token and cookies from the GET response
|
||||
var called bool
|
||||
|
||||
postHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form", nil,
|
||||
)
|
||||
postReq.Header.Set(
|
||||
"Content-Type", "application/x-www-form-urlencoded",
|
||||
)
|
||||
}))
|
||||
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(t, called, msg)
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
assert.True(t, called, "handler should be called with valid CSRF token")
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithoutToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
csrfPOSTWithoutTokenTest(
|
||||
t,
|
||||
config.EnvironmentDev,
|
||||
"handler should NOT be called without CSRF token",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
csrfMiddleware := m.CSRF()
|
||||
|
||||
// GET to establish the CSRF cookie
|
||||
getHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {},
|
||||
))
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
|
||||
// POST with wrong CSRF token
|
||||
// POST without CSRF token
|
||||
var called bool
|
||||
|
||||
postHandler := csrfMW(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
|
||||
form := url.Values{"csrf_token": {"invalid-token-value"}}
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form",
|
||||
strings.NewReader(form.Encode()),
|
||||
)
|
||||
postReq.Header.Set(
|
||||
"Content-Type", "application/x-www-form-urlencoded",
|
||||
)
|
||||
}))
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(
|
||||
t, called,
|
||||
"handler should NOT be called with invalid CSRF token",
|
||||
)
|
||||
assert.False(t, called, "handler should NOT be called without CSRF token")
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
csrfMiddleware := m.CSRF()
|
||||
|
||||
// GET to establish the CSRF cookie
|
||||
getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
cookies := getW.Result().Cookies()
|
||||
|
||||
// POST with wrong CSRF token
|
||||
var called bool
|
||||
postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
form := url.Values{"csrf_token": {"invalid-token-value"}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(t, called, "handler should NOT be called with invalid CSRF token")
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
}
|
||||
|
||||
func TestCSRF_GETDoesNotValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.CSRF()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/form", nil)
|
||||
// GET requests should pass through without CSRF validation
|
||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.True(
|
||||
t, called,
|
||||
"GET requests should pass through CSRF middleware",
|
||||
)
|
||||
assert.True(t, called, "GET requests should pass through CSRF middleware")
|
||||
}
|
||||
|
||||
func TestCSRFToken_NoMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
assert.Empty(
|
||||
t, middleware.CSRFToken(req),
|
||||
"CSRFToken should return empty string when "+
|
||||
"middleware has not run",
|
||||
)
|
||||
}
|
||||
|
||||
// --- TLS Detection Tests ---
|
||||
|
||||
func TestIsClientTLS_DirectTLS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
r.TLS = &tls.ConnectionState{}
|
||||
|
||||
assert.True(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect direct TLS connection",
|
||||
)
|
||||
}
|
||||
|
||||
func TestIsClientTLS_XForwardedProto(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
r.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
assert.True(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect TLS via X-Forwarded-Proto",
|
||||
)
|
||||
}
|
||||
|
||||
func TestIsClientTLS_PlaintextHTTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
assert.False(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect plaintext HTTP",
|
||||
)
|
||||
}
|
||||
|
||||
func TestIsClientTLS_XForwardedProtoHTTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
r.Header.Set("X-Forwarded-Proto", "http")
|
||||
|
||||
assert.False(
|
||||
t, middleware.IsClientTLS(r),
|
||||
"should detect plaintext when X-Forwarded-Proto is http",
|
||||
)
|
||||
}
|
||||
|
||||
// --- Production Mode: POST over plaintext HTTP ---
|
||||
|
||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithValidToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/form", nil,
|
||||
)
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
|
||||
// Verify cookie is NOT Secure (plaintext HTTP in prod)
|
||||
for _, c := range cookies {
|
||||
if c.Name == csrfCookieName {
|
||||
assert.False(t, c.Secure,
|
||||
"CSRF cookie should not be Secure "+
|
||||
"over plaintext HTTP")
|
||||
}
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/form", nil,
|
||||
)
|
||||
called, code := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
assert.True(t, called,
|
||||
"handler should be called -- prod mode over "+
|
||||
"plaintext HTTP must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, code,
|
||||
"should not return 403")
|
||||
}
|
||||
|
||||
// --- Production Mode: POST with X-Forwarded-Proto ---
|
||||
|
||||
func TestCSRF_ProdMode_BehindProxy_POSTWithValidToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "http://example.com/form", nil,
|
||||
)
|
||||
getReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
|
||||
// Verify cookie IS Secure (X-Forwarded-Proto: https)
|
||||
for _, c := range cookies {
|
||||
if c.Name == csrfCookieName {
|
||||
assert.True(t, c.Secure,
|
||||
"CSRF cookie should be Secure behind "+
|
||||
"TLS proxy")
|
||||
}
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "http://example.com/form", nil,
|
||||
)
|
||||
postReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
postReq.Header.Set("Origin", "https://example.com")
|
||||
|
||||
called, code := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
assert.True(t, called,
|
||||
"handler should be called -- prod mode behind "+
|
||||
"TLS proxy must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, code,
|
||||
"should not return 403")
|
||||
}
|
||||
|
||||
// --- Production Mode: direct TLS ---
|
||||
|
||||
func TestCSRF_ProdMode_DirectTLS_POSTWithValidToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
csrfMW := m.CSRF()
|
||||
|
||||
getReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "https://example.com/form", nil,
|
||||
)
|
||||
getReq.TLS = &tls.ConnectionState{}
|
||||
|
||||
token, cookies := csrfGetToken(t, csrfMW, getReq)
|
||||
|
||||
// Verify cookie IS Secure (direct TLS)
|
||||
for _, c := range cookies {
|
||||
if c.Name == csrfCookieName {
|
||||
assert.True(t, c.Secure,
|
||||
"CSRF cookie should be Secure over "+
|
||||
"direct TLS")
|
||||
}
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "https://example.com/form", nil,
|
||||
)
|
||||
postReq.TLS = &tls.ConnectionState{}
|
||||
postReq.Header.Set("Origin", "https://example.com")
|
||||
|
||||
called, code := csrfPostWithToken(
|
||||
t, csrfMW, postReq, token, cookies,
|
||||
)
|
||||
|
||||
assert.True(t, called,
|
||||
"handler should be called -- direct TLS must work")
|
||||
assert.NotEqual(t, http.StatusForbidden, code,
|
||||
"should not return 403")
|
||||
}
|
||||
|
||||
// --- Production Mode: POST without token still rejects ---
|
||||
|
||||
func TestCSRF_ProdMode_PlaintextHTTP_POSTWithoutToken(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
csrfPOSTWithoutTokenTest(
|
||||
t,
|
||||
config.EnvironmentProd,
|
||||
"handler should NOT be called without CSRF token "+
|
||||
"even in prod+plaintext",
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run")
|
||||
}
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// NewLoggingResponseWriterForTest wraps newLoggingResponseWriter
|
||||
// for use in external test packages.
|
||||
func NewLoggingResponseWriterForTest(
|
||||
w http.ResponseWriter,
|
||||
) *loggingResponseWriter {
|
||||
return newLoggingResponseWriter(w)
|
||||
}
|
||||
|
||||
// LoggingResponseWriterStatusCode returns the status code
|
||||
// captured by the loggingResponseWriter.
|
||||
func LoggingResponseWriterStatusCode(
|
||||
lrw *loggingResponseWriter,
|
||||
) int {
|
||||
return lrw.statusCode
|
||||
}
|
||||
|
||||
// IPFromHostPort exposes ipFromHostPort for testing.
|
||||
func IPFromHostPort(hp string) string {
|
||||
return ipFromHostPort(hp)
|
||||
}
|
||||
|
||||
// IsClientTLS exposes isClientTLS for testing.
|
||||
func IsClientTLS(r *http.Request) bool {
|
||||
return isClientTLS(r)
|
||||
}
|
||||
|
||||
// LoginRateLimitConst exposes the loginRateLimit constant.
|
||||
const LoginRateLimitConst = loginRateLimit
|
||||
@@ -1,5 +1,4 @@
|
||||
// Package middleware provides HTTP middleware for logging, auth,
|
||||
// CORS, and metrics.
|
||||
// Package middleware provides HTTP middleware for logging, auth, CORS, and metrics.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@@ -21,24 +20,16 @@ import (
|
||||
"sneak.berlin/go/webhooker/internal/session"
|
||||
)
|
||||
|
||||
const (
|
||||
// corsMaxAge is the maximum time (in seconds) that a
|
||||
// preflight response can be cached.
|
||||
corsMaxAge = 300
|
||||
)
|
||||
|
||||
//nolint:revive // MiddlewareParams is a standard fx naming convention.
|
||||
type MiddlewareParams struct {
|
||||
fx.In
|
||||
|
||||
Logger *logger.Logger
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
Session *session.Session
|
||||
}
|
||||
|
||||
// Middleware provides HTTP middleware for logging, CORS, auth, and
|
||||
// metrics.
|
||||
// Middleware provides HTTP middleware for logging, CORS, auth, and metrics.
|
||||
type Middleware struct {
|
||||
log *slog.Logger
|
||||
params *MiddlewareParams
|
||||
@@ -48,15 +39,11 @@ type Middleware struct {
|
||||
// New creates a Middleware from the provided fx parameters.
|
||||
//
|
||||
//nolint:revive // lc parameter is required by fx even if unused.
|
||||
func New(
|
||||
lc fx.Lifecycle,
|
||||
params MiddlewareParams,
|
||||
) (*Middleware, error) {
|
||||
func New(lc fx.Lifecycle, params MiddlewareParams) (*Middleware, error) {
|
||||
s := new(Middleware)
|
||||
s.params = ¶ms
|
||||
s.log = params.Logger.Get()
|
||||
s.session = params.Session
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -68,24 +55,19 @@ func ipFromHostPort(hp string) string {
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(h) > 0 && h[0] == '[' {
|
||||
return h[1 : len(h)-1]
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
type loggingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// newLoggingResponseWriter wraps w and records status codes.
|
||||
func newLoggingResponseWriter(
|
||||
w http.ResponseWriter,
|
||||
) *loggingResponseWriter {
|
||||
// nolint:revive // unexported type is only used internally
|
||||
func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
|
||||
return &loggingResponseWriter{w, http.StatusOK}
|
||||
}
|
||||
|
||||
@@ -94,30 +76,21 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
||||
lrw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Logging returns middleware that logs each HTTP request with
|
||||
// timing and metadata.
|
||||
// Logging returns middleware that logs each HTTP request with timing and metadata.
|
||||
func (s *Middleware) Logging() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
lrw := newLoggingResponseWriter(w)
|
||||
lrw := NewLoggingResponseWriter(w)
|
||||
ctx := r.Context()
|
||||
|
||||
defer func() {
|
||||
latency := time.Since(start)
|
||||
requestID := ""
|
||||
|
||||
if reqID := ctx.Value(
|
||||
middleware.RequestIDKey,
|
||||
); reqID != nil {
|
||||
if reqID := ctx.Value(middleware.RequestIDKey); reqID != nil {
|
||||
if id, ok := reqID.(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Info("http request",
|
||||
"request_start", start,
|
||||
"method", r.Method,
|
||||
@@ -137,29 +110,21 @@ func (s *Middleware) Logging() func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// CORS returns middleware that sets CORS headers (permissive in
|
||||
// dev, no-op in prod).
|
||||
// CORS returns middleware that sets CORS headers (permissive in dev, no-op in prod).
|
||||
func (s *Middleware) CORS() func(http.Handler) http.Handler {
|
||||
if s.params.Config.IsDev() {
|
||||
// In development, allow any origin for local testing.
|
||||
return cors.Handler(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{
|
||||
"GET", "POST", "PUT", "DELETE", "OPTIONS",
|
||||
},
|
||||
AllowedHeaders: []string{
|
||||
"Accept", "Authorization",
|
||||
"Content-Type", "X-CSRF-Token",
|
||||
},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: false,
|
||||
MaxAge: corsMaxAge,
|
||||
MaxAge: 300,
|
||||
})
|
||||
}
|
||||
|
||||
// In production, the web UI is server-rendered so
|
||||
// cross-origin requests are not expected. Return a no-op
|
||||
// middleware.
|
||||
// In production, the web UI is server-rendered so cross-origin
|
||||
// requests are not expected. Return a no-op middleware.
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
@@ -169,33 +134,20 @@ func (s *Middleware) CORS() func(http.Handler) http.Handler {
|
||||
// Unauthenticated users are redirected to the login page.
|
||||
func (s *Middleware) RequireAuth() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := s.session.Get(r)
|
||||
if err != nil {
|
||||
s.log.Debug(
|
||||
"auth middleware: failed to get session",
|
||||
"error", err,
|
||||
)
|
||||
http.Redirect(
|
||||
w, r, "/pages/login", http.StatusSeeOther,
|
||||
)
|
||||
|
||||
s.log.Debug("auth middleware: failed to get session", "error", err)
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.session.IsAuthenticated(sess) {
|
||||
s.log.Debug(
|
||||
"auth middleware: unauthenticated request",
|
||||
s.log.Debug("auth middleware: unauthenticated request",
|
||||
"path", r.URL.Path,
|
||||
"method", r.Method,
|
||||
)
|
||||
http.Redirect(
|
||||
w, r, "/pages/login", http.StatusSeeOther,
|
||||
)
|
||||
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -209,14 +161,12 @@ func (s *Middleware) Metrics() func(http.Handler) http.Handler {
|
||||
mdlw := ghmm.New(ghmm.Config{
|
||||
Recorder: metrics.NewRecorder(metrics.Config{}),
|
||||
})
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return std.Handler("", mdlw, next)
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsAuth returns middleware that protects metrics endpoints
|
||||
// with basic auth.
|
||||
// MetricsAuth returns middleware that protects metrics endpoints with basic auth.
|
||||
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
||||
return basicauth.New(
|
||||
"metrics",
|
||||
@@ -228,63 +178,33 @@ func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
||||
)
|
||||
}
|
||||
|
||||
// SecurityHeaders returns middleware that sets production security
|
||||
// headers on every response: HSTS, X-Content-Type-Options,
|
||||
// X-Frame-Options, CSP, Referrer-Policy, and Permissions-Policy.
|
||||
// SecurityHeaders returns middleware that sets production security headers
|
||||
// on every response: HSTS, X-Content-Type-Options, X-Frame-Options, CSP,
|
||||
// Referrer-Policy, and Permissions-Policy.
|
||||
func (s *Middleware) SecurityHeaders() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) {
|
||||
w.Header().Set(
|
||||
"Strict-Transport-Security",
|
||||
"max-age=63072000; includeSubDomains; preload",
|
||||
)
|
||||
w.Header().Set(
|
||||
"X-Content-Type-Options", "nosniff",
|
||||
)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set(
|
||||
"Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"script-src 'self' 'unsafe-inline'; "+
|
||||
"style-src 'self' 'unsafe-inline'",
|
||||
)
|
||||
w.Header().Set(
|
||||
"Referrer-Policy",
|
||||
"strict-origin-when-cross-origin",
|
||||
)
|
||||
w.Header().Set(
|
||||
"Permissions-Policy",
|
||||
"camera=(), microphone=(), geolocation=()",
|
||||
)
|
||||
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MaxBodySize returns middleware that limits the request body size
|
||||
// for POST requests. If the body exceeds the given limit in
|
||||
// bytes, the server returns 413 Request Entity Too Large. This
|
||||
// prevents clients from sending arbitrarily large form bodies.
|
||||
func (s *Middleware) MaxBodySize(
|
||||
maxBytes int64,
|
||||
) func(http.Handler) http.Handler {
|
||||
// MaxBodySize returns middleware that limits the request body size for POST
|
||||
// requests. If the body exceeds the given limit in bytes, the server returns
|
||||
// 413 Request Entity Too Large. This prevents clients from sending arbitrarily
|
||||
// large form bodies.
|
||||
func (s *Middleware) MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) {
|
||||
if r.Method == http.MethodPost ||
|
||||
r.Method == http.MethodPut ||
|
||||
r.Method == http.MethodPatch {
|
||||
r.Body = http.MaxBytesReader(
|
||||
w, r.Body, maxBytes,
|
||||
)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package middleware_test
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -13,37 +12,25 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/middleware"
|
||||
"sneak.berlin/go/webhooker/internal/session"
|
||||
)
|
||||
|
||||
const testKeySize = 32
|
||||
|
||||
// testMiddleware creates a Middleware with minimal dependencies
|
||||
// for testing. It uses a real session.Session backed by an
|
||||
// in-memory cookie store.
|
||||
func testMiddleware(
|
||||
t *testing.T,
|
||||
env string,
|
||||
) (*middleware.Middleware, *session.Session) {
|
||||
// testMiddleware creates a Middleware with minimal dependencies for testing.
|
||||
// It uses a real session.Session backed by an in-memory cookie store.
|
||||
func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) {
|
||||
t.Helper()
|
||||
|
||||
log := slog.New(slog.NewTextHandler(
|
||||
os.Stderr,
|
||||
&slog.HandlerOptions{Level: slog.LevelDebug},
|
||||
))
|
||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
|
||||
cfg := &config.Config{
|
||||
Environment: env,
|
||||
}
|
||||
|
||||
// Create a real session manager with a known key
|
||||
key := make([]byte, testKeySize)
|
||||
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
|
||||
store := sessions.NewCookieStore(key)
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
@@ -53,33 +40,40 @@ func testMiddleware(
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
sessManager := session.NewForTest(store, cfg, log, key)
|
||||
sessManager := newTestSession(t, store, cfg, log, key)
|
||||
|
||||
m := middleware.NewForTest(log, cfg, sessManager)
|
||||
m := &Middleware{
|
||||
log: log,
|
||||
params: &MiddlewareParams{
|
||||
Config: cfg,
|
||||
},
|
||||
session: sessManager,
|
||||
}
|
||||
|
||||
return m, sessManager
|
||||
}
|
||||
|
||||
// newTestSession creates a session.Session with a pre-configured cookie store
|
||||
// for testing. This avoids needing the fx lifecycle and database.
|
||||
func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *session.Session {
|
||||
t.Helper()
|
||||
return session.NewForTest(store, cfg, log, key)
|
||||
}
|
||||
|
||||
// --- Logging Middleware Tests ---
|
||||
|
||||
func TestLogging_SetsStatusCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
handler := m.Logging()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
_, err := w.Write([]byte("created"))
|
||||
if err != nil {
|
||||
if _, err := w.Write([]byte("created")); err != nil {
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
@@ -90,20 +84,15 @@ func TestLogging_SetsStatusCode(t *testing.T) {
|
||||
|
||||
func TestLogging_DefaultStatusOK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
handler := m.Logging()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, err := w.Write([]byte("ok"))
|
||||
if err != nil {
|
||||
handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
if _, err := w.Write([]byte("ok")); err != nil {
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
@@ -114,31 +103,20 @@ func TestLogging_DefaultStatusOK(t *testing.T) {
|
||||
|
||||
func TestLogging_PassesThroughToNext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.Logging()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.Logging()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/api/webhook", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/webhook", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.True(
|
||||
t, called,
|
||||
"logging middleware should call the next handler",
|
||||
)
|
||||
assert.True(t, called, "logging middleware should call the next handler")
|
||||
}
|
||||
|
||||
// --- LoggingResponseWriter Tests ---
|
||||
@@ -147,33 +125,24 @@ func TestLoggingResponseWriter_CapturesStatusCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
lrw := middleware.NewLoggingResponseWriterForTest(w)
|
||||
lrw := NewLoggingResponseWriter(w)
|
||||
|
||||
// Default should be 200
|
||||
assert.Equal(
|
||||
t, http.StatusOK,
|
||||
middleware.LoggingResponseWriterStatusCode(lrw),
|
||||
)
|
||||
assert.Equal(t, http.StatusOK, lrw.statusCode)
|
||||
|
||||
// WriteHeader should capture the status code
|
||||
lrw.WriteHeader(http.StatusNotFound)
|
||||
|
||||
assert.Equal(
|
||||
t, http.StatusNotFound,
|
||||
middleware.LoggingResponseWriterStatusCode(lrw),
|
||||
)
|
||||
assert.Equal(t, http.StatusNotFound, lrw.statusCode)
|
||||
|
||||
// Underlying writer should also get the status code
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestLoggingResponseWriter_WriteDelegatesToUnderlying(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestLoggingResponseWriter_WriteDelegatesToUnderlying(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
lrw := middleware.NewLoggingResponseWriterForTest(w)
|
||||
lrw := NewLoggingResponseWriter(w)
|
||||
|
||||
n, err := lrw.Write([]byte("hello world"))
|
||||
require.NoError(t, err)
|
||||
@@ -185,124 +154,79 @@ func TestLoggingResponseWriter_WriteDelegatesToUnderlying(
|
||||
|
||||
func TestCORS_DevMode_AllowsAnyOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
handler := m.CORS()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
// Preflight request
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodOptions, "/api/test", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/test", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// In dev mode, CORS should allow any origin
|
||||
assert.Equal(
|
||||
t, "*",
|
||||
w.Header().Get("Access-Control-Allow-Origin"),
|
||||
)
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
func TestCORS_ProdMode_NoOp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentProd)
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.CORS()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/api/test", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.Header.Set("Origin", "http://evil.com")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.True(
|
||||
t, called,
|
||||
"prod CORS middleware should pass through to handler",
|
||||
)
|
||||
assert.True(t, called, "prod CORS middleware should pass through to handler")
|
||||
// In prod, no CORS headers should be set (no-op middleware)
|
||||
assert.Empty(
|
||||
t,
|
||||
w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"prod mode should not set CORS headers",
|
||||
)
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"prod mode should not set CORS headers")
|
||||
}
|
||||
|
||||
// --- RequireAuth Middleware Tests ---
|
||||
|
||||
func TestRequireAuth_NoSession_RedirectsToLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.RequireAuth()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/dashboard", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.False(
|
||||
t, called,
|
||||
"handler should not be called for "+
|
||||
"unauthenticated request",
|
||||
)
|
||||
assert.False(t, called, "handler should not be called for unauthenticated request")
|
||||
assert.Equal(t, http.StatusSeeOther, w.Code)
|
||||
assert.Equal(t, "/pages/login", w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestRequireAuth_AuthenticatedSession_PassesThrough(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestRequireAuth_AuthenticatedSession_PassesThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, sessManager := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.RequireAuth()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
// Create an authenticated session by making a request,
|
||||
// setting session data, and saving the session cookie
|
||||
setupReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/setup", nil,
|
||||
)
|
||||
// Create an authenticated session by making a request, setting session data,
|
||||
// and saving the session cookie
|
||||
setupReq := httptest.NewRequest(http.MethodGet, "/setup", nil)
|
||||
setupW := httptest.NewRecorder()
|
||||
|
||||
sess, err := sessManager.Get(setupReq)
|
||||
@@ -315,74 +239,47 @@ func TestRequireAuth_AuthenticatedSession_PassesThrough(
|
||||
require.NotEmpty(t, cookies, "session cookie should be set")
|
||||
|
||||
// Make the actual request with the session cookie
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/dashboard", nil,
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.True(
|
||||
t, called,
|
||||
"handler should be called for authenticated request",
|
||||
)
|
||||
assert.True(t, called, "handler should be called for authenticated request")
|
||||
}
|
||||
|
||||
func TestRequireAuth_UnauthenticatedSession_RedirectsToLogin(
|
||||
t *testing.T,
|
||||
) {
|
||||
func TestRequireAuth_UnauthenticatedSession_RedirectsToLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, sessManager := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.RequireAuth()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.RequireAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
// Create a session but don't authenticate it
|
||||
setupReq := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/setup", nil,
|
||||
)
|
||||
setupReq := httptest.NewRequest(http.MethodGet, "/setup", nil)
|
||||
setupW := httptest.NewRecorder()
|
||||
|
||||
sess, err := sessManager.Get(setupReq)
|
||||
require.NoError(t, err)
|
||||
// Don't call SetUser -- session exists but is not
|
||||
// authenticated
|
||||
// Don't call SetUser — session exists but is not authenticated
|
||||
require.NoError(t, sessManager.Save(setupReq, setupW, sess))
|
||||
|
||||
cookies := setupW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/dashboard", nil,
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.False(
|
||||
t, called,
|
||||
"handler should not be called for "+
|
||||
"unauthenticated session",
|
||||
)
|
||||
assert.False(t, called, "handler should not be called for unauthenticated session")
|
||||
assert.Equal(t, http.StatusSeeOther, w.Code)
|
||||
assert.Equal(t, "/pages/login", w.Header().Get("Location"))
|
||||
}
|
||||
@@ -407,9 +304,7 @@ func TestIpFromHostPort(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := middleware.IPFromHostPort(tt.input)
|
||||
|
||||
result := ipFromHostPort(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -417,124 +312,122 @@ func TestIpFromHostPort(t *testing.T) {
|
||||
|
||||
// --- MetricsAuth Tests ---
|
||||
|
||||
// metricsAuthMiddleware creates a Middleware configured for
|
||||
// metrics auth testing. This helper de-duplicates the setup in
|
||||
// metrics auth test functions.
|
||||
func metricsAuthMiddleware(
|
||||
t *testing.T,
|
||||
) *middleware.Middleware {
|
||||
t.Helper()
|
||||
|
||||
log := slog.New(slog.NewTextHandler(
|
||||
os.Stderr,
|
||||
&slog.HandlerOptions{Level: slog.LevelDebug},
|
||||
))
|
||||
func TestMetricsAuth_ValidCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
cfg := &config.Config{
|
||||
Environment: config.EnvironmentDev,
|
||||
MetricsUsername: "admin",
|
||||
MetricsPassword: "secret",
|
||||
}
|
||||
|
||||
key := make([]byte, testKeySize)
|
||||
key := make([]byte, 32)
|
||||
store := sessions.NewCookieStore(key)
|
||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
||||
|
||||
sessManager := session.NewForTest(store, cfg, log, key)
|
||||
|
||||
return middleware.NewForTest(log, cfg, sessManager)
|
||||
}
|
||||
|
||||
func TestMetricsAuth_ValidCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := metricsAuthMiddleware(t)
|
||||
m := &Middleware{
|
||||
log: log,
|
||||
params: &MiddlewareParams{
|
||||
Config: cfg,
|
||||
},
|
||||
session: sessManager,
|
||||
}
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.MetricsAuth()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/metrics", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
req.SetBasicAuth("admin", "secret")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.True(
|
||||
t, called,
|
||||
"handler should be called with valid basic auth",
|
||||
)
|
||||
assert.True(t, called, "handler should be called with valid basic auth")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestMetricsAuth_InvalidCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := metricsAuthMiddleware(t)
|
||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
cfg := &config.Config{
|
||||
Environment: config.EnvironmentDev,
|
||||
MetricsUsername: "admin",
|
||||
MetricsPassword: "secret",
|
||||
}
|
||||
|
||||
key := make([]byte, 32)
|
||||
store := sessions.NewCookieStore(key)
|
||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
||||
|
||||
sessManager := session.NewForTest(store, cfg, log, key)
|
||||
|
||||
m := &Middleware{
|
||||
log: log,
|
||||
params: &MiddlewareParams{
|
||||
Config: cfg,
|
||||
},
|
||||
session: sessManager,
|
||||
}
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.MetricsAuth()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.MetricsAuth()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/metrics", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
req.SetBasicAuth("admin", "wrong-password")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.False(
|
||||
t, called,
|
||||
"handler should not be called with invalid basic auth",
|
||||
)
|
||||
assert.False(t, called, "handler should not be called with invalid basic auth")
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestMetricsAuth_NoCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := metricsAuthMiddleware(t)
|
||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
cfg := &config.Config{
|
||||
Environment: config.EnvironmentDev,
|
||||
MetricsUsername: "admin",
|
||||
MetricsPassword: "secret",
|
||||
}
|
||||
|
||||
key := make([]byte, 32)
|
||||
store := sessions.NewCookieStore(key)
|
||||
store.Options = &sessions.Options{Path: "/", MaxAge: 86400}
|
||||
|
||||
sessManager := session.NewForTest(store, cfg, log, key)
|
||||
|
||||
m := &Middleware{
|
||||
log: log,
|
||||
params: &MiddlewareParams{
|
||||
Config: cfg,
|
||||
},
|
||||
session: sessManager,
|
||||
}
|
||||
|
||||
var called bool
|
||||
|
||||
handler := m.MetricsAuth()(http.HandlerFunc(
|
||||
func(_ http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.MetricsAuth()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/metrics", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
// No basic auth header
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.False(
|
||||
t, called,
|
||||
"handler should not be called without credentials",
|
||||
)
|
||||
assert.False(t, called, "handler should not be called without credentials")
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
@@ -542,23 +435,16 @@ func TestMetricsAuth_NoCredentials(t *testing.T) {
|
||||
|
||||
func TestCORS_DevMode_AllowsMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
handler := m.CORS()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
// Preflight for POST
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodOptions, "/api/webhooks", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/webhooks", nil)
|
||||
req.Header.Set("Origin", "http://localhost:5173")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
@@ -572,17 +458,14 @@ func TestCORS_DevMode_AllowsMethods(t *testing.T) {
|
||||
func TestSessionKeyFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Verify that the session initialization correctly validates
|
||||
// key format. A proper 32-byte key encoded as base64 should
|
||||
// work.
|
||||
key := make([]byte, testKeySize)
|
||||
|
||||
// Verify that the session initialization correctly validates key format.
|
||||
// A proper 32-byte key encoded as base64 should work.
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i + 1)
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(key)
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, decoded, testKeySize)
|
||||
assert.Len(t, decoded, 32)
|
||||
}
|
||||
|
||||
@@ -8,56 +8,40 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// loginRateLimit is the maximum number of login attempts
|
||||
// per interval.
|
||||
// loginRateLimit is the maximum number of login attempts per interval.
|
||||
loginRateLimit = 5
|
||||
|
||||
// loginRateInterval is the time window for the rate limit.
|
||||
loginRateInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
// LoginRateLimit returns middleware that enforces per-IP rate
|
||||
// limiting on login attempts using go-chi/httprate. Only POST
|
||||
// requests are rate-limited; GET requests (rendering the login
|
||||
// form) pass through unaffected. When the rate limit is exceeded,
|
||||
// a 429 Too Many Requests response is returned. IP extraction
|
||||
// honours X-Forwarded-For, X-Real-IP, and True-Client-IP headers
|
||||
// for reverse-proxy setups.
|
||||
// LoginRateLimit returns middleware that enforces per-IP rate limiting
|
||||
// on login attempts using go-chi/httprate. Only POST requests are
|
||||
// rate-limited; GET requests (rendering the login form) pass through
|
||||
// unaffected. When the rate limit is exceeded, a 429 Too Many Requests
|
||||
// response is returned. IP extraction honours X-Forwarded-For,
|
||||
// X-Real-IP, and True-Client-IP headers for reverse-proxy setups.
|
||||
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
||||
limiter := httprate.Limit(
|
||||
loginRateLimit,
|
||||
loginRateInterval,
|
||||
httprate.WithKeyFuncs(httprate.KeyByRealIP),
|
||||
httprate.WithLimitHandler(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
httprate.WithLimitHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
m.log.Warn("login rate limit exceeded",
|
||||
"path", r.URL.Path,
|
||||
)
|
||||
http.Error(
|
||||
w,
|
||||
"Too many login attempts. "+
|
||||
"Please try again later.",
|
||||
http.StatusTooManyRequests,
|
||||
)
|
||||
},
|
||||
)),
|
||||
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
||||
})),
|
||||
)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
limited := limiter(next)
|
||||
|
||||
return http.HandlerFunc(func(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) {
|
||||
// Only rate-limit POST requests (actual login
|
||||
// attempts)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only rate-limit POST requests (actual login attempts)
|
||||
if r.Method != http.MethodPost {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
limited.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,147 +1,90 @@
|
||||
package middleware_test
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/middleware"
|
||||
)
|
||||
|
||||
func TestLoginRateLimit_AllowsGET(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var callCount int
|
||||
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
callCount++
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
// GET requests should never be rate-limited
|
||||
for i := range 20 {
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/pages/login", nil,
|
||||
)
|
||||
for i := 0; i < 20; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/pages/login", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(
|
||||
t, http.StatusOK, w.Code,
|
||||
"GET request %d should pass", i,
|
||||
)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "GET request %d should pass", i)
|
||||
}
|
||||
|
||||
assert.Equal(t, 20, callCount)
|
||||
}
|
||||
|
||||
func TestLoginRateLimit_LimitsPOST(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var callCount int
|
||||
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
callCount++
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
// First loginRateLimit POST requests should succeed
|
||||
for i := range middleware.LoginRateLimitConst {
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/pages/login", nil,
|
||||
)
|
||||
for i := 0; i < loginRateLimit; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(
|
||||
t, http.StatusOK, w.Code,
|
||||
"POST request %d should pass", i,
|
||||
)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "POST request %d should pass", i)
|
||||
}
|
||||
|
||||
// Next POST should be rate-limited
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/pages/login", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(
|
||||
t, http.StatusTooManyRequests, w.Code,
|
||||
"POST after limit should be 429",
|
||||
)
|
||||
assert.Equal(t, middleware.LoginRateLimitConst, callCount)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code, "POST after limit should be 429")
|
||||
assert.Equal(t, loginRateLimit, callCount)
|
||||
}
|
||||
|
||||
func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
}))
|
||||
|
||||
// Exhaust limit for IP1
|
||||
for range middleware.LoginRateLimitConst {
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/pages/login", nil,
|
||||
)
|
||||
for i := 0; i < loginRateLimit; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "1.2.3.4:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// IP1 should be rate-limited
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/pages/login", nil,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "1.2.3.4:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
|
||||
// IP2 should still be allowed
|
||||
req2 := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodPost, "/pages/login", nil,
|
||||
)
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req2.RemoteAddr = "5.6.7.8:12345"
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(
|
||||
t, http.StatusOK, w2.Code,
|
||||
"different IP should not be affected",
|
||||
)
|
||||
assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected")
|
||||
}
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/session"
|
||||
)
|
||||
|
||||
// NewForTest creates a Middleware with the minimum dependencies
|
||||
// needed for testing. This bypasses the fx lifecycle.
|
||||
func NewForTest(
|
||||
log *slog.Logger,
|
||||
cfg *config.Config,
|
||||
sess *session.Session,
|
||||
) *Middleware {
|
||||
return &Middleware{
|
||||
log: log,
|
||||
params: &MiddlewareParams{
|
||||
Config: cfg,
|
||||
},
|
||||
session: sess,
|
||||
}
|
||||
}
|
||||
@@ -1,33 +1,18 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// httpReadTimeout is the maximum duration for reading the
|
||||
// entire request, including the body.
|
||||
httpReadTimeout = 10 * time.Second
|
||||
|
||||
// httpWriteTimeout is the maximum duration before timing out
|
||||
// writes of the response.
|
||||
httpWriteTimeout = 10 * time.Second
|
||||
|
||||
// httpMaxHeaderBytes is the maximum number of bytes the
|
||||
// server will read parsing the request headers.
|
||||
httpMaxHeaderBytes = 1 << 20
|
||||
)
|
||||
|
||||
func (s *Server) serveUntilShutdown() {
|
||||
listenAddr := fmt.Sprintf(":%d", s.params.Config.Port)
|
||||
s.httpServer = &http.Server{
|
||||
Addr: listenAddr,
|
||||
ReadTimeout: httpReadTimeout,
|
||||
WriteTimeout: httpWriteTimeout,
|
||||
MaxHeaderBytes: httpMaxHeaderBytes,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
Handler: s,
|
||||
}
|
||||
|
||||
@@ -36,21 +21,14 @@ func (s *Server) serveUntilShutdown() {
|
||||
s.SetupRoutes()
|
||||
|
||||
s.log.Info("http begin listen", "listenaddr", listenAddr)
|
||||
|
||||
err := s.httpServer.ListenAndServe()
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.log.Error("listen error", "error", err)
|
||||
|
||||
if s.cancelFunc != nil {
|
||||
s.cancelFunc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP delegates to the router.
|
||||
func (s *Server) ServeHTTP(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) {
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
@@ -11,24 +11,16 @@ import (
|
||||
"sneak.berlin/go/webhooker/static"
|
||||
)
|
||||
|
||||
// maxFormBodySize is the maximum allowed request body size (in
|
||||
// bytes) for form POST endpoints. 1 MB is generous for any form
|
||||
// submission while preventing abuse from oversized payloads.
|
||||
// maxFormBodySize is the maximum allowed request body size (in bytes) for
|
||||
// form POST endpoints. 1 MB is generous for any form submission while
|
||||
// preventing abuse from oversized payloads.
|
||||
const maxFormBodySize int64 = 1 * 1024 * 1024 // 1 MB
|
||||
|
||||
// requestTimeout is the maximum time allowed for a single HTTP
|
||||
// request.
|
||||
const requestTimeout = 60 * time.Second
|
||||
|
||||
// SetupRoutes configures all HTTP routes and middleware on the
|
||||
// server's router.
|
||||
// SetupRoutes configures all HTTP routes and middleware on the server's router.
|
||||
func (s *Server) SetupRoutes() {
|
||||
s.router = chi.NewRouter()
|
||||
s.setupGlobalMiddleware()
|
||||
s.setupRoutes()
|
||||
}
|
||||
|
||||
func (s *Server) setupGlobalMiddleware() {
|
||||
// Global middleware stack — applied to every request.
|
||||
s.router.Use(middleware.Recoverer)
|
||||
s.router.Use(middleware.RequestID)
|
||||
s.router.Use(s.mw.SecurityHeaders())
|
||||
@@ -40,28 +32,24 @@ func (s *Server) setupGlobalMiddleware() {
|
||||
}
|
||||
|
||||
s.router.Use(s.mw.CORS())
|
||||
s.router.Use(middleware.Timeout(requestTimeout))
|
||||
s.router.Use(middleware.Timeout(60 * time.Second))
|
||||
|
||||
// Sentry error reporting (if SENTRY_DSN is set). Repanic is
|
||||
// true so panics still bubble up to the Recoverer middleware.
|
||||
// Sentry error reporting (if SENTRY_DSN is set). Repanic is true
|
||||
// so panics still bubble up to the Recoverer middleware above.
|
||||
if s.sentryEnabled {
|
||||
sentryHandler := sentryhttp.New(sentryhttp.Options{
|
||||
Repanic: true,
|
||||
})
|
||||
s.router.Use(sentryHandler.Handle)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
// Routes
|
||||
s.router.Get("/", s.h.HandleIndex())
|
||||
|
||||
s.router.Mount(
|
||||
"/s",
|
||||
http.StripPrefix("/s", http.FileServer(http.FS(static.Static))),
|
||||
)
|
||||
s.router.Mount("/s", http.StripPrefix("/s", http.FileServer(http.FS(static.Static))))
|
||||
|
||||
s.router.Route("/api/v1", func(_ chi.Router) {
|
||||
// API routes will be added here.
|
||||
// TODO: Add API routes here
|
||||
})
|
||||
|
||||
s.router.Get(
|
||||
@@ -73,89 +61,62 @@ func (s *Server) setupRoutes() {
|
||||
if s.params.Config.MetricsUsername != "" {
|
||||
s.router.Group(func(r chi.Router) {
|
||||
r.Use(s.mw.MetricsAuth())
|
||||
r.Get(
|
||||
"/metrics",
|
||||
http.HandlerFunc(
|
||||
promhttp.Handler().ServeHTTP,
|
||||
),
|
||||
)
|
||||
r.Get("/metrics", http.HandlerFunc(promhttp.Handler().ServeHTTP))
|
||||
})
|
||||
}
|
||||
|
||||
s.setupPageRoutes()
|
||||
s.setupUserRoutes()
|
||||
s.setupSourceRoutes()
|
||||
s.setupWebhookRoutes()
|
||||
}
|
||||
|
||||
func (s *Server) setupPageRoutes() {
|
||||
// pages that are rendered server-side — CSRF-protected, body-size
|
||||
// limited, and with per-IP rate limiting on the login endpoint.
|
||||
s.router.Route("/pages", func(r chi.Router) {
|
||||
r.Use(s.mw.CSRF())
|
||||
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
||||
|
||||
// Login page — rate-limited to prevent brute-force attacks
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(s.mw.LoginRateLimit())
|
||||
r.Get("/login", s.h.HandleLoginPage())
|
||||
r.Post("/login", s.h.HandleLoginSubmit())
|
||||
})
|
||||
|
||||
// Logout (auth required)
|
||||
r.Post("/logout", s.h.HandleLogout())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) setupUserRoutes() {
|
||||
// User profile routes
|
||||
s.router.Route("/user/{username}", func(r chi.Router) {
|
||||
r.Use(s.mw.CSRF())
|
||||
r.Get("/", s.h.HandleProfile())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) setupSourceRoutes() {
|
||||
// Webhook management routes (require authentication, CSRF-protected)
|
||||
s.router.Route("/sources", func(r chi.Router) {
|
||||
r.Use(s.mw.CSRF())
|
||||
r.Use(s.mw.RequireAuth())
|
||||
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
||||
r.Get("/", s.h.HandleSourceList())
|
||||
r.Get("/new", s.h.HandleSourceCreate())
|
||||
r.Post("/new", s.h.HandleSourceCreateSubmit())
|
||||
r.Get("/", s.h.HandleSourceList()) // List all webhooks
|
||||
r.Get("/new", s.h.HandleSourceCreate()) // Show create form
|
||||
r.Post("/new", s.h.HandleSourceCreateSubmit()) // Handle create submission
|
||||
})
|
||||
|
||||
s.router.Route("/source/{sourceID}", func(r chi.Router) {
|
||||
r.Use(s.mw.CSRF())
|
||||
r.Use(s.mw.RequireAuth())
|
||||
r.Use(s.mw.MaxBodySize(maxFormBodySize))
|
||||
r.Get("/", s.h.HandleSourceDetail())
|
||||
r.Get("/edit", s.h.HandleSourceEdit())
|
||||
r.Post("/edit", s.h.HandleSourceEditSubmit())
|
||||
r.Post("/delete", s.h.HandleSourceDelete())
|
||||
r.Get("/logs", s.h.HandleSourceLogs())
|
||||
r.Post(
|
||||
"/entrypoints",
|
||||
s.h.HandleEntrypointCreate(),
|
||||
)
|
||||
r.Post(
|
||||
"/entrypoints/{entrypointID}/delete",
|
||||
s.h.HandleEntrypointDelete(),
|
||||
)
|
||||
r.Post(
|
||||
"/entrypoints/{entrypointID}/toggle",
|
||||
s.h.HandleEntrypointToggle(),
|
||||
)
|
||||
r.Post("/targets", s.h.HandleTargetCreate())
|
||||
r.Post(
|
||||
"/targets/{targetID}/delete",
|
||||
s.h.HandleTargetDelete(),
|
||||
)
|
||||
r.Post(
|
||||
"/targets/{targetID}/toggle",
|
||||
s.h.HandleTargetToggle(),
|
||||
)
|
||||
r.Get("/", s.h.HandleSourceDetail()) // View webhook details
|
||||
r.Get("/edit", s.h.HandleSourceEdit()) // Show edit form
|
||||
r.Post("/edit", s.h.HandleSourceEditSubmit()) // Handle edit submission
|
||||
r.Post("/delete", s.h.HandleSourceDelete()) // Delete webhook
|
||||
r.Get("/logs", s.h.HandleSourceLogs()) // View webhook logs
|
||||
r.Post("/entrypoints", s.h.HandleEntrypointCreate()) // Add entrypoint
|
||||
r.Post("/entrypoints/{entrypointID}/delete", s.h.HandleEntrypointDelete()) // Delete entrypoint
|
||||
r.Post("/entrypoints/{entrypointID}/toggle", s.h.HandleEntrypointToggle()) // Toggle entrypoint active
|
||||
r.Post("/targets", s.h.HandleTargetCreate()) // Add target
|
||||
r.Post("/targets/{targetID}/delete", s.h.HandleTargetDelete()) // Delete target
|
||||
r.Post("/targets/{targetID}/toggle", s.h.HandleTargetToggle()) // Toggle target active
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) setupWebhookRoutes() {
|
||||
s.router.HandleFunc(
|
||||
"/webhook/{uuid}",
|
||||
s.h.HandleWebhook(),
|
||||
)
|
||||
// Entrypoint endpoint — accepts incoming webhook POST requests only.
|
||||
// Using HandleFunc so the handler itself can return 405 for non-POST
|
||||
// methods (chi's Method routing returns 405 without Allow header).
|
||||
s.router.HandleFunc("/webhook/{uuid}", s.h.HandleWebhook())
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Package server wires up HTTP routes and manages the
|
||||
// application lifecycle.
|
||||
// Package server wires up HTTP routes and manages the application lifecycle.
|
||||
package server
|
||||
|
||||
import (
|
||||
@@ -23,20 +22,9 @@ import (
|
||||
"github.com/go-chi/chi"
|
||||
)
|
||||
|
||||
const (
|
||||
// shutdownTimeout is the maximum time to wait for the HTTP
|
||||
// server to finish in-flight requests during shutdown.
|
||||
shutdownTimeout = 5 * time.Second
|
||||
|
||||
// sentryFlushTimeout is the maximum time to wait for Sentry
|
||||
// to flush pending events during shutdown.
|
||||
sentryFlushTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
//nolint:revive // ServerParams is a standard fx naming convention.
|
||||
type ServerParams struct {
|
||||
fx.In
|
||||
|
||||
Logger *logger.Logger
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
@@ -44,13 +32,13 @@ type ServerParams struct {
|
||||
Handlers *handlers.Handlers
|
||||
}
|
||||
|
||||
// Server is the main HTTP server that wires up routes and manages
|
||||
// graceful shutdown.
|
||||
// Server is the main HTTP server that wires up routes and manages graceful shutdown.
|
||||
type Server struct {
|
||||
startupTime time.Time
|
||||
exitCode int
|
||||
sentryEnabled bool
|
||||
log *slog.Logger
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
httpServer *http.Server
|
||||
router *chi.Mux
|
||||
@@ -59,8 +47,7 @@ type Server struct {
|
||||
h *handlers.Handlers
|
||||
}
|
||||
|
||||
// New creates a Server that starts the HTTP listener on fx start
|
||||
// and stops it gracefully.
|
||||
// New creates a Server that starts the HTTP listener on fx start and stops it gracefully.
|
||||
func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
|
||||
s := new(Server)
|
||||
s.params = params
|
||||
@@ -72,16 +59,13 @@ func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
|
||||
OnStart: func(_ context.Context) error {
|
||||
s.startupTime = time.Now()
|
||||
go s.Run()
|
||||
|
||||
return nil
|
||||
},
|
||||
OnStop: func(ctx context.Context) error {
|
||||
s.cleanShutdown(ctx)
|
||||
|
||||
OnStop: func(_ context.Context) error {
|
||||
s.cleanShutdown()
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -95,12 +79,6 @@ func (s *Server) Run() {
|
||||
s.serve()
|
||||
}
|
||||
|
||||
// MaintenanceMode returns whether the server is in maintenance
|
||||
// mode.
|
||||
func (s *Server) MaintenanceMode() bool {
|
||||
return s.params.Config.MaintenanceMode
|
||||
}
|
||||
|
||||
func (s *Server) enableSentry() {
|
||||
s.sentryEnabled = false
|
||||
|
||||
@@ -110,36 +88,28 @@ func (s *Server) enableSentry() {
|
||||
|
||||
err := sentry.Init(sentry.ClientOptions{
|
||||
Dsn: s.params.Config.SentryDSN,
|
||||
Release: fmt.Sprintf(
|
||||
"%s-%s",
|
||||
s.params.Globals.Appname,
|
||||
s.params.Globals.Version,
|
||||
),
|
||||
Release: fmt.Sprintf("%s-%s", s.params.Globals.Appname, s.params.Globals.Version),
|
||||
})
|
||||
if err != nil {
|
||||
s.log.Error("sentry init failure", "error", err)
|
||||
// Don't use fatal since we still want the service to run
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Info("sentry error reporting activated")
|
||||
s.sentryEnabled = true
|
||||
}
|
||||
|
||||
func (s *Server) serve() int {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
s.cancelFunc = cancelFunc
|
||||
s.ctx, s.cancelFunc = context.WithCancel(context.Background())
|
||||
|
||||
// signal watcher
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
|
||||
signal.Ignore(syscall.SIGPIPE)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
// block and wait for signal
|
||||
sig := <-c
|
||||
s.log.Info("signal received", "signal", sig.String())
|
||||
|
||||
if s.cancelFunc != nil {
|
||||
// cancelling the main context will trigger a clean
|
||||
// shutdown via the fx OnStop hook.
|
||||
@@ -149,9 +119,9 @@ func (s *Server) serve() int {
|
||||
|
||||
go s.serveUntilShutdown()
|
||||
|
||||
<-ctx.Done()
|
||||
<-s.ctx.Done()
|
||||
// Shutdown is handled by the fx OnStop hook (cleanShutdown).
|
||||
// Do not call cleanShutdown() here to avoid double invocation.
|
||||
// Do not call cleanShutdown() here to avoid a double invocation.
|
||||
return s.exitCode
|
||||
}
|
||||
|
||||
@@ -159,29 +129,28 @@ func (s *Server) cleanupForExit() {
|
||||
s.log.Info("cleaning up")
|
||||
}
|
||||
|
||||
func (s *Server) cleanShutdown(ctx context.Context) {
|
||||
func (s *Server) cleanShutdown() {
|
||||
// initiate clean shutdown
|
||||
s.exitCode = 0
|
||||
|
||||
ctxShutdown, shutdownCancel := context.WithTimeout(
|
||||
ctx, shutdownTimeout,
|
||||
)
|
||||
ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
err := s.httpServer.Shutdown(ctxShutdown)
|
||||
if err != nil {
|
||||
s.log.Error(
|
||||
"server clean shutdown failed", "error", err,
|
||||
)
|
||||
if err := s.httpServer.Shutdown(ctxShutdown); err != nil {
|
||||
s.log.Error("server clean shutdown failed", "error", err)
|
||||
}
|
||||
|
||||
s.cleanupForExit()
|
||||
|
||||
if s.sentryEnabled {
|
||||
sentry.Flush(sentryFlushTimeout)
|
||||
sentry.Flush(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// MaintenanceMode returns whether the server is in maintenance mode.
|
||||
func (s *Server) MaintenanceMode() bool {
|
||||
return s.params.Config.MaintenanceMode
|
||||
}
|
||||
|
||||
func (s *Server) configure() {
|
||||
// identify ourselves in the logs
|
||||
s.params.Logger.Identify()
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
// Package session manages HTTP session storage and authentication
|
||||
// state.
|
||||
// Package session manages HTTP session storage and authentication state.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
@@ -19,44 +16,28 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// SessionName is the name of the session cookie.
|
||||
// SessionName is the name of the session cookie
|
||||
SessionName = "webhooker_session"
|
||||
|
||||
// UserIDKey is the session key for user ID.
|
||||
// UserIDKey is the session key for user ID
|
||||
UserIDKey = "user_id"
|
||||
|
||||
// UsernameKey is the session key for username.
|
||||
// UsernameKey is the session key for username
|
||||
UsernameKey = "username"
|
||||
|
||||
// AuthenticatedKey is the session key for authentication
|
||||
// status.
|
||||
// AuthenticatedKey is the session key for authentication status
|
||||
AuthenticatedKey = "authenticated"
|
||||
|
||||
// sessionKeyLength is the required length in bytes for the
|
||||
// session authentication key.
|
||||
sessionKeyLength = 32
|
||||
|
||||
// sessionMaxAgeDays is the session cookie lifetime in days.
|
||||
sessionMaxAgeDays = 7
|
||||
|
||||
// secondsPerDay is the number of seconds in a day.
|
||||
secondsPerDay = 86400
|
||||
)
|
||||
|
||||
// ErrSessionKeyLength is returned when the decoded session key
|
||||
// does not have the expected length.
|
||||
var ErrSessionKeyLength = errors.New("session key length mismatch")
|
||||
|
||||
// Params holds dependencies injected by fx.
|
||||
type Params struct {
|
||||
// nolint:revive // SessionParams is a standard fx naming convention
|
||||
type SessionParams struct {
|
||||
fx.In
|
||||
|
||||
Config *config.Config
|
||||
Database *database.Database
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// Session manages encrypted session storage.
|
||||
// Session manages encrypted session storage
|
||||
type Session struct {
|
||||
store *sessions.CookieStore
|
||||
key []byte // raw 32-byte auth key, also used for CSRF cookie signing
|
||||
@@ -64,44 +45,29 @@ type Session struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// New creates a new session manager. The cookie store is
|
||||
// initialized during the fx OnStart phase after the database is
|
||||
// connected, using a session key that is auto-generated and stored
|
||||
// in the database.
|
||||
func New(
|
||||
lc fx.Lifecycle,
|
||||
params Params,
|
||||
) (*Session, error) {
|
||||
// New creates a new session manager. The cookie store is initialized
|
||||
// during the fx OnStart phase after the database is connected, using
|
||||
// a session key that is auto-generated and stored in the database.
|
||||
func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
|
||||
s := &Session{
|
||||
log: params.Logger.Get(),
|
||||
config: params.Config,
|
||||
}
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(_ context.Context) error {
|
||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
sessionKey, err := params.Database.GetOrCreateSessionKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to get session key: %w", err,
|
||||
)
|
||||
return fmt.Errorf("failed to get session key: %w", err)
|
||||
}
|
||||
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(
|
||||
sessionKey,
|
||||
)
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid session key format: %w", err,
|
||||
)
|
||||
return fmt.Errorf("invalid session key format: %w", err)
|
||||
}
|
||||
|
||||
if len(keyBytes) != sessionKeyLength {
|
||||
return fmt.Errorf(
|
||||
"%w: want %d, got %d",
|
||||
ErrSessionKeyLength,
|
||||
sessionKeyLength,
|
||||
len(keyBytes),
|
||||
)
|
||||
if len(keyBytes) != 32 {
|
||||
return fmt.Errorf("session key must be 32 bytes (got %d)", len(keyBytes))
|
||||
}
|
||||
|
||||
store := sessions.NewCookieStore(keyBytes)
|
||||
@@ -109,16 +75,15 @@ func New(
|
||||
// Configure cookie options for security
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: secondsPerDay * sessionMaxAgeDays,
|
||||
MaxAge: 86400 * 7, // 7 days
|
||||
HttpOnly: true,
|
||||
Secure: !params.Config.IsDev(),
|
||||
Secure: !params.Config.IsDev(), // HTTPS in production
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
s.key = keyBytes
|
||||
s.store = store
|
||||
s.log.Info("session manager initialized")
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -126,126 +91,99 @@ func New(
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Get retrieves a session for the request.
|
||||
func (s *Session) Get(
|
||||
r *http.Request,
|
||||
) (*sessions.Session, error) {
|
||||
// Get retrieves a session for the request
|
||||
func (s *Session) Get(r *http.Request) (*sessions.Session, error) {
|
||||
return s.store.Get(r, SessionName)
|
||||
}
|
||||
|
||||
// GetKey returns the raw 32-byte authentication key used for
|
||||
// session encryption. This key is also suitable for CSRF cookie
|
||||
// signing.
|
||||
// GetKey returns the raw 32-byte authentication key used for session
|
||||
// encryption. This key is also suitable for CSRF cookie signing.
|
||||
func (s *Session) GetKey() []byte {
|
||||
return s.key
|
||||
}
|
||||
|
||||
// Save saves the session.
|
||||
func (s *Session) Save(
|
||||
r *http.Request,
|
||||
w http.ResponseWriter,
|
||||
sess *sessions.Session,
|
||||
) error {
|
||||
// Save saves the session
|
||||
func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error {
|
||||
return sess.Save(r, w)
|
||||
}
|
||||
|
||||
// SetUser sets the user information in the session.
|
||||
func (s *Session) SetUser(
|
||||
sess *sessions.Session,
|
||||
userID, username string,
|
||||
) {
|
||||
// SetUser sets the user information in the session
|
||||
func (s *Session) SetUser(sess *sessions.Session, userID, username string) {
|
||||
sess.Values[UserIDKey] = userID
|
||||
sess.Values[UsernameKey] = username
|
||||
sess.Values[AuthenticatedKey] = true
|
||||
}
|
||||
|
||||
// ClearUser removes user information from the session.
|
||||
// ClearUser removes user information from the session
|
||||
func (s *Session) ClearUser(sess *sessions.Session) {
|
||||
delete(sess.Values, UserIDKey)
|
||||
delete(sess.Values, UsernameKey)
|
||||
delete(sess.Values, AuthenticatedKey)
|
||||
}
|
||||
|
||||
// IsAuthenticated checks if the session has an authenticated
|
||||
// user.
|
||||
// IsAuthenticated checks if the session has an authenticated user
|
||||
func (s *Session) IsAuthenticated(sess *sessions.Session) bool {
|
||||
auth, ok := sess.Values[AuthenticatedKey].(bool)
|
||||
|
||||
return ok && auth
|
||||
}
|
||||
|
||||
// GetUserID retrieves the user ID from the session.
|
||||
func (s *Session) GetUserID(
|
||||
sess *sessions.Session,
|
||||
) (string, bool) {
|
||||
// GetUserID retrieves the user ID from the session
|
||||
func (s *Session) GetUserID(sess *sessions.Session) (string, bool) {
|
||||
userID, ok := sess.Values[UserIDKey].(string)
|
||||
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// GetUsername retrieves the username from the session.
|
||||
func (s *Session) GetUsername(
|
||||
sess *sessions.Session,
|
||||
) (string, bool) {
|
||||
// GetUsername retrieves the username from the session
|
||||
func (s *Session) GetUsername(sess *sessions.Session) (string, bool) {
|
||||
username, ok := sess.Values[UsernameKey].(string)
|
||||
|
||||
return username, ok
|
||||
}
|
||||
|
||||
// Destroy invalidates the session.
|
||||
// Destroy invalidates the session
|
||||
func (s *Session) Destroy(sess *sessions.Session) {
|
||||
sess.Options.MaxAge = -1
|
||||
s.ClearUser(sess)
|
||||
}
|
||||
|
||||
// Regenerate creates a new session with the same values but a
|
||||
// fresh ID. The old session is destroyed (MaxAge = -1) and saved,
|
||||
// then a new session is created. This prevents session fixation
|
||||
// attacks by ensuring the session ID changes after privilege
|
||||
// escalation (e.g. login).
|
||||
func (s *Session) Regenerate(
|
||||
r *http.Request,
|
||||
w http.ResponseWriter,
|
||||
oldSess *sessions.Session,
|
||||
) (*sessions.Session, error) {
|
||||
// Regenerate creates a new session with the same values but a fresh ID.
|
||||
// The old session is destroyed (MaxAge = -1) and saved, then a new session
|
||||
// is created. This prevents session fixation attacks by ensuring the
|
||||
// session ID changes after privilege escalation (e.g. login).
|
||||
func (s *Session) Regenerate(r *http.Request, w http.ResponseWriter, oldSess *sessions.Session) (*sessions.Session, error) {
|
||||
// Copy the values from the old session
|
||||
oldValues := make(map[any]any)
|
||||
maps.Copy(oldValues, oldSess.Values)
|
||||
oldValues := make(map[interface{}]interface{})
|
||||
for k, v := range oldSess.Values {
|
||||
oldValues[k] = v
|
||||
}
|
||||
|
||||
// Destroy the old session
|
||||
oldSess.Options.MaxAge = -1
|
||||
s.ClearUser(oldSess)
|
||||
|
||||
err := oldSess.Save(r, w)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to destroy old session: %w", err,
|
||||
)
|
||||
if err := oldSess.Save(r, w); err != nil {
|
||||
return nil, fmt.Errorf("failed to destroy old session: %w", err)
|
||||
}
|
||||
|
||||
// Create a new session (gorilla/sessions generates a new ID)
|
||||
newSess, err := s.store.New(r, SessionName)
|
||||
if err != nil {
|
||||
// store.New may return an error alongside a new empty
|
||||
// session if the old cookie is now invalid. That is
|
||||
// expected after we destroyed it above. Only fail on a
|
||||
// nil session.
|
||||
// store.New may return an error alongside a new empty session
|
||||
// if the old cookie is now invalid. That is expected after we
|
||||
// destroyed it above. Only fail on a nil session.
|
||||
if newSess == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to create new session: %w", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to create new session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Restore the copied values into the new session
|
||||
maps.Copy(newSess.Values, oldValues)
|
||||
for k, v := range oldValues {
|
||||
newSess.Values[k] = v
|
||||
}
|
||||
|
||||
// Apply the standard session options (the destroyed old
|
||||
// session had MaxAge = -1, which store.New might inherit
|
||||
// from the cookie).
|
||||
// Apply the standard session options (the destroyed old session had
|
||||
// MaxAge = -1, which store.New might inherit from the cookie).
|
||||
newSess.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: secondsPerDay * sessionMaxAgeDays,
|
||||
MaxAge: 86400 * 7,
|
||||
HttpOnly: true,
|
||||
Secure: !s.config.IsDev(),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package session_test
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -12,22 +11,15 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/session"
|
||||
)
|
||||
|
||||
const testKeySize = 32
|
||||
|
||||
// testSession creates a Session with a real cookie store for
|
||||
// testing.
|
||||
func testSession(t *testing.T) *session.Session {
|
||||
// testSession creates a Session with a real cookie store for testing.
|
||||
func testSession(t *testing.T) *Session {
|
||||
t.Helper()
|
||||
|
||||
key := make([]byte, testKeySize)
|
||||
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i + 42)
|
||||
}
|
||||
|
||||
store := sessions.NewCookieStore(key)
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
@@ -40,47 +32,34 @@ func testSession(t *testing.T) *session.Session {
|
||||
cfg := &config.Config{
|
||||
Environment: config.EnvironmentDev,
|
||||
}
|
||||
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
|
||||
log := slog.New(slog.NewTextHandler(
|
||||
os.Stderr,
|
||||
&slog.HandlerOptions{Level: slog.LevelDebug},
|
||||
))
|
||||
|
||||
return session.NewForTest(store, cfg, log, key)
|
||||
return NewForTest(store, cfg, log, key)
|
||||
}
|
||||
|
||||
// --- Get and Save Tests ---
|
||||
|
||||
func TestGet_NewSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sess)
|
||||
assert.True(
|
||||
t, sess.IsNew,
|
||||
"session should be new when no cookie is present",
|
||||
)
|
||||
assert.True(t, sess.IsNew, "session should be new when no cookie is present")
|
||||
}
|
||||
|
||||
func TestGet_ExistingSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
// Create and save a session
|
||||
req1 := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
|
||||
sess1, err := s.Get(req1)
|
||||
require.NoError(t, err)
|
||||
|
||||
sess1.Values["test_key"] = "test_value"
|
||||
require.NoError(t, s.Save(req1, w1, sess1))
|
||||
|
||||
@@ -89,34 +68,26 @@ func TestGet_ExistingSession(t *testing.T) {
|
||||
require.NotEmpty(t, cookies)
|
||||
|
||||
// Make a new request with the session cookie
|
||||
req2 := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
for _, c := range cookies {
|
||||
req2.AddCookie(c)
|
||||
}
|
||||
|
||||
sess2, err := s.Get(req2)
|
||||
require.NoError(t, err)
|
||||
assert.False(
|
||||
t, sess2.IsNew,
|
||||
"session should not be new when cookie is present",
|
||||
)
|
||||
assert.False(t, sess2.IsNew, "session should not be new when cookie is present")
|
||||
assert.Equal(t, "test_value", sess2.Values["test_key"])
|
||||
}
|
||||
|
||||
func TestSave_SetsCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
sess.Values["key"] = "value"
|
||||
|
||||
err = s.Save(req, w, sess)
|
||||
@@ -127,73 +98,48 @@ func TestSave_SetsCookie(t *testing.T) {
|
||||
|
||||
// Verify the cookie has the expected name
|
||||
var found bool
|
||||
|
||||
for _, c := range cookies {
|
||||
if c.Name == session.SessionName {
|
||||
if c.Name == SessionName {
|
||||
found = true
|
||||
|
||||
assert.True(
|
||||
t, c.HttpOnly,
|
||||
"session cookie should be HTTP-only",
|
||||
)
|
||||
|
||||
assert.True(t, c.HttpOnly, "session cookie should be HTTP-only")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(
|
||||
t, found,
|
||||
"should find a cookie named %s", session.SessionName,
|
||||
)
|
||||
assert.True(t, found, "should find a cookie named %s", SessionName)
|
||||
}
|
||||
|
||||
// --- SetUser and User Retrieval Tests ---
|
||||
|
||||
func TestSetUser_SetsAllFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.SetUser(sess, "user-abc-123", "alice")
|
||||
|
||||
assert.Equal(
|
||||
t, "user-abc-123", sess.Values[session.UserIDKey],
|
||||
)
|
||||
assert.Equal(
|
||||
t, "alice", sess.Values[session.UsernameKey],
|
||||
)
|
||||
assert.Equal(
|
||||
t, true, sess.Values[session.AuthenticatedKey],
|
||||
)
|
||||
assert.Equal(t, "user-abc-123", sess.Values[UserIDKey])
|
||||
assert.Equal(t, "alice", sess.Values[UsernameKey])
|
||||
assert.Equal(t, true, sess.Values[AuthenticatedKey])
|
||||
}
|
||||
|
||||
func TestGetUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Before setting user
|
||||
userID, ok := s.GetUserID(sess)
|
||||
assert.False(
|
||||
t, ok, "should return false when no user ID is set",
|
||||
)
|
||||
assert.False(t, ok, "should return false when no user ID is set")
|
||||
assert.Empty(t, userID)
|
||||
|
||||
// After setting user
|
||||
s.SetUser(sess, "user-xyz", "bob")
|
||||
|
||||
userID, ok = s.GetUserID(sess)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "user-xyz", userID)
|
||||
@@ -201,25 +147,19 @@ func TestGetUserID(t *testing.T) {
|
||||
|
||||
func TestGetUsername(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Before setting user
|
||||
username, ok := s.GetUsername(sess)
|
||||
assert.False(
|
||||
t, ok, "should return false when no username is set",
|
||||
)
|
||||
assert.False(t, ok, "should return false when no username is set")
|
||||
assert.Empty(t, username)
|
||||
|
||||
// After setting user
|
||||
s.SetUser(sess, "user-xyz", "bob")
|
||||
|
||||
username, ok = s.GetUsername(sess)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "bob", username)
|
||||
@@ -229,29 +169,20 @@ func TestGetUsername(t *testing.T) {
|
||||
|
||||
func TestIsAuthenticated_NoSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(
|
||||
t, s.IsAuthenticated(sess),
|
||||
"new session should not be authenticated",
|
||||
)
|
||||
assert.False(t, s.IsAuthenticated(sess), "new session should not be authenticated")
|
||||
}
|
||||
|
||||
func TestIsAuthenticated_AfterSetUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -261,12 +192,9 @@ func TestIsAuthenticated_AfterSetUser(t *testing.T) {
|
||||
|
||||
func TestIsAuthenticated_AfterClearUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -274,71 +202,52 @@ func TestIsAuthenticated_AfterClearUser(t *testing.T) {
|
||||
require.True(t, s.IsAuthenticated(sess))
|
||||
|
||||
s.ClearUser(sess)
|
||||
|
||||
assert.False(
|
||||
t, s.IsAuthenticated(sess),
|
||||
"should not be authenticated after ClearUser",
|
||||
)
|
||||
assert.False(t, s.IsAuthenticated(sess), "should not be authenticated after ClearUser")
|
||||
}
|
||||
|
||||
func TestIsAuthenticated_WrongType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set authenticated to a non-bool value
|
||||
sess.Values[session.AuthenticatedKey] = "yes"
|
||||
|
||||
assert.False(
|
||||
t, s.IsAuthenticated(sess),
|
||||
"should return false for non-bool authenticated value",
|
||||
)
|
||||
sess.Values[AuthenticatedKey] = "yes"
|
||||
assert.False(t, s.IsAuthenticated(sess), "should return false for non-bool authenticated value")
|
||||
}
|
||||
|
||||
// --- ClearUser Tests ---
|
||||
|
||||
func TestClearUser_RemovesAllKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.SetUser(sess, "user-123", "alice")
|
||||
s.ClearUser(sess)
|
||||
|
||||
_, hasUserID := sess.Values[session.UserIDKey]
|
||||
_, hasUserID := sess.Values[UserIDKey]
|
||||
assert.False(t, hasUserID, "UserIDKey should be removed")
|
||||
|
||||
_, hasUsername := sess.Values[session.UsernameKey]
|
||||
_, hasUsername := sess.Values[UsernameKey]
|
||||
assert.False(t, hasUsername, "UsernameKey should be removed")
|
||||
|
||||
_, hasAuth := sess.Values[session.AuthenticatedKey]
|
||||
assert.False(
|
||||
t, hasAuth, "AuthenticatedKey should be removed",
|
||||
)
|
||||
_, hasAuth := sess.Values[AuthenticatedKey]
|
||||
assert.False(t, hasAuth, "AuthenticatedKey should be removed")
|
||||
}
|
||||
|
||||
// --- Destroy Tests ---
|
||||
|
||||
func TestDestroy_InvalidatesSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -346,18 +255,11 @@ func TestDestroy_InvalidatesSession(t *testing.T) {
|
||||
|
||||
s.Destroy(sess)
|
||||
|
||||
// After Destroy: MaxAge should be -1 (delete cookie) and
|
||||
// user data cleared
|
||||
assert.Equal(
|
||||
t, -1, sess.Options.MaxAge,
|
||||
"Destroy should set MaxAge to -1",
|
||||
)
|
||||
assert.False(
|
||||
t, s.IsAuthenticated(sess),
|
||||
"should not be authenticated after Destroy",
|
||||
)
|
||||
// After Destroy: MaxAge should be -1 (delete cookie) and user data cleared
|
||||
assert.Equal(t, -1, sess.Options.MaxAge, "Destroy should set MaxAge to -1")
|
||||
assert.False(t, s.IsAuthenticated(sess), "should not be authenticated after Destroy")
|
||||
|
||||
_, hasUserID := sess.Values[session.UserIDKey]
|
||||
_, hasUserID := sess.Values[UserIDKey]
|
||||
assert.False(t, hasUserID, "Destroy should clear user ID")
|
||||
}
|
||||
|
||||
@@ -365,12 +267,10 @@ func TestDestroy_InvalidatesSession(t *testing.T) {
|
||||
|
||||
func TestSessionPersistence_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
// Step 1: Create session, set user, save
|
||||
req1 := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
|
||||
sess1, err := s.Get(req1)
|
||||
@@ -381,13 +281,8 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
|
||||
cookies := w1.Result().Cookies()
|
||||
require.NotEmpty(t, cookies)
|
||||
|
||||
// Step 2: New request with cookies -- session data should
|
||||
// persist
|
||||
req2 := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/profile", nil,
|
||||
)
|
||||
|
||||
// Step 2: New request with cookies — session data should persist
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/profile", nil)
|
||||
for _, c := range cookies {
|
||||
req2.AddCookie(c)
|
||||
}
|
||||
@@ -395,10 +290,7 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
|
||||
sess2, err := s.Get(req2)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(
|
||||
t, s.IsAuthenticated(sess2),
|
||||
"session should be authenticated after round-trip",
|
||||
)
|
||||
assert.True(t, s.IsAuthenticated(sess2), "session should be authenticated after round-trip")
|
||||
|
||||
userID, ok := s.GetUserID(sess2)
|
||||
assert.True(t, ok)
|
||||
@@ -413,23 +305,19 @@ func TestSessionPersistence_RoundTrip(t *testing.T) {
|
||||
|
||||
func TestSessionConstants(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "webhooker_session", session.SessionName)
|
||||
assert.Equal(t, "user_id", session.UserIDKey)
|
||||
assert.Equal(t, "username", session.UsernameKey)
|
||||
assert.Equal(t, "authenticated", session.AuthenticatedKey)
|
||||
assert.Equal(t, "webhooker_session", SessionName)
|
||||
assert.Equal(t, "user_id", UserIDKey)
|
||||
assert.Equal(t, "username", UsernameKey)
|
||||
assert.Equal(t, "authenticated", AuthenticatedKey)
|
||||
}
|
||||
|
||||
// --- Edge Cases ---
|
||||
|
||||
func TestSetUser_OverwritesPreviousUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
req := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
sess, err := s.Get(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -450,12 +338,10 @@ func TestSetUser_OverwritesPreviousUser(t *testing.T) {
|
||||
|
||||
func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := testSession(t)
|
||||
|
||||
// Create a session
|
||||
req1 := httptest.NewRequestWithContext(
|
||||
context.Background(), http.MethodGet, "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
|
||||
sess, err := s.Get(req1)
|
||||
@@ -467,15 +353,10 @@ func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
|
||||
require.NotEmpty(t, cookies)
|
||||
|
||||
// Destroy and save
|
||||
req2 := httptest.NewRequestWithContext(
|
||||
context.Background(),
|
||||
http.MethodGet, "/logout", nil,
|
||||
)
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/logout", nil)
|
||||
for _, c := range cookies {
|
||||
req2.AddCookie(c)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
|
||||
sess2, err := s.Get(req2)
|
||||
@@ -483,25 +364,15 @@ func TestDestroy_ThenSave_DeletesCookie(t *testing.T) {
|
||||
s.Destroy(sess2)
|
||||
require.NoError(t, s.Save(req2, w2, sess2))
|
||||
|
||||
// The cookie should have MaxAge = -1 (browser should delete)
|
||||
// The cookie should have MaxAge = -1 (browser should delete it)
|
||||
responseCookies := w2.Result().Cookies()
|
||||
|
||||
var sessionCookie *http.Cookie
|
||||
|
||||
for _, c := range responseCookies {
|
||||
if c.Name == session.SessionName {
|
||||
if c.Name == SessionName {
|
||||
sessionCookie = c
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(
|
||||
t, sessionCookie,
|
||||
"should have a session cookie in response",
|
||||
)
|
||||
assert.Negative(
|
||||
t, sessionCookie.MaxAge,
|
||||
"destroyed session cookie should have negative MaxAge",
|
||||
)
|
||||
require.NotNil(t, sessionCookie, "should have a session cookie in response")
|
||||
assert.True(t, sessionCookie.MaxAge < 0, "destroyed session cookie should have negative MaxAge")
|
||||
}
|
||||
|
||||
67
templates/index.html
Normal file
67
templates/index.html
Normal file
@@ -0,0 +1,67 @@
|
||||
{{template "base" .}}
|
||||
|
||||
{{define "title"}}Home - Webhooker{{end}}
|
||||
|
||||
{{define "content"}}
|
||||
<div class="max-w-4xl mx-auto px-6 py-12">
|
||||
<div class="text-center mb-10">
|
||||
<h1 class="text-4xl font-medium text-gray-900">Welcome to Webhooker</h1>
|
||||
<p class="mt-3 text-lg text-gray-500">A reliable webhook proxy service for event delivery</p>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||
<!-- Server Status Card -->
|
||||
<div class="card-elevated p-6">
|
||||
<div class="flex items-center mb-4">
|
||||
<div class="rounded-full bg-success-50 p-3 mr-4">
|
||||
<svg class="w-6 h-6 text-success-500" fill="currentColor" viewBox="0 0 16 16">
|
||||
<path d="M1.333 2.667C1.333 1.194 4.318 0 8 0s6.667 1.194 6.667 2.667V4c0 1.473-2.985 2.667-6.667 2.667S1.333 5.473 1.333 4V2.667z"/>
|
||||
<path d="M1.333 6.334v3C1.333 10.805 4.318 12 8 12s6.667-1.194 6.667-2.667V6.334a6.51 6.51 0 0 1-1.458.79C11.81 7.684 9.967 8 8 8c-1.966 0-3.809-.317-5.208-.876a6.508 6.508 0 0 1-1.458-.79z"/>
|
||||
<path d="M14.667 11.668a6.51 6.51 0 0 1-1.458.789c-1.4.56-3.242.876-5.21.876-1.966 0-3.809-.316-5.208-.876a6.51 6.51 0 0 1-1.458-.79v1.666C1.333 14.806 4.318 16 8 16s6.667-1.194 6.667-2.667v-1.665z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<h2 class="text-lg font-medium text-gray-900">Server Status</h2>
|
||||
<span class="badge-success">Online</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="space-y-3">
|
||||
<div>
|
||||
<p class="text-sm text-gray-500">Uptime</p>
|
||||
<p class="text-2xl font-medium text-gray-900">{{.Uptime}}</p>
|
||||
</div>
|
||||
<div>
|
||||
<p class="text-sm text-gray-500">Version</p>
|
||||
<p class="font-mono text-sm text-gray-700">{{.Version}}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Users Card -->
|
||||
<div class="card-elevated p-6">
|
||||
<div class="flex items-center mb-4">
|
||||
<div class="rounded-full bg-primary-50 p-3 mr-4">
|
||||
<svg class="w-6 h-6 text-primary-500" fill="currentColor" viewBox="0 0 16 16">
|
||||
<path d="M15 14s1 0 1-1-1-4-5-4-5 3-5 4 1 1 1 1h8zm-7.978-1A.261.261 0 0 1 7 12.996c.001-.264.167-1.03.76-1.72C8.312 10.629 9.282 10 11 10c1.717 0 2.687.63 3.24 1.276.593.69.758 1.457.76 1.72l-.008.002a.274.274 0 0 1-.014.002H7.022zM11 7a2 2 0 1 0 0-4 2 2 0 0 0 0 4zm3-2a3 3 0 1 1-6 0 3 3 0 0 1 6 0zM6.936 9.28a5.88 5.88 0 0 0-1.23-.247A7.35 7.35 0 0 0 5 9c-4 0-5 3-5 4 0 .667.333 1 1 1h4.216A2.238 2.238 0 0 1 5 13c0-1.01.377-2.042 1.09-2.904.243-.294.526-.569.846-.816zM4.92 10A5.493 5.493 0 0 0 4 13H1c0-.26.164-1.03.76-1.724.545-.636 1.492-1.256 3.16-1.275zM1.5 5.5a3 3 0 1 1 6 0 3 3 0 0 1-6 0zm3-2a2 2 0 1 0 0 4 2 2 0 0 0 0-4z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<h2 class="text-lg font-medium text-gray-900">Users</h2>
|
||||
<p class="text-sm text-gray-500">Registered accounts</p>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<p class="text-4xl font-medium text-gray-900">{{.UserCount}}</p>
|
||||
<p class="text-sm text-gray-500 mt-1">Total users</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{{if not .User}}
|
||||
<div class="text-center mt-10">
|
||||
<p class="text-gray-500 mb-4">Ready to get started?</p>
|
||||
<a href="/pages/login" class="btn-primary">Login to your account</a>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
{{end}}
|
||||
Reference in New Issue
Block a user