initial
This commit is contained in:
138
internal/config/config.go
Normal file
138
internal/config/config.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
pkgconfig "git.eeqj.de/sneak/webhooker/pkg/config"
|
||||
"go.uber.org/fx"
|
||||
|
||||
// spooky action at a distance!
|
||||
// this populates the environment
|
||||
// from a ./.env file automatically
|
||||
// for development configuration.
|
||||
// .env contents should be things like
|
||||
// `DBURL=postgres://user:pass@.../`
|
||||
// (without the backticks, of course)
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvironmentDev represents development environment
|
||||
EnvironmentDev = "dev"
|
||||
// EnvironmentProd represents production environment
|
||||
EnvironmentProd = "prod"
|
||||
// DevSessionKey is an insecure default session key for development
|
||||
// This is "webhooker-dev-session-key-insecure!" base64 encoded
|
||||
DevSessionKey = "d2ViaG9va2VyLWRldi1zZXNzaW9uLWtleS1pbnNlY3VyZSE="
|
||||
)
|
||||
|
||||
// nolint:revive // ConfigParams is a standard fx naming convention
|
||||
type ConfigParams struct {
|
||||
fx.In
|
||||
Globals *globals.Globals
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
DBURL string
|
||||
Debug bool
|
||||
MaintenanceMode bool
|
||||
DevelopmentMode bool
|
||||
DevAdminUsername string
|
||||
DevAdminPassword string
|
||||
Environment string
|
||||
MetricsPassword string
|
||||
MetricsUsername string
|
||||
Port int
|
||||
SentryDSN string
|
||||
SessionKey string
|
||||
params *ConfigParams
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// IsDev returns true if running in development environment
|
||||
func (c *Config) IsDev() bool {
|
||||
return c.Environment == EnvironmentDev
|
||||
}
|
||||
|
||||
// IsProd returns true if running in production environment
|
||||
func (c *Config) IsProd() bool {
|
||||
return c.Environment == EnvironmentProd
|
||||
}
|
||||
|
||||
// nolint:revive // lc parameter is required by fx even if unused
|
||||
func New(lc fx.Lifecycle, params ConfigParams) (*Config, error) {
|
||||
log := params.Logger.Get()
|
||||
|
||||
// Determine environment from WEBHOOKER_ENVIRONMENT env var, default to dev
|
||||
environment := os.Getenv("WEBHOOKER_ENVIRONMENT")
|
||||
if environment == "" {
|
||||
environment = EnvironmentDev
|
||||
}
|
||||
|
||||
// Validate environment
|
||||
if environment != EnvironmentDev && environment != EnvironmentProd {
|
||||
return nil, fmt.Errorf("WEBHOOKER_ENVIRONMENT must be either '%s' or '%s', got '%s'",
|
||||
EnvironmentDev, EnvironmentProd, environment)
|
||||
}
|
||||
|
||||
// Set the environment in the config package
|
||||
pkgconfig.SetEnvironment(environment)
|
||||
|
||||
// Load configuration values
|
||||
s := &Config{
|
||||
DBURL: pkgconfig.GetString("dburl"),
|
||||
Debug: pkgconfig.GetBool("debug"),
|
||||
MaintenanceMode: pkgconfig.GetBool("maintenanceMode"),
|
||||
DevelopmentMode: pkgconfig.GetBool("developmentMode"),
|
||||
DevAdminUsername: pkgconfig.GetString("devAdminUsername"),
|
||||
DevAdminPassword: pkgconfig.GetString("devAdminPassword"),
|
||||
Environment: pkgconfig.GetString("environment", environment),
|
||||
MetricsUsername: pkgconfig.GetString("metricsUsername"),
|
||||
MetricsPassword: pkgconfig.GetString("metricsPassword"),
|
||||
Port: pkgconfig.GetInt("port", 8080),
|
||||
SentryDSN: pkgconfig.GetSecretString("sentryDSN"),
|
||||
SessionKey: pkgconfig.GetSecretString("sessionKey"),
|
||||
log: log,
|
||||
params: ¶ms,
|
||||
}
|
||||
|
||||
// Validate database URL
|
||||
if s.DBURL == "" {
|
||||
return nil, fmt.Errorf("database URL (dburl) is required")
|
||||
}
|
||||
|
||||
// In production, require session key
|
||||
if s.IsProd() && s.SessionKey == "" {
|
||||
return nil, fmt.Errorf("SESSION_KEY is required in production environment")
|
||||
}
|
||||
|
||||
// In development mode, warn if using default session key
|
||||
if s.IsDev() && s.SessionKey == DevSessionKey {
|
||||
log.Warn("Using insecure default session key for development mode")
|
||||
}
|
||||
|
||||
if s.Debug {
|
||||
params.Logger.EnableDebugLogging()
|
||||
s.log = params.Logger.Get()
|
||||
log.Debug("Debug mode enabled")
|
||||
}
|
||||
|
||||
// Log configuration summary (without secrets)
|
||||
log.Info("Configuration loaded",
|
||||
"environment", s.Environment,
|
||||
"port", s.Port,
|
||||
"debug", s.Debug,
|
||||
"maintenanceMode", s.MaintenanceMode,
|
||||
"developmentMode", s.DevelopmentMode,
|
||||
"hasSessionKey", s.SessionKey != "",
|
||||
"hasSentryDSN", s.SentryDSN != "",
|
||||
"hasMetricsAuth", s.MetricsUsername != "" && s.MetricsPassword != "",
|
||||
)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
300
internal/config/config_test.go
Normal file
300
internal/config/config_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
pkgconfig "git.eeqj.de/sneak/webhooker/pkg/config"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/fx"
|
||||
"go.uber.org/fx/fxtest"
|
||||
)
|
||||
|
||||
// createTestConfig creates a test configuration file in memory
|
||||
func createTestConfig(fs afero.Fs) error {
|
||||
configYAML := `
|
||||
environments:
|
||||
dev:
|
||||
config:
|
||||
port: 8080
|
||||
debug: true
|
||||
maintenanceMode: false
|
||||
developmentMode: true
|
||||
environment: dev
|
||||
dburl: postgres://test:test@localhost:5432/test_dev?sslmode=disable
|
||||
metricsUsername: testuser
|
||||
metricsPassword: testpass
|
||||
devAdminUsername: devadmin
|
||||
devAdminPassword: devpass
|
||||
secrets:
|
||||
sessionKey: d2ViaG9va2VyLWRldi1zZXNzaW9uLWtleS1pbnNlY3VyZSE=
|
||||
sentryDSN: ""
|
||||
|
||||
prod:
|
||||
config:
|
||||
port: $ENV:PORT
|
||||
debug: $ENV:DEBUG
|
||||
maintenanceMode: $ENV:MAINTENANCE_MODE
|
||||
developmentMode: false
|
||||
environment: prod
|
||||
dburl: $ENV:DBURL
|
||||
metricsUsername: $ENV:METRICS_USERNAME
|
||||
metricsPassword: $ENV:METRICS_PASSWORD
|
||||
devAdminUsername: ""
|
||||
devAdminPassword: ""
|
||||
secrets:
|
||||
sessionKey: $ENV:SESSION_KEY
|
||||
sentryDSN: $ENV:SENTRY_DSN
|
||||
|
||||
configDefaults:
|
||||
port: 8080
|
||||
debug: false
|
||||
maintenanceMode: false
|
||||
developmentMode: false
|
||||
environment: dev
|
||||
metricsUsername: ""
|
||||
metricsPassword: ""
|
||||
devAdminUsername: ""
|
||||
devAdminPassword: ""
|
||||
`
|
||||
return afero.WriteFile(fs, "config.yaml", []byte(configYAML), 0644)
|
||||
}
|
||||
|
||||
func TestEnvironmentConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
envVars map[string]string
|
||||
expectError bool
|
||||
isDev bool
|
||||
isProd bool
|
||||
}{
|
||||
{
|
||||
name: "default is dev",
|
||||
envValue: "",
|
||||
expectError: false,
|
||||
isDev: true,
|
||||
isProd: false,
|
||||
},
|
||||
{
|
||||
name: "explicit dev",
|
||||
envValue: "dev",
|
||||
expectError: false,
|
||||
isDev: true,
|
||||
isProd: false,
|
||||
},
|
||||
{
|
||||
name: "explicit prod with session key",
|
||||
envValue: "prod",
|
||||
envVars: map[string]string{
|
||||
"SESSION_KEY": "cHJvZC1zZXNzaW9uLWtleS0zMi1ieXRlcy1sb25nISE=",
|
||||
"DBURL": "postgres://prod:prod@localhost:5432/prod?sslmode=require",
|
||||
},
|
||||
expectError: false,
|
||||
isDev: false,
|
||||
isProd: true,
|
||||
},
|
||||
{
|
||||
name: "invalid environment",
|
||||
envValue: "staging",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create in-memory filesystem with test config
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(t, createTestConfig(fs))
|
||||
pkgconfig.SetFs(fs)
|
||||
|
||||
// Set environment variable if specified
|
||||
if tt.envValue != "" {
|
||||
os.Setenv("WEBHOOKER_ENVIRONMENT", tt.envValue)
|
||||
defer os.Unsetenv("WEBHOOKER_ENVIRONMENT")
|
||||
}
|
||||
|
||||
// Set additional environment variables
|
||||
for k, v := range tt.envVars {
|
||||
os.Setenv(k, v)
|
||||
defer os.Unsetenv(k)
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
// Use regular fx.New for error cases since fxtest doesn't expose errors the same way
|
||||
var cfg *Config
|
||||
app := fx.New(
|
||||
fx.NopLogger, // Suppress fx logs in tests
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&cfg),
|
||||
)
|
||||
assert.Error(t, app.Err())
|
||||
} else {
|
||||
// Use fxtest for success cases
|
||||
var cfg *Config
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&cfg),
|
||||
)
|
||||
require.NoError(t, app.Err())
|
||||
app.RequireStart()
|
||||
defer app.RequireStop()
|
||||
|
||||
assert.Equal(t, tt.isDev, cfg.IsDev())
|
||||
assert.Equal(t, tt.isProd, cfg.IsProd())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionKeyDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
environment string
|
||||
sessionKey string
|
||||
dburl string
|
||||
expectError bool
|
||||
expectedKey string
|
||||
}{
|
||||
{
|
||||
name: "dev mode with default session key",
|
||||
environment: "dev",
|
||||
sessionKey: "",
|
||||
expectError: false,
|
||||
expectedKey: DevSessionKey,
|
||||
},
|
||||
{
|
||||
name: "dev mode with custom session key",
|
||||
environment: "dev",
|
||||
sessionKey: "Y3VzdG9tLXNlc3Npb24ta2V5LTMyLWJ5dGVzLWxvbmchIQ==",
|
||||
expectError: false,
|
||||
expectedKey: "Y3VzdG9tLXNlc3Npb24ta2V5LTMyLWJ5dGVzLWxvbmchIQ==",
|
||||
},
|
||||
{
|
||||
name: "prod mode with no session key fails",
|
||||
environment: "prod",
|
||||
sessionKey: "",
|
||||
dburl: "postgres://prod:prod@localhost:5432/prod",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "prod mode with session key succeeds",
|
||||
environment: "prod",
|
||||
sessionKey: "cHJvZC1zZXNzaW9uLWtleS0zMi1ieXRlcy1sb25nISE=",
|
||||
dburl: "postgres://prod:prod@localhost:5432/prod",
|
||||
expectError: false,
|
||||
expectedKey: "cHJvZC1zZXNzaW9uLWtleS0zMi1ieXRlcy1sb25nISE=",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create in-memory filesystem with test config
|
||||
fs := afero.NewMemMapFs()
|
||||
|
||||
// Create custom config for session key tests
|
||||
configYAML := `
|
||||
environments:
|
||||
dev:
|
||||
config:
|
||||
environment: dev
|
||||
developmentMode: true
|
||||
dburl: postgres://test:test@localhost:5432/test_dev
|
||||
secrets:`
|
||||
|
||||
// Only add sessionKey line if it's not empty
|
||||
if tt.sessionKey != "" {
|
||||
configYAML += `
|
||||
sessionKey: ` + tt.sessionKey
|
||||
} else if tt.environment == "dev" {
|
||||
// For dev mode with no session key, use the default
|
||||
configYAML += `
|
||||
sessionKey: d2ViaG9va2VyLWRldi1zZXNzaW9uLWtleS1pbnNlY3VyZSE=`
|
||||
}
|
||||
|
||||
// Add prod config if testing prod
|
||||
if tt.environment == "prod" {
|
||||
configYAML += `
|
||||
prod:
|
||||
config:
|
||||
environment: prod
|
||||
developmentMode: false
|
||||
dburl: $ENV:DBURL
|
||||
secrets:
|
||||
sessionKey: $ENV:SESSION_KEY`
|
||||
}
|
||||
|
||||
require.NoError(t, afero.WriteFile(fs, "config.yaml", []byte(configYAML), 0644))
|
||||
pkgconfig.SetFs(fs)
|
||||
|
||||
// Clean up any existing env vars
|
||||
os.Unsetenv("WEBHOOKER_ENVIRONMENT")
|
||||
os.Unsetenv("SESSION_KEY")
|
||||
os.Unsetenv("DBURL")
|
||||
|
||||
// Set environment variables
|
||||
os.Setenv("WEBHOOKER_ENVIRONMENT", tt.environment)
|
||||
defer os.Unsetenv("WEBHOOKER_ENVIRONMENT")
|
||||
|
||||
if tt.sessionKey != "" && tt.environment == "prod" {
|
||||
os.Setenv("SESSION_KEY", tt.sessionKey)
|
||||
defer os.Unsetenv("SESSION_KEY")
|
||||
}
|
||||
|
||||
if tt.dburl != "" {
|
||||
os.Setenv("DBURL", tt.dburl)
|
||||
defer os.Unsetenv("DBURL")
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
// Use regular fx.New for error cases
|
||||
var cfg *Config
|
||||
app := fx.New(
|
||||
fx.NopLogger, // Suppress fx logs in tests
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&cfg),
|
||||
)
|
||||
assert.Error(t, app.Err())
|
||||
} else {
|
||||
// Use fxtest for success cases
|
||||
var cfg *Config
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&cfg),
|
||||
)
|
||||
require.NoError(t, app.Err())
|
||||
app.RequireStart()
|
||||
defer app.RequireStop()
|
||||
|
||||
if tt.environment == "dev" && tt.sessionKey == "" {
|
||||
// Dev mode with no session key uses default
|
||||
assert.Equal(t, DevSessionKey, cfg.SessionKey)
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedKey, cfg.SessionKey)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
25
internal/database/base_model.go
Normal file
25
internal/database/base_model.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseModel contains common fields for all models
|
||||
// This replaces gorm.Model but uses UUID instead of uint for ID
|
||||
type BaseModel struct {
|
||||
ID string `gorm:"type:uuid;primary_key" json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||
}
|
||||
|
||||
// BeforeCreate hook to set UUID before creating a record
|
||||
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
|
||||
if b.ID == "" {
|
||||
b.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
144
internal/database/database.go
Normal file
144
internal/database/database.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"go.uber.org/fx"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
_ "modernc.org/sqlite" // Pure Go SQLite driver
|
||||
)
|
||||
|
||||
// nolint:revive // DatabaseParams is a standard fx naming convention
|
||||
type DatabaseParams struct {
|
||||
fx.In
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
db *gorm.DB
|
||||
log *slog.Logger
|
||||
params *DatabaseParams
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
||||
d := &Database{
|
||||
params: ¶ms,
|
||||
log: params.Logger.Get(),
|
||||
}
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
return d.connect()
|
||||
},
|
||||
OnStop: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
return d.close()
|
||||
},
|
||||
})
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Database) connect() error {
|
||||
dbURL := d.params.Config.DBURL
|
||||
if dbURL == "" {
|
||||
// Default to SQLite for development
|
||||
dbURL = "file:webhooker.db?cache=shared&mode=rwc"
|
||||
}
|
||||
|
||||
// First, open the database with the pure Go driver
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
d.log.Error("failed to open database", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Then use it with GORM
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
d.log.Error("failed to connect to database", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
d.db = db
|
||||
d.log.Info("connected to database", "database", dbURL)
|
||||
|
||||
// Run migrations
|
||||
return d.migrate()
|
||||
}
|
||||
|
||||
func (d *Database) migrate() error {
|
||||
// Run GORM auto-migrations
|
||||
if err := d.Migrate(); err != nil {
|
||||
d.log.Error("failed to run database migrations", "error", err)
|
||||
return err
|
||||
}
|
||||
d.log.Info("database migrations completed")
|
||||
|
||||
// Check if admin user exists
|
||||
var userCount int64
|
||||
if err := d.db.Model(&User{}).Count(&userCount).Error; err != nil {
|
||||
d.log.Error("failed to count users", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if userCount == 0 {
|
||||
// Create admin user
|
||||
d.log.Info("no users found, creating admin user")
|
||||
|
||||
// Generate random password
|
||||
password, err := GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
d.log.Error("failed to generate random password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
hashedPassword, err := HashPassword(password)
|
||||
if err != nil {
|
||||
d.log.Error("failed to hash password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create admin user
|
||||
adminUser := &User{
|
||||
Username: "admin",
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
if err := d.db.Create(adminUser).Error; err != nil {
|
||||
d.log.Error("failed to create admin user", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Log the password - this will only happen once on first startup
|
||||
d.log.Info("admin user created",
|
||||
"username", "admin",
|
||||
"password", password,
|
||||
"message", "SAVE THIS PASSWORD - it will not be shown again!")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) close() error {
|
||||
if d.db != nil {
|
||||
sqlDB, err := d.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) DB() *gorm.DB {
|
||||
return d.db
|
||||
}
|
||||
81
internal/database/database_test.go
Normal file
81
internal/database/database_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"go.uber.org/fx/fxtest"
|
||||
)
|
||||
|
||||
func TestDatabaseConnection(t *testing.T) {
|
||||
// Set up test dependencies
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
// Create globals
|
||||
globals.Appname = "webhooker-test"
|
||||
globals.Version = "test"
|
||||
globals.Buildarch = "test"
|
||||
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create globals: %v", err)
|
||||
}
|
||||
|
||||
// Create logger
|
||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create logger: %v", err)
|
||||
}
|
||||
|
||||
// Create config
|
||||
c, err := config.New(lc, config.ConfigParams{
|
||||
Globals: g,
|
||||
Logger: l,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create config: %v", err)
|
||||
}
|
||||
|
||||
// Set test database URL
|
||||
c.DBURL = "file:test.db?cache=shared&mode=rwc"
|
||||
|
||||
// Create database
|
||||
db, err := New(lc, DatabaseParams{
|
||||
Config: c,
|
||||
Logger: l,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
|
||||
// Start lifecycle (this will trigger the connection)
|
||||
ctx := context.Background()
|
||||
err = lc.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if stopErr := lc.Stop(ctx); stopErr != nil {
|
||||
t.Errorf("Failed to stop lifecycle: %v", stopErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Verify we can get the DB instance
|
||||
if db.DB() == nil {
|
||||
t.Error("Expected non-nil database connection")
|
||||
}
|
||||
|
||||
// Test that we can perform a simple query
|
||||
var result int
|
||||
err = db.DB().Raw("SELECT 1").Scan(&result).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute test query: %v", err)
|
||||
}
|
||||
|
||||
if result != 1 {
|
||||
t.Errorf("Expected query result to be 1, got %d", result)
|
||||
}
|
||||
}
|
||||
16
internal/database/model_apikey.go
Normal file
16
internal/database/model_apikey.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package database
|
||||
|
||||
import "time"
|
||||
|
||||
// APIKey represents an API key for a user
|
||||
type APIKey struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
Key string `gorm:"uniqueIndex;not null" json:"key"`
|
||||
Description string `json:"description"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitempty"`
|
||||
}
|
||||
25
internal/database/model_delivery.go
Normal file
25
internal/database/model_delivery.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package database
|
||||
|
||||
// DeliveryStatus represents the status of a delivery
|
||||
type DeliveryStatus string
|
||||
|
||||
const (
|
||||
DeliveryStatusPending DeliveryStatus = "pending"
|
||||
DeliveryStatusDelivered DeliveryStatus = "delivered"
|
||||
DeliveryStatusFailed DeliveryStatus = "failed"
|
||||
DeliveryStatusRetrying DeliveryStatus = "retrying"
|
||||
)
|
||||
|
||||
// Delivery represents a delivery attempt for an event to a target
|
||||
type Delivery struct {
|
||||
BaseModel
|
||||
|
||||
EventID string `gorm:"type:uuid;not null" json:"event_id"`
|
||||
TargetID string `gorm:"type:uuid;not null" json:"target_id"`
|
||||
Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"`
|
||||
|
||||
// Relations
|
||||
Event Event `json:"event,omitempty"`
|
||||
Target Target `json:"target,omitempty"`
|
||||
DeliveryResults []DeliveryResult `json:"delivery_results,omitempty"`
|
||||
}
|
||||
17
internal/database/model_delivery_result.go
Normal file
17
internal/database/model_delivery_result.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package database
|
||||
|
||||
// DeliveryResult represents the result of a delivery attempt
|
||||
type DeliveryResult struct {
|
||||
BaseModel
|
||||
|
||||
DeliveryID string `gorm:"type:uuid;not null" json:"delivery_id"`
|
||||
AttemptNum int `gorm:"not null" json:"attempt_num"`
|
||||
Success bool `json:"success"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ResponseBody string `gorm:"type:text" json:"response_body,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Duration int64 `json:"duration_ms"` // Duration in milliseconds
|
||||
|
||||
// Relations
|
||||
Delivery Delivery `json:"delivery,omitempty"`
|
||||
}
|
||||
20
internal/database/model_event.go
Normal file
20
internal/database/model_event.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package database
|
||||
|
||||
// Event represents a webhook event
|
||||
type Event struct {
|
||||
BaseModel
|
||||
|
||||
ProcessorID string `gorm:"type:uuid;not null" json:"processor_id"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
|
||||
// Request data
|
||||
Method string `gorm:"not null" json:"method"`
|
||||
Headers string `gorm:"type:text" json:"headers"` // JSON
|
||||
Body string `gorm:"type:text" json:"body"`
|
||||
ContentType string `json:"content_type"`
|
||||
|
||||
// Relations
|
||||
Processor Processor `json:"processor,omitempty"`
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
16
internal/database/model_processor.go
Normal file
16
internal/database/model_processor.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package database
|
||||
|
||||
// Processor represents an event processor
|
||||
type Processor struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Description string `json:"description"`
|
||||
RetentionDays int `gorm:"default:30" json:"retention_days"` // Days to retain events
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitempty"`
|
||||
Webhooks []Webhook `json:"webhooks,omitempty"`
|
||||
Targets []Target `json:"targets,omitempty"`
|
||||
}
|
||||
32
internal/database/model_target.go
Normal file
32
internal/database/model_target.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package database
|
||||
|
||||
// TargetType represents the type of delivery target
|
||||
type TargetType string
|
||||
|
||||
const (
|
||||
TargetTypeHTTP TargetType = "http"
|
||||
TargetTypeRetry TargetType = "retry"
|
||||
TargetTypeDatabase TargetType = "database"
|
||||
TargetTypeLog TargetType = "log"
|
||||
)
|
||||
|
||||
// Target represents a delivery target for a processor
|
||||
type Target struct {
|
||||
BaseModel
|
||||
|
||||
ProcessorID string `gorm:"type:uuid;not null" json:"processor_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Type TargetType `gorm:"not null" json:"type"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
|
||||
// Configuration fields (JSON stored based on type)
|
||||
Config string `gorm:"type:text" json:"config"` // JSON configuration
|
||||
|
||||
// For retry targets
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
MaxQueueSize int `json:"max_queue_size,omitempty"`
|
||||
|
||||
// Relations
|
||||
Processor Processor `json:"processor,omitempty"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
13
internal/database/model_user.go
Normal file
13
internal/database/model_user.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package database
|
||||
|
||||
// User represents a user of the webhooker service
|
||||
type User struct {
|
||||
BaseModel
|
||||
|
||||
Username string `gorm:"uniqueIndex;not null" json:"username"`
|
||||
Password string `gorm:"not null" json:"-"` // Argon2 hashed
|
||||
|
||||
// Relations
|
||||
Processors []Processor `json:"processors,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
}
|
||||
14
internal/database/model_webhook.go
Normal file
14
internal/database/model_webhook.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package database
|
||||
|
||||
// Webhook represents a webhook endpoint that feeds into a processor
|
||||
type Webhook struct {
|
||||
BaseModel
|
||||
|
||||
ProcessorID string `gorm:"type:uuid;not null" json:"processor_id"`
|
||||
Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this webhook
|
||||
Description string `json:"description"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
|
||||
// Relations
|
||||
Processor Processor `json:"processor,omitempty"`
|
||||
}
|
||||
15
internal/database/models.go
Normal file
15
internal/database/models.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package database
|
||||
|
||||
// Migrate runs database migrations for all models
|
||||
func (d *Database) Migrate() error {
|
||||
return d.db.AutoMigrate(
|
||||
&User{},
|
||||
&APIKey{},
|
||||
&Processor{},
|
||||
&Webhook{},
|
||||
&Target{},
|
||||
&Event{},
|
||||
&Delivery{},
|
||||
&DeliveryResult{},
|
||||
)
|
||||
}
|
||||
187
internal/database/password.go
Normal file
187
internal/database/password.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// Argon2 parameters - these are up-to-date secure defaults
|
||||
const (
|
||||
argon2Time = 1
|
||||
argon2Memory = 64 * 1024 // 64 MB
|
||||
argon2Threads = 4
|
||||
argon2KeyLen = 32
|
||||
argon2SaltLen = 16
|
||||
)
|
||||
|
||||
// PasswordConfig holds Argon2 configuration
|
||||
type PasswordConfig struct {
|
||||
Time uint32
|
||||
Memory uint32
|
||||
Threads uint8
|
||||
KeyLen uint32
|
||||
SaltLen uint32
|
||||
}
|
||||
|
||||
// DefaultPasswordConfig returns secure default Argon2 parameters
|
||||
func DefaultPasswordConfig() *PasswordConfig {
|
||||
return &PasswordConfig{
|
||||
Time: argon2Time,
|
||||
Memory: argon2Memory,
|
||||
Threads: argon2Threads,
|
||||
KeyLen: argon2KeyLen,
|
||||
SaltLen: argon2SaltLen,
|
||||
}
|
||||
}
|
||||
|
||||
// HashPassword generates an Argon2id hash of the password
|
||||
func HashPassword(password string) (string, error) {
|
||||
config := DefaultPasswordConfig()
|
||||
|
||||
// Generate a salt
|
||||
salt := make([]byte, config.SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate the hash
|
||||
hash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
|
||||
// Encode the hash and parameters
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash
|
||||
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, config.Memory, config.Time, config.Threads, b64Salt, b64Hash)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks if the provided password matches the hash
|
||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
// Extract parameters and hash from encoded string
|
||||
config, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Generate hash of the provided password
|
||||
otherHash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
|
||||
// Compare hashes using constant time comparison
|
||||
return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil
|
||||
}
|
||||
|
||||
// decodeHash extracts parameters, salt, and hash from an encoded hash string
|
||||
func decodeHash(encodedHash string) (*PasswordConfig, []byte, []byte, error) {
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return nil, nil, nil, fmt.Errorf("invalid hash format")
|
||||
}
|
||||
|
||||
if parts[1] != "argon2id" {
|
||||
return nil, nil, nil, fmt.Errorf("invalid algorithm")
|
||||
}
|
||||
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, fmt.Errorf("incompatible argon2 version")
|
||||
}
|
||||
|
||||
config := &PasswordConfig{}
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &config.Memory, &config.Time, &config.Threads); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
saltLen := len(salt)
|
||||
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("salt length out of range")
|
||||
}
|
||||
config.SaltLen = uint32(saltLen) // nolint:gosec // checked above
|
||||
|
||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
hashLen := len(hash)
|
||||
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("hash length out of range")
|
||||
}
|
||||
config.KeyLen = uint32(hashLen) // nolint:gosec // checked above
|
||||
|
||||
return config, salt, hash, nil
|
||||
}
|
||||
|
||||
// GenerateRandomPassword generates a cryptographically secure random password
|
||||
func GenerateRandomPassword(length int) (string, error) {
|
||||
const (
|
||||
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
lowercase = "abcdefghijklmnopqrstuvwxyz"
|
||||
digits = "0123456789"
|
||||
special = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
)
|
||||
|
||||
// Combine all character sets
|
||||
allChars := uppercase + lowercase + digits + special
|
||||
|
||||
// Create password slice
|
||||
password := make([]byte, length)
|
||||
|
||||
// Ensure at least one character from each set for password complexity
|
||||
if length >= 4 {
|
||||
// Get one character from each set
|
||||
password[0] = uppercase[cryptoRandInt(len(uppercase))]
|
||||
password[1] = lowercase[cryptoRandInt(len(lowercase))]
|
||||
password[2] = digits[cryptoRandInt(len(digits))]
|
||||
password[3] = special[cryptoRandInt(len(special))]
|
||||
|
||||
// Fill the rest randomly from all characters
|
||||
for i := 4; i < length; i++ {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
|
||||
// Shuffle the password to avoid predictable pattern
|
||||
for i := len(password) - 1; i > 0; i-- {
|
||||
j := cryptoRandInt(i + 1)
|
||||
password[i], password[j] = password[j], password[i]
|
||||
}
|
||||
} else {
|
||||
// For very short passwords, just use all characters
|
||||
for i := 0; i < length; i++ {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
}
|
||||
|
||||
return string(password), nil
|
||||
}
|
||||
|
||||
// cryptoRandInt generates a cryptographically secure random integer in [0, max)
|
||||
func cryptoRandInt(max int) int {
|
||||
if max <= 0 {
|
||||
panic("max must be positive")
|
||||
}
|
||||
|
||||
// Calculate the maximum valid value to avoid modulo bias
|
||||
// For example, if max=200 and we have 256 possible values,
|
||||
// we only accept values 0-199 (reject 200-255)
|
||||
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand error: %v", err))
|
||||
}
|
||||
|
||||
return int(nBig.Int64())
|
||||
}
|
||||
126
internal/database/password_test.go
Normal file
126
internal/database/password_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateRandomPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
}{
|
||||
{"Short password", 8},
|
||||
{"Medium password", 16},
|
||||
{"Long password", 32},
|
||||
{"Very short password", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
password, err := GenerateRandomPassword(tt.length)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if len(password) != tt.length {
|
||||
t.Errorf("Password length = %v, want %v", len(password), tt.length)
|
||||
}
|
||||
|
||||
// For passwords >= 4 chars, check complexity
|
||||
if tt.length >= 4 {
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
hasSpecial := false
|
||||
|
||||
for _, char := range password {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
|
||||
t.Errorf("Password lacks required complexity: upper=%v, lower=%v, digit=%v, special=%v",
|
||||
hasUpper, hasLower, hasDigit, hasSpecial)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomPasswordUniqueness(t *testing.T) {
|
||||
// Generate multiple passwords and ensure they're different
|
||||
passwords := make(map[string]bool)
|
||||
const numPasswords = 100
|
||||
|
||||
for i := 0; i < numPasswords; i++ {
|
||||
password, err := GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if passwords[password] {
|
||||
t.Errorf("Duplicate password generated: %s", password)
|
||||
}
|
||||
passwords[password] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
password := "testPassword123!"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
// Check that hash has correct format
|
||||
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||
t.Errorf("Hash doesn't have correct prefix: %s", hash)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
valid, err := VerifyPassword(password, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
if !valid {
|
||||
t.Error("VerifyPassword() returned false for correct password")
|
||||
}
|
||||
|
||||
// Verify wrong password fails
|
||||
valid, err = VerifyPassword("wrongPassword", hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
if valid {
|
||||
t.Error("VerifyPassword() returned true for wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPasswordUniqueness(t *testing.T) {
|
||||
password := "testPassword123!"
|
||||
|
||||
// Same password should produce different hashes due to salt
|
||||
hash1, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
hash2, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error("Same password produced identical hashes (salt not working)")
|
||||
}
|
||||
}
|
||||
28
internal/globals/globals.go
Normal file
28
internal/globals/globals.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// these get populated from main() and copied into the Globals object.
|
||||
var (
|
||||
Appname string
|
||||
Version string
|
||||
Buildarch string
|
||||
)
|
||||
|
||||
type Globals struct {
|
||||
Appname string
|
||||
Version string
|
||||
Buildarch string
|
||||
}
|
||||
|
||||
// nolint:revive // lc parameter is required by fx even if unused
|
||||
func New(lc fx.Lifecycle) (*Globals, error) {
|
||||
n := &Globals{
|
||||
Appname: Appname,
|
||||
Buildarch: Buildarch,
|
||||
Version: Version,
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
30
internal/globals/globals_test.go
Normal file
30
internal/globals/globals_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/fx/fxtest"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
// Set test values
|
||||
Appname = "test-app"
|
||||
Version = "1.0.0"
|
||||
Buildarch = "test-arch"
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
globals, err := New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if globals.Appname != "test-app" {
|
||||
t.Errorf("Appname = %v, want %v", globals.Appname, "test-app")
|
||||
}
|
||||
if globals.Version != "1.0.0" {
|
||||
t.Errorf("Version = %v, want %v", globals.Version, "1.0.0")
|
||||
}
|
||||
if globals.Buildarch != "test-arch" {
|
||||
t.Errorf("Buildarch = %v, want %v", globals.Buildarch, "test-arch")
|
||||
}
|
||||
}
|
||||
127
internal/handlers/auth.go
Normal file
127
internal/handlers/auth.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/database"
|
||||
)
|
||||
|
||||
// HandleLoginPage returns a handler for the login page (GET)
|
||||
func (h *Handlers) HandleLoginPage() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if already logged in
|
||||
sess, err := h.session.Get(r)
|
||||
if err == nil && h.session.IsAuthenticated(sess) {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Render login page
|
||||
data := map[string]interface{}{
|
||||
"Error": "",
|
||||
}
|
||||
|
||||
h.renderTemplate(w, r, []string{"templates/base.html", "templates/login.html"}, data)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleLoginSubmit handles the login form submission (POST)
|
||||
func (h *Handlers) HandleLoginSubmit() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse form data
|
||||
if err := r.ParseForm(); err != nil {
|
||||
h.log.Error("failed to parse form", "error", err)
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
|
||||
// Validate input
|
||||
if username == "" || password == "" {
|
||||
data := map[string]interface{}{
|
||||
"Error": "Username and password are required",
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
h.renderTemplate(w, r, []string{"templates/base.html", "templates/login.html"}, data)
|
||||
return
|
||||
}
|
||||
|
||||
// Find user in database
|
||||
var user database.User
|
||||
if err := h.db.DB().Where("username = ?", username).First(&user).Error; err != nil {
|
||||
h.log.Debug("user not found", "username", username)
|
||||
data := map[string]interface{}{
|
||||
"Error": "Invalid username or password",
|
||||
}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
h.renderTemplate(w, r, []string{"templates/base.html", "templates/login.html"}, data)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify password
|
||||
valid, err := database.VerifyPassword(password, user.Password)
|
||||
if err != nil {
|
||||
h.log.Error("failed to verify password", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !valid {
|
||||
h.log.Debug("invalid password", "username", username)
|
||||
data := map[string]interface{}{
|
||||
"Error": "Invalid username or password",
|
||||
}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
h.renderTemplate(w, r, []string{"templates/base.html", "templates/login.html"}, data)
|
||||
return
|
||||
}
|
||||
|
||||
// Create session
|
||||
sess, err := h.session.Get(r)
|
||||
if err != nil {
|
||||
h.log.Error("failed to get session", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Set user in session
|
||||
h.session.SetUser(sess, user.ID, user.Username)
|
||||
|
||||
// Save session
|
||||
if err := h.session.Save(r, w, sess); err != nil {
|
||||
h.log.Error("failed to save session", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.log.Info("user logged in", "username", username, "user_id", user.ID)
|
||||
|
||||
// Redirect to home page
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleLogout handles user logout
|
||||
func (h *Handlers) HandleLogout() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := h.session.Get(r)
|
||||
if err != nil {
|
||||
h.log.Error("failed to get session", "error", err)
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy session
|
||||
h.session.Destroy(sess)
|
||||
|
||||
// Save the destroyed session
|
||||
if err := h.session.Save(r, w, sess); err != nil {
|
||||
h.log.Error("failed to save destroyed session", "error", err)
|
||||
}
|
||||
|
||||
// Redirect to login page
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
137
internal/handlers/handlers.go
Normal file
137
internal/handlers/handlers.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/database"
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/healthcheck"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"git.eeqj.de/sneak/webhooker/internal/session"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// nolint:revive // HandlersParams is a standard fx naming convention
|
||||
type HandlersParams struct {
|
||||
fx.In
|
||||
Logger *logger.Logger
|
||||
Globals *globals.Globals
|
||||
Database *database.Database
|
||||
Healthcheck *healthcheck.Healthcheck
|
||||
Session *session.Session
|
||||
}
|
||||
|
||||
type Handlers struct {
|
||||
params *HandlersParams
|
||||
log *slog.Logger
|
||||
hc *healthcheck.Healthcheck
|
||||
db *database.Database
|
||||
session *session.Session
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params HandlersParams) (*Handlers, error) {
|
||||
s := new(Handlers)
|
||||
s.params = ¶ms
|
||||
s.log = params.Logger.Get()
|
||||
s.hc = params.Healthcheck
|
||||
s.db = params.Database
|
||||
s.session = params.Session
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(ctx context.Context) error {
|
||||
// FIXME compile some templates here or something
|
||||
return nil
|
||||
},
|
||||
})
|
||||
return s, nil
|
||||
}
|
||||
|
||||
//nolint:unparam // r parameter will be used in the future for request context
|
||||
func (s *Handlers) respondJSON(w http.ResponseWriter, r *http.Request, data interface{}, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if data != nil {
|
||||
err := json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
s.log.Error("json encode error", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:unparam,unused // will be used for handling JSON requests
|
||||
func (s *Handlers) decodeJSON(w http.ResponseWriter, r *http.Request, v interface{}) error {
|
||||
return json.NewDecoder(r.Body).Decode(v)
|
||||
}
|
||||
|
||||
// TemplateData represents the common data passed to templates
|
||||
type TemplateData struct {
|
||||
User *UserInfo
|
||||
Version string
|
||||
UserCount int64
|
||||
Uptime string
|
||||
}
|
||||
|
||||
// UserInfo represents user information for templates
|
||||
type UserInfo struct {
|
||||
ID string
|
||||
Username string
|
||||
}
|
||||
|
||||
// renderTemplate renders a template with common data
|
||||
func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, templateFiles []string, data interface{}) {
|
||||
// Always include the common templates
|
||||
allTemplates := []string{"templates/htmlheader.html", "templates/navbar.html"}
|
||||
allTemplates = append(allTemplates, templateFiles...)
|
||||
|
||||
// Parse templates
|
||||
tmpl, err := template.ParseFiles(allTemplates...)
|
||||
if err != nil {
|
||||
s.log.Error("failed to parse template", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user from session if available
|
||||
var userInfo *UserInfo
|
||||
sess, err := s.session.Get(r)
|
||||
if err == nil && s.session.IsAuthenticated(sess) {
|
||||
if username, ok := s.session.GetUsername(sess); ok {
|
||||
if userID, ok := s.session.GetUserID(sess); ok {
|
||||
userInfo = &UserInfo{
|
||||
ID: userID,
|
||||
Username: username,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap data with base template data
|
||||
type templateDataWrapper struct {
|
||||
User *UserInfo
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
wrapper := templateDataWrapper{
|
||||
User: userInfo,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
// If data is a map, merge user info into it
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
m["User"] = userInfo
|
||||
if err := tmpl.Execute(w, m); err != nil {
|
||||
s.log.Error("failed to execute template", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise use wrapper
|
||||
if err := tmpl.Execute(w, wrapper); err != nil {
|
||||
s.log.Error("failed to execute template", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
130
internal/handlers/handlers_test.go
Normal file
130
internal/handlers/handlers_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/database"
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/healthcheck"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"git.eeqj.de/sneak/webhooker/internal/session"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/fx"
|
||||
"go.uber.org/fx/fxtest"
|
||||
)
|
||||
|
||||
func TestHandleIndex(t *testing.T) {
|
||||
var h *Handlers
|
||||
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
func() *config.Config {
|
||||
return &config.Config{
|
||||
// This is a base64 encoded 32-byte key: "test-session-key-32-bytes-long!!"
|
||||
SessionKey: "dGVzdC1zZXNzaW9uLWtleS0zMi1ieXRlcy1sb25nISE=",
|
||||
}
|
||||
},
|
||||
func() *database.Database {
|
||||
// Mock database with a mock DB method
|
||||
db := &database.Database{}
|
||||
return db
|
||||
},
|
||||
healthcheck.New,
|
||||
session.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&h),
|
||||
)
|
||||
app.RequireStart()
|
||||
defer app.RequireStop()
|
||||
|
||||
// Since we can't test actual template rendering without templates,
|
||||
// let's test that the handler is created and doesn't panic
|
||||
handler := h.HandleIndex()
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
|
||||
func TestRenderTemplate(t *testing.T) {
|
||||
var h *Handlers
|
||||
|
||||
app := fxtest.New(
|
||||
t,
|
||||
fx.Provide(
|
||||
globals.New,
|
||||
logger.New,
|
||||
func() *config.Config {
|
||||
return &config.Config{
|
||||
// This is a base64 encoded 32-byte key: "test-session-key-32-bytes-long!!"
|
||||
SessionKey: "dGVzdC1zZXNzaW9uLWtleS0zMi1ieXRlcy1sb25nISE=",
|
||||
}
|
||||
},
|
||||
func() *database.Database {
|
||||
// Mock database
|
||||
return &database.Database{}
|
||||
},
|
||||
healthcheck.New,
|
||||
session.New,
|
||||
New,
|
||||
),
|
||||
fx.Populate(&h),
|
||||
)
|
||||
app.RequireStart()
|
||||
defer app.RequireStop()
|
||||
|
||||
t.Run("handles missing templates gracefully", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Version": "1.0.0",
|
||||
}
|
||||
|
||||
// When templates don't exist, renderTemplate should return an error
|
||||
h.renderTemplate(w, req, []string{"nonexistent.html"}, data)
|
||||
|
||||
// Should return internal server error when template parsing fails
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatUptime(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
duration string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "minutes only",
|
||||
duration: "45m",
|
||||
expected: "45m",
|
||||
},
|
||||
{
|
||||
name: "hours and minutes",
|
||||
duration: "2h30m",
|
||||
expected: "2h 30m",
|
||||
},
|
||||
{
|
||||
name: "days, hours and minutes",
|
||||
duration: "25h45m",
|
||||
expected: "1d 1h 45m",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d, err := time.ParseDuration(tt.duration)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := formatUptime(d)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
12
internal/handlers/healthcheck.go
Normal file
12
internal/handlers/healthcheck.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (s *Handlers) HandleHealthCheck() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
resp := s.hc.Healthcheck()
|
||||
s.respondJSON(w, req, resp, 200)
|
||||
}
|
||||
}
|
||||
54
internal/handlers/index.go
Normal file
54
internal/handlers/index.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/database"
|
||||
)
|
||||
|
||||
type IndexResponse struct {
|
||||
Message string `json:"message"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func (s *Handlers) HandleIndex() http.HandlerFunc {
|
||||
// Calculate server start time
|
||||
startTime := time.Now()
|
||||
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
// Calculate uptime
|
||||
uptime := time.Since(startTime)
|
||||
uptimeStr := formatUptime(uptime)
|
||||
|
||||
// Get user count from database
|
||||
var userCount int64
|
||||
s.db.DB().Model(&database.User{}).Count(&userCount)
|
||||
|
||||
// Prepare template data
|
||||
data := map[string]interface{}{
|
||||
"Version": s.params.Globals.Version,
|
||||
"Uptime": uptimeStr,
|
||||
"UserCount": userCount,
|
||||
}
|
||||
|
||||
// Render the template
|
||||
s.renderTemplate(w, req, []string{"templates/base.html", "templates/index.html"}, data)
|
||||
}
|
||||
}
|
||||
|
||||
// formatUptime formats a duration into a human-readable string
|
||||
func formatUptime(d time.Duration) string {
|
||||
days := int(d.Hours()) / 24
|
||||
hours := int(d.Hours()) % 24
|
||||
minutes := int(d.Minutes()) % 60
|
||||
|
||||
if days > 0 {
|
||||
return fmt.Sprintf("%dd %dh %dm", days, hours, minutes)
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%dm", minutes)
|
||||
}
|
||||
59
internal/handlers/profile.go
Normal file
59
internal/handlers/profile.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
)
|
||||
|
||||
// HandleProfile returns a handler for the user profile page
|
||||
func (h *Handlers) HandleProfile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get username from URL
|
||||
requestedUsername := chi.URLParam(r, "username")
|
||||
if requestedUsername == "" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get session
|
||||
sess, err := h.session.Get(r)
|
||||
if err != nil || !h.session.IsAuthenticated(sess) {
|
||||
// Redirect to login if not authenticated
|
||||
http.Redirect(w, r, "/pages/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user info from session
|
||||
sessionUsername, ok := h.session.GetUsername(sess)
|
||||
if !ok {
|
||||
h.log.Error("authenticated session missing username")
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionUserID, ok := h.session.GetUserID(sess)
|
||||
if !ok {
|
||||
h.log.Error("authenticated session missing user ID")
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// For now, only allow users to view their own profile
|
||||
if requestedUsername != sessionUsername {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare data for template
|
||||
data := map[string]interface{}{
|
||||
"User": &UserInfo{
|
||||
ID: sessionUserID,
|
||||
Username: sessionUsername,
|
||||
},
|
||||
}
|
||||
|
||||
// Render the profile page
|
||||
h.renderTemplate(w, r, []string{"templates/base.html", "templates/profile.html"}, data)
|
||||
}
|
||||
}
|
||||
69
internal/handlers/source_management.go
Normal file
69
internal/handlers/source_management.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HandleSourceList shows a list of user's webhook sources
|
||||
func (h *Handlers) HandleSourceList() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source list page
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSourceCreate shows the form to create a new webhook source
|
||||
func (h *Handlers) HandleSourceCreate() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source creation form
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSourceCreateSubmit handles the source creation form submission
|
||||
func (h *Handlers) HandleSourceCreateSubmit() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source creation logic
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSourceDetail shows details for a specific webhook source
|
||||
func (h *Handlers) HandleSourceDetail() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source detail page
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSourceEdit shows the form to edit a webhook source
|
||||
func (h *Handlers) HandleSourceEdit() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source edit form
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSourceEditSubmit handles the source edit form submission
|
||||
func (h *Handlers) HandleSourceEditSubmit() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source update logic
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSourceDelete handles webhook source deletion
|
||||
func (h *Handlers) HandleSourceDelete() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source deletion logic
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSourceLogs shows the request/response logs for a webhook source
|
||||
func (h *Handlers) HandleSourceLogs() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement source logs page
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
42
internal/handlers/webhook.go
Normal file
42
internal/handlers/webhook.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
)
|
||||
|
||||
// HandleWebhook handles incoming webhook requests
|
||||
func (h *Handlers) HandleWebhook() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get webhook UUID from URL
|
||||
webhookUUID := chi.URLParam(r, "uuid")
|
||||
if webhookUUID == "" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Log the incoming webhook request
|
||||
h.log.Info("webhook request received",
|
||||
"uuid", webhookUUID,
|
||||
"method", r.Method,
|
||||
"remote_addr", r.RemoteAddr,
|
||||
"user_agent", r.UserAgent(),
|
||||
)
|
||||
|
||||
// Only POST methods are allowed for webhooks
|
||||
if r.Method != http.MethodPost {
|
||||
w.Header().Set("Allow", "POST")
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Implement webhook handling logic
|
||||
// For now, return "unimplemented" for all webhook POST requests
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, err := w.Write([]byte("unimplemented"))
|
||||
if err != nil {
|
||||
h.log.Error("failed to write response", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
73
internal/healthcheck/healthcheck.go
Normal file
73
internal/healthcheck/healthcheck.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/database"
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// nolint:revive // HealthcheckParams is a standard fx naming convention
|
||||
type HealthcheckParams struct {
|
||||
fx.In
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
Database *database.Database
|
||||
}
|
||||
|
||||
type Healthcheck struct {
|
||||
StartupTime time.Time
|
||||
log *slog.Logger
|
||||
params *HealthcheckParams
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params HealthcheckParams) (*Healthcheck, error) {
|
||||
s := new(Healthcheck)
|
||||
s.params = ¶ms
|
||||
s.log = params.Logger.Get()
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
s.StartupTime = time.Now()
|
||||
return nil
|
||||
},
|
||||
OnStop: func(ctx context.Context) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// nolint:revive // HealthcheckResponse is a clear, descriptive name
|
||||
type HealthcheckResponse struct {
|
||||
Status string `json:"status"`
|
||||
Now string `json:"now"`
|
||||
UptimeSeconds int64 `json:"uptime_seconds"`
|
||||
UptimeHuman string `json:"uptime_human"`
|
||||
Version string `json:"version"`
|
||||
Appname string `json:"appname"`
|
||||
Maintenance bool `json:"maintenance_mode"`
|
||||
}
|
||||
|
||||
func (s *Healthcheck) uptime() time.Duration {
|
||||
return time.Since(s.StartupTime)
|
||||
}
|
||||
|
||||
func (s *Healthcheck) Healthcheck() *HealthcheckResponse {
|
||||
resp := &HealthcheckResponse{
|
||||
Status: "ok",
|
||||
Now: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
UptimeSeconds: int64(s.uptime().Seconds()),
|
||||
UptimeHuman: s.uptime().String(),
|
||||
Appname: s.params.Globals.Appname,
|
||||
Version: s.params.Globals.Version,
|
||||
Maintenance: s.params.Config.MaintenanceMode,
|
||||
}
|
||||
return resp
|
||||
}
|
||||
112
internal/logger/logger.go
Normal file
112
internal/logger/logger.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// nolint:revive // LoggerParams is a standard fx naming convention
|
||||
type LoggerParams struct {
|
||||
fx.In
|
||||
Globals *globals.Globals
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
logger *slog.Logger
|
||||
params LoggerParams
|
||||
}
|
||||
|
||||
// nolint:revive // lc parameter is required by fx even if unused
|
||||
func New(lc fx.Lifecycle, params LoggerParams) (*Logger, error) {
|
||||
l := new(Logger)
|
||||
l.params = params
|
||||
|
||||
// Determine if we're running in a terminal
|
||||
tty := false
|
||||
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) != 0 {
|
||||
tty = true
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { // nolint:revive // groups unused
|
||||
// Always use UTC for timestamps
|
||||
if a.Key == slog.TimeKey {
|
||||
if t, ok := a.Value.Any().(time.Time); ok {
|
||||
return slog.Time(slog.TimeKey, t.UTC())
|
||||
}
|
||||
}
|
||||
return a
|
||||
},
|
||||
}
|
||||
|
||||
if tty {
|
||||
// Use text handler for terminal output (human-readable)
|
||||
handler = slog.NewTextHandler(os.Stdout, opts)
|
||||
} else {
|
||||
// Use JSON handler for production (machine-readable)
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
}
|
||||
|
||||
l.logger = slog.New(handler)
|
||||
|
||||
// Set as default logger
|
||||
slog.SetDefault(l.logger)
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (l *Logger) EnableDebugLogging() {
|
||||
// Recreate logger with debug level
|
||||
tty := false
|
||||
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) != 0 {
|
||||
tty = true
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { // nolint:revive // groups unused
|
||||
// Always use UTC for timestamps
|
||||
if a.Key == slog.TimeKey {
|
||||
if t, ok := a.Value.Any().(time.Time); ok {
|
||||
return slog.Time(slog.TimeKey, t.UTC())
|
||||
}
|
||||
}
|
||||
return a
|
||||
},
|
||||
}
|
||||
|
||||
if tty {
|
||||
handler = slog.NewTextHandler(os.Stdout, opts)
|
||||
} else {
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
}
|
||||
|
||||
l.logger = slog.New(handler)
|
||||
slog.SetDefault(l.logger)
|
||||
l.logger.Debug("debug logging enabled", "debug", true)
|
||||
}
|
||||
|
||||
func (l *Logger) Get() *slog.Logger {
|
||||
return l.logger
|
||||
}
|
||||
|
||||
func (l *Logger) Identify() {
|
||||
l.logger.Info("starting",
|
||||
"appname", l.params.Globals.Appname,
|
||||
"version", l.params.Globals.Version,
|
||||
"buildarch", l.params.Globals.Buildarch,
|
||||
)
|
||||
}
|
||||
|
||||
// Helper methods to maintain compatibility with existing code
|
||||
func (l *Logger) Writer() io.Writer {
|
||||
return os.Stdout
|
||||
}
|
||||
65
internal/logger/logger_test.go
Normal file
65
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"go.uber.org/fx/fxtest"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
// Set up globals
|
||||
globals.Appname = "test-app"
|
||||
globals.Version = "1.0.0"
|
||||
globals.Buildarch = "test-arch"
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("globals.New() error = %v", err)
|
||||
}
|
||||
|
||||
params := LoggerParams{
|
||||
Globals: g,
|
||||
}
|
||||
|
||||
logger, err := New(lc, params)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if logger.Get() == nil {
|
||||
t.Error("Get() returned nil logger")
|
||||
}
|
||||
|
||||
// Test that we can log without panic
|
||||
logger.Get().Info("test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestEnableDebugLogging(t *testing.T) {
|
||||
// Set up globals
|
||||
globals.Appname = "test-app"
|
||||
globals.Version = "1.0.0"
|
||||
globals.Buildarch = "test-arch"
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("globals.New() error = %v", err)
|
||||
}
|
||||
|
||||
params := LoggerParams{
|
||||
Globals: g,
|
||||
}
|
||||
|
||||
logger, err := New(lc, params)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
// Enable debug logging should not panic
|
||||
logger.EnableDebugLogging()
|
||||
|
||||
// Test debug logging
|
||||
logger.Get().Debug("debug message", "test", true)
|
||||
}
|
||||
149
internal/middleware/middleware.go
Normal file
149
internal/middleware/middleware.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
basicauth "github.com/99designs/basicauth-go"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
metrics "github.com/slok/go-http-metrics/metrics/prometheus"
|
||||
ghmm "github.com/slok/go-http-metrics/middleware"
|
||||
"github.com/slok/go-http-metrics/middleware/std"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// nolint:revive // MiddlewareParams is a standard fx naming convention
|
||||
type MiddlewareParams struct {
|
||||
fx.In
|
||||
Logger *logger.Logger
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
log *slog.Logger
|
||||
params *MiddlewareParams
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params MiddlewareParams) (*Middleware, error) {
|
||||
s := new(Middleware)
|
||||
s.params = ¶ms
|
||||
s.log = params.Logger.Get()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// the following is from
|
||||
// https://learning-cloud-native-go.github.io/docs/a6.adding_zerolog_logger/
|
||||
|
||||
func ipFromHostPort(hp string) string {
|
||||
h, _, err := net.SplitHostPort(hp)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if len(h) > 0 && h[0] == '[' {
|
||||
return h[1 : len(h)-1]
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
type loggingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// nolint:revive // unexported type is only used internally
|
||||
func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
|
||||
return &loggingResponseWriter{w, http.StatusOK}
|
||||
}
|
||||
|
||||
func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
||||
lrw.statusCode = code
|
||||
lrw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// type Middleware func(http.Handler) http.Handler
|
||||
// this returns a Middleware that is designed to do every request through the
|
||||
// mux, note the signature:
|
||||
func (s *Middleware) Logging() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
lrw := NewLoggingResponseWriter(w)
|
||||
ctx := r.Context()
|
||||
defer func() {
|
||||
latency := time.Since(start)
|
||||
requestID := ""
|
||||
if reqID := ctx.Value(middleware.RequestIDKey); reqID != nil {
|
||||
if id, ok := reqID.(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
s.log.Info("http request",
|
||||
"request_start", start,
|
||||
"method", r.Method,
|
||||
"url", r.URL.String(),
|
||||
"useragent", r.UserAgent(),
|
||||
"request_id", requestID,
|
||||
"referer", r.Referer(),
|
||||
"proto", r.Proto,
|
||||
"remoteIP", ipFromHostPort(r.RemoteAddr),
|
||||
"status", lrw.statusCode,
|
||||
"latency_ms", latency.Milliseconds(),
|
||||
)
|
||||
}()
|
||||
|
||||
next.ServeHTTP(lrw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Middleware) CORS() func(http.Handler) http.Handler {
|
||||
return cors.Handler(cors.Options{
|
||||
// CHANGEME! these are defaults, change them to suit your needs or
|
||||
// read from environment/viper.
|
||||
// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
|
||||
AllowedOrigins: []string{"*"},
|
||||
// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: false,
|
||||
MaxAge: 300, // Maximum value not ignored by any of major browsers
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Middleware) Auth() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: implement proper authentication
|
||||
s.log.Debug("AUTH: before request")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Middleware) Metrics() func(http.Handler) http.Handler {
|
||||
mdlw := ghmm.New(ghmm.Config{
|
||||
Recorder: metrics.NewRecorder(metrics.Config{}),
|
||||
})
|
||||
return func(next http.Handler) http.Handler {
|
||||
return std.Handler("", mdlw, next)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
||||
return basicauth.New(
|
||||
"metrics",
|
||||
map[string][]string{
|
||||
s.params.Config.MetricsUsername: {
|
||||
s.params.Config.MetricsPassword,
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
34
internal/server/http.go
Normal file
34
internal/server/http.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) serveUntilShutdown() {
|
||||
listenAddr := fmt.Sprintf(":%d", s.params.Config.Port)
|
||||
s.httpServer = &http.Server{
|
||||
Addr: listenAddr,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
Handler: s,
|
||||
}
|
||||
|
||||
// add routes
|
||||
// this does any necessary setup in each handler
|
||||
s.SetupRoutes()
|
||||
|
||||
s.log.Info("http begin listen", "listenaddr", listenAddr)
|
||||
if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.log.Error("listen error", "error", err)
|
||||
if s.cancelFunc != nil {
|
||||
s.cancelFunc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.router.ServeHTTP(w, r)
|
||||
}
|
||||
112
internal/server/routes.go
Normal file
112
internal/server/routes.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/static"
|
||||
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
func (s *Server) SetupRoutes() {
|
||||
s.router = chi.NewRouter()
|
||||
|
||||
// the mux .Use() takes a http.Handler wrapper func, like most
|
||||
// things that deal with "middlewares" like alice et c, and will
|
||||
// call ServeHTTP on it. These middlewares applied by the mux (you
|
||||
// can .Use() more than one) will be applied to every request into
|
||||
// the service.
|
||||
|
||||
s.router.Use(middleware.Recoverer)
|
||||
s.router.Use(middleware.RequestID)
|
||||
s.router.Use(s.mw.Logging())
|
||||
|
||||
// add metrics middleware only if we can serve them behind auth
|
||||
if s.params.Config.MetricsUsername != "" {
|
||||
s.router.Use(s.mw.Metrics())
|
||||
}
|
||||
|
||||
// set up CORS headers
|
||||
s.router.Use(s.mw.CORS())
|
||||
|
||||
// timeout for request context; your handlers must finish within
|
||||
// this window:
|
||||
s.router.Use(middleware.Timeout(60 * time.Second))
|
||||
|
||||
// this adds a sentry reporting middleware if and only if sentry is
|
||||
// enabled via setting of SENTRY_DSN in env.
|
||||
if s.sentryEnabled {
|
||||
// Options docs at
|
||||
// https://docs.sentry.io/platforms/go/guides/http/
|
||||
// we set sentry to repanic so that all panics bubble up to the
|
||||
// Recoverer chi middleware above.
|
||||
sentryHandler := sentryhttp.New(sentryhttp.Options{
|
||||
Repanic: true,
|
||||
})
|
||||
s.router.Use(sentryHandler.Handle)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// ROUTES
|
||||
// complete docs: https://github.com/go-chi/chi
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
||||
s.router.Get("/", s.h.HandleIndex())
|
||||
|
||||
s.router.Mount("/s", http.StripPrefix("/s", http.FileServer(http.FS(static.Static))))
|
||||
|
||||
s.router.Route("/api/v1", func(_ chi.Router) {
|
||||
// TODO: Add API routes here
|
||||
})
|
||||
|
||||
s.router.Get(
|
||||
"/.well-known/healthcheck.json",
|
||||
s.h.HandleHealthCheck(),
|
||||
)
|
||||
|
||||
// set up authenticated /metrics route:
|
||||
if s.params.Config.MetricsUsername != "" {
|
||||
s.router.Group(func(r chi.Router) {
|
||||
r.Use(s.mw.MetricsAuth())
|
||||
r.Get("/metrics", http.HandlerFunc(promhttp.Handler().ServeHTTP))
|
||||
})
|
||||
}
|
||||
|
||||
// pages that are rendered server-side
|
||||
s.router.Route("/pages", func(r chi.Router) {
|
||||
// Login page (no auth required)
|
||||
r.Get("/login", s.h.HandleLoginPage())
|
||||
r.Post("/login", s.h.HandleLoginSubmit())
|
||||
|
||||
// Logout (auth required)
|
||||
r.Post("/logout", s.h.HandleLogout())
|
||||
})
|
||||
|
||||
// User profile routes
|
||||
s.router.Route("/user/{username}", func(r chi.Router) {
|
||||
r.Get("/", s.h.HandleProfile())
|
||||
})
|
||||
|
||||
// Webhook source management routes (require authentication)
|
||||
s.router.Route("/sources", func(r chi.Router) {
|
||||
// TODO: Add authentication middleware here
|
||||
r.Get("/", s.h.HandleSourceList()) // List all sources
|
||||
r.Get("/new", s.h.HandleSourceCreate()) // Show create form
|
||||
r.Post("/new", s.h.HandleSourceCreateSubmit()) // Handle create submission
|
||||
})
|
||||
|
||||
s.router.Route("/source/{sourceID}", func(r chi.Router) {
|
||||
// TODO: Add authentication middleware here
|
||||
r.Get("/", s.h.HandleSourceDetail()) // View source details
|
||||
r.Get("/edit", s.h.HandleSourceEdit()) // Show edit form
|
||||
r.Post("/edit", s.h.HandleSourceEditSubmit()) // Handle edit submission
|
||||
r.Post("/delete", s.h.HandleSourceDelete()) // Delete source
|
||||
r.Get("/logs", s.h.HandleSourceLogs()) // View source logs
|
||||
})
|
||||
|
||||
// Webhook endpoint - accepts all HTTP methods
|
||||
s.router.HandleFunc("/webhook/{uuid}", s.h.HandleWebhook())
|
||||
}
|
||||
162
internal/server/server.go
Normal file
162
internal/server/server.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/handlers"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"git.eeqj.de/sneak/webhooker/internal/middleware"
|
||||
"go.uber.org/fx"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
// spooky action at a distance!
|
||||
// this populates the environment
|
||||
// from a ./.env file automatically
|
||||
// for development configuration.
|
||||
// .env contents should be things like
|
||||
// `DBURL=postgres://user:pass@.../`
|
||||
// (without the backticks, of course)
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
)
|
||||
|
||||
// ServerParams is a standard fx naming convention for dependency injection
|
||||
// nolint:golint
|
||||
type ServerParams struct {
|
||||
fx.In
|
||||
Logger *logger.Logger
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
Middleware *middleware.Middleware
|
||||
Handlers *handlers.Handlers
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
startupTime time.Time
|
||||
exitCode int
|
||||
sentryEnabled bool
|
||||
log *slog.Logger
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
httpServer *http.Server
|
||||
router *chi.Mux
|
||||
params ServerParams
|
||||
mw *middleware.Middleware
|
||||
h *handlers.Handlers
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params ServerParams) (*Server, error) {
|
||||
s := new(Server)
|
||||
s.params = params
|
||||
s.mw = params.Middleware
|
||||
s.h = params.Handlers
|
||||
s.log = params.Logger.Get()
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(ctx context.Context) error {
|
||||
s.startupTime = time.Now()
|
||||
go s.Run()
|
||||
return nil
|
||||
},
|
||||
OnStop: func(ctx context.Context) error {
|
||||
s.cleanShutdown()
|
||||
return nil
|
||||
},
|
||||
})
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) Run() {
|
||||
s.configure()
|
||||
|
||||
// logging before sentry, because sentry logs
|
||||
s.enableSentry()
|
||||
|
||||
s.serve()
|
||||
}
|
||||
|
||||
func (s *Server) enableSentry() {
|
||||
s.sentryEnabled = false
|
||||
|
||||
if s.params.Config.SentryDSN == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := sentry.Init(sentry.ClientOptions{
|
||||
Dsn: s.params.Config.SentryDSN,
|
||||
Release: fmt.Sprintf("%s-%s", s.params.Globals.Appname, s.params.Globals.Version),
|
||||
})
|
||||
if err != nil {
|
||||
s.log.Error("sentry init failure", "error", err)
|
||||
// Don't use fatal since we still want the service to run
|
||||
return
|
||||
}
|
||||
s.log.Info("sentry error reporting activated")
|
||||
s.sentryEnabled = true
|
||||
}
|
||||
|
||||
func (s *Server) serve() int {
|
||||
s.ctx, s.cancelFunc = context.WithCancel(context.Background())
|
||||
|
||||
// signal watcher
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Ignore(syscall.SIGPIPE)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
// block and wait for signal
|
||||
sig := <-c
|
||||
s.log.Info("signal received", "signal", sig.String())
|
||||
if s.cancelFunc != nil {
|
||||
// cancelling the main context will trigger a clean
|
||||
// shutdown.
|
||||
s.cancelFunc()
|
||||
}
|
||||
}()
|
||||
|
||||
go s.serveUntilShutdown()
|
||||
|
||||
<-s.ctx.Done()
|
||||
s.cleanShutdown()
|
||||
return s.exitCode
|
||||
}
|
||||
|
||||
func (s *Server) cleanupForExit() {
|
||||
s.log.Info("cleaning up")
|
||||
// TODO: close database connections, flush buffers, etc.
|
||||
}
|
||||
|
||||
func (s *Server) cleanShutdown() {
|
||||
// initiate clean shutdown
|
||||
s.exitCode = 0
|
||||
ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := s.httpServer.Shutdown(ctxShutdown); err != nil {
|
||||
s.log.Error("server clean shutdown failed", "error", err)
|
||||
}
|
||||
|
||||
s.cleanupForExit()
|
||||
|
||||
if s.sentryEnabled {
|
||||
sentry.Flush(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) MaintenanceMode() bool {
|
||||
return s.params.Config.MaintenanceMode
|
||||
}
|
||||
|
||||
func (s *Server) configure() {
|
||||
// identify ourselves in the logs
|
||||
s.params.Logger.Identify()
|
||||
}
|
||||
125
internal/session/session.go
Normal file
125
internal/session/session.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"github.com/gorilla/sessions"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
const (
|
||||
// SessionName is the name of the session cookie
|
||||
SessionName = "webhooker_session"
|
||||
|
||||
// UserIDKey is the session key for user ID
|
||||
UserIDKey = "user_id"
|
||||
|
||||
// UsernameKey is the session key for username
|
||||
UsernameKey = "username"
|
||||
|
||||
// AuthenticatedKey is the session key for authentication status
|
||||
AuthenticatedKey = "authenticated"
|
||||
)
|
||||
|
||||
// nolint:revive // SessionParams is a standard fx naming convention
|
||||
type SessionParams struct {
|
||||
fx.In
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// Session manages encrypted session storage
|
||||
type Session struct {
|
||||
store *sessions.CookieStore
|
||||
log *slog.Logger
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// New creates a new session manager
|
||||
func New(lc fx.Lifecycle, params SessionParams) (*Session, error) {
|
||||
if params.Config.SessionKey == "" {
|
||||
return nil, fmt.Errorf("SESSION_KEY environment variable is required")
|
||||
}
|
||||
|
||||
// Decode the base64 session key
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(params.Config.SessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid SESSION_KEY format: %w", err)
|
||||
}
|
||||
|
||||
if len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("SESSION_KEY must be 32 bytes (got %d)", len(keyBytes))
|
||||
}
|
||||
|
||||
store := sessions.NewCookieStore(keyBytes)
|
||||
|
||||
// Configure cookie options for security
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: 86400 * 7, // 7 days
|
||||
HttpOnly: true,
|
||||
Secure: !params.Config.IsDev(), // HTTPS in production
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
s := &Session{
|
||||
store: store,
|
||||
log: params.Logger.Get(),
|
||||
config: params.Config,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Get retrieves a session for the request
|
||||
func (s *Session) Get(r *http.Request) (*sessions.Session, error) {
|
||||
return s.store.Get(r, SessionName)
|
||||
}
|
||||
|
||||
// Save saves the session
|
||||
func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error {
|
||||
return sess.Save(r, w)
|
||||
}
|
||||
|
||||
// SetUser sets the user information in the session
|
||||
func (s *Session) SetUser(sess *sessions.Session, userID, username string) {
|
||||
sess.Values[UserIDKey] = userID
|
||||
sess.Values[UsernameKey] = username
|
||||
sess.Values[AuthenticatedKey] = true
|
||||
}
|
||||
|
||||
// ClearUser removes user information from the session
|
||||
func (s *Session) ClearUser(sess *sessions.Session) {
|
||||
delete(sess.Values, UserIDKey)
|
||||
delete(sess.Values, UsernameKey)
|
||||
delete(sess.Values, AuthenticatedKey)
|
||||
}
|
||||
|
||||
// IsAuthenticated checks if the session has an authenticated user
|
||||
func (s *Session) IsAuthenticated(sess *sessions.Session) bool {
|
||||
auth, ok := sess.Values[AuthenticatedKey].(bool)
|
||||
return ok && auth
|
||||
}
|
||||
|
||||
// GetUserID retrieves the user ID from the session
|
||||
func (s *Session) GetUserID(sess *sessions.Session) (string, bool) {
|
||||
userID, ok := sess.Values[UserIDKey].(string)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// GetUsername retrieves the username from the session
|
||||
func (s *Session) GetUsername(sess *sessions.Session) (string, bool) {
|
||||
username, ok := sess.Values[UsernameKey].(string)
|
||||
return username, ok
|
||||
}
|
||||
|
||||
// Destroy invalidates the session
|
||||
func (s *Session) Destroy(sess *sessions.Session) {
|
||||
sess.Options.MaxAge = -1
|
||||
s.ClearUser(sess)
|
||||
}
|
||||
Reference in New Issue
Block a user